| | import os |
| | import textwrap |
| | from pathlib import Path |
| | import logging |
| | import numpy as np |
| | from scipy.io.wavfile import write |
| | import config |
| | import csv |
| | import av |
| | import re |
| | from functools import wraps |
| | import time |
| | import threading |
| | |
# Bracket patterns used by filter_words; they are applied with .match(), so
# each one is anchored at the start of the word text.
# Optional leading whitespace followed by a complete bracketed tag "[...]".
p_pattern = re.compile(r"(\s*\[.*?\])")
# Optional leading whitespace followed by an opening "[" (start of a tag).
p_start_pattern = re.compile(r"(\s*\[.*)")
# Any run of non-newline characters that contains a closing "]".
p_end_pattern = re.compile(r"(\s*.*\])")
| |
|
| |
|
def filter_words(res_word):
    """
    Drop bracketed annotations from a word sequence and apply hotword fixes.

    A word is discarded when its text is a complete ``[...]`` tag, when it
    opens a tag with ``[``, or when it lies inside an open tag (up to and
    including the first following word whose text contains ``]``). Every
    surviving word has its ``text`` attribute rewritten in place through
    ``replace_hotwords``.

    Args:
        res_word: Iterable of word objects exposing a mutable ``text`` attribute.

    Returns:
        List of the word objects that were kept, in their original order.
    """
    kept = []
    inside_tag = False

    for item in res_word:
        # A self-contained "[...]" tag never affects the open/closed state.
        if p_pattern.match(item.text):
            continue

        # An opening "[" without its closing bracket starts a skipped span.
        if p_start_pattern.match(item.text):
            inside_tag = True
            continue

        if inside_tag:
            # The word carrying the closing "]" is dropped too, then the span ends.
            if p_end_pattern.match(item.text):
                inside_tag = False
            continue

        item.text = replace_hotwords(item.text)
        kept.append(item)

    return kept
| |
|
| |
|
| |
|
def replace_hotwords(text: str) -> str:
    """
    Apply every substitution listed in ``config.hotwords_json`` to ``text``.

    The mapping is walked in its iteration order and each key is replaced by
    its value with ``str.replace``, so a later substitution operates on the
    output of the earlier ones.

    Args:
        text: The input string to process.

    Returns:
        The string with all hotword substitutions applied.
    """
    result = text
    for hotword, replacement in config.hotwords_json.items():
        result = result.replace(hotword, replacement)
    logging.debug(f"Replace string: {text} => {result}")
    return result
| |
|
| |
|
def log_block(key: str, value, unit=''):
    """Log a formatted ``[ key ]: value unit`` line at INFO level.

    Nothing is logged when ``config.DEBUG`` is truthy.

    Args:
        key: Label shown inside the brackets, left-justified and padded to
            25 characters.
        value: Value printed after the label.
        unit: Optional unit appended after the value, separated by a space.
    """
    # NOTE(review): output is suppressed when DEBUG is on -- confirm this is
    # intended and not an inverted condition.
    if config.DEBUG:
        return
    key_fmt = f"[ {key.ljust(25)}]"
    # strip() removes the trailing space left behind when unit is empty.
    val_fmt = f"{value} {unit}".strip()
    logging.info(f"{key_fmt}: {val_fmt}")
| |
|
| |
|
def clear_screen():
    """Clear the terminal by running ``cls`` on Windows or ``clear`` elsewhere."""
    if os.name == "nt":
        command = "cls"
    else:
        command = "clear"
    os.system(command)
| |
|
| |
|
def print_transcript(text):
    """Join the pieces in ``text`` and print them wrapped at 60 columns."""
    joined = "".join(text)
    for row in textwrap.wrap(joined, width=60):
        print(row)
| |
|
| |
|
def format_time(s):
    """Convert a duration in seconds to the SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded once to whole milliseconds and then split with
    integer arithmetic. Deriving the millisecond field from the float
    fraction (``(s - int(s)) * 1000``) truncates values such as ``1.001`` to
    ``,000`` because of binary floating-point error.

    Args:
        s: Time offset in seconds (int or float).

    Returns:
        The formatted timestamp string.
    """
    total_ms = int(round(s * 1000))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
| |
|
| |
|
def create_srt_file(segments, resampled_file):
    """Write ``segments`` as numbered SRT cues to the file ``resampled_file``.

    Args:
        segments: Iterable of mappings with ``'start'``, ``'end'`` (values
            convertible to float seconds) and ``'text'`` keys.
        resampled_file: Path of the output file; it is opened for writing
            as UTF-8 and overwritten.
    """
    with open(resampled_file, 'w', encoding='utf-8') as out:
        # SRT cue numbers start at 1.
        for index, entry in enumerate(segments, start=1):
            begin = format_time(float(entry['start']))
            finish = format_time(float(entry['end']))
            caption = entry['text']
            out.write(f"{index}\n{begin} --> {finish}\n{caption}\n\n")
| |
|
| |
|
def resample(file: str, sr: int = 16000):
    """
    Decode the first audio stream of ``file`` and re-encode it as a mono,
    16-bit PCM WAV file at ``sr`` Hz.

    The output is named ``<stem of file>_resampled.wav``, a relative path,
    so it is written to the current working directory.

    Args:
        file (str): The audio (or media) file to open.
        sr (int): Target sample rate in Hz.

    Returns:
        resampled_file (str): Path of the resampled WAV file.

    Raises:
        StopIteration: If ``file`` contains no audio stream.
    """
    resampled_file = Path(file).stem + "_resampled.wav"

    # Both containers are context-managed so they are closed even when
    # decoding or encoding fails part-way through.
    with av.open(file) as container:
        # Fail fast, before the output file is created, if there is no audio.
        next(s for s in container.streams if s.type == 'audio')

        resampler = av.AudioResampler(
            format='s16',
            layout='mono',
            rate=sr,
        )

        with av.open(resampled_file, mode='w') as output_container:
            output_stream = output_container.add_stream('pcm_s16le', rate=sr)
            output_stream.layout = 'mono'

            for frame in container.decode(audio=0):
                # Clear the source timestamp before resampling
                # (presumably so timestamps are regenerated -- TODO confirm).
                frame.pts = None
                # resample() result is guarded against None, as in the
                # original code; treat None like "no frames produced".
                for resampled_frame in resampler.resample(frame) or ():
                    for packet in output_stream.encode(resampled_frame):
                        output_container.mux(packet)

            # Flush whatever the encoder is still buffering.
            for packet in output_stream.encode(None):
                output_container.mux(packet)

    return resampled_file
| |
|
| |
|
def save_to_wave(filename, data: np.ndarray, sample_rate=16000):
    """Write floating-point audio samples to a 16-bit PCM WAV file.

    Samples are clipped to ``[-1.0, 1.0]`` before being scaled by 32767, so
    values slightly outside the nominal range saturate instead of
    overflowing the int16 cast.

    Args:
        filename: Destination path (or file object) passed to
            ``scipy.io.wavfile.write``.
        data: Audio samples as floats, nominally in ``[-1.0, 1.0]``.
        sample_rate: Sample rate in Hz stored in the WAV header.
    """
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    write(filename, sample_rate, pcm)
| |
|
| |
|
def pcm_bytes_to_np_array(pcm_bytes: bytes, dtype=np.float32, channels=1):
    """Interpret raw 16-bit PCM bytes as a NumPy array.

    Args:
        pcm_bytes: Buffer of native-endian int16 samples.
        dtype: Target dtype. Only when it equals ``np.float32`` are the
            samples divided by 32768.0; any other dtype is a plain cast.
        channels: When greater than 1 the result is reshaped to
            ``(-1, channels)``; otherwise it stays one-dimensional.

    Returns:
        The converted sample array.
    """
    samples = np.frombuffer(pcm_bytes, dtype=np.int16).astype(dtype=dtype)
    if dtype == np.float32:
        # astype() returned a fresh writable copy, so in-place division is safe.
        samples /= 32768.0
    return samples.reshape(-1, channels) if channels > 1 else samples
| |
|
def timer(name: str):
    """Build a decorator that reports a function's wall-clock run time.

    After each successful call of the decorated function, the elapsed time
    (measured with ``time.perf_counter``) is passed to ``log_block`` under
    the label ``"<name> cost:"``. If the function raises, nothing is logged.

    Args:
        name: Label used in the log line.

    Returns:
        A decorator that preserves the wrapped function's metadata.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            began = time.perf_counter()
            outcome = func(*args, **kwargs)
            elapsed = time.perf_counter() - began
            log_block(f"{name} cost:", f"{elapsed:.2f} s")
            return outcome
        return wrapper
    return decorator
| |
|
def get_text_separator(language: str) -> str:
    """Return the separator for joining text pieces in the given language.

    ``"zh"`` yields an empty string; every other language code yields a
    single space.
    """
    if language == "zh":
        return ""
    return " "
| |
|
| |
|
def start_thread(target_function) -> threading.Thread:
    """Run ``target_function`` in a new daemon thread and return that thread.

    The thread is already started when it is returned.
    """
    worker = threading.Thread(target=target_function, daemon=True)
    worker.start()
    return worker
| |
|
| |
|
class TestDataWriter:
    """Append result rows to a CSV file, creating the header when needed.

    The file is always read and written as UTF-8 so that non-ASCII
    transcript or translation text does not depend on the platform's
    default locale encoding.
    """

    def __init__(self, file_path='test_data.csv'):
        """Remember the target path and make sure the file has a header row.

        Args:
            file_path: Path of the CSV file to append to.
        """
        self.file_path = file_path
        # Column order of the CSV; keys must match the dumped result's aliases.
        self.fieldnames = [
            'seg_id', 'transcribe_time', 'translate_time',
            'transcribeContent', 'from', 'to', 'translateContent', 'partial'
        ]
        self._ensure_file_has_header()

    def _ensure_file_has_header(self):
        """Write the header row if the file is missing or empty."""
        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
            with open(self.file_path, mode='w', newline='', encoding='utf-8') as file:
                writer = csv.DictWriter(file, fieldnames=self.fieldnames)
                writer.writeheader()

    def write(self, result: 'DebugResult'):
        """Append one row built from ``result.model_dump(by_alias=True)``.

        Args:
            result: Object exposing ``model_dump(by_alias=True)`` that returns
                a mapping whose keys are a subset of ``self.fieldnames``.
        """
        with open(self.file_path, mode='a', newline='', encoding='utf-8') as file:
            writer = csv.DictWriter(file, fieldnames=self.fieldnames)
            writer.writerow(result.model_dump(by_alias=True))
| |
|