from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextIteratorStreamer
)
import torch
import threading

app = FastAPI()

MODEL_NAME = "Qwen/Qwen2.5-Coder-7B"

# ---- Quantization config (4-bit NF4) ----
# NOTE: bitsandbytes 4-bit loading requires a CUDA GPU on standard installs,
# so device_map="auto" below places the model on the GPU when one is available.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float32,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True
)

class Prompt(BaseModel):
    message: str

# -------------------------------------------------
# ✅ NORMAL CHAT
# -------------------------------------------------
@app.post("/chat")
def chat(prompt: Prompt):
    # Move input tensors to the same device the model was placed on
    inputs = tokenizer(prompt.message, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True
    )

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": response}

# -------------------------------------------------
# 🚀 STREAMING CHAT (CHATGPT-LIKE)
# -------------------------------------------------
@app.post("/chat-stream")
def chat_stream(prompt: Prompt):
    inputs = tokenizer(prompt.message, return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the stream;
    # skip_special_tokens=True drops EOS/BOS markers from the decoded text.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_special_tokens=True,
        skip_prompt=True
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True
    )

    # generate() blocks until completion, so run it in a background thread
    # and let the response generator drain the streamer as tokens arrive.
    thread = threading.Thread(
        target=model.generate,
        kwargs=generation_kwargs
    )
    thread.start()

    def token_generator():
        for token in streamer:
            yield token

    return StreamingResponse(
        token_generator(),
        media_type="text/plain"
    )
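
# -------------------------------------------------
# 🧪 EXAMPLE CLIENT (a minimal sketch)
# -------------------------------------------------
# Illustrative only — save as a separate file (e.g. client.py).
# Assumptions: the server above lives in app.py and was started with
# `uvicorn app:app --port 8000`; the `requests` package is installed.
# The URL, port, and prompt text are placeholders, not part of the server.

import requests

def stream_chat(message: str, url: str = "http://localhost:8000/chat-stream") -> None:
    """POST a prompt and print tokens as the server streams them back."""
    with requests.post(url, json={"message": message}, stream=True) as resp:
        resp.raise_for_status()
        resp.encoding = "utf-8"  # the server streams plain UTF-8 text
        # chunk_size=None yields data as soon as it arrives;
        # decode_unicode=True decodes incrementally, so multi-byte
        # characters split across chunks are handled correctly.
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)
    print()

if __name__ == "__main__":
    stream_chat("Write a Python function that reverses a string.")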