|
|
| import re |
| import json |
| import copy |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class SurveyManager:
    """Stateless helpers for rendering and updating a survey dict.

    A survey is a plain dict with the keys shown in ``BASE_SURVEY_STRUCTURE``.
    Entries of ``sections`` are dicts that carry a heading under ``"name"`` or
    ``"title"``, optional ``"plan"`` / ``"content"`` text, and optionally a
    ``"subsections"`` list whose entries may in turn hold a
    ``"subsubsections"`` list.
    """

    BASE_SURVEY_STRUCTURE = {
        "title": "",
        "abstract": "",
        "introduction": {
            "content": "",
        },
        "sections": [],
        "conclusion": "",
    }

    # Prefix expected in each "/"-separated part of an update position,
    # outermost level first.
    _LEVEL_PREFIXES = ("section-", "subsection-", "subsubsection-")
    # Key holding the child list of a node at depth 0 and depth 1.
    _CHILD_KEYS = ("subsections", "subsubsections")
    # Errors that missing keys or wrongly typed values in survey/answer data
    # can raise; the best-effort helpers below turn these into a fallback.
    _DATA_ERRORS = (KeyError, IndexError, TypeError, ValueError, AttributeError)

    def __init__(self):
        pass

    @staticmethod
    def parse_update_pos(update_pos):
        """Normalise an update position string.

        Accepted forms:
            (1) "title", "abstract", "introduction", "conclusion" or "plan",
                returned unchanged;
            (2) "section-i", "section-i/subsection-j" or
                "section-i/subsection-j/subsubsection-k" (case-insensitive),
                returned lower-cased with the numbers re-rendered via ``int``.

        Raises:
            ValueError: if there are more than three "/"-separated parts or a
                part does not end in an integer.
        """
        if update_pos in ["title", "abstract", "introduction", "conclusion", "plan"]:
            return update_pos

        keys = update_pos.split("/")
        if len(keys) > len(SurveyManager._LEVEL_PREFIXES):
            raise ValueError("unsupported update_pos keys")
        return "/".join(
            f"{prefix}{int(key.lower().split(prefix)[-1])}"
            for key, prefix in zip(keys, SurveyManager._LEVEL_PREFIXES)
        )

    @staticmethod
    def _to_one_line(string):
        """Return the text to display for a survey field or node.

        A dict yields its non-empty "content" (resolved recursively); without
        one it yields "[PLAN] " followed by its "plan" with newlines replaced
        by spaces. Any other falsy value yields ""; everything else is
        returned unchanged.
        """
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._to_one_line(string["content"])
            return "[PLAN] " + string.get("plan", "").replace("\n", " ").strip()
        if not string:
            return ""
        return string

    @staticmethod
    def _iter_outline(current_survey):
        """Yield ``(depth, position, heading, node)`` for every section node.

        Nodes are visited depth-first in document order; ``depth`` is 0 for
        sections, 1 for subsections and 2 for subsubsections, and
        ``position`` is the 0-based index among siblings. The heading key
        ("name" if the section has one, otherwise "title") is chosen once per
        top-level section and reused for all of its descendants.
        """
        if "sections" not in current_survey:
            return
        for i, section in enumerate(current_survey["sections"]):
            title_key = "name" if "name" in section else "title"
            yield 0, i, section[title_key], section
            if "subsections" not in section:
                continue
            for j, subsection in enumerate(section["subsections"]):
                yield 1, j, subsection[title_key], subsection
                if "subsubsections" not in subsection:
                    continue
                for k, subsubsection in enumerate(subsection["subsubsections"]):
                    yield 2, k, subsubsection[title_key], subsubsection

    @staticmethod
    def convert_survey_dict_to_str(current_survey):
        """Render the whole survey as markdown-style text.

        Title, abstract, introduction and conclusion fall back to a "None"
        placeholder line when the field is missing or cannot be rendered.
        Returns "There is no survey." for an empty dict.
        """
        if current_survey == {}:
            return "There is no survey."

        def render_field(key, template, fallback):
            try:
                return template.format(SurveyManager._to_one_line(current_survey[key]))
            except SurveyManager._DATA_ERRORS:
                return fallback

        parts = [
            render_field("title", "# {}\n", "# Title: None\n"),
            render_field("abstract", "## Abstract\n{}\n", "## Abstract\nNone\n"),
            render_field("introduction", "## Introduction\n{}\n", "## Introduction\nNone\n"),
        ]
        for depth, _, name, node in SurveyManager._iter_outline(current_survey):
            # Sections are "##", subsections "###", subsubsections "####".
            hashes = "#" * (depth + 2)
            parts.append(f"{hashes} {name}\n{SurveyManager._to_one_line(node)}\n")
        parts.append(render_field("conclusion", "## Conclusion\n{}\n", "## Conclusion:\nNone\n"))
        return "".join(parts)

    @staticmethod
    def _abbr_one_line(string, abbr=True):
        """Return a single-line status string for a survey field or node.

        A dict yields its non-empty "content" (resolved recursively), else
        "[PLAN] <plan>" when it has a "plan", else "". Non-empty text is
        prefixed with "[OK] ", has newlines replaced by spaces and, when
        ``abbr`` is true and the raw text is longer than 50 characters, is
        cut to 50 characters followed by "...".
        """
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._abbr_one_line(string["content"], abbr=abbr)
            if "plan" in string:
                return "[PLAN] " + string["plan"].replace("\n", " ").strip()
            return ""
        if not string:
            return ""
        flat = string.replace("\n", " ").strip()
        if abbr and len(string) > 50:
            return "[OK] " + flat[:50] + "..."
        return "[OK] " + flat

    @staticmethod
    def convert_survey_dict_to_abbr_str(current_survey):
        """Render a compact outline with one status line per field/node.

        Title and abstract are shown in full; every other entry is
        abbreviated by ``_abbr_one_line``. Returns "There is no survey." for
        an empty dict.
        """
        if current_survey == {}:
            return "There is no survey."

        def render_field(key, label, abbr=True):
            try:
                content = SurveyManager._abbr_one_line(current_survey[key], abbr=abbr)
                return f"# {label}: {content}\n"
            except SurveyManager._DATA_ERRORS:
                return f"# {label}: None\n"

        labels = ("# Section", " ## Subsection", " ### Subsubsection")
        parts = [
            render_field("title", "Title", abbr=False),
            render_field("abstract", "Abstract", abbr=False),
            render_field("introduction", "Introduction"),
        ]
        for depth, position, name, node in SurveyManager._iter_outline(current_survey):
            content = SurveyManager._abbr_one_line(node)
            parts.append(f"{labels[depth]}-{position + 1} [{name}]: {content}\n")
        parts.append(render_field("conclusion", "Conclusion"))
        return "".join(parts)

    @staticmethod
    def update_one_section(sections, i, content):
        """Set ``sections[i]["content"]`` if ``i`` is a valid 0-based index.

        Returns True when the content was written, False when ``i`` is out
        of range (negative indices are rejected, not wrapped).
        """
        if 0 <= i <= len(sections) - 1:
            sections[i]["content"] = content
            return True
        return False

    @staticmethod
    def update_current_survey(current_survey, answer) -> bool:
        """Apply one parsed model answer to ``current_survey`` in place.

        ``answer["update"]`` selects the target:
            * "plan": only valid while the survey is still ``{}``; every
              key of ``answer["content"]`` is deep-copied into the survey.
            * "abstract" / "conclusion": the field is replaced by the content.
            * "introduction": the field is replaced by ``{"content": ...}``.
            * "section-i[/subsection-j[/subsubsection-k]]" (1-based): the
              "content" of the addressed node is set.
        Any other position (including "title", which has no branch here)
        fails to parse as a section path and yields False.

        Returns:
            True if the survey was modified, False if the position is
            unknown, out of range, or the data is malformed.
        """
        try:
            update_pos, content = answer["update"], answer["content"]

            if update_pos == "plan":
                if current_survey != {}:
                    return False
                for key, value in content.items():
                    current_survey[key] = copy.deepcopy(value)
                return True

            if update_pos in ["conclusion", "abstract", "introduction"]:
                if update_pos not in current_survey:
                    return False
                if update_pos == "introduction":
                    current_survey[update_pos] = {"content": content}
                else:
                    current_survey[update_pos] = content
                return True

            keys = update_pos.split("/")
            if len(keys) > len(SurveyManager._LEVEL_PREFIXES):
                return False
            # Positions are 1-based in the answer, 0-based in the lists.
            indices = [
                int(key.lower().split(prefix)[-1]) - 1
                for key, prefix in zip(keys, SurveyManager._LEVEL_PREFIXES)
            ]

            siblings = current_survey["sections"]
            for depth, index in enumerate(indices[:-1]):
                # Reject out-of-range parents explicitly: a negative index
                # (e.g. from "section-0") would otherwise wrap around and
                # silently address the last section.
                if not 0 <= index < len(siblings):
                    return False
                siblings = siblings[index][SurveyManager._CHILD_KEYS[depth]]
            return SurveyManager.update_one_section(siblings, indices[-1], content)
        except SurveyManager._DATA_ERRORS:
            return False
| |
|
|
| from prompts import * |
class PromptManger:
    """Class-level aliases for the prompt templates star-imported from ``prompts``."""

    # Template returned by BufferManager._build_system_prompt.
    system_prompt = SYSTEM_PROMPT_0415_BUFFER
    # Template whose "<user_query>" / "<init_survey>" placeholders are filled
    # by BufferManager._build_user_prompt_v0.
    user_prompt_v0 = USER_PROMPT_v0_0424_BUFFER
    # Template filled by BufferManager._build_user_prompt.
    user_prompt = USER_PROMPT_0415_BUFFER
|
|
|
|
|
|
|
|
|
|
class BufferManager:
    """Collect the prompts and responses produced during the rollout phase.

    ``batch_rollout_data`` holds one record per rollout::

        {
            "query": str,              # content of the last raw-prompt message
            "uid": ...,                # id taken from the input batch
            "state": {
                "done": bool,          # whether this rollout has finished
                "current_survey": dict,  # structured survey built so far
                "score": ...,          # set by update_all_scores
                "format_score": ...,   # set by update_all_format_scores
            },
            "trajectory": [            # one entry per generator step
                {
                    "step": int,
                    "original_response": str,   # raw model output
                    "answer_thought": str,      # first <think> block
                    "answer": dict,             # parsed <answer> JSON or {}
                    "tool_call_thought": str,   # <think> of the next plan
                    "tool_call": dict,          # parsed <tool_call> JSON or {}
                    "search_keywords": ...,     # from the env feedback
                    "summarys": list,           # from the env feedback
                    "update_success": bool,
                },
                ...
            ],
            "history_messages": [      # one [system, user, assistant] list
                ...                    # per generator step
            ],
        }
    """

    def __init__(self, prompts, repeat_n: int = 1):
        """Create ``repeat_n`` empty rollout records per prompt in the batch.

        Args:
            prompts: batch object exposing ``batch['input_ids']`` and the
                ``non_tensor_batch`` entries 'uid', 'raw_prompt' and
                'ground_truth' (project type; NOTE(review): exact class is
                not visible here).
            repeat_n: number of independent rollouts per query.
        """
        self.step = 0
        self.batch_rollout_data = []
        self.running_ids = []

        batch_size = prompts.batch['input_ids'].size(0)
        uids = prompts.non_tensor_batch['uid']
        raw_prompts = prompts.non_tensor_batch['raw_prompt'].copy()
        ground_truths = prompts.non_tensor_batch['ground_truth']

        # The query is the content of the last message of each raw prompt.
        querys = [raw_prompts[i_batch][-1]["content"] for i_batch in range(batch_size)]

        assert len(querys) == len(uids)
        # ground_truths is only zipped along (it bounds the iteration length).
        for query, uid, _ in zip(querys, uids, ground_truths):
            for _ in range(repeat_n):
                self.batch_rollout_data.append({
                    "query": query,
                    "uid": uid,
                    "state": {
                        "done": False,
                        "current_survey": {},
                    },
                    "trajectory": [],
                    "history_messages": [],
                })

    @staticmethod
    def _build_system_prompt():
        """Return the system prompt template."""
        return PromptManger.system_prompt

    @staticmethod
    def _build_user_prompt_v0(query, current_survey):
        """Build the first-turn user prompt from the query and initial survey."""
        prompt = PromptManger.user_prompt_v0.replace("<user_query>", query)
        return prompt.replace(
            "<init_survey>",
            SurveyManager.convert_survey_dict_to_abbr_str(current_survey),
        )

    @staticmethod
    def _latest_summaries(trajs):
        """Return the "summarys" of the most recent step that has any, else []."""
        for traj in reversed(trajs):
            if len(traj["summarys"]) > 0:
                return traj["summarys"]
        return []

    @staticmethod
    def _build_user_prompt(query, current_survey, trajs):
        """Build a follow-up user prompt.

        Fills the query, the abbreviated survey, the last step's thought and
        tool call, and the most recent non-empty search summaries into the
        user prompt template.
        """
        last_traj = trajs[-1]

        prompt = PromptManger.user_prompt.replace("<user_query>", query)
        prompt = prompt.replace(
            "<current_survey>",
            SurveyManager.convert_survey_dict_to_abbr_str(current_survey),
        )

        if last_traj["tool_call_thought"] == "":
            prompt = prompt.replace(
                "<last_step_thought>",
                "Your last thought is not available, please give new plan",
            )
        else:
            prompt = prompt.replace("<last_step_thought>", last_traj["tool_call_thought"])
        prompt = prompt.replace("<last_step_tool_call>", json.dumps(last_traj["tool_call"]))

        summaries = BufferManager._latest_summaries(trajs)
        if len(summaries) == 0:
            return prompt.replace("<summarys>", "There is no result.")
        return prompt.replace(
            "<summarys>",
            f"There are {len(summaries)} results:\n\n" + "\n\n".join(summaries),
        )

    @staticmethod
    def _first_unwritten_position(current_survey):
        """Return the update position of the first part lacking content.

        Parts are checked in document order: abstract, introduction, every
        section / subsection / subsubsection (depth-first, 1-based positions)
        and finally the conclusion. Abstract and conclusion count as
        unwritten only while they are still dicts without a "content" key.
        Returns "" when nothing is missing.
        """
        abstract = current_survey["abstract"]
        if isinstance(abstract, dict) and "content" not in abstract:
            return "abstract"
        if "content" not in current_survey["introduction"]:
            return "introduction"

        # enumerate (not list.index) so that two equal nodes cannot be
        # confused with each other.
        for i, section in enumerate(current_survey.get("sections", []), start=1):
            if "content" not in section:
                return f"section-{i}"
            for j, subsection in enumerate(section.get("subsections", []), start=1):
                if "content" not in subsection:
                    return f"section-{i}/subsection-{j}"
                for k, subsubsection in enumerate(subsection.get("subsubsections", []), start=1):
                    if "content" not in subsubsection:
                        return f"section-{i}/subsection-{j}/subsubsection-{k}"

        # The conclusion must be checked after the sections, not as an
        # alternative to them; otherwise it is never looked at once the
        # survey has a "sections" key.
        conclusion = current_survey.get("conclusion")
        if isinstance(conclusion, dict) and "content" not in conclusion:
            return "conclusion"
        return ""

    @staticmethod
    def _build_user_prompt_force_correct(query, current_survey, trajs):
        """Build a follow-up prompt after a failed update.

        Overwrites the last step's "tool_call_thought" with a hint naming
        the next part to write (or asking for more information / to
        finalize), then delegates to ``_build_user_prompt``.
        """
        if current_survey == {}:
            now_section = "plan"
        else:
            now_section = BufferManager._first_unwritten_position(current_survey)
            if now_section == "":
                trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."

        if now_section != "":
            trajs[-1]["tool_call_thought"] = f"Next I will provide {now_section}"

        summary_num = len(BufferManager._latest_summaries(trajs))
        if now_section == "plan" and summary_num == 0:
            trajs[-1]["tool_call_thought"] = "I need to get enough information."

        return BufferManager._build_user_prompt(query, current_survey, trajs)

    @staticmethod
    def _check_finalize(query, current_survey, trajs):
        """Return True only if the survey exists and every part has content.

        ``query`` and ``trajs`` are accepted for signature compatibility and
        are not used.
        """
        if current_survey == {}:
            return False
        return BufferManager._first_unwritten_position(current_survey) == ""

    def build_prompt_for_generator(self):
        """Build the next [system, user] message pair for every unfinished rollout.

        Records the indices of the rollouts that received a prompt in
        ``self.running_ids`` (same order as the returned list) and appends
        each message list to the rollout's "history_messages".
        """
        total_messages = []
        self.running_ids = []
        for running_id, data in enumerate(self.batch_rollout_data):
            if data["state"]["done"]:
                continue

            survey = data["state"]["current_survey"]
            trajectory = data["trajectory"]
            if len(trajectory) == 0:
                user_prompt = BufferManager._build_user_prompt_v0(data["query"], survey)
            elif trajectory[-1]["update_success"]:
                user_prompt = BufferManager._build_user_prompt(data["query"], survey, trajectory)
            else:
                user_prompt = BufferManager._build_user_prompt_force_correct(
                    data["query"], survey, trajectory)

            messages = [
                {
                    "role": "system",
                    "content": BufferManager._build_system_prompt(),
                },
                {
                    "role": "user",
                    "content": user_prompt,
                },
            ]
            data["history_messages"].append(messages)
            total_messages.append(messages)
            self.running_ids.append(running_id)
        return total_messages

    def update_all_scores(self, scores):
        """Store one score per rollout under ``state["score"]``."""
        assert len(scores) == len(self.batch_rollout_data)
        for score, log in zip(scores, self.batch_rollout_data):
            log["state"]["score"] = score

    def update_all_format_scores(self, scores):
        """Store one score per rollout under ``state["format_score"]``."""
        assert len(scores) == len(self.batch_rollout_data)
        for score, log in zip(scores, self.batch_rollout_data):
            log["state"]["format_score"] = score

    def update_trajectory(self, model_responses, env_feedbacks):
        """Record one generator step for every rollout in ``self.running_ids``.

        Args:
            model_responses: dicts as returned by ``parse_generator_response``.
            env_feedbacks: dicts with the keys "done", "search_keywords" and
                "summarys".

        For each rollout the parsed answer is applied to the survey (when
        the response is well-formed), a trajectory entry is appended, the raw
        response is appended as the assistant message of the latest history
        entry, and a "done" signal is revoked if the survey is incomplete.
        """
        assert len(self.running_ids) == len(model_responses)
        assert len(self.running_ids) == len(env_feedbacks)

        for running_id, response, feedback in zip(self.running_ids, model_responses, env_feedbacks):
            record = self.batch_rollout_data[running_id]
            state = record["state"]
            state["done"] = feedback["done"]

            update_success = False
            if response["true"]:
                answer = response["answer"]
                has_answer = len(answer) != 0
                if state["current_survey"] != {}:
                    if has_answer:
                        update_success = SurveyManager.update_current_survey(
                            state["current_survey"], answer)
                else:
                    # No survey yet: an answer is only applied when the last
                    # user prompt carried search results; with no results an
                    # empty answer counts as a successful step.
                    no_result = "There is no result" in record["history_messages"][-1][1]["content"]
                    if has_answer and not no_result:
                        update_success = SurveyManager.update_current_survey(
                            state["current_survey"], answer)
                    elif no_result and not has_answer:
                        update_success = True

            record["trajectory"].append({
                "step": self.step,
                "original_response": response["original_response"],
                "answer_thought": response["answer_thought"],
                "answer": response["answer"],
                "tool_call_thought": response["tool_call_thought"],
                "tool_call": response["tool_call"],
                "search_keywords": feedback["search_keywords"],
                "summarys": feedback["summarys"],
                "update_success": update_success and response["true"],
            })

            record["history_messages"][-1].append({
                "role": "assistant",
                "content": response["original_response"],
            })

            if state["done"]:
                real_done = BufferManager._check_finalize(
                    record["query"], state["current_survey"], record["trajectory"])
                if not real_done:
                    state["done"] = False

    @staticmethod
    def match_reference(text: str):
        """Return the sorted bib keys cited in ``text``.

        Keys are collected from every ``\\...cite...{k1,k2}`` command;
        placeholders of the form ``#<digits>``, empty keys and ``*`` are
        skipped, and keys that appear in a ``\\nocite{...}`` are removed.
        """
        placeholder_reg = re.compile(r"^#\d+$")

        def iter_keys(pattern):
            for group in re.findall(pattern, text):
                for bib in group.split(","):
                    # Strip before the placeholder test so that " #1" in
                    # "a, #1" is recognised as a placeholder too.
                    bib = bib.strip()
                    if bib and bib != "*" and not placeholder_reg.match(bib):
                        yield bib

        bibkeys = set(iter_keys(r"\\\w*cite(?!style)\w*\{(.+?)\}"))
        for bib in iter_keys(r"\\nocite{(.+?)\}"):
            # discard, not remove: the same key may be \nocite'd twice.
            bibkeys.discard(bib)
        return sorted(bibkeys)

    @staticmethod
    def _extract_tag(pattern, text):
        """Return the stripped first capture of ``pattern`` in ``text``, or ""."""
        match = re.search(pattern, text, re.DOTALL)
        return match.group(1).strip() if match else ""

    @staticmethod
    def _validate_plan_node(node):
        """Raise ValueError unless ``node`` is a dict with str "plan" and "title"."""
        if not isinstance(node, dict):
            raise ValueError("plan node must be a dict")
        for field in ("plan", "title"):
            if not isinstance(node.get(field), str):
                raise ValueError(f"plan node needs a string {field!r}")

    @staticmethod
    def _validate_plan(plan):
        """Raise ValueError unless ``plan`` is a well-formed survey plan.

        The plan must have exactly the keys "title" (str), "abstract",
        "introduction", "conclusion" (dicts with a "plan" key) and
        "sections" (list of nodes). Every node at every level needs string
        "plan" and "title" entries; nested "subsections" / "subsubsections"
        must be lists; a top-level section may not be titled "Methodology".
        """
        required = {"abstract", "introduction", "conclusion", "sections", "title"}
        if set(plan) != required:
            raise ValueError("plan must have exactly the keys " + ", ".join(sorted(required)))
        if not isinstance(plan["title"], str):
            raise ValueError("plan title must be a string")
        for key in ("abstract", "introduction", "conclusion"):
            if not isinstance(plan[key], dict) or "plan" not in plan[key]:
                raise ValueError(f"plan[{key!r}] must be a dict with a 'plan' entry")

        if not isinstance(plan["sections"], list):
            raise ValueError("plan sections must be a list")
        for section in plan["sections"]:
            BufferManager._validate_plan_node(section)
            if section["title"] == "Methodology":
                raise ValueError("section title 'Methodology' is not allowed")
            if "subsections" not in section:
                continue
            if not isinstance(section["subsections"], list):
                raise ValueError("subsections must be a list")
            for subsection in section["subsections"]:
                BufferManager._validate_plan_node(subsection)
                # Look for the nested list on the subsection itself, not on
                # its parent section.
                if "subsubsections" not in subsection:
                    continue
                if not isinstance(subsection["subsubsections"], list):
                    raise ValueError("subsubsections must be a list")
                for subsubsection in subsection["subsubsections"]:
                    BufferManager._validate_plan_node(subsubsection)

    @staticmethod
    def _parse_answer(text):
        """Decode and validate the JSON body of an <answer> block.

        Returns the decoded object (``{}`` is allowed and returned as is),
        with "update" normalised by ``SurveyManager.parse_update_pos``.
        Raises on any malformed input.
        """
        answer = json.loads(text)
        if answer != {}:
            if not isinstance(answer["update"], str):
                raise ValueError("answer 'update' must be a string")
            answer["update"] = SurveyManager.parse_update_pos(answer["update"])
            if answer["update"] == "plan":
                plan = answer["content"]
                if not isinstance(plan, dict):
                    raise ValueError("plan content must be a dict")
                plan.pop("instruction", None)
                BufferManager._validate_plan(plan)
            elif not isinstance(answer["content"], str):
                raise ValueError("answer 'content' must be a string")
        return answer

    @staticmethod
    def _parse_tool_call(text):
        """Decode and validate the JSON body of a <tool_call> block."""
        tool_call = json.loads(text)
        if tool_call["name"] not in ["search_engine", "finalize"]:
            raise ValueError("unknown tool name")
        if tool_call["name"] == "search_engine" and not isinstance(
                tool_call["arguments"]["query"], list):
            raise ValueError("search_engine 'query' must be a list")
        return tool_call

    @staticmethod
    def parse_generator_response(response):
        """Parse one raw generator response into its structured parts.

        **standard format**

            Current Update:
            <think> [Your Thoughts]: str </think>
            <answer> {"update": str, "content": str}: dict </answer>

            Next Plan:
            <think> [Your Thoughts]: str </think>
            <tool_call> {"name": "search_engine", "arguments": {"query": list}}: dict </tool_call>

        The tool name must be "search_engine" or "finalize".

        Returns:
            dict with the keys "original_response", "answer_thought",
            "answer" (parsed dict, ``{}`` if absent or invalid),
            "tool_call_thought", "tool_call" (parsed dict, ``{}`` if absent
            or invalid) and "true", which is True only when both the answer
            and the tool call parsed successfully and the response contains
            no character in the range U+4E00..U+9FA5.
        """
        extracted_result = {"original_response": response}

        think_pattern = r"<think>(.*?)</think>"
        answer_pattern = r"<answer>(.*?)</answer>"
        tool_pattern = r"<tool_call>(.*?)</tool_call>"

        # --- "Current Update" part -------------------------------------
        current_update = response.split("Current Update:")[-1].split("Next Plan:")[0]
        extracted_result["answer_thought"] = BufferManager._extract_tag(think_pattern, current_update)

        answer, has_answer = {}, False
        answer_match = re.search(answer_pattern, current_update, re.DOTALL)
        if answer_match:
            try:
                answer = BufferManager._parse_answer(answer_match.group(1).strip())
                has_answer = True
            except Exception:
                # Model output is untrusted: any decoding or validation
                # failure simply means "no usable answer".
                answer = {}
        extracted_result["answer"] = answer

        # --- "Next Plan" part ------------------------------------------
        # Prefer the text after "Next Plan:", then the text after
        # "</answer>", and fall back to the whole response.
        next_plan = response
        for marker in ("Next Plan:", "</answer>"):
            pieces = response.split(marker)
            if len(pieces) > 1:
                next_plan = pieces[1]
                break
        extracted_result["tool_call_thought"] = BufferManager._extract_tag(think_pattern, next_plan)

        tool_call, has_tool_call = {}, False
        tool_match = re.search(tool_pattern, next_plan, re.DOTALL)
        if tool_match:
            try:
                tool_call = BufferManager._parse_tool_call(tool_match.group(1).strip())
                has_tool_call = True
            except Exception:
                tool_call = {}
        extracted_result["tool_call"] = tool_call

        has_chinese = re.search(r"[\u4e00-\u9fa5]", response) is not None
        extracted_result["true"] = has_answer and has_tool_call and not has_chinese

        return extracted_result
| |
|
|
class BufferManager_V2(BufferManager):
    """BufferManager variant that is built from a plain iterable of queries."""

    def __init__(self, querys, repeat_n=1):
        """Create ``repeat_n`` empty rollout records for each query.

        Each query is printed as it is processed, and its records share the
        uid ``"query_<index>"``. The parent initializer is not called.
        """
        self.step = 0
        self.batch_rollout_data = []
        self.running_ids = []

        for query_index, query in enumerate(querys):
            print("CURRENT QUERY: ", query)
            uid = f"query_{query_index}"
            # A fresh record (with its own state dict) is built per repeat.
            self.batch_rollout_data.extend(
                {
                    "query": query,
                    "uid": uid,
                    "state": {
                        "done": False,
                        "current_survey": {},
                    },
                    "trajectory": [],
                    "history_messages": [],
                }
                for _ in range(repeat_n)
            )
|
|
|
|