diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index 525c7ac8..50eb1292 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -125,7 +125,7 @@ class AudioSamples(Iterable): pts_seconds: float """The :term:`pts` of the first sample, in seconds.""" duration_seconds: float - """The duration of the sampleas, in seconds.""" + """The duration of the samples, in seconds.""" sample_rate: int """The sample rate of the samples, in Hz.""" diff --git a/test/test_encoders.py b/test/test_encoders.py index bf5f9cc6..da2a185d 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1,5 +1,7 @@ +import json import re import subprocess +from pathlib import Path import pytest import torch @@ -16,12 +18,55 @@ ) +def validate_frames_properties(*, actual: Path, expected: Path): + + frames_actual, frames_expected = ( + json.loads( + subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-hide_banner", + "-select_streams", + "a:0", + "-show_frames", + "-of", + "json", + f"{f}", + ], + check=True, + capture_output=True, + text=True, + ).stdout + )["frames"] + for f in (actual, expected) + ) + + # frames_actual and frames_expected are both a list of dicts, each dict + # corresponds to a frame and each key-value pair corresponds to a frame + # property like pts, nb_samples, etc., similar to the AVFrame fields. + assert isinstance(frames_actual, list) + assert all(isinstance(d, dict) for d in frames_actual) + + assert len(frames_actual) == len(frames_expected) + for frame_index, (d_actual, d_expected) in enumerate( + zip(frames_actual, frames_expected) + ): + for prop in d_expected: + if prop == "pkt_pos": + continue # TODO this probably matters + assert ( + d_actual[prop] == d_expected[prop] + ), f"{prop} value is different for frame {frame_index}:" + + class TestAudioEncoder: def decode(self, source) -> torch.Tensor: if isinstance(source, TestContainerFile): source = str(source.path) - return AudioDecoder(source).get_all_samples().data + return AudioDecoder(source).get_all_samples() def test_bad_input(self): with pytest.raises(ValueError, match="Expected samples to be a Tensor"): @@ -63,12 +108,12 @@ def test_bad_input_parametrized(self, method, tmp_path): else dict(format="mp3") ) - decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3), sample_rate=10) + decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3).data, sample_rate=10) with pytest.raises(RuntimeError, match="invalid sample rate=10"): getattr(decoder, method)(**valid_params) decoder = AudioEncoder( - self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate + self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate ) with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): getattr(decoder, method)(**valid_params, bit_rate=-1) @@ -81,7 +126,7 @@ def test_bad_input_parametrized(self, method, tmp_path): getattr(decoder, method)(**valid_params) decoder = AudioEncoder( - self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate + self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate ) for num_channels in (0, 3): with pytest.raises( @@ -101,7 +146,7 @@ def test_round_trip(self, method, format, tmp_path): pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset) + source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) @@ -116,7 +161,7 @@ def test_round_trip(self, method, format, tmp_path): rtol, atol = (0, 1e-4) if format == "wav" else (None, None) torch.testing.assert_close( - self.decode(encoded_source), source_samples, rtol=rtol, atol=atol + self.decode(encoded_source).data, source_samples, rtol=rtol, atol=atol ) @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") @@ -144,7 +189,7 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa check=True, ) - encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate) + encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate) params = dict(bit_rate=bit_rate, num_channels=num_channels) if method == "to_file": encoded_by_us = tmp_path / f"output.{format}" @@ -162,12 +207,22 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa rtol, atol = 0, 1e-3 else: rtol, atol = None, None + samples_by_us = self.decode(encoded_by_us) + samples_by_ffmpeg = self.decode(encoded_by_ffmpeg) torch.testing.assert_close( - self.decode(encoded_by_ffmpeg), - self.decode(encoded_by_us), + samples_by_us.data, + samples_by_ffmpeg.data, rtol=rtol, atol=atol, ) + assert samples_by_us.pts_seconds == samples_by_ffmpeg.pts_seconds + assert samples_by_us.duration_seconds == samples_by_ffmpeg.duration_seconds + assert samples_by_us.sample_rate == samples_by_ffmpeg.sample_rate + + if method == "to_file": + validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg) + else: + assert method == "to_tensor", "wrong test parametrization!" @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @@ -179,7 +234,7 @@ def test_to_tensor_against_to_file( if get_ffmpeg_major_version() == 4 and format == "wav": pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") - encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate) + encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate) params = dict(bit_rate=bit_rate, num_channels=num_channels) encoded_file = tmp_path / f"output.{format}" @@ -189,7 +244,7 @@ def test_to_tensor_against_to_file( ) torch.testing.assert_close( - self.decode(encoded_file), self.decode(encoded_tensor) + self.decode(encoded_file).data, self.decode(encoded_tensor).data ) def test_encode_to_tensor_long_output(self): @@ -205,7 +260,7 @@ def test_encode_to_tensor_long_output(self): INITIAL_TENSOR_SIZE = 10_000_000 assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE - torch.testing.assert_close(self.decode(encoded_tensor), samples) + torch.testing.assert_close(self.decode(encoded_tensor).data, samples) def test_contiguity(self): # Ensure that 2 waveforms with the same values are encoded in the same @@ -262,4 +317,4 @@ def test_num_channels( if num_channels_output is None: num_channels_output = num_channels_input - assert self.decode(encoded_source).shape[0] == num_channels_output + assert self.decode(encoded_source).data.shape[0] == num_channels_output