From dca93c26fd29635c39cc1dd3ffce35c5fd16435d Mon Sep 17 00:00:00 2001 From: Maria Lyubimtseva Date: Mon, 13 Apr 2026 10:46:52 -0700 Subject: [PATCH] Add support for multi-dimensional signals in SNR computation The SNR function now reshapes inputs with more than two dimensions, assuming the first dimension is the batch dimension, to compute SNR across all signal dimensions. This allows handling inputs like images or multi-channel audio. PiperOrigin-RevId: 899078362 --- pyproject.toml | 2 ++ src/metrax/audio_metrics.py | 13 +++++++++++++ src/metrax/audio_metrics_test.py | 27 +++++++++++++++++++++++++-- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0bf5e95..54d279f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "jax[cpu]==0.6.2", "jax_tpu_embedding==0.1.0.dev20250618", "keras-hub", + "tokenizers", "keras-rs>=0.2.1", "nltk>=3.9.1", "pytest>=8.4.1", @@ -52,6 +53,7 @@ dev = [ "protobuf>=5.29.5", "rouge-score>=0.1.2", "scikit-learn>=1.7.1", + "sentencepiece", "tensorflow", "torchmetrics>=1.8.1", ] diff --git a/src/metrax/audio_metrics.py b/src/metrax/audio_metrics.py index 2ab88f8..c8769c4 100644 --- a/src/metrax/audio_metrics.py +++ b/src/metrax/audio_metrics.py @@ -20,6 +20,8 @@ from metrax import base +# TODO(jiwonshin): Move SNR class out of audio metrics since now it can be used +# for image data as well. @flax.struct.dataclass class SNR(base.Average): r"""SNR (Signal-to-Noise Ratio) Metric for audio. @@ -55,6 +57,13 @@ def _calculate_snr( ) -> jax.Array: """Computes SNR (Signal-to-Noise Ratio) values for a batch of audio signals. + If the input has more than 2 dimensions, it is assumed that the first + dimension is the batch dimension and all others are signal dimensions. The + input is then reshaped to (batch, signal_dimensions) to compute the SNR over + all signal dimensions for each example in the batch. E.g. image data of + shape (batch, H, W, C) is reshaped to (batch, H * W * C) to compute the SNR + for each image in the batch. + Args: preds: The estimated or predicted audio signal. JAX Array. target: The ground truth audio signal. JAX Array. @@ -71,6 +80,10 @@ def _calculate_snr( f' {target.shape}' ) + if preds.ndim > 2: + target = jnp.reshape(target, (target.shape[0], -1)) + preds = jnp.reshape(preds, (preds.shape[0], -1)) + target_processed, preds_processed = jax.lax.cond( zero_mean, lambda t, p: ( diff --git a/src/metrax/audio_metrics_test.py b/src/metrax/audio_metrics_test.py index 2356840..f965ddc 100644 --- a/src/metrax/audio_metrics_test.py +++ b/src/metrax/audio_metrics_test.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for metrax image metrics.""" +"""Tests for metrax audio metrics.""" import os os.environ['KERAS_BACKEND'] = 'jax' -from absl.testing import absltest +from absl.testing import absltest # pylint: disable=g-import-not-at-top from absl.testing import parameterized import jax.numpy as jnp import metrax @@ -43,6 +43,13 @@ AUDIO_PREDS_2D_NOISY = ( AUDIO_TARGET_2D + 0.5 * np.random.randn(*AUDIO_SHAPE_2D) ).astype(np.float32) +# 3D batch of signals +AUDIO_SHAPE_3D = (2, 3, 100) +AUDIO_TARGET_3D = (np.random.randn(*AUDIO_SHAPE_3D) * 5.0).astype(np.float32) +AUDIO_PREDS_3D_NOISY = ( + AUDIO_TARGET_3D + 0.5 * np.random.randn(*AUDIO_SHAPE_3D) +).astype(np.float32) + # Target and preds are all zeros. AUDIO_SHAPE_ZEROS = (100,) AUDIO_TARGET_ZEROS = np.zeros(AUDIO_SHAPE_ZEROS).astype(np.float32) @@ -88,6 +95,18 @@ class AudioMetricsTest(parameterized.TestCase): AUDIO_PREDS_2D_NOISY, True, ), + ( + 'snr_3d_noisy_false_zero_mean', + AUDIO_TARGET_3D, + AUDIO_PREDS_3D_NOISY, + False, + ), + ( + 'snr_3d_noisy_true_zero_mean', + AUDIO_TARGET_3D, + AUDIO_PREDS_3D_NOISY, + True, + ), ( 'snr_zeros_false_zero_mean', AUDIO_TARGET_ZEROS, @@ -110,6 +129,10 @@ def test_snr( ) metrax_snr_result = metrax_snr_metric.compute() + if preds_np.ndim > 2: + preds_np = preds_np.reshape(preds_np.shape[0], -1) + target_np = target_np.reshape(target_np.shape[0], -1) + torchmetrics_snr_result = ( tm_snr.signal_noise_ratio( preds=torch.from_numpy(preds_np),