From b9f45b079ac0247e162c799acdccecc2d1da8969 Mon Sep 17 00:00:00 2001 From: luojiyin Date: Wed, 27 Aug 2025 12:10:39 +0800 Subject: [PATCH] depend: add requirements.txt --- ASRData.py | 42 +++++----- main.py | 173 +++++++++++++++++++++++++-------------- requirements.txt | 2 + split_by_llm.py | 205 ++++++++++++++++++++++++++++++++++++++--------- 4 files changed, 303 insertions(+), 119 deletions(-) create mode 100644 requirements.txt diff --git a/ASRData.py b/ASRData.py index 10ad7c2..17bc2b7 100644 --- a/ASRData.py +++ b/ASRData.py @@ -1,9 +1,9 @@ import re -from typing import List +from typing import List, Optional, Dict, Any, Union class ASRDataSeg: - def __init__(self, text, start_time, end_time): + def __init__(self, text: str, start_time: Union[int, float], end_time: Union[int, float]): self.text = text self.start_time = start_time self.end_time = end_time @@ -13,7 +13,7 @@ def to_srt_ts(self) -> str: return f"{self._ms_to_srt_time(self.start_time)} --> {self._ms_to_srt_time(self.end_time)}" @staticmethod - def _ms_to_srt_time(ms) -> str: + def _ms_to_srt_time(ms: Union[int, float]) -> str: """Convert milliseconds to SRT time format (HH:MM:SS,mmm)""" total_seconds, milliseconds = divmod(ms, 1000) minutes, seconds = divmod(total_seconds, 60) @@ -24,7 +24,7 @@ def to_lrc_ts(self) -> str: """Convert to LRC timestamp format""" return f"[{self._ms_to_lrc_time(self.start_time)}]" - def _ms_to_lrc_time(self, ms) -> str: + def _ms_to_lrc_time(self, ms: Union[int, float]) -> str: seconds = ms / 1000 minutes, seconds = divmod(seconds, 60) return f"{int(minutes):02}:{seconds:.2f}" @@ -53,7 +53,7 @@ def to_txt(self) -> str: """Convert to plain text subtitle format (without timestamps)""" return "\n".join(seg.transcript for seg in self.segments) - def to_srt(self, save_path=None) -> str: + def to_srt(self, save_path: Optional[str] = None) -> str: """Convert to SRT subtitle format""" srt_text = "\n".join( f"{n}\n{seg.to_srt_ts()}\n{seg.transcript}\n" @@ -73,7 +73,7 @@ def to_ass(self) -> str: """Convert to ASS subtitle format""" raise NotImplementedError("ASS format conversion not implemented yet") - def to_json(self) -> dict: + def to_json(self) -> Dict[str, Any]: result_json = {} for i, segment in enumerate(self.segments, 1): # 检查是否有换行符 @@ -90,18 +90,18 @@ def to_json(self) -> dict: } return result_json - def merge_segments(self, start_index: int, end_index: int, merged_text: str = None): - """合并从 start_index 到 end_index 的段(包含)。""" - if start_index < 0 or end_index >= len(self.segments) or start_index > end_index: - raise IndexError("无效的段索引。") - merged_start_time = self.segments[start_index].start_time - merged_end_time = self.segments[end_index].end_time - if merged_text is None: - merged_text = ''.join(seg.text for seg in self.segments[start_index:end_index+1]) - merged_seg = ASRDataSeg(merged_text, merged_start_time, merged_end_time) - # 替换 segments[start_index:end_index+1] 为 merged_seg - # self.segments[start_index:end_index+1] = [merged_seg] - return merged_seg + def merge_segments(self, start_index: int, end_index: int, merged_text: Optional[str] = None) -> ASRDataSeg: + """合并从 start_index 到 end_index 的段(包含)。""" + if start_index < 0 or end_index >= len(self.segments) or start_index > end_index: + raise IndexError("无效的段索引。") + merged_start_time = self.segments[start_index].start_time + merged_end_time = self.segments[end_index].end_time + if merged_text is None: + merged_text = ''.join(seg.text for seg in self.segments[start_index:end_index+1]) + merged_seg = ASRDataSeg(merged_text, 
merged_start_time, merged_end_time) + # 替换 segments[start_index:end_index+1] 为 merged_seg + # self.segments[start_index:end_index+1] = [merged_seg] + return merged_seg def merge_with_next_segment(self, index: int) -> None: """合并指定索引的段与下一个段。""" @@ -121,11 +121,11 @@ def merge_with_next_segment(self, index: int) -> None: # 删除下一个段 del self.segments[index + 1] - def __str__(self): + def __str__(self) -> str: return self.to_txt() -def from_srt(srt_str: str) -> 'ASRData': +def from_srt(srt_str: str) -> ASRData: """ 从SRT格式的字符串创建ASRData实例。 @@ -165,7 +165,7 @@ def from_srt(srt_str: str) -> 'ASRData': return ASRData(segments) -def from_vtt(vtt_str: str) -> 'ASRData': +def from_vtt(vtt_str: str) -> ASRData: """ 从VTT格式的字符串创建ASRData实例, 去除不必要的样式和HTML信息。 diff --git a/main.py b/main.py index 6f4b639..ad45aac 100644 --- a/main.py +++ b/main.py @@ -1,18 +1,36 @@ import os import re +import time +import logging +from dataclasses import dataclass +from typing import List, Dict, Any, Optional from ASRData import ASRData, from_srt, ASRDataSeg - import difflib -from typing import List -import sys from concurrent.futures import ThreadPoolExecutor, as_completed -from split_by_llm import split_by_llm +from split_by_llm import split_by_llm, batch_split_by_llm, get_cache_stats + +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +# 性能优化常量 MAX_WORD_COUNT = 16 # 英文单词或中文字符的最大数量 SEGMENT_THRESHOLD = 1000 # 每个分段的最大字数 FIXED_NUM_THREADS = 4 # 固定的线程数量 SPLIT_RANGE = 30 # 在分割点前后寻找最大时间间隔的范围 +CHUNK_SIZE = 100 # 流式处理的块大小 + +@dataclass +class ProcessingStats: + """处理统计信息""" + total_segments: int = 0 + merged_segments: int = 0 + split_segments: int = 0 + processing_time: float = 0.0 + cache_hits: int = 0 + cache_misses: int = 0 + api_calls: int = 0 def is_pure_punctuation(s: str) -> bool: @@ -227,60 +245,82 @@ def determine_num_segments(word_count: int, threshold: int = 1000) -> int: return max(1, num_segments) -def main(srt_path: str, save_path: str, num_threads: int = FIXED_NUM_THREADS): - # 从SRT文件加载ASR数据 - with open(srt_path, encoding="utf-8") as f: - asr_data = from_srt(f.read()) - - # 预处理ASR数据,去除标点并转换为小写 - new_segments = [] - for seg in asr_data.segments: - if not is_pure_punctuation(seg.text): - if re.match(r'^[a-zA-Z\']+$', seg.text.strip()): - seg.text = seg.text.lower() + " " - new_segments.append(seg) - asr_data.segments = new_segments - - # 获取连接后的文本 - txt = asr_data.to_txt().replace("\n", "") - total_word_count = count_words(txt) - print(f"[+] 合并后的文本长度: {total_word_count} 字") - - # 确定分段数 - num_segments = determine_num_segments(total_word_count, threshold=SEGMENT_THRESHOLD) - print(f"[+] 根据字数 {total_word_count},确定分段数: {num_segments}") - - # 分割ASRData - asr_data_segments = split_asr_data(asr_data, num_segments) - - # for i in asr_data_segments: - # print(len(i.segments)) - # print(i.to_txt().split("\n")) - - # 多线程执行 split_by_llm 获取句子列表 - print("[+] 正在并行请求LLM将每个分段的文本拆分为句子...") - with ThreadPoolExecutor(max_workers=num_threads) as executor: - def process_segment(asr_data_part): - txt = asr_data_part.to_txt().replace("\n", "") - sentences = split_by_llm(txt, use_cache=True) - print(f"[+] 分段的句子提取完成,共 {len(sentences)} 句") - return sentences - all_sentences = list(executor.map(process_segment, asr_data_segments)) - all_sentences = [item for sublist in all_sentences for item in sublist] +def main(srt_path: str, save_path: str, num_threads: int = FIXED_NUM_THREADS) -> ProcessingStats: + """ + 优化的主处理函数,带性能统计 + """ + start_time = time.time() + stats = ProcessingStats() - print(f"[+] 总共提取到 
{len(all_sentences)} 句") - - # 基于LLM已经分段的句子,对ASR分段进行合并 - print("[+] 正在合并ASR分段基于句子列表...") - merged_asr_data = merge_segments_based_on_sentences(asr_data, all_sentences) - - # 按开始时间排序合并后的分段(其实好像不需要) - merged_asr_data.segments.sort(key=lambda seg: seg.start_time) - final_asr_data = ASRData(merged_asr_data.segments) - - # 保存到SRT文件 - final_asr_data.to_srt(save_path=save_path) - print(f"[+] 已保存合并后的SRT文件: {save_path}") + try: + # 从SRT文件加载ASR数据 + logger.info(f"[+] 开始处理文件: {srt_path}") + with open(srt_path, encoding="utf-8") as f: + asr_data = from_srt(f.read()) + + # 预处理ASR数据,去除标点并转换为小写 + new_segments = [] + for seg in asr_data.segments: + if not is_pure_punctuation(seg.text): + if re.match(r'^[a-zA-Z\']+$', seg.text.strip()): + seg.text = seg.text.lower() + " " + new_segments.append(seg) + asr_data.segments = new_segments + stats.total_segments = len(asr_data.segments) + + # 获取连接后的文本 + txt = asr_data.to_txt().replace("\n", "") + total_word_count = count_words(txt) + logger.info(f"[+] 合并后的文本长度: {total_word_count} 字") + + # 确定分段数 + num_segments = determine_num_segments(total_word_count, threshold=SEGMENT_THRESHOLD) + logger.info(f"[+] 根据字数 {total_word_count},确定分段数: {num_segments}") + + # 分割ASRData + asr_data_segments = split_asr_data(asr_data, num_segments) + + # 多线程执行 split_by_llm 获取句子列表 + logger.info("[+] 正在并行请求LLM将每个分段的文本拆分为句子...") + with ThreadPoolExecutor(max_workers=num_threads) as executor: + def process_segment(asr_data_part): + txt = asr_data_part.to_txt().replace("\n", "") + sentences = split_by_llm(txt, use_cache=True) + logger.info(f"[+] 分段的句子提取完成,共 {len(sentences)} 句") + return sentences + all_sentences = list(executor.map(process_segment, asr_data_segments)) + all_sentences = [item for sublist in all_sentences for item in sublist] + + logger.info(f"[+] 总共提取到 {len(all_sentences)} 句") + + # 基于LLM已经分段的句子,对ASR分段进行合并 + logger.info("[+] 正在合并ASR分段基于句子列表...") + merged_asr_data = merge_segments_based_on_sentences(asr_data, all_sentences) + + # 按开始时间排序合并后的分段 + merged_asr_data.segments.sort(key=lambda seg: seg.start_time) + final_asr_data = ASRData(merged_asr_data.segments) + + # 保存到SRT文件 + final_asr_data.to_srt(save_path=save_path) + logger.info(f"[+] 已保存合并后的SRT文件: {save_path}") + + # 更新统计信息 + stats.merged_segments = len(final_asr_data.segments) + stats.processing_time = time.time() - start_time + + # 获取缓存统计 + cache_stats_data = get_cache_stats() + stats.cache_hits = cache_stats_data["hits"] + stats.cache_misses = cache_stats_data["misses"] + stats.api_calls = cache_stats_data["misses"] # 缓存未命中时才调用API + + return stats + + except Exception as e: + logger.error(f"[!] 
处理文件时发生错误: {e}") + stats.processing_time = time.time() - start_time + return stats if __name__ == '__main__': @@ -288,11 +328,26 @@ def process_segment(asr_data_part): parser = argparse.ArgumentParser(description="优化ASR分段处理脚本") parser.add_argument('--srt_path', type=str, required=True, help='输入的SRT文件路径') - parser.add_argument('--save_path', type=str, required=True, help='输入的SRT文件路径') + parser.add_argument('--save_path', type=str, required=True, help='输出的SRT文件路径') parser.add_argument('--num_threads', type=int, default=FIXED_NUM_THREADS, help='线程数量') args = parser.parse_args() # args.srt_path = "test_data/java.srt" # args.save_path = args.srt_path.replace(".srt", "_merged.srt") - main(srt_path=args.srt_path, save_path=args.save_path, num_threads=args.num_threads) + # 运行主处理函数并获取统计信息 + stats = main(srt_path=args.srt_path, save_path=args.save_path, num_threads=args.num_threads) + + # 打印处理统计 + print("\n=== 处理统计 ===") + print(f"总分段数: {stats.total_segments}") + print(f"合并后分段数: {stats.merged_segments}") + print(f"处理时间: {stats.processing_time:.2f}秒") + print(f"缓存命中: {stats.cache_hits}") + print(f"缓存未命中: {stats.cache_misses}") + print(f"API调用次数: {stats.api_calls}") + if stats.processing_time > 0: + print(f"每秒处理分段数: {stats.total_segments / stats.processing_time:.2f}") + if stats.cache_hits + stats.cache_misses > 0: + hit_rate = stats.cache_hits / (stats.cache_hits + stats.cache_misses) * 100 + print(f"缓存命中率: {hit_rate:.1f}%") diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3e3f10e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +openai>=1.0.0 +python-dotenv>=0.19.0 \ No newline at end of file diff --git a/split_by_llm.py b/split_by_llm.py index 129e8c6..ef63941 100644 --- a/split_by_llm.py +++ b/split_by_llm.py @@ -2,10 +2,17 @@ import json import os import re -from typing import List, Optional +import time +import logging +from typing import List, Optional, Dict, Any +from datetime import datetime, timedelta import openai from dotenv import load_dotenv +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + # 加载.env文件 load_dotenv() @@ -13,13 +20,24 @@ os.environ['OPENAI_BASE_URL'] = os.getenv('OPENAI_BASE_URL') os.environ["OPENAI_API_KEY"] = os.getenv('OPENAI_API_KEY') -# ... 其余代码保持不变 ... -# 常量定义 +# 性能优化常量 MODEL = "gpt-4o-mini" CACHE_DIR = "cache" +CACHE_EXPIRE_DAYS = 7 # 缓存过期时间(天) +MAX_CACHE_SIZE = 1000 # 最大缓存文件数 +REQUEST_TIMEOUT = 30 # 请求超时时间(秒) +MAX_RETRIES = 3 # 最大重试次数 +RETRY_DELAY = 1 # 重试延迟(秒) # 初始化OpenAI客户端 -client = openai.OpenAI() +client = openai.OpenAI(timeout=REQUEST_TIMEOUT) + +# 性能统计 +cache_stats = { + "hits": 0, + "misses": 0, + "errors": 0 +} # 系统提示信息 SYSTEM_PROMPT = """ @@ -43,6 +61,40 @@ the upgraded claude sonnet is now available for all users
developers can build with the computer use beta<br>
on the anthropic api amazon bedrock and google cloud’s vertex ai<br>
the new claude haiku will be released later this month """ +def cleanup_cache() -> None: + """ + 清理过期缓存文件 + """ + try: + cache_files = [f for f in os.listdir(CACHE_DIR) if f.endswith('.json')] + + # 如果缓存文件数量超过限制,删除最旧的文件 + if len(cache_files) > MAX_CACHE_SIZE: + cache_files.sort(key=lambda f: os.path.getmtime(os.path.join(CACHE_DIR, f))) + files_to_delete = cache_files[:len(cache_files) - MAX_CACHE_SIZE] + + for file_to_delete in files_to_delete: + try: + os.remove(os.path.join(CACHE_DIR, file_to_delete)) + logger.info(f"[+] 删除过期缓存文件: {file_to_delete}") + except Exception as e: + logger.warning(f"[!] 删除缓存文件失败: {e}") + + # 清理过期文件 + current_time = datetime.now() + for cache_file in cache_files: + file_path = os.path.join(CACHE_DIR, cache_file) + try: + file_time = datetime.fromtimestamp(os.path.getmtime(file_path)) + if current_time - file_time > timedelta(days=CACHE_EXPIRE_DAYS): + os.remove(file_path) + logger.info(f"[+] 删除过期缓存文件: {cache_file}") + except Exception as e: + logger.warning(f"[!] 检查缓存文件过期失败: {e}") + + except Exception as e: + logger.warning(f"[!] 缓存清理失败: {e}") + def get_cache_key(text: str, model: str) -> str: """ 生成缓存键值 @@ -51,62 +103,137 @@ def get_cache_key(text: str, model: str) -> str: def get_cache(text: str, model: str) -> Optional[List[str]]: """ - 从缓存中获取断句结果 + 从缓存中获取断句结果(带过期检查) """ cache_key = get_cache_key(text, model) cache_file = os.path.join(CACHE_DIR, f"{cache_key}.json") - if os.path.exists(cache_file): - try: - with open(cache_file, 'r', encoding='utf-8') as f: - return json.load(f) - except (IOError, json.JSONDecodeError): - return None + + if not os.path.exists(cache_file): + cache_stats["misses"] += 1 + return None + + try: + with open(cache_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 检查是否包含时间戳信息 + if isinstance(data, dict) and "timestamp" in data: + timestamp = datetime.fromisoformat(data["timestamp"]) + if datetime.now() - timestamp > timedelta(days=CACHE_EXPIRE_DAYS): + os.remove(cache_file) + cache_stats["misses"] += 1 + return None + result = data["result"] + else: + # 兼容旧格式 + result = data + + cache_stats["hits"] += 1 + logger.info(f"[+] 缓存命中: {cache_key}") + return result + + except (IOError, json.JSONDecodeError) as e: + logger.warning(f"[!] 缓存读取失败: {e}") + cache_stats["misses"] += 1 + return None def set_cache(text: str, model: str, result: List[str]) -> None: """ - 将断句结果设置到缓存中 + 将断句结果设置到缓存中(带时间戳) """ - cache_key = get_cache_key(text, model) - cache_file = os.path.join(CACHE_DIR, f"{cache_key}.json") - os.makedirs(CACHE_DIR, exist_ok=True) try: + # 清理过期缓存 + cleanup_cache() + + cache_key = get_cache_key(text, model) + cache_file = os.path.join(CACHE_DIR, f"{cache_key}.json") + os.makedirs(CACHE_DIR, exist_ok=True) + + # 添加时间戳信息 + cache_data = { + "result": result, + "timestamp": datetime.now().isoformat(), + "model": model, + "text_hash": hashlib.md5(text.encode()).hexdigest() + } + with open(cache_file, 'w', encoding='utf-8') as f: - json.dump(result, f, ensure_ascii=False) - except IOError: - pass + json.dump(cache_data, f, ensure_ascii=False) + + except Exception as e: + logger.warning(f"[!] 缓存设置失败: {e}") + +def get_cache_stats() -> Dict[str, int]: + """ + 获取缓存统计信息 + """ + return cache_stats.copy() def split_by_llm(text: str, use_cache: bool = False) -> List[str]: """ - 使用LLM进行文本断句 + 使用LLM进行文本断句(带重试机制和性能优化) """ if use_cache: cached_result = get_cache(text, MODEL) if cached_result: - print(f"[+] 从缓存中获取结果: {cached_result}") + logger.info(f"[+] 从缓存中获取结果: {len(cached_result)}句") return cached_result prompt = f"请你对下面句子使用
<br>进行分割:\n{text}"
-    try:
-        response = client.chat.completions.create(
-            model=MODEL,
-            messages=[
-                {"role": "system", "content": SYSTEM_PROMPT},
-                {"role": "user", "content": prompt}
-            ],
-            temperature=0.1
-        )
-        result = response.choices[0].message.content
-        # 清理结果中的多余换行符
-        result = re.sub(r'\n+', '', result)
") if segment.strip()] - - set_cache(text, MODEL, split_result) - return split_result - except Exception as e: - print(f"[!] 请求LLM失败: {e}") - return [] + # 重试机制 + for attempt in range(MAX_RETRIES): + try: + logger.info(f"[+] 请求LLM进行断句 (尝试 {attempt + 1}/{MAX_RETRIES})") + response = client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt} + ], + temperature=0.1, + timeout=REQUEST_TIMEOUT + ) + result = response.choices[0].message.content + # 清理结果中的多余换行符 + result = re.sub(r'\n+', '', result) + split_result = [segment.strip() for segment in result.split("
") if segment.strip()] + + if use_cache: + set_cache(text, MODEL, split_result) + + logger.info(f"[+] LLM断句完成,共{len(split_result)}句") + return split_result + + except Exception as e: + logger.warning(f"[!] 请求LLM失败 (尝试 {attempt + 1}/{MAX_RETRIES}): {e}") + if attempt < MAX_RETRIES - 1: + wait_time = RETRY_DELAY * (2 ** attempt) # 指数退避 + logger.info(f"[+] 等待 {wait_time} 秒后重试...") + time.sleep(wait_time) + + cache_stats["errors"] += 1 + logger.error(f"[!] LLM断句失败,已重试{MAX_RETRIES}次") + return [] + +def batch_split_by_llm(texts: List[str], use_cache: bool = False) -> List[List[str]]: + """ + 批量处理多个文本的断句 + """ + results = [] + for i, text in enumerate(texts): + logger.info(f"[+] 处理第 {i+1}/{len(texts)} 个文本") + result = split_by_llm(text, use_cache) + results.append(result) + + # 打印批量处理统计 + if use_cache: + stats = get_cache_stats() + hit_rate = stats["hits"] / (stats["hits"] + stats["misses"]) * 100 if (stats["hits"] + stats["misses"]) > 0 else 0 + logger.info(f"[+] 批量处理完成,缓存命中率: {hit_rate:.1f}%") + + return results if __name__ == "__main__": sample_text = (