# FastAPI server which handles all the backend and GenAI aspects of the application
# uvicorn server:app --reload
# Avoid the --reload flag: the LLM components would keep reloading and the system can overheat.
from fastapi import FastAPI, File, UploadFile, Form, Request, Query
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import json
from typing import Dict
from pydantic import BaseModel
from contextlib import asynccontextmanager

# llm system imports:
from llm_system.core.llm import get_llm, get_output_parser      # Functions
from llm_system.core.llm import get_dummy_response               # Function
from llm_system.core.llm import get_dummy_response_stream        # Function
from llm_system.core.qdrant_database import VectorDB             # Class (migrated to Qdrant)
from llm_system.core.history import HistoryStore                 # Class
from llm_system.chains.rag import build_rag_chain                # Function
from llm_system import config                                    # Constants
from llm_system.core.ingestion import ingest_file                # Function
from llm_system.core.evaluation_deepeval import RAGEvaluator     # RAG evaluator
from llm_system.core.cache import ResponseCache                  # Query response cache (<100ms cache hits)

# Helper Modules:
import pg_db   # PostgreSQL database module (migrated from sq_db)
import files

# Type hinting imports:
from langchain_core.vectorstores import VectorStore as T_VECTOR_STORE
from langchain_core.messages import BaseMessage as T_MESSAGE

import logger

log = logger.get_logger("rag_server")
# ------------------------------------------------------------------------------
# Constants:
# ------------------------------------------------------------------------------
# UPLOADS_DIR: str = "user_uploads"
OLD_FILE_THRESHOLD: int = 3600 * 1  # 1 hour in seconds
# OLD_FILE_THRESHOLD: int = 20      # 1 min

# ------------------------------------------------------------------------------
# FastAPI Startup:
# ------------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Define the lifespan context manager for startup/shutdown."""
    # [ Startup ]
    log.info("[LifeSpan] Starting the server components.")
    app.state.llm_chat = get_llm(
        model_name=config.LLM_CHAT_MODEL_NAME,
        context_size=config.MAX_CONTENT_SIZE,
        temperature=config.LLM_CHAT_TEMPERATURE,
        verify_connection=config.VERIFY_LLM_CONNECTION
    )
    # app.state.llm_summary = get_llm(...)
    app.state.llm_summary = app.state.llm_chat
    app.state.output_parser = get_output_parser()
    app.state.vector_db = VectorDB(
        embed_model=config.EMB_MODEL_NAME,
        retriever_num_docs=config.DOCS_NUM_COUNT,
        verify_connection=config.VERIFY_EMB_CONNECTION,
    )
    app.state.history_store = HistoryStore()
    app.state.rag_chain = build_rag_chain(
        llm_chat=app.state.llm_chat,
        llm_summary=app.state.llm_summary,
        retriever=app.state.vector_db.get_retriever(),
        get_history_fn=app.state.history_store.get_session_history,
    )
    # Initialize RAG evaluator using DeepEval
    app.state.evaluator = RAGEvaluator(
        llm_model=config.LLM_CHAT_MODEL_NAME,
        ollama_base_url=config.OLLAMA_BASE_URL,
        temperature=0.0,
    )
    # Initialize response cache (cache hits = <100ms, no LLM generation needed)
    app.state.response_cache = ResponseCache(ttl_seconds=3600)  # 1 hour TTL
    log.info("ResponseCache instance created and stored in app.state")
    log.info("[LifeSpan] All LLM components initialized.")
    # pg_db.delete_database()
    pg_db.create_tables()
    # Files
    files.check_create_uploads_folder()
    files.delete_empty_user_folders()
    # [ Lifespan ]
    yield
    # [ Shutdown ]
    log.info("[LifeSpan] Shutting down LLM server...")
    # Add any cleanup here,
    # like saving the vector DB or shutting down subprocesses.

# Make one FastAPI app instance with the lifespan context manager
app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=getattr(config, "ALLOWED_ORIGINS", ["http://localhost:8501", "http://127.0.0.1:5500"]),
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"]
)
# ------------------------------------------------------------------------------
# Basic API Endpoints:
# ------------------------------------------------------------------------------
@app.get("/")
async def root():
    """Root endpoint to check if the server is running."""
    return {
        "message": "LLM RAG Server is running!",
        "further": "Proceed to code your application :)",
        "thought": "You really are not supposed to be reading this waste of time, but if you are, then you are a curious person. I like that!",
    }

@app.get("/cache-debug")
async def cache_debug():
    """Debug endpoint to inspect the current cache state and metrics.

    Returns detailed information about all cached responses including:
    - Total cache size (number of cached queries)
    - Cache keys (SHA256 hashes of normalized questions)
    - Entry previews (first 100 chars of cached answers)
    - Timestamps (creation time for LRU eviction tracking)

    Use this endpoint to:
    - Verify the cache is working and storing responses
    - Monitor cache performance and hit patterns
    - Debug cache-related issues
    - Track memory usage (cache_size vs max 500 entries)

    Returns:
        Dict with:
        - cache_size (int): Current number of cached entries
        - cache_keys (list): All cache keys (SHA256 hashes; see the key-derivation sketch after this endpoint)
        - entries (list): Detailed info per cached response:
            - key: Cache key (SHA256 hash)
            - answer_preview: First 100 chars of the cached answer
            - created_at: Unix timestamp when cached

    Example:
        GET /cache-debug
        Response:
        {
            "cache_size": 3,
            "cache_keys": ["a7f3b2c...", "f1e2d3c...", "9a8b7c6..."],
            "entries": [
                {
                    "key": "a7f3b2c...",
                    "answer_preview": "RAG is a technique that combines retrieval with...",
                    "created_at": 1702500000.123
                },
                ...
            ]
        }
    """
    from llm_system.core.cache import _response_cache
    return {
        "cache_size": len(_response_cache),
        "cache_keys": list(_response_cache.keys()),
        "entries": [
            {
                "key": k,
                "answer_preview": v["answer"][:100] if v.get("answer") else None,
                "created_at": v.get("created_at")
            }
            for k, v in _response_cache.items()
        ]
    }
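

# The cache keys above are SHA256 hashes of the normalized question. The sketch
# below is illustrative only: it assumes normalization means lower-casing and
# collapsing whitespace; the authoritative implementation lives in
# llm_system.core.cache and may differ.
def _example_cache_key(question: str) -> str:
    import hashlib
    normalized = " ".join(question.lower().split())   # assumed normalization step
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()
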
@app.post("/cache-clear")
async def cache_clear(request: Request, clear_request: dict = None):
    """Clear the response cache to get fresh answers.

    Useful when documents are uploaded/updated and cached responses are stale.

    Request (optional):
        {
            "session_id": "user123"   # If provided, clears only this user's cache
        }

    Response:
        {
            "status": "success",
            "message": "Cache cleared",
            "cleared_entries": 5
        }
    """
    from llm_system.core.cache import _response_cache
    response_cache = request.app.state.response_cache
    if clear_request and "session_id" in clear_request:
        # Clear cache for a specific user
        session_id = clear_request["session_id"]
        before_size = len(_response_cache)
        response_cache.clear_user_cache(session_id)
        after_size = len(_response_cache)
        cleared = before_size - after_size
        log.info(f"Cleared {cleared} cache entries for user: {session_id}")
        return {
            "status": "success",
            "message": f"Cache cleared for user: {session_id}",
            "cleared_entries": cleared
        }
    else:
        # Clear the entire cache
        before_size = len(_response_cache)
        response_cache.clear()
        log.info(f"Entire cache cleared ({before_size} entries)")
        return {
            "status": "success",
            "message": "Entire cache cleared",
            "cleared_entries": before_size
        }

# Define data model for chat request
class BasicChatRequest(BaseModel):
    query: str
    session_id: str
    dummy: bool = False


@app.post("/simple")
async def simple(request: Request, chat_request: BasicChatRequest):
    """Endpoint to handle one-time generation queries.

    - POST request expects JSON `{"query": "", "session_id": "", "dummy": T/F}`.
    - Returns JSON with `{"response": "", "session_id": ""}`.
    """
    llm = request.app.state.llm_chat | request.app.state.output_parser
    session_id = chat_request.session_id.strip() or "unknown_session"
    try:
        query = chat_request.query
        dummy = chat_request.dummy
        log.info(f"/simple Requested by '{session_id}'")
        if dummy:
            log.info(f"/simple Dummy response returned for '{session_id}'")
            return get_dummy_response()
        else:
            result = await llm.ainvoke(input=query)
            log.info(f"/simple Response generated for '{session_id}'.")
            return {"response": result, "session_id": session_id}
    except Exception as e:
        log.exception(f"/simple Error {e} for '{session_id}'")
        return JSONResponse(status_code=500, content={"error": str(e)})

# Streaming endpoint for the simple LLM response:
class StreamChatRequest(BaseModel):
    query: str
    session_id: str
    dummy: bool = False


@app.post("/simple/stream")
async def chat_stream(request: Request, chat_request: StreamChatRequest):
    """Endpoint to handle streaming responses for one-time generation queries.

    - POST request expects JSON `{"query": "", "session_id": "", "dummy": T/F}`.
    - Returns NDJSON lines with types "metadata", "content", or "error"
      (a consuming client sketch follows this endpoint).
    """
    llm = request.app.state.llm_chat | request.app.state.output_parser
    session_id = chat_request.session_id.strip() or "unknown_session"

    async def token_streamer():
        try:
            dummy = chat_request.dummy
            s = 'dummy' if dummy else 'real'
            log.info(f"/simple/stream {s} response requested by '{session_id}'")
            # Start by sending the metadata first.
            yield json.dumps({
                "type": "metadata",
                "data": {"session_id": session_id}
            }) + "\n"
            # NDJSON (newline-delimited JSON) - the frontend rebuilds the full response by splitting on newlines.
            # Then send the actual response content:
            if dummy:
                # If dummy is True, stream a dummy response
                resp = get_dummy_response_stream(
                    batch_tokens=config.BATCH_TOKEN_PS,
                    token_rate=config.TOKENS_PER_SEC
                )
                for chunk in resp:
                    if await request.is_disconnected():
                        log.warning(f"/simple/stream client disconnected for '{session_id}'")
                        break
                    yield json.dumps({
                        "type": "content",
                        "data": chunk
                    }) + "\n"
            else:
                async for chunk in llm.astream(chat_request.query):
                    if await request.is_disconnected():
                        log.warning(f"/simple/stream client disconnected for '{session_id}'")
                        break
                    yield json.dumps({
                        "type": "content",
                        "data": chunk
                    }) + "\n"
            # At the end, a "done" marker could be sent if conditional logic is needed;
            # the server automatically ends the response when the generator finishes.
            # yield json.dumps({
            #     "type": "end",
            #     "data": "done"
            # }) + "\n"
            log.info(f"/simple/stream Streaming completed for '{session_id}'")
        except Exception as e:
            log.exception(f"/simple/stream Error {e} for '{session_id}'")
            yield json.dumps({
                "type": "error",
                "data": str(e)
            }) + "\n"

    # Return a StreamingResponse with the token streamer generator (this is what enables streaming)
    return StreamingResponse(token_streamer(), media_type="text/plain")
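

# A minimal client sketch for consuming the NDJSON stream above. It is illustrative
# only and assumes the `httpx` package is installed and the server is reachable at
# http://localhost:8000; adapt the URL and payload to your deployment.
async def _example_stream_client():
    import httpx
    payload = {"query": "Hello!", "session_id": "demo_user", "dummy": True}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/simple/stream", json=payload) as resp:
            answer = ""
            async for line in resp.aiter_lines():
                if not line.strip():
                    continue
                msg = json.loads(line)        # one JSON object per line (NDJSON)
                if msg["type"] == "content":
                    answer += msg["data"]     # concatenate streamed tokens
            return answer
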
# ------------------------------------------------------------------------------
# Initialization End-points:
# ------------------------------------------------------------------------------
# Helper function to delete old files and embeddings:
def delete_old_files(user_id: str, time: int = OLD_FILE_THRESHOLD):
    """Delete files and embeddings older than the specified time (in seconds)."""
    log.info(
        f"/delete Deleting old files and embeddings for user '{user_id}' older than {time} seconds")
    # Delete old files
    old_files = pg_db.get_old_files(user_id=user_id, time=time)
    if old_files['files']:
        log.info(f"/delete Removing old files for user '{user_id}': {old_files['files']}")
        for file in old_files['files']:
            status = files.delete_file(user_id=user_id, file_name=file)
            if status:
                file_id = pg_db.get_file_id_by_name(user_id=user_id, file_name=file)
                pg_db.mark_file_removed(user_id=user_id, file_id=file_id)
        # Delete old embeddings
        if old_files['embeddings']:
            log.info(f"/delete Removing old embeddings for user '{user_id}'")
            vs: VectorDB = app.state.vector_db
            db: T_VECTOR_STORE = vs.get_vector_store()
            resp = db.delete(old_files['embeddings'])
            # Save the changes to disk
            vs.save_db_to_disk()
            if resp is True:
                pg_db.mark_embeddings_removed(vector_ids=old_files['embeddings'])
                log.info(f"/delete Old embeddings removed for user '{user_id}'")
            else:
                log.error(f"/delete Failed to remove old embeddings for user '{user_id}': {resp}")
    else:
        log.info(f"/delete No old files found for user '{user_id}'")

# First end-point to call on client initialization:
class LoginRequest(BaseModel):
    login_id: str
    password: str


@app.post("/login")
async def login(request: Request, login_request: LoginRequest):
    """User authentication endpoint with session initialization.

    Authenticates user credentials against the PostgreSQL database, creates the user's
    upload folder, and cleans up old user files (older than OLD_FILE_THRESHOLD).
    Sets up an isolated document namespace for multi-user RAG queries.

    Args:
        request: FastAPI Request object
        login_request: LoginRequest with:
            - login_id (str): Username
            - password (str): User's password (validated against DB)

    Returns:
        JSONResponse with status 200:
            {
                "user_id": str (same as login_id),
                "name": str (full name from database)
            }
        JSONResponse with status 401 on authentication failure:
            {"error": str (authentication error message)}

    Side Effects:
        - Creates user upload folder: user_uploads/{user_id}/
        - Deletes files older than OLD_FILE_THRESHOLD from the user's folder
        - User becomes isolated for document-based RAG queries

    Security:
        - Password validated via pg_db.authenticate_user()
        - Returns 401 Unauthorized on failed authentication
        - user_id determines document filtering in RAG queries

    Example:
        POST /login
        {
            "login_id": "alice",
            "password": "secure_password"
        }
        Response (200):
        {"user_id": "alice", "name": "Alice Johnson"}
    """
    login_id = login_request.login_id.strip()
    password = login_request.password.strip()
    log.info(f"/login Requested by '{login_id}'")
    # Check if the user exists in the database
    status, msg = pg_db.authenticate_user(user_id=login_id, password=password)
    if status:
        user_id = login_id
        # Check if a folder exists in UPLOADS_DIR for this user_id
        files.create_user_uploads_folder(user_id=user_id)
        # Delete any older data if it exists
        delete_old_files(user_id=user_id, time=OLD_FILE_THRESHOLD)
        return JSONResponse(content={"user_id": user_id, "name": msg}, status_code=200)
    else:
        return JSONResponse(content={"error": msg}, status_code=401)
    # # For now, we will just return a dummy user_id
    # # In future, can implement actual user authentication and return a real user_id
    # user_id = login_id
    # log.info(f"/login requested by '{user_id}'")
    # # Check if folder exists in UPLOADS_DIR with user_id
    # files.create_user_uploads_folder(user_id=user_id)
    # # Delete any older data if exists (older than 24 hours)
    # delete_old_files(user_id=user_id, time=OLD_FILE_THRESHOLD)
    # # Get the chat history for the user_id
    # hs: HistoryStore = request.app.state.history_store
    # history = hs.get_session_history(session_id=user_id)
    # if not history:
    #     log.info(f"/login No history found for user '{user_id}'")
    # else:
    #     log.info(f"/login History found for user '{user_id}' with {len(history.messages)} messages")
    # return {"user_id": user_id, "chat_history": history.messages}
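

# A minimal sketch of the login call from a Python client, illustrative only.
# It assumes the `requests` package and a server at http://localhost:8000;
# the credentials passed in are placeholders supplied by the caller.
def _example_login(login_id: str, password: str) -> dict:
    import requests
    resp = requests.post(
        "http://localhost:8000/login",
        json={"login_id": login_id, "password": password},
        timeout=30,
    )
    resp.raise_for_status()   # raises on 401 (bad credentials)
    return resp.json()        # {"user_id": ..., "name": ...}
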
# Endpoint for user registration:
class RegisterRequest(BaseModel):
    name: str
    user_id: str
    password: str


@app.post("/register")
async def register(request: Request, register_request: RegisterRequest):
    """Endpoint to handle user registration.

    - POST request expects JSON `{"name": "Full Name", "user_id": "any_u_id", "password": "raw_pw"}`.
    - Returns JSON with `{"status": "success"}` or `{"error": "message"}`.
    """
    name = register_request.name.strip()
    user_id = register_request.user_id.strip()
    password = register_request.password.strip()
    log.info(f"/register Requested by {name} with '{user_id}'")
    print(f"Name: {name}, UserID: {user_id}")  # do not print raw passwords
    # Check if the user already exists
    status = pg_db.check_user_exists(user_id=user_id)
    if status:
        log.error(f"/register UserID '{user_id}' already exists.")
        return JSONResponse(content={"error": "User already exists"}, status_code=400)
    # If the user does not exist, add the user to the database
    status = pg_db.add_user(user_id=user_id, name=name, password=password)
    if status:
        return JSONResponse(content={"status": "success"}, status_code=201)
    else:
        return JSONResponse(content={"error": "Failed to register user"}, status_code=500)

# ------------------------------------------------------------------------------
# Chat History Endpoints:
# ------------------------------------------------------------------------------
# Endpoint to get chat history for a user:
@app.post("/chat_history")
async def chat_history(user_id: str = Form(...)):
    """Endpoint to get the chat history for a user.

    - POST request expects `user_id` as a form parameter.
    - Returns JSON with `{"chat_history": [user chat history]}` or `{"error": "message"}`.
    """
    log.info(f"/chat_history Requested by '{user_id}'")
    hs: HistoryStore = app.state.history_store
    history = hs.get_session_history(session_id=user_id)
    if history:
        messages = []
        for msg in history.messages:
            msg: T_MESSAGE
            if msg.type == "ai":
                messages.append({"role": "assistant", "content": msg.text()})
            elif msg.type == "human":
                messages.append({"role": "human", "content": msg.text()})
        return JSONResponse(content={"chat_history": messages}, status_code=200)
    else:
        return JSONResponse(content={"error": "No chat history found"}, status_code=404)


# Endpoint /clear_chat_history to clear chat history for a user:
@app.post("/clear_chat_history")
async def clear_chat_history(user_id: str = Form(...)):
    """Endpoint to clear the chat history for a user.

    - POST request expects `user_id` as a form parameter.
    - Returns JSON with `{"status": "success"}` or `{"error": "message"}`.
    """
    log.info(f"/clear_chat_history Requested by '{user_id}'")
    hs: HistoryStore = app.state.history_store
    status = hs.clear_session_history(session_id=user_id)
    if status:
        return JSONResponse(content={"status": "success"}, status_code=200)
    else:
        return JSONResponse(content={"error": "No history found to clear"}, status_code=404)

# ------------------------------------------------------------------------------
# File handling endpoints:
# ------------------------------------------------------------------------------
# Endpoint to receive file uploads:
@app.post("/upload")
async def upload_file(file: UploadFile = File(...), user_id: str = Form(...)):
    """File upload endpoint for RAG document ingestion.

    Handles multipart file uploads and stores them in a user-isolated directory.
    Saves file metadata to PostgreSQL for tracking. Files are ready for embedding
    via the /embed endpoint. Supports PDF, TXT, DOCX, and other document formats.

    Args:
        file (UploadFile): Binary file content (PDF, TXT, DOCX, etc.)
        user_id (str): User identifier for directory isolation

    Returns:
        JSONResponse (200): {"message": str (stored_filename)}
        JSONResponse (500): {"error": str (error_message)} on failure

    Side Effects:
        - Stores the file in: user_uploads/{user_id}/{filename}
        - Adds file metadata to PostgreSQL (user_id, filename, timestamp)
        - The file is NOT immediately searchable; it requires the /embed endpoint

    Security:
        - Files stored in a user-specific directory
        - Prevents cross-user document access via RAG filtering

    Example:
        POST /upload (multipart form)
            file: <binary PDF content>
            user_id: alice
        Response (200):
            {"message": "document_2024_01_15_123456.pdf"}
    """
    log.info(f"/upload Received file: {file.filename} from user: {user_id}")
    filename = file.filename if file.filename else "unknown_file"
    status, message = files.save_file(
        user_id=user_id,
        file_value_binary=await file.read(),
        file_name=filename
    )
    if status:
        filename = message
        pg_db.add_file_compat(user_id=user_id, filename=filename)
        return JSONResponse(content={"message": filename}, status_code=200)
    else:
        log.error(f"/upload File upload failed for user {user_id}: {filename}")
        return JSONResponse(content={"error": message}, status_code=500)

# Endpoint to embed the uploaded file:
# takes user_id and file_name as input
class EmbedRequest(BaseModel):
    user_id: str
    file_name: str


@app.post("/embed")
async def embed_file(embed_request: EmbedRequest, request: Request):
    """Document embedding endpoint with multimodal support.

    Processes uploaded documents into semantic embeddings and stores them in the Qdrant
    vector database. Automatically extracts and embeds images from PDFs when
    available. Multimodal embeddings (Jina) enable unified search across text and images.

    This is a computationally expensive operation (5-30s depending on document size).
    Embeddings enable semantic search: similar questions retrieve similar documents.

    Args:
        embed_request: EmbedRequest with:
            - user_id (str): User identifier
            - file_name (str): Filename from the /upload response
        request: FastAPI Request object (contains app state: vector_db)

    Returns:
        JSONResponse (200): {
            "status": "success",
            "message": str,
            "items_embedded": int (text chunks + images),
            "text_chunks": int,
            "images_extracted": int,
            "image_paths": [str] (paths to extracted images)
        } - embedding completed with multimodal metadata
        JSONResponse (500): {"error": str} - embedding failed

    Side Effects:
        - Reads the document from: user_uploads/{user_id}/{file_name}
        - Chunks the document (configurable chunk size)
        - Extracts images from the PDF (if available)
        - Computes embeddings via the configured embedding model (text + images)
        - Stores vectors + metadata in the Qdrant collection
        - Updates PostgreSQL with embedding metadata

    Workflow (see the illustrative client sketch after this endpoint):
        1. Call /upload to store the file
        2. Call /embed with the returned filename (now returns image metadata)
        3. Use /rag to query (documents + images are now searchable)

    Performance:
        - Depends on document size and image count
        - Typical PDF: 5-10s
        - Large documents with many images: 20-30s

    Multimodal:
        - Uses Jina v4 embeddings if configured (unified 2048-dim space)
        - Falls back to Ollama embeddings if Jina is not available (text-only)
        - Image extraction is automatic from PDF XObjects

    Example:
        POST /embed
        {
            "user_id": "alice",
            "file_name": "document_2024_01_15_123456.pdf"
        }
        Response (200):
        {
            "status": "success",
            "message": "Ingested 82 items (20 text chunks + 62 images)",
            "items_embedded": 82,
            "text_chunks": 20,
            "images_extracted": 62,
            "image_paths": [
                "user_uploads/extracted_images/alice/document/img_001.png",
                ...
            ]
        }
    """
    user_id = embed_request.user_id.strip()
    file_name = embed_request.file_name.strip()
    log.info(f"[/EMBED START] user_id='{user_id}', file_name='{file_name}'")
    # Get the file path
    file_path = files.get_file_path(user_id=user_id, file_name=file_name)
    log.info(f"File path: {file_path}")
    # Call the ingest_file function to process the file (now with multimodal support)
    log.info("Calling ingest_file() with multimodal support...")
    status, doc_ids, message = ingest_file(
        user_id=user_id,
        file_path=file_path,
        vectorstore=request.app.state.vector_db,
        embeddings=request.app.state.vector_db.get_embeddings()
    )
    log.info(f"ingest_file() returned: status={status}, doc_ids_count={len(doc_ids) if doc_ids else 0}, message={message}")
    if status:
        log.info("Ingestion succeeded, storing embeddings in database")
        file_id = pg_db.get_file_id_by_name(user_id=user_id, file_name=file_name)
        log.info(f"File ID: {file_id}, storing {len(doc_ids)} vector IDs")
        for vid in doc_ids:
            pg_db.add_embedding_compat(file_id=file_id, vector_id=vid)
        # Extract image metadata from the message and doc_ids
        # Parse message format: "Ingested XX items (YY text chunks + ZZ images)."
        import re
        text_chunks = 0
        images_extracted = 0
        match = re.search(r'(\d+) text chunks \+ (\d+) images', message)
        if match:
            text_chunks = int(match.group(1))
            images_extracted = int(match.group(2))
        # Build the response with multimodal metadata
        response_data = {
            "status": "success",
            "message": message,
            "items_embedded": len(doc_ids),
            "text_chunks": text_chunks,
            "images_extracted": images_extracted
        }
        # Add image paths if images were extracted
        if images_extracted > 0:
            from pathlib import Path
            image_dir = Path("user_uploads") / "extracted_images" / user_id
            # Find all subdirectories that might contain this document's images
            image_paths = []
            if image_dir.exists():
                for subdir in image_dir.iterdir():
                    if subdir.is_dir():
                        for img_file in subdir.glob("*.png"):
                            image_paths.append(str(img_file))
            response_data["image_paths"] = image_paths[:images_extracted]  # Limit to the extracted count
        log.info(f"[/EMBED SUCCESS] Embedding completed with {text_chunks} text chunks and {images_extracted} images")
        return JSONResponse(content=response_data, status_code=200)
    else:
        log.error(f"[/EMBED FAILED] Embedding failed for '{user_id}' and file '{file_name}': {message}")
        return JSONResponse(content={"error": message}, status_code=500)
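

# An illustrative end-to-end client sketch for the upload -> embed workflow described
# in the /embed docstring. It assumes the `requests` package, a server at
# http://localhost:8000, and a local file path supplied by the caller; it is a sketch,
# not part of the server logic.
def _example_upload_and_embed(user_id: str, local_path: str) -> dict:
    import requests
    base = "http://localhost:8000"
    # 1. Upload the raw file (multipart form); the server returns the stored filename.
    with open(local_path, "rb") as fh:
        upload = requests.post(f"{base}/upload", files={"file": fh}, data={"user_id": user_id}, timeout=60)
    upload.raise_for_status()
    stored_name = upload.json()["message"]
    # 2. Trigger embedding for the stored file; this can take 5-30s per the docstring.
    embed = requests.post(f"{base}/embed", json={"user_id": user_id, "file_name": stored_name}, timeout=300)
    embed.raise_for_status()
    return embed.json()  # includes text_chunks / images_extracted counts
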
# ------------------------------------------------------------------------------
# Data management endpoints:
# ------------------------------------------------------------------------------
# Endpoint /clear_my_files to clear all files uploaded by a user:
@app.post("/clear_my_files")
async def clear_my_files(user_id: str = Form(...)):
    """Endpoint to clear all files uploaded by a user.

    - POST request expects `user_id` as a form parameter.
    - Returns JSON with `{"status": "success"}` or `{"error": "message"}`.
    """
    log.info(f"/clear_my_files Requested by '{user_id}'")
    delete_old_files(user_id=user_id, time=1)
    return JSONResponse(content={"status": "success"}, status_code=200)


# Endpoint to get all the files uploaded by a user:
# This will be called first at initialization, and then after each file upload
@app.get("/uploads")
async def get_files(user_id: str = Query(...)):
    """Endpoint to get all the files uploaded by a user.

    - GET request expects `user_id` as a query parameter.
    - Returns JSON with `{"files": ["file1", "file2", ...]}`.
    """
    log.info(f"/uploads Requested by '{user_id}'")
    files_list = pg_db.get_user_files_compat(user_id=user_id)
    return {"files": files_list}

# Send a PDF iframe based on user and file name:
# params: type=pdf/ppt/txt, user_id, file_name, num_pages
class FileIframeRequest(BaseModel):
    # type: Literal["pdf", "ppt", "txt"]
    user_id: str
    file_name: str
    num_pages: int = 5


@app.post("/iframe")
async def get_file_iframe(file_request: FileIframeRequest):
    """Endpoint to get the iframe for a file.

    - POST request expects JSON `{"user_id": "", "file_name": "", "num_pages": 5}`.
    - Returns JSON with `{"iframe": "<iframe>...</iframe>"}`.
    """
    user_id = file_request.user_id.strip()
    file_name = file_request.file_name.strip()
    num_pages = file_request.num_pages
    log.info(f"/iframe Requested by '{user_id}' for file '{file_name}'")
    # Get the iframe for the requested file
    status, message = files.get_pdf_iframe(
        user_id=user_id,
        file_name=file_name,
        num_pages=num_pages
    )
    if status:
        return JSONResponse(content={"iframe": message}, status_code=200)
    else:
        return JSONResponse(content={"error": message}, status_code=404)

# ------------------------------------------------------------------------------
# RAG Chain Endpoint:
# ------------------------------------------------------------------------------
# Create endpoint for RAG:
# input = {
#     query: str,
#     session_id: str,
#     dummy: bool = False
# }
# Output is streamed in the same format as the simple streaming chat endpoint.
class RagChatRequest(BaseModel):
    query: str
    session_id: str
    dummy: bool = False


@app.post("/rag")
async def rag(request: Request, chat_request: RagChatRequest):
    """RAG-powered streaming endpoint for question answering.

    Implements Retrieval-Augmented Generation with query caching for a roughly 700x
    performance improvement on repeated questions. Streams tokens in NDJSON format
    for real-time response display. Supports optional async evaluation metrics
    (Answer Relevancy, Faithfulness) without blocking the response stream.

    Args:
        request: FastAPI Request object (contains app state: rag_chain, evaluator, cache)
        chat_request: RagChatRequest with:
            - query (str): User's question
            - session_id (str): User/session identifier for context filtering
            - dummy (bool): If True, returns a simulated response for testing

    Yields:
        NDJSON (JSON lines) with types (a consuming client sketch follows this endpoint):
        - "metadata": {"session_id": str}
        - "content": str (streamed answer tokens)
        - "context": [{"source": str, "content": str}, ...] (retrieved documents)
        - "metrics": {"answer_relevancy": float, "faithfulness": float} (optional)
        - "cached": True (indicates a cache hit, skips evaluation)
        - "error": str (if an error occurs)

    Performance:
        - Cache hit (repeated question): <100ms
        - Cache miss (new question): 70-90s (includes LLM + evaluation)
        - Cache key: SHA256(normalized_question) - global across all users
        - Caching improves P50 latency from 70s to 30-40s in typical workloads

    Security:
        - Documents filtered by user_id and the "public" group
        - Each user only sees their uploaded files + public documents

    Example:
        POST /rag
        {
            "query": "What is retrieval-augmented generation?",
            "session_id": "user123",
            "dummy": false
        }
        Response (NDJSON):
        {"type": "metadata", "data": {"session_id": "user123"}}
        {"type": "content", "data": "Retrieval-Augmented Generation"}
        {"type": "context", "data": [{"source": "doc1.pdf", "content": "..."}]}
        {"type": "metrics", "data": {"answer_relevancy": 0.92, "faithfulness": 0.88}}
    """
    rag_chain = request.app.state.rag_chain
    evaluator = request.app.state.evaluator
    response_cache = request.app.state.response_cache
    session_id = chat_request.session_id.strip() or "unknown_session"

    async def token_streamer():
        try:
            dummy = chat_request.dummy
            log.info(f"/rag {'dummy' if dummy else 'real'} response requested by '{session_id}' query='{chat_request.query[:40]}...'")
            # Start by sending the metadata first.
            yield json.dumps({
                "type": "metadata",
                "data": {"session_id": session_id}
            }) + "\n"
            # Check the cache FIRST - on a hit, return the cached answer immediately (<100ms)
            if not dummy:
                cached_answer = response_cache.get(chat_request.query, session_id)
                if cached_answer:
                    log.info("CACHE HIT! Returning cached response (saves ~70s)")
                    yield json.dumps({
                        "type": "content",
                        "data": cached_answer
                    }) + "\n"
                    yield json.dumps({
                        "type": "cached",
                        "data": True
                    }) + "\n"
                    return
            if dummy:
                # If dummy is True, stream a dummy response
                resp = get_dummy_response_stream(
                    batch_tokens=config.BATCH_TOKEN_PS,
                    token_rate=config.TOKENS_PER_SEC
                )
                for chunk in resp:
                    if await request.is_disconnected():
                        log.warning(f"/rag client disconnected for '{session_id}'")
                        break
                    yield json.dumps({
                        "type": "content",
                        "data": chunk
                    }) + "\n"
            else:
                log.info(f"Starting RAG streaming for '{session_id}'")
                # Variables collected for evaluation
                collected_answer = ""
                collected_contexts = []
                context_sent = False
                # Search kwargs for the configurable retriever:
                search_kwargs = {
                    "k": 15,
                    "search_type": "similarity",
                    "filter": {
                        "$or": [
                            {"user_id": session_id},
                            {"user_id": "public"}
                        ]
                    },
                }
                async for chunk in rag_chain.astream(
                    input={"input": chat_request.query},
                    config={
                        "configurable": {
                            "session_id": session_id,
                            "search_kwargs": search_kwargs
                        }
                    }
                ):
                    if await request.is_disconnected():
                        log.warning(f"/rag client disconnected for '{session_id}'")
                        break
                    # Each chunk carries answer/input/context keys
                    if "answer" in chunk:
                        collected_answer += chunk["answer"]
                        log.debug(f"Answer chunk collected, total length: {len(collected_answer)}")
                        yield json.dumps({
                            "type": "content",
                            "data": chunk["answer"]
                        }) + "\n"
                    elif "context" in chunk and not context_sent:
                        log.info(f"Context chunk received with {len(chunk['context'])} documents")
                        # Send one context message per retrieved document, then mark the context as sent
                        for document in chunk["context"]:
                            if await request.is_disconnected():
                                log.warning(f"/rag client disconnected for '{session_id}'")
                                break
                            # Collect context for evaluation
                            collected_contexts.append(document.page_content)
                            # Hide user_id from metadata on the UI
                            if "user_id" in document.metadata:
                                if document.metadata["user_id"] == "public":
                                    document.metadata["isPublicDocument"] = True
                                else:
                                    document.metadata["isPublicDocument"] = False
                                document.metadata.pop("user_id")
                            # Prepare context data with multimodal support
                            context_data = {
                                "metadata": document.metadata,
                                "page_content": document.page_content
                            }
                            # If this is an image document, include the image path in the response
                            if document.metadata.get("type") == "image" and "image_path" in document.metadata:
                                context_data["image_path"] = document.metadata["image_path"]
                                context_data["is_image"] = True
                            yield json.dumps({
                                "type": "context",
                                "data": context_data
                            }) + "\n"
                        context_sent = True
                # Non-blocking metric evaluation via a background task (P99 < 8s)
                log.info(f"Collected answer length: {len(collected_answer)}, contexts: {len(collected_contexts)}")
                # Cache the response for future identical queries
                if collected_answer and collected_contexts:
                    log.info("Caching response for future queries (saves ~70s on a cache hit)")
                    response_cache.set(chat_request.query, session_id, collected_answer)
                if collected_answer and collected_contexts and config.ENABLE_METRICS_EVALUATION:
                    log.info("Starting background evaluation (non-blocking)")

                    # Async callback to handle metrics when ready
                    async def _on_metrics_ready(metrics: Dict):
                        """Called when the background evaluation completes."""
                        log.info(f"Background metrics ready: {metrics}")
                        # In production, store in Redis/DB for UI polling
                        # For now, just log it

                    try:
                        # Start background evaluation (returns immediately)
                        await evaluator.evaluate_response_background(
                            question=chat_request.query,
                            answer=collected_answer,
                            contexts=collected_contexts,
                            callback=_on_metrics_ready,
                        )
                        # Send placeholder metrics immediately (non-blocking)
                        yield json.dumps({
                            "type": "metrics",
                            "data": {
                                "status": "computing",
                                "answer_relevancy": None,
                                "faithfulness": None,
                                "message": "Metrics computing in background..."
                            }
                        }) + "\n"
                        log.info("Background evaluation task started (non-blocking)")
                    except Exception as eval_error:
                        log.error(f"Failed to start background evaluation: {eval_error}")
                        yield json.dumps({
                            "type": "metrics",
                            "data": {
                                "error": "Evaluation failed",
                                "details": str(eval_error)
                            }
                        }) + "\n"
                elif not config.ENABLE_METRICS_EVALUATION:
                    log.info("Metrics evaluation disabled (ENABLE_METRICS_EVALUATION=false)")
                else:
                    log.warning(f"Skipping evaluation: answer={len(collected_answer) > 0}, contexts={len(collected_contexts) > 0}")
            log.info(f"/rag Streaming completed for '{session_id}'")
        except Exception as e:
            log.exception(f"/rag Error {e} for '{session_id}'")
            yield json.dumps({
                "type": "error",
                "data": str(e)
            }) + "\n"

    return StreamingResponse(token_streamer(), media_type="text/plain")
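

# A minimal client sketch for the /rag NDJSON stream, illustrative only. It assumes
# the `httpx` package and a server at http://localhost:8000; the question and
# session_id are placeholders supplied by the caller. Each NDJSON line is routed by
# its "type" field, matching the message types documented in the /rag docstring.
async def _example_rag_client(question: str, session_id: str) -> str:
    import httpx
    payload = {"query": question, "session_id": session_id, "dummy": False}
    answer, sources = "", []
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/rag", json=payload) as resp:
            async for line in resp.aiter_lines():
                if not line.strip():
                    continue
                msg = json.loads(line)
                if msg["type"] == "content":
                    answer += msg["data"]                    # streamed answer tokens
                elif msg["type"] == "context":
                    sources.append(msg["data"]["metadata"])  # retrieved document metadata
                elif msg["type"] == "cached":
                    pass                                     # answer was served from the response cache
                elif msg["type"] == "error":
                    raise RuntimeError(msg["data"])
    return answer
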
# ------------------------------------------------------------------------------
# Run the FastAPI server:
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    print("WARNING: Starting server without explicit uvicorn command. Not recommended for production use.")
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        reload=False
    )