diff --git a/realbook/callbacks/spectrogram_visualization.py b/realbook/callbacks/spectrogram_visualization.py index 405538a..d3037a1 100644 --- a/realbook/callbacks/spectrogram_visualization.py +++ b/realbook/callbacks/spectrogram_visualization.py @@ -63,6 +63,7 @@ class SpectrogramVisualizationCallback(tf.keras.callbacks.Callback): - transpose: whether to transpose the spectrogram before plotting it. - add_colorbar: whether to add a colorbar to the spectrogram plot. - Any remaining keyword arguments are passed through to librosa.display.specshow. + If `hop_length` is not set, it is inferred. """ def __init__( @@ -135,11 +136,13 @@ def on_train_begin(self, logs: Any = None) -> None: # Ignore the single channel dimension, if it exists. spectrograms = spectrograms[:, :, :, 0] - # We can infer the hop length, as we know the input audio length - # and sample rate used in the spectrogram - length_in_samples = data.shape[-1] - length_in_frames = spectrograms.shape[-2] - hop_length = int(tf.math.ceil(length_in_samples / length_in_frames)) + if "hop_length" not in self.specshow_arguments: + # We can infer the hop length, as we know the input audio length + # and sample rate used in the spectrogram + length_in_samples = data.shape[-1] + length_in_frames = spectrograms.shape[-2] if self.transpose else spectrograms.shape[-1] + hop_length = int(tf.math.ceil(length_in_samples / length_in_frames)) + self.specshow_arguments["hop_length"] = hop_length figs = [] for spectrogram in spectrograms: @@ -157,7 +160,6 @@ def on_train_begin(self, logs: Any = None) -> None: img = librosa.display.specshow( spectrogram, sr=self.sample_rate_hz, - hop_length=hop_length, ax=ax, **self.specshow_arguments, ) diff --git a/tests/callbacks/test_spectrogram_visualization.py b/tests/callbacks/test_spectrogram_visualization.py index 6e7ac9c..2ceab77 100644 --- a/tests/callbacks/test_spectrogram_visualization.py +++ b/tests/callbacks/test_spectrogram_visualization.py @@ -350,3 +350,36 @@ def test_enable_colorbar() -> None: model.fit(fake_data, callbacks=[cb]) assert True + + +@pytest.mark.skipif( + SpectrogramVisualizationCallback is None, + reason="SpectrogramVisualizationCallback import fails on this platform", +) +def test_set_hop_length() -> None: + fake_data = tf.data.Dataset.zip( + ( + tf.data.Dataset.from_tensor_slices([TEST_AUDIO]), + tf.data.Dataset.from_tensor_slices([1]), + ) + ).batch(1) + + model = tf.keras.Sequential( + [ + tf.keras.Input(shape=(None,)), + Spectrogram(), + tf.keras.layers.Dense(1), + ] + ) + model.compile(loss="binary_crossentropy") + + cb = SpectrogramVisualizationCallback( + FakeWriter(), + fake_data, + sample_rate=DEFAULT_SAMPLE_RATE, + raise_on_error=True, + hop_length=1024, + ) + + model.fit(fake_data, callbacks=[cb]) + assert True