Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
urls = { repository = "https://github.com/m-bain/whisperx" }
authors = [{ name = "Max Bain" }]
name = "whisperx"
version = "3.4.2"
version = "3.4.3"
description = "Time-Accurate Automatic Speech Recognition using Whisper."
readme = "README.md"
requires-python = ">=3.9, <3.13"
Expand All @@ -15,7 +15,7 @@ dependencies = [
"numpy>=2.0.2",
"onnxruntime>=1.19",
"pandas>=2.2.3",
"pyannote-audio>=3.3.2",
"pyannote-audio>=3.3.2,<4.0.0",
"torch>=2.5.1",
"torchaudio>=2.5.1",
"transformers>=4.48.0",
Expand Down
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 13 additions & 13 deletions whisperx/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters

PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof']
PUNKT_ABBREVIATIONS = ['dr', 'vs', 'mr', 'mrs', 'prof', 'jr', 'sr', 'ph.d']

LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

Expand Down Expand Up @@ -124,14 +124,14 @@ def align(
"""
Align phoneme recognition predictions to known transcription.
"""

if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)

MAX_DURATION = audio.shape[1] / SAMPLE_RATE

model_dictionary = align_model_metadata["dictionary"]
Expand All @@ -148,7 +148,7 @@ def align(
base_progress = ((sdx + 1) / total_segments) * 100
percent_complete = (50 + base_progress / 2) if combined_progress else base_progress
print(f"Progress: {percent_complete:.2f}%...")

num_leading = len(segment["text"]) - len(segment["text"].lstrip())
num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
text = segment["text"]
Expand All @@ -165,7 +165,7 @@ def align(
# wav2vec2 models use "|" character to represent spaces
if model_lang not in LANGUAGES_WITHOUT_SPACES:
char_ = char_.replace(" ", "|")

# ignore whitespace at beginning and end of transcript
if cdx < num_leading:
pass
Expand All @@ -187,7 +187,7 @@ def align(
# index for placeholder
clean_wdx.append(wdx)


punkt_param = PunktParameters()
punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS)
sentence_splitter = PunktSentenceTokenizer(punkt_param)
Expand All @@ -199,12 +199,12 @@ def align(
"clean_wdx": clean_wdx,
"sentence_spans": sentence_spans
}

aligned_segments: List[SingleAlignedSegment] = []

# 2. Get prediction matrix from alignment model & align
for sdx, segment in enumerate(transcript):

t1 = segment["start"]
t2 = segment["end"]
text = segment["text"]
Expand Down Expand Up @@ -247,7 +247,7 @@ def align(
)
else:
lengths = None

with torch.inference_mode():
if model_type == "torchaudio":
emissions, _ = model(waveform_segment.to(device), lengths=lengths)
Expand Down Expand Up @@ -304,7 +304,7 @@ def align(
word_idx += 1
elif cdx == len(text) - 1 or text[cdx+1] == " ":
word_idx += 1

char_segments_arr = pd.DataFrame(char_segments_arr)

aligned_subsegments = []
Expand Down Expand Up @@ -333,7 +333,7 @@ def align(
word_end = word_chars["end"].max()
word_score = round(word_chars["score"].mean(), 3)

# -1 indicates unalignable
# -1 indicates unalignable
word_segment = {"word": word_text}

if not np.isnan(word_start):
Expand All @@ -344,7 +344,7 @@ def align(
word_segment["score"] = word_score

sentence_words.append(word_segment)

aligned_subsegments.append({
"text": sentence_text,
"start": sentence_start,
Expand Down