Skip to content

Use psnr to compare frames #662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dev = [
"numpy",
"pytest",
"pillow",
"torcheval",
]

[tool.usort]
Expand Down
8 changes: 3 additions & 5 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,13 @@ def test_get_frame_at_pts(self, device):
frame6, _, _ = get_frame_at_pts(decoder, 6.02)
assert_frames_equal(frame6, reference_frame6.to(device))
frame6, _, _ = get_frame_at_pts(decoder, 6.039366)
assert_frames_equal(frame6, reference_frame6.to(device))
prev_frame_psnr = assert_frames_equal(frame6, reference_frame6.to(device))
# Note that this timestamp is exactly on a frame boundary, so it should
# return the next frame since the right boundary of the interval is
# open.
next_frame, _, _ = get_frame_at_pts(decoder, 6.039367)
if device == "cpu":
# We can only compare exact equality on CPU.
with pytest.raises(AssertionError):
assert_frames_equal(next_frame, reference_frame6.to(device))
with pytest.raises(AssertionError):
assert_frames_equal(next_frame, reference_frame6.to(device), psnr=prev_frame_psnr)

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frame_at_index(self, device):
Expand Down
6 changes: 3 additions & 3 deletions test/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,11 @@ def test_sampling_range(
cm = (
contextlib.nullcontext()
if assert_all_equal
else pytest.raises(AssertionError, match="Tensor-likes are not")
else pytest.raises(AssertionError, match="low psnr")
)
with cm:
for clip in clips:
assert_frames_equal(clip.data, clips[0].data)
assert_frames_equal(clip.data, clips[0].data, psnr=float("inf"))


@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
Expand Down Expand Up @@ -447,7 +447,7 @@ def test_random_sampler_randomness(sampler):
# Call with a different seed, expect different results
torch.manual_seed(1)
clips_3 = sampler(decoder, num_clips=num_clips)
with pytest.raises(AssertionError, match="Tensor-likes are not"):
with pytest.raises(AssertionError, match="low psnr"):
assert_frames_equal(clips_1[0].data, clips_3[0].data)

# Make sure we didn't alter the builtin Python RNG
Expand Down
34 changes: 16 additions & 18 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,23 @@ def get_ffmpeg_major_version():
return int(get_ffmpeg_library_versions()["ffmpeg_version"].split(".")[0])


# For use with decoded data frames. On CPU Linux, we expect exact, bit-for-bit
# equality. On CUDA Linux, we expect a small tolerance.
# On other platforms (e.g. MacOS), we also allow a small tolerance. FFmpeg does
# not guarantee bit-for-bit equality across systems and architectures, so we
# also cannot. We currently use Linux on x86_64 as our reference system.
def assert_frames_equal(*args, **kwargs):
if sys.platform == "linux":
if args[0].device.type == "cuda":
atol = 2
if get_ffmpeg_major_version() == 4:
assert_tensor_close_on_at_least(
args[0], args[1], percentage=95, atol=atol
)
else:
torch.testing.assert_close(*args, **kwargs, atol=atol, rtol=0)
else:
torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0)
# For use with decoded data frames. `psnr` sets the PSNR threshold when
# frames are considered equal. `float("inf")` correspond to bit-to-bit
# identical frames. Function returns calculated psnr value.
def assert_frames_equal(input, other, psnr=40, msg=None):
if torch.allclose(input, other, atol=0, rtol=0):
return float("inf")
else:
torch.testing.assert_close(*args, **kwargs, atol=3, rtol=0)
from torcheval.metrics import PeakSignalNoiseRatio

metric = PeakSignalNoiseRatio()
metric.update(input, other)
m = metric.compute()
message = f"low psnr: {m} < {psnr}"
if (msg):
message += f" ({msg})"
assert m >= psnr, message
return m


# Asserts that at least `percentage`% of the values are within the absolute tolerance.
Expand Down