Spaces:
Sleeping
Sleeping
| """ | |
| Module for managing PostgreSQL database operations. | |
| - Replaces SQLite (sq_db.py) with PostgreSQL using SQLAlchemy | |
| - Provides functions to track user uploaded/generated files/data | |
| - Includes creating tables, adding files and embeddings, and managing users | |
| """ | |
| import bcrypt | |
| import os | |
| from typing import List, Optional, Dict, Any | |
| from datetime import datetime, timedelta | |
| from contextlib import contextmanager | |
| import pytz | |
| from sqlalchemy import create_engine, Column, String, Integer, DateTime, Boolean, Text, ForeignKey | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import sessionmaker, Session, relationship | |
| from logger import get_logger | |
log = get_logger(name="pg_db")
# Connection URL: taken from the DATABASE_URL environment variable, falling back
# to a local database on localhost:5432.
# NOTE(review): the fallback URL embeds credentials in source; consider requiring
# DATABASE_URL outside local development.
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://raguser:ragpass@localhost:5432/ragdb")
# Shared engine. pool_pre_ping checks a pooled connection before handing it out,
# so stale connections are replaced instead of failing the caller's query.
engine = create_engine(DATABASE_URL, pool_pre_ping=True, echo=False)
# Session factory: sessions do not autocommit or autoflush; get_db commits
# explicitly when its `with` body succeeds.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base for the ORM models below.
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base — confirm version.
Base = declarative_base()
# IST timezone, used by the model timestamp defaults.
IST = pytz.timezone('Asia/Kolkata')
| # ------------------------------------------------------------------------------ | |
| # SQLAlchemy Models | |
| # ------------------------------------------------------------------------------ | |
class User(Base):
    """A registered account.

    Users are soft-deleted: ``available`` is set to False instead of removing
    the row, and the user lookups in this module filter on ``available``.
    """
    __tablename__ = "users"
    # Login identifier and primary key.
    # NOTE(review): index=True on a primary-key column likely adds a second,
    # redundant index next to the PK index — confirm before changing (schema change).
    user_id = Column(String, primary_key=True, index=True)
    # create_user stores a utf-8-decoded bcrypt hash here, never the plain password.
    password_hash = Column(String, nullable=False)
    # Defaults to an IST-aware "now".
    # NOTE(review): plain DateTime (no timezone=True) is a timezone-naive column,
    # so the stored wall-clock value depends on driver/server conversion of the
    # aware value — confirm the intended semantics.
    created_at = Column(DateTime, default=lambda: datetime.now(IST))
    # Time of the most recent successful login; NULL until the first login.
    last_login = Column(DateTime, nullable=True)
    # Soft-delete flag: False marks the user as deleted.
    available = Column(Boolean, default=True)
    # Relationships
    files = relationship("UserFile", back_populates="user")
    embeddings = relationship("UserEmbedding", back_populates="user")
class UserFile(Base):
    """A file recorded for a user; soft-deleted by clearing ``available``."""
    __tablename__ = "user_files"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning user; must reference an existing users.user_id (foreign key).
    user_id = Column(String, ForeignKey("users.user_id"), nullable=False, index=True)
    # File name as given by the caller. There is no unique constraint, so the
    # same name can appear on several rows for one user.
    file_name = Column(String, nullable=False)
    # Path where the file is stored.
    file_path = Column(String, nullable=False)
    # Optional free-form type label; no format is enforced here.
    file_type = Column(String, nullable=True)
    # Defaults to an IST-aware "now".
    # NOTE(review): aware value written into a timezone-naive DateTime column — confirm semantics.
    uploaded_at = Column(DateTime, default=lambda: datetime.now(IST))
    # Soft-delete flag: False marks the file as removed.
    available = Column(Boolean, default=True)
    # Relationship
    user = relationship("User", back_populates="files")
class UserEmbedding(Base):
    """A record of a vector stored in Qdrant for a user.

    There is no foreign key to ``user_files``; add_embedding_compat ties an
    embedding to a file only by copying the file's name into ``source``.
    """
    __tablename__ = "user_embeddings"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning user; must reference an existing users.user_id (foreign key).
    user_id = Column(String, ForeignKey("users.user_id"), nullable=False, index=True)
    qdrant_doc_id = Column(String, nullable=True)  # ID in Qdrant
    # Origin label; add_embedding_compat stores the source file name here.
    source = Column(String, nullable=True)
    # Defaults to an IST-aware "now".
    # NOTE(review): aware value written into a timezone-naive DateTime column — confirm semantics.
    created_at = Column(DateTime, default=lambda: datetime.now(IST))
    # Soft-delete flag: False marks the embedding as removed.
    available = Column(Boolean, default=True)
    # Relationship
    user = relationship("User", back_populates="embeddings")
| # ------------------------------------------------------------------------------ | |
| # Database Management Functions | |
| # ------------------------------------------------------------------------------ | |
def init_database():
    """Create every table registered on ``Base.metadata`` (existing ones are kept).

    Returns:
        True when table creation succeeded; False if it raised (the error is
        logged, not propagated).
    """
    try:
        Base.metadata.create_all(bind=engine)
        log.info("Database tables created successfully.")
    except Exception as exc:
        log.error(f"Error creating database tables: {exc}")
        return False
    return True
@contextmanager
def get_db(session_factory=None):
    """Provide a transactional database session as a context manager.

    Usage::

        with get_db() as db:
            db.add(obj)

    The session is committed when the ``with`` body finishes normally. If the
    body raises, the session is rolled back and the exception is re-raised.
    The session is always closed.

    Args:
        session_factory: Optional zero-argument callable returning a session.
            Defaults to the module-level ``SessionLocal``; pass another
            factory to use a different engine (e.g. in tests).

    Yields:
        The open session.
    """
    # The decorator is essential: every caller does `with get_db() as db:`,
    # and a bare generator object does not support the `with` protocol.
    factory = SessionLocal if session_factory is None else session_factory
    db = factory()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
def get_connection():
    """Open and return a new, unmanaged database session.

    Kept for compatibility with code written against the SQLite module. The
    caller owns the session: nothing is committed, rolled back or closed
    automatically.
    """
    session = SessionLocal()
    return session
def delete_database() -> bool:
    """Drop every table registered on ``Base.metadata``.

    This is the PostgreSQL counterpart of deleting the old SQLite file and is
    destructive: all rows in the mapped tables are lost.

    Returns:
        True if the tables were dropped; False if dropping raised (logged).
    """
    try:
        Base.metadata.drop_all(bind=engine)
        log.info("All database tables dropped successfully.")
    except Exception as exc:
        log.error(f"Error dropping database tables: {exc}")
        return False
    return True
| # ------------------------------------------------------------------------------ | |
| # User Management Functions | |
| # ------------------------------------------------------------------------------ | |
def create_user(user_id: str, password: str) -> bool:
    """Create a new user whose password is stored as a bcrypt hash.

    Args:
        user_id: Identifier (primary key) for the new user.
        password: Plain-text password; only its bcrypt hash is persisted.

    Returns:
        True if the user row was written. False if a row with this
        ``user_id`` already exists (soft-deleted or not) or any database
        error occurred (logged).
    """
    try:
        with get_db() as db:
            # Any existing row blocks creation, including a soft-deleted one,
            # because user_id is the primary key.
            existing_user = db.query(User).filter(User.user_id == user_id).first()
            if existing_user:
                log.warning(f"User '{user_id}' already exists.")
                return False
            password_hash = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8')
            db.add(User(user_id=user_id, password_hash=password_hash))
            # Flush so a constraint violation (e.g. a concurrent insert of the
            # same user_id) raises here, before success is logged, rather than
            # only at commit time after the success message.
            db.flush()
            log.info(f"User '{user_id}' created successfully.")
            return True
    except Exception as e:
        log.error(f"Error creating user '{user_id}': {e}")
        return False
def verify_user(user_id: str, password: str) -> bool:
    """Check ``password`` against the stored bcrypt hash of an available user.

    On success the user's ``last_login`` is set to the current IST time; the
    change is committed when the session closes.

    Returns:
        True only when an available user exists and the password matches.
        False otherwise, including on any error (logged).
    """
    try:
        with get_db() as db:
            account = (
                db.query(User)
                .filter(User.user_id == user_id, User.available == True)  # noqa: E712
                .first()
            )
            if account is None:
                log.warning(f"User '{user_id}' not found.")
                return False
            candidate = password.encode('utf-8')
            stored = account.password_hash.encode('utf-8')
            if not bcrypt.checkpw(candidate, stored):
                log.warning(f"Invalid password for user '{user_id}'.")
                return False
            account.last_login = datetime.now(IST)
            log.info(f"User '{user_id}' authenticated successfully.")
            return True
    except Exception as exc:
        log.error(f"Error verifying user '{user_id}': {exc}")
        return False
def get_all_users() -> List[Dict[str, Any]]:
    """List every available (not soft-deleted) user.

    Returns:
        One dict per user with ``user_id``, ``created_at`` and ``last_login``.
        Timestamps are ISO-8601 strings, or None when unset. Returns an empty
        list on error (logged).
    """
    def _iso(value):
        return value.isoformat() if value else None

    try:
        with get_db() as db:
            active = db.query(User).filter(User.available == True).all()  # noqa: E712
            summaries: List[Dict[str, Any]] = []
            for row in active:
                summaries.append({
                    "user_id": row.user_id,
                    "created_at": _iso(row.created_at),
                    "last_login": _iso(row.last_login),
                })
            return summaries
    except Exception as exc:
        log.error(f"Error getting all users: {exc}")
        return []
def delete_user(user_id: str) -> bool:
    """Soft-delete a user by setting ``available`` to False.

    The row itself is kept.

    Returns:
        True if a row with this ``user_id`` exists (even if it was already
        soft-deleted). False if no such row exists or on error (logged).
    """
    try:
        with get_db() as db:
            target = db.query(User).filter(User.user_id == user_id).first()
            if not target:
                log.warning(f"User '{user_id}' not found.")
                return False
            target.available = False
            log.info(f"User '{user_id}' marked as deleted.")
            return True
    except Exception as exc:
        log.error(f"Error deleting user '{user_id}': {exc}")
        return False
| # ------------------------------------------------------------------------------ | |
| # File Management Functions | |
| # ------------------------------------------------------------------------------ | |
def add_file(user_id: str, file_name: str, file_path: str, file_type: Optional[str] = None) -> bool:
    """Record a file belonging to a user.

    Args:
        user_id: Owner; must match an existing ``users.user_id`` (foreign key).
        file_name: Name to record for the file.
        file_path: Where the file is stored.
        file_type: Optional type label.

    Returns:
        True if the row was written; False on any database error (logged).
    """
    try:
        with get_db() as db:
            db.add(UserFile(
                user_id=user_id,
                file_name=file_name,
                file_path=file_path,
                file_type=file_type,
            ))
            # Flush so constraint errors (e.g. an unknown user_id violating the
            # users foreign key) surface before the success message is logged.
            db.flush()
            log.info(f"File '{file_name}' added for user '{user_id}'.")
            return True
    except Exception as e:
        log.error(f"Error adding file '{file_name}' for user '{user_id}': {e}")
        return False
def get_user_files(user_id: str) -> List[Dict[str, Any]]:
    """Return the available (not soft-deleted) files owned by ``user_id``.

    Each item holds ``id``, ``file_name``, ``file_path``, ``file_type`` and
    ``uploaded_at``; ``uploaded_at`` is an ISO-8601 string, or None when
    unset. Returns an empty list on error (logged).
    """
    try:
        with get_db() as db:
            query = db.query(UserFile).filter(
                UserFile.user_id == user_id,
                UserFile.available == True,  # noqa: E712
            )
            result: List[Dict[str, Any]] = []
            for row in query.all():
                stamp = row.uploaded_at
                result.append({
                    "id": row.id,
                    "file_name": row.file_name,
                    "file_path": row.file_path,
                    "file_type": row.file_type,
                    "uploaded_at": stamp.isoformat() if stamp else None,
                })
            return result
    except Exception as exc:
        log.error(f"Error getting files for user '{user_id}': {exc}")
        return []
def delete_file(file_id: int) -> bool:
    """Soft-delete the file row with primary key ``file_id``.

    Only the database row's ``available`` flag is changed; nothing on disk is
    touched here.

    Returns:
        True if the row exists; False if it does not or on error (logged).
    """
    try:
        with get_db() as db:
            record = db.query(UserFile).filter(UserFile.id == file_id).first()
            if not record:
                log.warning(f"File ID {file_id} not found.")
                return False
            record.available = False
            log.info(f"File ID {file_id} marked as deleted.")
            return True
    except Exception as exc:
        log.error(f"Error deleting file ID {file_id}: {exc}")
        return False
| # ------------------------------------------------------------------------------ | |
| # Embedding Management Functions | |
| # ------------------------------------------------------------------------------ | |
def add_embedding(user_id: str, qdrant_doc_id: Optional[str] = None, source: Optional[str] = None) -> bool:
    """Record an embedding (a Qdrant document) for a user.

    Args:
        user_id: Owner; must match an existing ``users.user_id`` (foreign key).
        qdrant_doc_id: Optional id of the document in Qdrant.
        source: Optional origin label.

    Returns:
        True if the row was written; False on any database error (logged).
    """
    try:
        with get_db() as db:
            db.add(UserEmbedding(
                user_id=user_id,
                qdrant_doc_id=qdrant_doc_id,
                source=source,
            ))
            # Flush so constraint errors (e.g. an unknown user_id) surface
            # before the success message is logged.
            db.flush()
            log.info(f"Embedding added for user '{user_id}'.")
            return True
    except Exception as e:
        log.error(f"Error adding embedding for user '{user_id}': {e}")
        return False
def get_user_embeddings(user_id: str) -> List[Dict[str, Any]]:
    """Return the available (not soft-deleted) embeddings owned by ``user_id``.

    Each item holds ``id``, ``qdrant_doc_id``, ``source`` and ``created_at``;
    ``created_at`` is an ISO-8601 string, or None when unset. Returns an empty
    list on error (logged).
    """
    try:
        with get_db() as db:
            query = db.query(UserEmbedding).filter(
                UserEmbedding.user_id == user_id,
                UserEmbedding.available == True,  # noqa: E712
            )
            result: List[Dict[str, Any]] = []
            for emb in query.all():
                stamp = emb.created_at
                result.append({
                    "id": emb.id,
                    "qdrant_doc_id": emb.qdrant_doc_id,
                    "source": emb.source,
                    "created_at": stamp.isoformat() if stamp else None,
                })
            return result
    except Exception as exc:
        log.error(f"Error getting embeddings for user '{user_id}': {exc}")
        return []
def delete_embedding(embedding_id: int) -> bool:
    """Soft-delete the embedding row with primary key ``embedding_id``.

    Only the database row's ``available`` flag is changed; Qdrant is not
    touched here.

    Returns:
        True if the row exists; False if it does not or on error (logged).
    """
    try:
        with get_db() as db:
            record = db.query(UserEmbedding).filter(UserEmbedding.id == embedding_id).first()
            if not record:
                log.warning(f"Embedding ID {embedding_id} not found.")
                return False
            record.available = False
            log.info(f"Embedding ID {embedding_id} marked as deleted.")
            return True
    except Exception as exc:
        log.error(f"Error deleting embedding ID {embedding_id}: {exc}")
        return False
| # ============================================================================== | |
| # Compatibility Functions for sq_db.py API | |
| # ============================================================================== | |
def create_tables() -> bool:
    """Alias for ``init_database()`` matching the sq_db.py API.

    Returns:
        ``init_database()``'s result: True if the tables were created, False
        on error. Returning it lets callers detect failure; callers that
        ignore the return value are unaffected.
    """
    return init_database()
def add_user(user_id: str, name: str, password: str) -> bool:
    """Create a user using the sq_db.py-compatible signature.

    ``name`` is accepted only for signature compatibility and is not stored:
    the ``users`` table has no name column.

    Returns:
        The result of ``create_user(user_id, password)``.
    """
    outcome = create_user(user_id, password)
    return outcome
def check_user_exists(user_id: str) -> bool:
    """Report whether an available (not soft-deleted) user with this id exists.

    Returns:
        True if such a user exists; False if not or on error (logged).
    """
    try:
        with get_db() as db:
            match = (
                db.query(User)
                .filter(User.available == True, User.user_id == user_id)  # noqa: E712
                .first()
            )
            return match is not None
    except Exception as exc:
        log.error(f"Error checking if user '{user_id}' exists: {exc}")
        return False
def authenticate_user(user_id: str, password: str) -> tuple[bool, str]:
    """Authenticate a user and return a ``(success, message)`` pair (sq_db.py API).

    Delegates to ``verify_user``, which checks the password and, on success,
    sets ``last_login`` to the current IST time and logs the login.

    Returns:
        ``(True, "Authentication successful")`` on success,
        ``(False, "Invalid credentials")`` when verification fails, or
        ``(False, "Error: <message>")`` if an unexpected exception escapes.
    """
    # Do not update last_login again here. verify_user already records it with
    # an IST-aware timestamp, and a second session writing a naive
    # datetime.now() would overwrite that value and log the success twice.
    try:
        if verify_user(user_id=user_id, password=password):
            return (True, "Authentication successful")
        return (False, "Invalid credentials")
    except Exception as e:
        log.error(f"Error authenticating user '{user_id}': {e}")
        return (False, f"Error: {str(e)}")
def get_file_id_by_name(user_id: str, file_name: str) -> int:
    """Look up the id of an available file by owner and file name.

    Returns:
        The matching row's id, or -1 if there is none or on error (logged).
        NOTE(review): when several available rows share the same name, which
        one is returned is unspecified (the query has no ORDER BY).
    """
    try:
        with get_db() as db:
            record = (
                db.query(UserFile)
                .filter(
                    UserFile.user_id == user_id,
                    UserFile.file_name == file_name,
                    UserFile.available == True,  # noqa: E712
                )
                .first()
            )
            if not record:
                return -1
            return record.id
    except Exception as exc:
        log.error(f"Error getting file ID for '{file_name}': {exc}")
        return -1
def mark_file_removed(user_id: str, file_id: int) -> bool:
    """Soft-delete file ``file_id``, but only if it belongs to ``user_id``.

    Returns:
        True if a matching row exists (its ``available`` flag is cleared).
        False if the file does not exist for that user or on error (logged).
    """
    try:
        with get_db() as db:
            owned = (
                db.query(UserFile)
                .filter(UserFile.id == file_id, UserFile.user_id == user_id)
                .first()
            )
            if not owned:
                log.warning(f"File ID {file_id} not found for user '{user_id}'.")
                return False
            owned.available = False
            log.info(f"File ID {file_id} marked as removed.")
            return True
    except Exception as exc:
        log.error(f"Error marking file ID {file_id} as removed: {exc}")
        return False
def mark_embeddings_removed(vector_ids: List[str]) -> bool:
    """Soft-delete every embedding whose ``qdrant_doc_id`` is in ``vector_ids``.

    Ids with no matching row are ignored.

    Returns:
        True unless a database error occurs, in which case it is logged and
        False is returned.
    """
    try:
        with get_db() as db:
            id_filter = UserEmbedding.qdrant_doc_id.in_(vector_ids)
            matched = db.query(UserEmbedding).filter(id_filter).all()
            for row in matched:
                row.available = False
            log.info(f"Marked {len(matched)} embeddings as removed.")
            return True
    except Exception as exc:
        log.error(f"Error marking embeddings as removed: {exc}")
        return False
def get_old_files(user_id: str, time: int = 12*3600) -> dict:
    """Find a user's available files uploaded more than ``time`` seconds ago.

    Args:
        user_id: Owner whose files are inspected.
        time: Age threshold in seconds (default 12 hours).

    Returns:
        ``{'files': [...], 'embeddings': [...]}``:

        - ``files`` holds the names of the old files.
        - ``embeddings`` holds the ``qdrant_doc_id`` values of the user's
          available embeddings whose ``source`` is one of those names
          (add_embedding_compat stores the file name in ``source``).

        Both lists are empty on error (logged).
    """
    try:
        with get_db() as db:
            # Same IST-aware clock as the uploaded_at default. A naive
            # datetime.now() would depend on the process's local timezone.
            cutoff_time = datetime.now(IST) - timedelta(seconds=time)
            old_files = db.query(UserFile).filter(
                UserFile.user_id == user_id,
                UserFile.uploaded_at < cutoff_time,
                UserFile.available == True,  # noqa: E712
            ).all()
            filenames = [f.file_name for f in old_files]
            embedding_ids: List[str] = []
            if filenames:
                # user_embeddings has no foreign key to user_files, so
                # query(UserEmbedding).join(UserFile) cannot infer an ON clause
                # and raises. The only link is (user_id, source == file_name).
                embeddings = db.query(UserEmbedding).filter(
                    UserEmbedding.user_id == user_id,
                    UserEmbedding.source.in_(filenames),
                    UserEmbedding.available == True,  # noqa: E712
                ).all()
                embedding_ids = [e.qdrant_doc_id for e in embeddings if e.qdrant_doc_id]
            return {
                'files': filenames,
                'embeddings': embedding_ids
            }
    except Exception as e:
        log.error(f"Error getting old files for user '{user_id}': {e}")
        return {'files': [], 'embeddings': []}
| # Modified add_file to match sq_db signature (filename only) | |
def add_file_compat(user_id: str, filename: str) -> int:
    """Add a file record given only its name (sq_db.py compatibility).

    The stored path defaults to ``/fastAPI/user_uploads/<user_id>/<filename>``
    and ``file_type`` is left unset.

    Returns:
        The id of the newly inserted row, or -1 on any error (logged).
    """
    file_path = f"/fastAPI/user_uploads/{user_id}/{filename}"  # Default path
    try:
        with get_db() as db:
            record = UserFile(user_id=user_id, file_name=filename, file_path=file_path)
            db.add(record)
            # Flush to get the primary key generated by this very INSERT.
            # Looking the id up afterwards by (user_id, file_name) could return
            # another row when the same name was recorded before.
            db.flush()
            new_id = record.id
            log.info(f"File '{filename}' added for user '{user_id}'.")
            return new_id
    except Exception as e:
        log.error(f"Error adding file '{filename}': {e}")
        return -1
| # Modified add_embedding to match sq_db signature | |
def add_embedding_compat(file_id: int, vector_id: str) -> bool:
    """Record an embedding for the file with id ``file_id`` (sq_db.py compatibility).

    The embedding takes the file's ``user_id`` and stores the file's name in
    ``source``. The file is looked up by id only; its ``available`` flag is
    not checked.

    Returns:
        True if the embedding row was written. False if the file does not
        exist or on any error (logged).
    """
    try:
        with get_db() as db:
            file_record = db.query(UserFile).filter(UserFile.id == file_id).first()
            if not file_record:
                log.warning(f"File ID {file_id} not found.")
                return False
            # Insert in this same session and transaction. Calling
            # add_embedding() here would check out a second pooled connection
            # while this session still holds one.
            db.add(UserEmbedding(
                user_id=file_record.user_id,
                qdrant_doc_id=vector_id,
                source=file_record.file_name,
            ))
            db.flush()
            log.info(f"Embedding added for user '{file_record.user_id}'.")
            return True
    except Exception as e:
        log.error(f"Error adding embedding for file ID {file_id}: {e}")
        return False
| # Modified get_user_files to return list of strings (filenames only) | |
def get_user_files_compat(user_id: str) -> List[str]:
    """Return only the names of a user's available files (sq_db.py compatibility).

    The list is empty if the user has no files or the lookup failed.
    """
    records = get_user_files(user_id=user_id)
    names: List[str] = []
    for record in records:
        names.append(record['file_name'])
    return names
# Initialize database on module import: importing this module creates any
# missing tables. init_database() logs failures and returns False rather than
# raising, so a database error here does not make the import itself fail.
init_database()