Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions ASRData.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import re
from typing import List
from typing import List, Optional, Dict, Any, Union


class ASRDataSeg:
def __init__(self, text, start_time, end_time):
def __init__(self, text: str, start_time: Union[int, float], end_time: Union[int, float]):
    """A single ASR (speech-recognition) segment.

    Args:
        text: Recognized text of this segment.
        start_time: Segment start time in milliseconds (fed to _ms_to_srt_time).
        end_time: Segment end time in milliseconds.
    """
    self.text = text
    self.start_time = start_time  # milliseconds
    self.end_time = end_time  # milliseconds
Expand All @@ -13,7 +13,7 @@ def to_srt_ts(self) -> str:
return f"{self._ms_to_srt_time(self.start_time)} --> {self._ms_to_srt_time(self.end_time)}"

@staticmethod
def _ms_to_srt_time(ms) -> str:
def _ms_to_srt_time(ms: Union[int, float]) -> str:
"""Convert milliseconds to SRT time format (HH:MM:SS,mmm)"""
total_seconds, milliseconds = divmod(ms, 1000)
minutes, seconds = divmod(total_seconds, 60)
Expand All @@ -24,7 +24,7 @@ def to_lrc_ts(self) -> str:
"""Convert to LRC timestamp format"""
return f"[{self._ms_to_lrc_time(self.start_time)}]"

def _ms_to_lrc_time(self, ms) -> str:
def _ms_to_lrc_time(self, ms: Union[int, float]) -> str:
seconds = ms / 1000
minutes, seconds = divmod(seconds, 60)
return f"{int(minutes):02}:{seconds:.2f}"
Expand Down Expand Up @@ -53,7 +53,7 @@ def to_txt(self) -> str:
"""Convert to plain text subtitle format (without timestamps)"""
return "\n".join(seg.transcript for seg in self.segments)

def to_srt(self, save_path=None) -> str:
def to_srt(self, save_path: Optional[str] = None) -> str:
"""Convert to SRT subtitle format"""
srt_text = "\n".join(
f"{n}\n{seg.to_srt_ts()}\n{seg.transcript}\n"
Expand All @@ -73,7 +73,7 @@ def to_ass(self) -> str:
"""Convert to ASS subtitle format"""
raise NotImplementedError("ASS format conversion not implemented yet")

def to_json(self) -> dict:
def to_json(self) -> Dict[str, Any]:
result_json = {}
for i, segment in enumerate(self.segments, 1):
# 检查是否有换行符
Expand All @@ -90,18 +90,18 @@ def to_json(self) -> dict:
}
return result_json

def merge_segments(self, start_index: int, end_index: int, merged_text: str = None):
"""合并从 start_index 到 end_index 的段(包含)。"""
if start_index < 0 or end_index >= len(self.segments) or start_index > end_index:
raise IndexError("无效的段索引。")
merged_start_time = self.segments[start_index].start_time
merged_end_time = self.segments[end_index].end_time
if merged_text is None:
merged_text = ''.join(seg.text for seg in self.segments[start_index:end_index+1])
merged_seg = ASRDataSeg(merged_text, merged_start_time, merged_end_time)
# 替换 segments[start_index:end_index+1] 为 merged_seg
# self.segments[start_index:end_index+1] = [merged_seg]
return merged_seg
def merge_segments(self, start_index: int, end_index: int, merged_text: Optional[str] = None) -> ASRDataSeg:
    """Build one segment spanning segments[start_index..end_index] (inclusive).

    Note: the segment list itself is intentionally NOT modified; the merged
    segment is only constructed and returned.

    Args:
        start_index: Index of the first segment to merge.
        end_index: Index of the last segment to merge (inclusive).
        merged_text: Text for the merged segment; when None, the covered
            segments' texts are concatenated.

    Returns:
        The newly constructed merged ASRDataSeg.

    Raises:
        IndexError: If the index range is invalid.
    """
    if not (0 <= start_index <= end_index < len(self.segments)):
        raise IndexError("无效的段索引。")
    covered = self.segments[start_index:end_index + 1]
    if merged_text is None:
        merged_text = ''.join(part.text for part in covered)
    return ASRDataSeg(merged_text, covered[0].start_time, covered[-1].end_time)

def merge_with_next_segment(self, index: int) -> None:
"""合并指定索引的段与下一个段。"""
Expand All @@ -121,11 +121,11 @@ def merge_with_next_segment(self, index: int) -> None:
# 删除下一个段
del self.segments[index + 1]

def __str__(self):
def __str__(self) -> str:
    """String form of the ASR data: the plain-text transcript."""
    transcript = self.to_txt()
    return transcript


def from_srt(srt_str: str) -> 'ASRData':
def from_srt(srt_str: str) -> ASRData:
"""
从SRT格式的字符串创建ASRData实例。

Expand Down Expand Up @@ -165,7 +165,7 @@ def from_srt(srt_str: str) -> 'ASRData':

return ASRData(segments)

def from_vtt(vtt_str: str) -> 'ASRData':
def from_vtt(vtt_str: str) -> ASRData:
"""
从VTT格式的字符串创建ASRData实例, 去除不必要的样式和HTML信息。

Expand Down
173 changes: 114 additions & 59 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
import os
import re
import time
import logging
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from ASRData import ASRData, from_srt, ASRDataSeg

import difflib
from typing import List
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed

from split_by_llm import split_by_llm
from split_by_llm import split_by_llm, batch_split_by_llm, get_cache_stats

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 性能优化常量
MAX_WORD_COUNT = 16 # 英文单词或中文字符的最大数量
SEGMENT_THRESHOLD = 1000 # 每个分段的最大字数
FIXED_NUM_THREADS = 4 # 固定的线程数量
SPLIT_RANGE = 30 # 在分割点前后寻找最大时间间隔的范围
CHUNK_SIZE = 100 # 流式处理的块大小

@dataclass
class ProcessingStats:
    """Aggregate statistics for one subtitle-processing run."""
    total_segments: int = 0  # segment count after preprocessing
    merged_segments: int = 0  # segment count after sentence-based merging
    split_segments: int = 0  # NOTE(review): never updated in the visible code — confirm intended use
    processing_time: float = 0.0  # wall-clock seconds for the whole run
    cache_hits: int = 0  # LLM split-cache hits (from get_cache_stats)
    cache_misses: int = 0  # LLM split-cache misses
    api_calls: int = 0  # equals cache_misses: the API is called only on a miss


def is_pure_punctuation(s: str) -> bool:
Expand Down Expand Up @@ -227,72 +245,109 @@ def determine_num_segments(word_count: int, threshold: int = 1000) -> int:
return max(1, num_segments)


def main(srt_path: str, save_path: str, num_threads: int = FIXED_NUM_THREADS) -> ProcessingStats:
    """Split-and-merge an SRT file into sentence-aligned subtitles.

    Loads the SRT, normalizes segment texts, asks the LLM (in parallel) to
    split the joined transcript into sentences, merges the ASR segments
    along those sentence boundaries, and writes the result to ``save_path``.

    Args:
        srt_path: Path of the input SRT file.
        save_path: Path the merged SRT file is written to.
        num_threads: Worker threads used for the parallel LLM requests.

    Returns:
        ProcessingStats describing the run; on failure a partially filled
        stats object is returned instead of raising.
    """
    start_time = time.time()
    stats = ProcessingStats()

    try:
        logger.info("[+] 开始处理文件: %s", srt_path)
        with open(srt_path, encoding="utf-8") as f:
            asr_data = from_srt(f.read())

        # Preprocess: drop punctuation-only segments; lowercase pure-ASCII
        # word segments and pad a trailing space so later joins keep word
        # boundaries.
        kept_segments = []
        for seg in asr_data.segments:
            if is_pure_punctuation(seg.text):
                continue
            if re.match(r'^[a-zA-Z\']+$', seg.text.strip()):
                seg.text = seg.text.lower() + " "
            kept_segments.append(seg)
        asr_data.segments = kept_segments
        stats.total_segments = len(asr_data.segments)

        # Join the transcript and decide how many chunks to send to the LLM.
        txt = asr_data.to_txt().replace("\n", "")
        total_word_count = count_words(txt)
        logger.info("[+] 合并后的文本长度: %s 字", total_word_count)

        num_segments = determine_num_segments(total_word_count, threshold=SEGMENT_THRESHOLD)
        logger.info("[+] 根据字数 %s,确定分段数: %s", total_word_count, num_segments)

        asr_data_segments = split_asr_data(asr_data, num_segments)

        logger.info("[+] 正在并行请求LLM将每个分段的文本拆分为句子...")

        def process_segment(asr_data_part):
            # One LLM request per chunk; executor.map preserves input order.
            part_txt = asr_data_part.to_txt().replace("\n", "")
            sentences = split_by_llm(part_txt, use_cache=True)
            logger.info("[+] 分段的句子提取完成,共 %s 句", len(sentences))
            return sentences

        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            per_segment_sentences = list(executor.map(process_segment, asr_data_segments))
        all_sentences = [sentence for group in per_segment_sentences for sentence in group]
        logger.info("[+] 总共提取到 %s 句", len(all_sentences))

        # Merge ASR segments along the LLM sentence boundaries.
        logger.info("[+] 正在合并ASR分段基于句子列表...")
        merged_asr_data = merge_segments_based_on_sentences(asr_data, all_sentences)

        # Sort by start time before writing out.
        merged_asr_data.segments.sort(key=lambda seg: seg.start_time)
        final_asr_data = ASRData(merged_asr_data.segments)

        final_asr_data.to_srt(save_path=save_path)
        logger.info("[+] 已保存合并后的SRT文件: %s", save_path)

        stats.merged_segments = len(final_asr_data.segments)
        stats.processing_time = time.time() - start_time

        # Cache statistics: the API is called only when the cache misses.
        cache_stats_data = get_cache_stats()
        stats.cache_hits = cache_stats_data["hits"]
        stats.cache_misses = cache_stats_data["misses"]
        stats.api_calls = cache_stats_data["misses"]

        return stats

    except Exception:
        # logger.exception (rather than logger.error) records the full
        # traceback; the partially filled stats are still returned so the
        # caller's reporting keeps working.
        logger.exception("[!] 处理文件时发生错误")
        stats.processing_time = time.time() - start_time
        return stats


if __name__ == '__main__':
    import argparse

    # CLI entry point: parse paths/threads, run the pipeline, print stats.
    parser = argparse.ArgumentParser(description="优化ASR分段处理脚本")
    parser.add_argument('--srt_path', type=str, required=True, help='输入的SRT文件路径')
    parser.add_argument('--save_path', type=str, required=True, help='输出的SRT文件路径')
    parser.add_argument('--num_threads', type=int, default=FIXED_NUM_THREADS, help='线程数量')
    args = parser.parse_args()

    stats = main(srt_path=args.srt_path, save_path=args.save_path, num_threads=args.num_threads)

    print("\n=== 处理统计 ===")
    print(f"总分段数: {stats.total_segments}")
    print(f"合并后分段数: {stats.merged_segments}")
    print(f"处理时间: {stats.processing_time:.2f}秒")
    print(f"缓存命中: {stats.cache_hits}")
    print(f"缓存未命中: {stats.cache_misses}")
    print(f"API调用次数: {stats.api_calls}")
    # Guard the derived ratios against division by zero (instantaneous
    # runs / empty caches).
    if stats.processing_time > 0:
        print(f"每秒处理分段数: {stats.total_segments / stats.processing_time:.2f}")
    total_lookups = stats.cache_hits + stats.cache_misses
    if total_lookups > 0:
        hit_rate = stats.cache_hits / total_lookups * 100
        print(f"缓存命中率: {hit_rate:.1f}%")
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
openai>=1.0.0
python-dotenv>=0.19.0
Loading