# coding=utf-8
# Implements API for ChatGLM3-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python openai_api.py
# Visit http://localhost:8000/docs for the interactive API docs.
# In OpenAI's API, max_tokens is equivalent to HuggingFace's max_new_tokens, not max_length.
# For example, for the 6B model, setting max_tokens = 8192 raises an error, because after
# deducting the history and the prompt, the model cannot output that many tokens.
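#
# A minimal client-side sketch (illustrative, not part of this server): it
# assumes the `openai` v1 Python SDK; the api_key is a placeholder, since this
# server does not check authentication.
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
#   response = client.chat.completions.create(
#       model="chatglm3-6b",
#       messages=[{"role": "user", "content": "Hello"}],
#       max_tokens=256,  # max_new_tokens semantics, per the note above
#   )
#   print(response.choices[0].message.content)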

import os
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

from utils import process_response, generate_chatglm3, generate_stream_chatglm3

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory on shutdown
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    functions: Optional[Union[dict, List[dict]]] = None
    # Additional parameters
    repetition_penalty: Optional[float] = 1.1
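
# Illustrative JSON body for POST /v1/chat/completions (values are examples
# only; the get_weather function is hypothetical):
# {
#   "model": "chatglm3-6b",
#   "messages": [{"role": "user", "content": "What is the weather in Beijing?"}],
#   "temperature": 0.8,
#   "stream": false,
#   "functions": [{"name": "get_weather",
#                  "description": "Query current weather for a city",
#                  "parameters": {"type": "object",
#                                 "properties": {"city": {"type": "string"}},
#                                 "required": ["city"]}}]
# }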


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call"]]


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="chatglm3-6b")
    return ModelList(data=[model_card])


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    global model, tokenizer

    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        functions=request.functions,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    if request.stream:
        generate = predict(request.model, gen_params)
        return EventSourceResponse(generate, media_type="text/event-stream")

    response = generate_chatglm3(model, tokenizer, gen_params)

    # Remove the leading newline, then trim surrounding whitespace
    if response["text"].startswith("\n"):
        response["text"] = response["text"][1:]
    response["text"] = response["text"].strip()

    usage = UsageInfo()
    function_call, finish_reason = None, "stop"
    if request.functions:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except Exception:
            logger.warning("Failed to parse tool call: the response may not be a tool call, or it may already have been answered.")

    if isinstance(function_call, dict):
        finish_reason = "function_call"
        function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )
    logger.debug(f"==== message ====\n{message}")

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )

    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)


async def predict(model_id: str, params: dict):
    global model, tokenizer

    # First chunk: announce the assistant role with an empty delta
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield chunk.model_dump_json(exclude_unset=True)

    previous_text = ""
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning("Failed to parse tool call: the response may not be a tool call, or it may already have been answered.")

        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield chunk.model_dump_json(exclude_unset=True)

    # Final chunk: empty delta with finish_reason="stop", then the SSE sentinel
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield chunk.model_dump_json(exclude_unset=True)
    yield '[DONE]'
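
# Sketch of consuming the stream above (illustrative; the `openai` v1 SDK
# handles the SSE `data:` framing and the final [DONE] sentinel emitted here):
#
#   stream = client.chat.completions.create(
#       model="chatglm3-6b",
#       messages=[{"role": "user", "content": "Hello"}],
#       stream=True,
#   )
#   for chunk in stream:
#       print(chunk.choices[0].delta.content or "", end="", flush=True)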


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
    if 'cuda' in DEVICE:  # AMD and NVIDIA GPUs can use half precision
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE).eval()
    else:  # CPU and other devices must use full (float32) precision
        model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval()

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
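
# Quick smoke test once the server is up (illustrative):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "chatglm3-6b", "messages": [{"role": "user", "content": "Hi"}]}'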