diff --git a/pyproject.toml b/pyproject.toml index 0bf5e95..54d279f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "jax[cpu]==0.6.2", "jax_tpu_embedding==0.1.0.dev20250618", "keras-hub", + "tokenizers", "keras-rs>=0.2.1", "nltk>=3.9.1", "pytest>=8.4.1", @@ -52,6 +53,7 @@ dev = [ "protobuf>=5.29.5", "rouge-score>=0.1.2", "scikit-learn>=1.7.1", + "sentencepiece", "tensorflow", "torchmetrics>=1.8.1", ] diff --git a/src/metrax/audio_metrics.py b/src/metrax/audio_metrics.py index 2ab88f8..c8769c4 100644 --- a/src/metrax/audio_metrics.py +++ b/src/metrax/audio_metrics.py @@ -20,6 +20,8 @@ from metrax import base +# TODO(jiwonshin): Move SNR class out of audio metrics since now it can be used +# for image data as well. @flax.struct.dataclass class SNR(base.Average): r"""SNR (Signal-to-Noise Ratio) Metric for audio. @@ -55,6 +57,13 @@ def _calculate_snr( ) -> jax.Array: """Computes SNR (Signal-to-Noise Ratio) values for a batch of audio signals. + If the input has more than 2 dimensions, it is assumed that the first + dimension is the batch dimension and all others are signal dimensions. The + input is then reshaped to (batch, signal_dimensions) to compute the SNR over + all signal dimensions for each example in the batch. E.g. image data of + shape (batch, H, W, C) is reshaped to (batch, H * W * C) to compute the SNR + for each image in the batch. + Args: preds: The estimated or predicted audio signal. JAX Array. target: The ground truth audio signal. JAX Array. @@ -71,6 +80,10 @@ def _calculate_snr( f' {target.shape}' ) + if preds.ndim > 2: + target = jnp.reshape(target, (target.shape[0], -1)) + preds = jnp.reshape(preds, (preds.shape[0], -1)) + target_processed, preds_processed = jax.lax.cond( zero_mean, lambda t, p: ( diff --git a/src/metrax/audio_metrics_test.py b/src/metrax/audio_metrics_test.py index 2356840..f965ddc 100644 --- a/src/metrax/audio_metrics_test.py +++ b/src/metrax/audio_metrics_test.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for metrax image metrics.""" +"""Tests for metrax audio metrics.""" import os os.environ['KERAS_BACKEND'] = 'jax' -from absl.testing import absltest +from absl.testing import absltest # pylint: disable=g-import-not-at-top from absl.testing import parameterized import jax.numpy as jnp import metrax @@ -43,6 +43,13 @@ AUDIO_PREDS_2D_NOISY = ( AUDIO_TARGET_2D + 0.5 * np.random.randn(*AUDIO_SHAPE_2D) ).astype(np.float32) +# 3D batch of signals +AUDIO_SHAPE_3D = (2, 3, 100) +AUDIO_TARGET_3D = (np.random.randn(*AUDIO_SHAPE_3D) * 5.0).astype(np.float32) +AUDIO_PREDS_3D_NOISY = ( + AUDIO_TARGET_3D + 0.5 * np.random.randn(*AUDIO_SHAPE_3D) +).astype(np.float32) + # Target and preds are all zeros. AUDIO_SHAPE_ZEROS = (100,) AUDIO_TARGET_ZEROS = np.zeros(AUDIO_SHAPE_ZEROS).astype(np.float32) @@ -88,6 +95,18 @@ class AudioMetricsTest(parameterized.TestCase): AUDIO_PREDS_2D_NOISY, True, ), + ( + 'snr_3d_noisy_false_zero_mean', + AUDIO_TARGET_3D, + AUDIO_PREDS_3D_NOISY, + False, + ), + ( + 'snr_3d_noisy_true_zero_mean', + AUDIO_TARGET_3D, + AUDIO_PREDS_3D_NOISY, + True, + ), ( 'snr_zeros_false_zero_mean', AUDIO_TARGET_ZEROS, @@ -110,6 +129,10 @@ def test_snr( ) metrax_snr_result = metrax_snr_metric.compute() + if preds_np.ndim > 2: + preds_np = preds_np.reshape(preds_np.shape[0], -1) + target_np = target_np.reshape(target_np.shape[0], -1) + torchmetrics_snr_result = ( tm_snr.signal_noise_ratio( preds=torch.from_numpy(preds_np),