""" ColPali vision-based document retrieval embeddings. Uses ColPali to generate multi-vector embeddings from document images, enabling visual understanding of tables, figures, and complex layouts. """ import logging from typing import Optional from pathlib import Path import numpy as np from PIL import Image import torch # Try to import ColPali - graceful fallback if not available try: from colpali_engine.models import ColPali from colpali_engine.utils import process_images COLPALI_AVAILABLE = True except ImportError: COLPALI_AVAILABLE = False log = logging.getLogger(__name__) # Global model instance for efficiency _colpali_model: Optional[ColPali] = None _device: Optional[torch.device] = None def _get_device() -> torch.device: """Get optimal device (GPU if available, else CPU).""" global _device if _device is None: _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") log.info(f"ColPali using device: {_device}") return _device def _get_colpali_model() -> Optional[ColPali]: """Load ColPali model (cached for efficiency).""" global _colpali_model if not COLPALI_AVAILABLE: log.warning("ColPali not available. Install colpali-engine for vision embeddings.") return None if _colpali_model is None: try: log.info("Loading ColPali model...") device = _get_device() model = ColPali.from_pretrained( "vidore/colpali", torch_dtype=torch.bfloat16, device_map=device ) model.eval() _colpali_model = model log.info("✅ ColPali model loaded successfully") except Exception as e: log.error(f"Failed to load ColPali model: {e}") return None return _colpali_model def embed_image(image_path: str) -> Optional[dict]: """ Generate ColPali multi-vector embeddings from an image. Args: image_path: Path to the image file Returns: Dict with: - 'embeddings': List of patch embeddings (multi-vector) - 'num_patches': Number of patches - 'image_id': ID of the image Or None if embedding fails """ if not COLPALI_AVAILABLE: return None try: model = _get_colpali_model() if model is None: return None # Load and prepare image image_path = Path(image_path) if not image_path.exists(): log.warning(f"Image not found: {image_path}") return None image = Image.open(image_path).convert("RGB") # Process with ColPali with torch.no_grad(): # Process images returns list of processed images processed_images = process_images([image]) # Get embeddings embeddings = model.encode_images( processed_images, batch_size=1 ) # Convert to CPU and numpy for storage if isinstance(embeddings, torch.Tensor): embeddings = embeddings.cpu().numpy() # embeddings shape: (num_patches, embedding_dim) num_patches = len(embeddings) if isinstance(embeddings, (list, np.ndarray)) else 1 log.debug(f"Generated {num_patches} patch embeddings for {image_path.name}") return { "embeddings": embeddings.tolist() if isinstance(embeddings, np.ndarray) else embeddings, "num_patches": num_patches, "image_id": image_path.stem, "model": "colpali" } except Exception as e: log.error(f"ColPali embedding failed for {image_path}: {e}") return None def batch_embed_images(image_paths: list) -> dict: """ Batch embed multiple images efficiently. Args: image_paths: List of image file paths Returns: Dict mapping image_id -> embedding result """ if not COLPALI_AVAILABLE: return {} try: model = _get_colpali_model() if model is None: return {} # Load all valid images images = [] valid_paths = [] for path in image_paths: path = Path(path) if path.exists(): try: img = Image.open(path).convert("RGB") images.append(img) valid_paths.append(path) except Exception as e: log.warning(f"Could not load image {path}: {e}") if not images: return {} log.info(f"Batch embedding {len(images)} images with ColPali...") # Process all images with torch.no_grad(): processed_images = process_images(images) # Embed in batches (adjust batch_size based on GPU memory) embeddings = model.encode_images( processed_images, batch_size=4 ) # Build results dict results = {} for path, emb in zip(valid_paths, embeddings): emb_np = emb.cpu().numpy() if isinstance(emb, torch.Tensor) else emb results[path.stem] = { "embeddings": emb_np.tolist(), "num_patches": len(emb_np), "image_id": path.stem, "model": "colpali" } log.info(f"✅ Batch embedded {len(results)} images") return results except Exception as e: log.error(f"Batch embedding failed: {e}") return {} def is_colpali_available() -> bool: """Check if ColPali is available.""" return COLPALI_AVAILABLE and _get_colpali_model() is not None