Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions realbook/callbacks/spectrogram_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class SpectrogramVisualizationCallback(tf.keras.callbacks.Callback):
- transpose: whether to transpose the spectrogram before plotting it.
- add_colorbar: whether to add a colorbar to the spectrogram plot.
- Any remaining keyword arguments are passed through to librosa.display.specshow.
If `hop_length` is not set, it is inferred from the input audio length and the number of spectrogram frames.
"""

def __init__(
Expand Down Expand Up @@ -135,11 +136,13 @@ def on_train_begin(self, logs: Any = None) -> None:
# Ignore the single channel dimension, if it exists.
spectrograms = spectrograms[:, :, :, 0]

# We can infer the hop length, as we know the input audio length
# and sample rate used in the spectrogram
length_in_samples = data.shape[-1]
length_in_frames = spectrograms.shape[-2]
hop_length = int(tf.math.ceil(length_in_samples / length_in_frames))
if "hop_length" not in self.specshow_arguments:
# We can infer the hop length, as we know the input audio length
# and sample rate used in the spectrogram
length_in_samples = data.shape[-1]
length_in_frames = spectrograms.shape[-2] if self.transpose else spectrograms.shape[-1]
hop_length = int(tf.math.ceil(length_in_samples / length_in_frames))
self.specshow_arguments["hop_length"] = hop_length

figs = []
for spectrogram in spectrograms:
Expand All @@ -157,7 +160,6 @@ def on_train_begin(self, logs: Any = None) -> None:
img = librosa.display.specshow(
spectrogram,
sr=self.sample_rate_hz,
hop_length=hop_length,
ax=ax,
**self.specshow_arguments,
)
Expand Down
33 changes: 33 additions & 0 deletions tests/callbacks/test_spectrogram_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,36 @@ def test_enable_colorbar() -> None:

model.fit(fake_data, callbacks=[cb])
assert True


@pytest.mark.skipif(
    SpectrogramVisualizationCallback is None,
    reason="SpectrogramVisualizationCallback import fails on this platform",
)
def test_set_hop_length() -> None:
    """Smoke test: the callback accepts an explicit ``hop_length`` keyword.

    A one-example dataset is fed through a tiny Spectrogram -> Dense model
    while the callback is constructed with ``hop_length=1024``. There is no
    assertion on the rendered output; the test passes as long as ``fit``
    completes without raising.
    """
    # A single (audio, label) pair, batched with batch size 1.
    audio_ds = tf.data.Dataset.from_tensor_slices([TEST_AUDIO])
    label_ds = tf.data.Dataset.from_tensor_slices([1])
    batched = tf.data.Dataset.zip((audio_ds, label_ds)).batch(1)

    layers = [
        tf.keras.Input(shape=(None,)),
        Spectrogram(),
        tf.keras.layers.Dense(1),
    ]
    net = tf.keras.Sequential(layers)
    net.compile(loss="binary_crossentropy")

    # NOTE(review): raise_on_error=True presumably makes failures inside the
    # callback propagate to the test instead of being swallowed -- confirm
    # against the callback implementation.
    callback = SpectrogramVisualizationCallback(
        FakeWriter(),
        batched,
        sample_rate=DEFAULT_SAMPLE_RATE,
        raise_on_error=True,
        hop_length=1024,
    )

    net.fit(batched, callbacks=[callback])
    assert True
Loading