Spaces:
Runtime error
Runtime error
| try: | |
| import openai | |
| except ImportError: | |
| print("The 'openai' module is not installed. Please install it using 'pip install openai'.") | |
| exit(1) | |
| import json | |
| import re | |
| import time | |
| from tqdm import tqdm | |
| from config import LLM_RETRIES, LLM_REQUEST_INTERVAL, LLM_RETRY_DELAY, LLM_MAX_TEXT_LENGTH, LLM_PROMPT | |
def send_request(client, prompt, text, model):
    """Send a single chat-completion request and return the reply text.

    ``text`` is first passed through ``remove_json_escape_characters`` and
    then appended to ``prompt`` (separated by a blank line) as one user
    message. The raw response object is echoed to stdout.

    :param client: object exposing ``chat.completions.create``
    :param prompt: instruction text placed before the user text
    :param text: user text to process
    :param model: model name forwarded to the API
    :return: content of the first choice, or ``None`` when the call raises
        ``openai.OpenAIError`` (the error is printed, not re-raised)
    """
    sanitized = remove_json_escape_characters(text)
    user_message = {"role": "user", "content": f"{prompt}\n\n{sanitized}"}
    try:
        completion = client.chat.completions.create(
            model=model,
            messages=[user_message],
            max_tokens=4096,
        )
        print(completion)
        reply = completion.choices[0].message.content
    except openai.OpenAIError as err:
        print(f"OpenAI API error: {err}")
        reply = None
    return reply
def clean_text(text):
    """Strip ASCII control characters (codes 0-31 and 127) from ``text``.

    Values that are not ``str`` are handed back unchanged.
    """
    if not isinstance(text, str):
        return text
    # Remove ASCII control characters (0-31 and 127).
    return re.sub(r'[\x00-\x1F\x7F]', '', text)
# Shape matched: an optional "[" prefix, an object whose first key has at
# least three characters before its ":", everything up to the LAST "}" in the
# text, and an optional closing "]".
# The prefix/suffix quantifiers are "*" (not "{3,}" / "+") so that compact
# arrays such as '[{"k": 1}]' keep their surrounding brackets.
_JSON_BLOCK_RE = re.compile(
    r'((\[[^\}]*)?\{\s*[^\}\{]{3,}?:.*\}([^\{]*\])?)',
    re.M | re.S,
)


def extract_json(response_text):
    """Pull the JSON-looking portion out of a model reply.

    The reply is also dumped to ``debug.txt`` in the current directory
    (overwritten on every call). The dump is best-effort: a failure to write
    it is reported on stdout and does not prevent extraction.

    :param response_text: full text returned by the model
    :return: the matched substring (object, or array of objects including its
        brackets), or ``None`` when nothing JSON-like is found
    """
    try:
        with open("debug.txt", "w", encoding="utf8") as debug_file:
            debug_file.write(response_text)
    except OSError as err:
        print(f"Could not write debug.txt: {err}")

    found = _JSON_BLOCK_RE.search(response_text)
    return found.group(0) if found else None
def clean_and_load_json(json_string):
    """Parse a JSON string taken from a model reply.

    Control characters are stripped with ``clean_text`` first. The text is
    then parsed as-is; only if that fails is it retried with every single
    quote turned into a double quote (to accept Python-style quoting).
    Re-quoting unconditionally would corrupt valid JSON whose string values
    contain apostrophes, e.g. ``{"text": "it's"}``.

    Each candidate is written to ``debug.json`` (best-effort) before it is
    parsed, so the file holds the last text that was attempted.

    :param json_string: JSON-looking text
    :return: the decoded value, or ``None`` when no candidate decodes (the
        last decode error is printed)
    """
    cleaned = clean_text(json_string)
    candidates = [cleaned]
    requoted = cleaned.replace("'", '"')
    if requoted != cleaned:
        candidates.append(requoted)

    last_error = None
    for candidate in candidates:
        # Debug dump of the text about to be parsed.
        try:
            with open("debug.json", "w", encoding="utf8") as debug_file:
                debug_file.write(candidate)
        except OSError as err:
            print(f"Could not write debug.json: {err}")
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as err:
            last_error = err

    print(f"JSON decode error: {last_error}")
    return None
def validate_json(json_obj, required_keys):
    """Return ``True`` when ``json_obj`` is a list, ``False`` otherwise.

    ``required_keys`` is kept in the signature for callers but is not
    consulted: the key-membership check that used to follow the list test
    sat after an unconditional ``return`` and never ran, so it has been
    removed rather than silently activated.

    :param json_obj: decoded JSON value (may be ``None``)
    :param required_keys: currently unused
    :return: whether ``json_obj`` is a ``list``
    """
    return isinstance(json_obj, list)
def _request_items(client, prompt, part, model, required_keys):
    """Run one request for ``part`` and return ``(items, failure)``.

    ``items`` is the validated JSON value on success and ``None`` otherwise;
    ``failure`` is the reason to show when ``items`` is ``None``.
    """
    response = send_request(client, prompt, part, model)
    if not response:
        return None, "API request failed."
    json_string = extract_json(response)
    if not json_string:
        return None, "No JSON found in response."
    json_obj = clean_and_load_json(json_string)
    if not validate_json(json_obj, required_keys):
        return None, "Invalid JSON structure."
    return json_obj, None


def process_text(client, prompt, text, model, required_keys):
    """Split ``text`` into chunks, query the model per chunk, merge results.

    ``text`` is cut into slices of ``LLM_MAX_TEXT_LENGTH`` characters. Each
    slice is tried up to ``LLM_RETRIES + 1`` times; the first reply that
    yields a valid JSON list is appended to the result. Between failed
    attempts the function sleeps ``LLM_RETRY_DELAY``; after each slice it
    sleeps ``LLM_REQUEST_INTERVAL``. A slice whose attempts all fail
    contributes nothing.

    :param client: API client passed through to ``send_request``
    :param prompt: instruction text sent with every slice
    :param text: full input text
    :param model: model name
    :param required_keys: forwarded to ``validate_json``
    :return: list with the items collected from all slices
    """
    parts = [
        text[start:start + LLM_MAX_TEXT_LENGTH]
        for start in range(0, len(text), LLM_MAX_TEXT_LENGTH)
    ]
    total_attempts = LLM_RETRIES + 1
    results = []
    for part in tqdm(parts, desc="Processing text"):
        for attempt in range(1, total_attempts + 1):
            items, failure = _request_items(client, prompt, part, model, required_keys)
            if items is not None:
                results.extend(items)
                break
            if attempt < total_attempts:
                print(f"{failure} Retrying ({attempt}/{LLM_RETRIES})...")
                time.sleep(LLM_RETRY_DELAY)
            else:
                # Last attempt: no retry follows, so neither announce one
                # nor wait for it.
                print(f"{failure} Giving up on this part after {total_attempts} attempts.")
        time.sleep(LLM_REQUEST_INTERVAL)
    return results
def llm_operation(api_base, api_key, model, prompt, text, required_keys):
    """Build an OpenAI client for the given endpoint and run ``process_text``.

    :param api_base: base URL of the API endpoint
    :param api_key: API key for that endpoint
    :param model: model name
    :param prompt: instruction text
    :param text: input text to process
    :param required_keys: forwarded to ``process_text``
    :return: whatever ``process_text`` returns
    """
    llm_client = openai.OpenAI(base_url=api_base, api_key=api_key)
    extracted = process_text(llm_client, prompt, text, model, required_keys)
    return extracted
# Deletion table for the characters that need escaping inside a JSON string:
# double quote, backslash, forward slash, backspace, form feed, newline,
# carriage return and tab.
_JSON_ESCAPE_DELETIONS = str.maketrans('', '', '"\\/\b\f\n\r\t')


def remove_json_escape_characters(s):
    """Remove characters from user-submitted text that tend to make the
    LLM's output fail JSON validation.

    The characters removed are ``"``, ``\\``, ``/``, backspace, form feed,
    newline, carriage return and tab; everything else is kept as-is.

    :param s: text to sanitize
    :return: ``s`` without the characters listed above
    """
    # One pass with a translation table instead of a regex alternation plus
    # a per-match callback.
    return s.translate(_JSON_ESCAPE_DELETIONS)