| from contextlib import contextmanager |
| from codetiming import Timer |
@contextmanager
def _timer(name: str, timing_raw):
    """Time the enclosed block and store the result in ``timing_raw``.

    Args:
        name: Key under which the measurement is stored; it is also passed to
            ``Timer`` as the timer name.
        timing_raw: Mutable mapping updated in place with
            ``timing_raw[name] = timer.last``.

    The measurement is recorded even when the enclosed block raises, so a
    failing stage still shows up in the timing report. The exception itself
    is propagated unchanged.
    """
    timer = Timer(name=name, logger=None)
    try:
        with timer:
            yield
    finally:
        # Leaving the ``with`` block stops the timer on both the normal and
        # the exceptional path, so ``last`` is up to date here.
        timing_raw[name] = timer.last
| |
| from buffer import SurveyManager |
| from buffer import BufferManager_V2 as BufferManager |
| from vllm import LLM, SamplingParams |
| from transformers import AutoTokenizer |
| import re |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.middleware.cors import CORSMiddleware |
| import asyncio |
| import argparse |
| from pydantic import BaseModel |
| import json |
| import aiohttp |
|
|
|
|
# ASGI application exposing the WebSocket and HTTP endpoints defined below.
app = FastAPI()

# CORS is left fully open: any origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): presumably this is for a browser frontend served from another
# origin during development; tighten before exposing the service publicly.
_CORS_SETTINGS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# WebSocket connections that currently receive broadcast updates.
active_connections = set()
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a WebSocket client and keep it registered until it goes away.

    The connection is added to ``active_connections`` so that broadcasts can
    reach it. Incoming messages are read and ignored; the receive loop exists
    only to notice when the peer disconnects.

    The connection is always unregistered on exit, whatever ended the loop.
    ``discard`` is used because the broadcaster may already have dropped the
    socket after a failed send.
    """
    await websocket.accept()
    active_connections.add(websocket)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal termination: the client closed the connection.
        pass
    finally:
        active_connections.discard(websocket)
|
|
async def post_to_frontend(payload):
    """Broadcast ``payload`` to every registered WebSocket connection.

    Args:
        payload: Text passed to ``send_text`` of each connection.

    A connection whose send fails is logged and dropped from
    ``active_connections``; the remaining connections are still served.
    """
    print(f"Sending payload to frontend: {payload}")
    # Iterate over a snapshot: the set can change while a send is awaited.
    for ws in list(active_connections):
        try:
            await ws.send_text(payload)
        except Exception as e:
            print(f"Error sending to WebSocket: {e}")
            # ``discard`` rather than ``remove``: the endpoint coroutine may
            # have unregistered this socket while the send was in flight.
            active_connections.discard(ws)
|
|
|
|
def write_to_json(data, path):
    """Serialize ``data`` as indented, non-ASCII-preserving JSON into ``path``.

    Args:
        data: JSON-serializable object.
        path: Destination file; it is created or overwritten (UTF-8).
    """
    with open(path, mode="w", encoding="utf8") as out_file:
        serialized = json.dumps(data, ensure_ascii=False, indent=4)
        out_file.write(serialized)
|
|
class OriginalvLLMRollout:
    """Wrapper around a vLLM ``LLM`` engine with fixed sampling settings."""

    def __init__(self, model_name_or_path):
        """Create the engine and the sampling parameters.

        Args:
            model_name_or_path: Passed to ``LLM`` as both the model and the
                tokenizer location.
        """
        self.rollout_model = LLM(
            model=model_name_or_path,
            tokenizer=model_name_or_path,
            gpu_memory_utilization=0.95,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.8,
            repetition_penalty=1.05,
            top_k=20,
            max_tokens=2748,
        )

    @staticmethod
    def _first_texts(request_outputs):
        """Return the text of the first candidate of every request output."""
        return [request_output.outputs[0].text for request_output in request_outputs]

    def generate(self, input_texts):
        """Complete each prompt in ``input_texts``; returns one text per prompt."""
        request_outputs = self.rollout_model.generate(
            input_texts, self.sampling_params, use_tqdm=False
        )
        return self._first_texts(request_outputs)

    def chat(self, input_messages):
        """Answer each conversation in ``input_messages``; returns one text each."""
        request_outputs = self.rollout_model.chat(
            input_messages, self.sampling_params, use_tqdm=False
        )
        return self._first_texts(request_outputs)
|
|
def _build_frontend_update(rollout_item):
    """Serialize the latest state of one rollout record for the frontend.

    Returns the JSON string to broadcast, or ``None`` when the most recent
    trajectory step did not report ``update_success``.

    NOTE(review): the original nesting of this section was lost; it is assumed
    that an update is only sent after a successful step -- confirm.
    """
    trajectory = rollout_item["trajectory"]
    latest_step = trajectory[-1]
    if not latest_step["update_success"]:
        return None

    # Use the most recent step that produced any paper summaries.
    summaries = next(
        (step["summarys"] for step in reversed(trajectory) if step["summarys"]),
        [],
    )
    papers = "\n".join(summaries) if summaries else "No summaries yet."
    return json.dumps(
        {
            "markdown": json_to_markdown(rollout_item),
            "searchKeywords": latest_step["search_keywords"],
            "nowUpdate": latest_step["answer_thought"],
            "nextUpdate": latest_step["tool_call_thought"],
            "query": rollout_item["query"],
            "papers": papers,
        },
        ensure_ascii=False,
    )


async def rollout_with_env(querys, batch_size, max_turns, model_path, url,
                           deploy_port=None):
    """Run multi-turn rollouts for ``querys`` against the environment at ``url``.

    Each turn asks the model for the next action of every unfinished query,
    posts the resulting tool calls to ``url`` and feeds the response back into
    the batch's ``BufferManager``.

    Args:
        querys: List of query strings.
        batch_size: Number of queries handled by one ``BufferManager``; must
            be positive.
        max_turns: Upper bound on the number of turns per batch.
        model_path: Passed to ``OriginalvLLMRollout`` to load the model.
        url: Endpoint that receives the tool calls as a JSON POST and returns
            the environment feedback as JSON.
        deploy_port: When not ``None``, progress of the last record in the
            batch is broadcast to WebSocket clients after every turn. Only
            its presence is checked here.

    Returns:
        The ``batch_rollout_data`` records of all batches, each extended with
        a ``"survey_text"`` entry.

    Raises:
        ValueError: If ``batch_size`` is not positive.
        aiohttp.ClientResponseError: If the environment answers with an HTTP
            error status.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    batch_querys = [
        querys[start:start + batch_size]
        for start in range(0, len(querys), batch_size)
    ]
    print("QUERY NUMBER with BATCH: ", [len(x) for x in batch_querys])
    if not batch_querys:
        # Nothing to do; avoid loading the model for an empty request.
        return []

    # NOTE(review): a new engine is created on every call, i.e. once per
    # request when this runs behind the HTTP endpoint.
    vllm_manager = OriginalvLLMRollout(model_path)

    total_rollout_data = []
    # One HTTP session is shared by all turns instead of opening one per turn.
    async with aiohttp.ClientSession() as session:
        for batch in batch_querys:
            buffer_manager = BufferManager(batch)

            while buffer_manager.step < max_turns:
                messagess_todo = buffer_manager.build_prompt_for_generator()
                if len(messagess_todo) == 0:
                    # No query in this batch needs another turn.
                    break

                timing_raw = {}
                with _timer('vllm sampling', timing_raw):
                    # Generation is blocking; run it off the event loop so
                    # WebSocket traffic keeps flowing.
                    response_texts = await asyncio.to_thread(vllm_manager.chat, messagess_todo)

                extracted_results = [
                    BufferManager.parse_generator_response(response_text)
                    for response_text in response_texts
                ]

                payload = {
                    "tool_calls": [x["tool_call"] for x in extracted_results]
                }
                if buffer_manager.step <= 2:
                    payload["topk"] = 20
                with _timer('get env feedback', timing_raw):
                    async with session.post(url, json=payload) as resp:
                        # Fail loudly instead of feeding an error body into
                        # the trajectory.
                        resp.raise_for_status()
                        env_response_batched = await resp.json()

                with _timer('postprocessing', timing_raw):
                    buffer_manager.update_trajectory(extracted_results, env_response_batched)
                    buffer_manager.step += 1

                print(timing_raw)

                if deploy_port is not None:
                    frontend_payload = _build_frontend_update(
                        buffer_manager.batch_rollout_data[-1]
                    )
                    if frontend_payload is not None:
                        try:
                            await post_to_frontend(frontend_payload)
                        except Exception as e:
                            # Best effort: a broken frontend must not abort
                            # the rollout.
                            print(f"Error posting to frontend: {e}")

            for item in buffer_manager.batch_rollout_data:
                item["survey_text"] = SurveyManager.convert_survey_dict_to_str(item["state"]["current_survey"])

            total_rollout_data.extend(buffer_manager.batch_rollout_data)

            del buffer_manager

    return total_rollout_data
|
|
|
|
def json_to_markdown(json_data):
    """Render one rollout record as Markdown with numbered references.

    LaTeX-style cite markers in the survey text are replaced by ``[n]``
    numbers assigned in order of first appearance, and a ``## References``
    section is appended.

    Args:
        json_data: Rollout record providing
            ``json_data["state"]["current_survey"]`` and
            ``json_data["trajectory"]``, a list of steps that each carry a
            ``"summarys"`` list of strings. The bibkey is read from between
            the first and second colon of a summary's first line, and the
            title from between the ``Title:`` and ``Abstract:`` markers.

    Returns:
        The Markdown text.

    Cite entries that are empty, ``*`` or a ``#<digits>`` placeholder are
    dropped. A cited key without a matching summary is listed by its bare key
    (with a warning) instead of raising.
    """
    text = SurveyManager.convert_survey_dict_to_str(json_data["state"]["current_survey"])

    # bibkey -> Markdown link to the paper.
    links = {}
    for traj in json_data["trajectory"]:
        for summary in traj["summarys"]:
            header_fields = summary.split("\n")[0].split(":")
            if len(header_fields) < 2:
                print(f"Warning: cannot parse bibkey from summary: {summary[:80]!r}")
                continue
            bibkey = header_fields[1].strip()
            _, has_title, after_title = summary.partition("Title:")
            if has_title:
                title = after_title.partition("Abstract:")[0].strip()
            else:
                title = bibkey
            # The arXiv id is whatever follows the last "arxivid" in the key.
            arxivid = bibkey.split("arxivid")[-1].strip()
            links[bibkey] = f"[{title}](https://arxiv.org/abs/{arxivid})"

    cite_pattern = re.compile(r"\\cite\{(.+?)\}")
    placeholder_pattern = re.compile(r"^#\d+$")

    def cited_keys(cite_body):
        """Return the usable keys of one cite marker, in order."""
        keys = []
        for raw_key in cite_body.split(","):
            # Strip first so that " #2" is recognised as a placeholder too.
            key = raw_key.strip()
            if key and key != "*" and not placeholder_pattern.match(key):
                keys.append(key)
        return keys

    # Number keys by first appearance; the dict keeps insertion order.
    bibkeys_index = {}
    for cite_body in cite_pattern.findall(text):
        for key in cited_keys(cite_body):
            bibkeys_index.setdefault(key, len(bibkeys_index) + 1)

    def replace_cite(match):
        numbers = [str(bibkeys_index[key]) for key in cited_keys(match.group(1))]
        if numbers:
            return "[" + ",".join(numbers) + "]"
        return ""

    text = cite_pattern.sub(replace_cite, text)

    references = []
    for key, number in bibkeys_index.items():
        if key not in links:
            print(f"Warning: {key} not found in summaries")
        references.append(f"[{number}] {links.get(key, key)}")
    text += "\n## References\n" + "\n\n".join(references)
    return text
| |
async def test_surveyGen(model_path, out_path,querys, url, deploy_port=None):
    """Generate surveys for ``querys`` and write them to ``out_path``.

    The rollout runs with a batch size of 1 and at most 1000 turns. Every
    resulting record is rendered to Markdown; the documents are joined with a
    blank line and written as UTF-8, overwriting ``out_path``.
    """
    rollout_records = await rollout_with_env(
        querys, 1, 1000, model_path, url, deploy_port
    )
    combined_markdown = "\n\n".join(
        [json_to_markdown(record) for record in rollout_records]
    )
    with open(out_path, mode="w", encoding="utf8") as out_file:
        out_file.write(combined_markdown)
| |
| |
| |
| |
|
|
|
|
|
|
# Request body accepted by the POST /generate_survey endpoint.
class QueryRequest(BaseModel):
    # The single query for which a survey is generated.
    query: str
|
|
@app.post("/generate_survey")
async def generate_survey(request: QueryRequest):
    """Generate a survey for the posted query and report the outcome.

    Model path, output file, retriever URL and port are read from the
    module-level ``args`` namespace. Any exception raised during generation is
    printed and returned as an ``"error"`` status instead of propagating.
    """
    # ``args`` is only read here, so no ``global`` declaration is needed.
    model_path = args.model_path
    out_path = args.output_file
    querys = [request.query]
    url = args.retriver_url
    deploy_port = args.port

    try:
        await test_surveyGen(model_path, out_path, querys, url, deploy_port)
    except Exception as e:
        print(f"Error generating survey: {e}")
        return {"status": "error", "message": str(e)}
    else:
        return {"status": "success", "message": "Survey generated successfully."}
| |
|
|
if __name__ == "__main__":
    # Two modes: with --port the FastAPI app is served and queries arrive via
    # HTTP; without it a single survey is generated for --query and written
    # to --output_file.
    parser = argparse.ArgumentParser(description="Run survey generation with vLLM.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model.")
    # Not marked required: in server mode every request carries its own query.
    parser.add_argument("--query", type=str, default=None, help="Query to generate survey (required unless --port is given).")
    parser.add_argument("--output_file", type=str, required=True, help="Path to the output Markdown file.")
    parser.add_argument("--retriver_url", type=str, default="http://localhost:8400", help="URL of the retriever service.")
    # Parsed as int so that a malformed port is rejected by argparse.
    parser.add_argument("--port", type=int, default=None, help="Deploy port, default is None, which means not deploy.")
    # ``args`` stays a module-level name: the HTTP handler reads it.
    args = parser.parse_args()

    if args.port is not None:
        import uvicorn
        uvicorn.run(app, host="localhost", port=args.port)
    else:
        if args.query is None:
            parser.error("--query is required when --port is not given.")
        asyncio.run(
            test_surveyGen(
                model_path=args.model_path,
                out_path=args.output_file,
                querys=[args.query],
                url=args.retriver_url,
            )
        )
| |