Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ dev = [
"jax[cpu]==0.6.2",
"jax_tpu_embedding==0.1.0.dev20250618",
"keras-hub",
"tokenizers",
"keras-rs>=0.2.1",
"nltk>=3.9.1",
"pytest>=8.4.1",
"Pillow>=9.0.0",
"protobuf>=5.29.5",
"rouge-score>=0.1.2",
"scikit-learn>=1.7.1",
"sentencepiece",
"tensorflow",
"torchmetrics>=1.8.1",
]
Expand Down
13 changes: 13 additions & 0 deletions src/metrax/audio_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from metrax import base


# TODO(jiwonshin): Move SNR class out of audio metrics since now it can be used
# for image data as well.
@flax.struct.dataclass
class SNR(base.Average):
r"""SNR (Signal-to-Noise Ratio) Metric for audio.
Expand Down Expand Up @@ -55,6 +57,13 @@ def _calculate_snr(
) -> jax.Array:
"""Computes SNR (Signal-to-Noise Ratio) values for a batch of audio signals.

If the input has more than 2 dimensions, it is assumed that the first
dimension is the batch dimension and all others are signal dimensions. The
input is then reshaped to (batch, signal_dimensions) to compute the SNR over
all signal dimensions for each example in the batch. E.g. image data of
shape (batch, H, W, C) is reshaped to (batch, H * W * C) to compute the SNR
for each image in the batch.

Args:
preds: The estimated or predicted signal (audio or image data). JAX Array.
target: The ground truth signal (audio or image data). JAX Array.
Expand All @@ -71,6 +80,10 @@ def _calculate_snr(
f' {target.shape}'
)

if preds.ndim > 2:
target = jnp.reshape(target, (target.shape[0], -1))
preds = jnp.reshape(preds, (preds.shape[0], -1))

target_processed, preds_processed = jax.lax.cond(
zero_mean,
lambda t, p: (
Expand Down
27 changes: 25 additions & 2 deletions src/metrax/audio_metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for metrax image metrics."""
"""Tests for metrax audio metrics."""

import os

os.environ['KERAS_BACKEND'] = 'jax'

from absl.testing import absltest
from absl.testing import absltest # pylint: disable=g-import-not-at-top
from absl.testing import parameterized
import jax.numpy as jnp
import metrax
Expand All @@ -43,6 +43,13 @@
AUDIO_PREDS_2D_NOISY = (
AUDIO_TARGET_2D + 0.5 * np.random.randn(*AUDIO_SHAPE_2D)
).astype(np.float32)
# 3D batch of signals; inputs with ndim > 2 are flattened to (batch, -1).
AUDIO_SHAPE_3D = (2, 3, 100)
AUDIO_TARGET_3D = (np.random.randn(*AUDIO_SHAPE_3D) * 5.0).astype(np.float32)
AUDIO_PREDS_3D_NOISY = (
AUDIO_TARGET_3D + 0.5 * np.random.randn(*AUDIO_SHAPE_3D)
).astype(np.float32)

# Target and preds are all zeros.
AUDIO_SHAPE_ZEROS = (100,)
AUDIO_TARGET_ZEROS = np.zeros(AUDIO_SHAPE_ZEROS).astype(np.float32)
Expand Down Expand Up @@ -88,6 +95,18 @@ class AudioMetricsTest(parameterized.TestCase):
AUDIO_PREDS_2D_NOISY,
True,
),
(
'snr_3d_noisy_false_zero_mean',
AUDIO_TARGET_3D,
AUDIO_PREDS_3D_NOISY,
False,
),
(
'snr_3d_noisy_true_zero_mean',
AUDIO_TARGET_3D,
AUDIO_PREDS_3D_NOISY,
True,
),
(
'snr_zeros_false_zero_mean',
AUDIO_TARGET_ZEROS,
Expand All @@ -110,6 +129,10 @@ def test_snr(
)
metrax_snr_result = metrax_snr_metric.compute()

if preds_np.ndim > 2:
preds_np = preds_np.reshape(preds_np.shape[0], -1)
target_np = target_np.reshape(target_np.shape[0], -1)

torchmetrics_snr_result = (
tm_snr.signal_noise_ratio(
preds=torch.from_numpy(preds_np),
Expand Down
Loading