| """ | |
| ColPali vision-based document retrieval embeddings. | |
| Uses ColPali to generate multi-vector embeddings from document images, | |
| enabling visual understanding of tables, figures, and complex layouts. | |
| """ | |
| import logging | |
| from typing import Optional | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| # Try to import ColPali - graceful fallback if not available | |
| try: | |
| from colpali_engine.models import ColPali | |
| from colpali_engine.utils import process_images | |
| COLPALI_AVAILABLE = True | |
| except ImportError: | |
| COLPALI_AVAILABLE = False | |
| log = logging.getLogger(__name__) | |
| # Global model instance for efficiency | |
| _colpali_model: Optional[ColPali] = None | |
| _device: Optional[torch.device] = None | |


def _get_device() -> torch.device:
    """Get optimal device (GPU if available, else CPU)."""
    global _device
    if _device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log.info(f"ColPali using device: {_device}")
    return _device


def _get_colpali_model() -> Optional["ColPali"]:
    """Load ColPali model (cached for efficiency)."""
    global _colpali_model
    if not COLPALI_AVAILABLE:
        log.warning("ColPali not available. Install colpali-engine for vision embeddings.")
        return None
    if _colpali_model is None:
        try:
            log.info("Loading ColPali model...")
            device = _get_device()
            model = ColPali.from_pretrained(
                "vidore/colpali",
                torch_dtype=torch.bfloat16,
                device_map=device,
            )
            model.eval()
            _colpali_model = model
            log.info("✅ ColPali model loaded successfully")
        except Exception as e:
            log.error(f"Failed to load ColPali model: {e}")
            return None
    return _colpali_model


def embed_image(image_path: str) -> Optional[dict]:
    """
    Generate ColPali multi-vector embeddings from an image.

    Args:
        image_path: Path to the image file

    Returns:
        Dict with:
            - 'embeddings': List of patch embeddings (multi-vector)
            - 'num_patches': Number of patches
            - 'image_id': ID of the image (file stem)
            - 'model': Name of the embedding model ("colpali")
        Or None if embedding fails
    """
    if not COLPALI_AVAILABLE:
        return None

    try:
        model = _get_colpali_model()
        if model is None:
            return None

        # Load and prepare image
        image_path = Path(image_path)
        if not image_path.exists():
            log.warning(f"Image not found: {image_path}")
            return None
        image = Image.open(image_path).convert("RGB")

        # Process with ColPali
        with torch.no_grad():
            # process_images returns a list of processed images
            processed_images = process_images([image])
            # Get embeddings
            embeddings = model.encode_images(
                processed_images,
                batch_size=1,
            )

        # Convert to CPU and numpy for storage
        # (cast to float32 first: numpy has no bfloat16 dtype)
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().float().numpy()

        # embeddings shape: (num_patches, embedding_dim)
        num_patches = len(embeddings) if isinstance(embeddings, (list, np.ndarray)) else 1
        log.debug(f"Generated {num_patches} patch embeddings for {image_path.name}")

        return {
            "embeddings": embeddings.tolist() if isinstance(embeddings, np.ndarray) else embeddings,
            "num_patches": num_patches,
            "image_id": image_path.stem,
            "model": "colpali",
        }
    except Exception as e:
        log.error(f"ColPali embedding failed for {image_path}: {e}")
        return None
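

# Illustrative only: a minimal sketch of how a caller might turn the dict
# returned by embed_image() back into a numpy matrix for similarity search.
# The helper name _embedding_matrix_from_result is hypothetical and not part
# of this module's public API; it only assumes the result layout documented
# above (a nested list of shape (num_patches, embedding_dim)).
def _embedding_matrix_from_result(result: Optional[dict]) -> Optional[np.ndarray]:
    """Convert an embed_image() result back to a (num_patches, dim) float32 array."""
    if not result or "embeddings" not in result:
        return None
    return np.asarray(result["embeddings"], dtype=np.float32)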


def batch_embed_images(image_paths: list) -> dict:
    """
    Batch embed multiple images efficiently.

    Args:
        image_paths: List of image file paths

    Returns:
        Dict mapping image_id -> embedding result
    """
    if not COLPALI_AVAILABLE:
        return {}

    try:
        model = _get_colpali_model()
        if model is None:
            return {}

        # Load all valid images
        images = []
        valid_paths = []
        for path in image_paths:
            path = Path(path)
            if path.exists():
                try:
                    img = Image.open(path).convert("RGB")
                    images.append(img)
                    valid_paths.append(path)
                except Exception as e:
                    log.warning(f"Could not load image {path}: {e}")
        if not images:
            return {}

        log.info(f"Batch embedding {len(images)} images with ColPali...")

        # Process all images
        with torch.no_grad():
            processed_images = process_images(images)
            # Embed in batches (adjust batch_size based on GPU memory)
            embeddings = model.encode_images(
                processed_images,
                batch_size=4,
            )

        # Build results dict
        results = {}
        for path, emb in zip(valid_paths, embeddings):
            # Cast to float32 before numpy conversion (numpy has no bfloat16 dtype)
            emb_np = emb.cpu().float().numpy() if isinstance(emb, torch.Tensor) else np.asarray(emb)
            results[path.stem] = {
                "embeddings": emb_np.tolist(),
                "num_patches": len(emb_np),
                "image_id": path.stem,
                "model": "colpali",
            }

        log.info(f"✅ Batch embedded {len(results)} images")
        return results
    except Exception as e:
        log.error(f"Batch embedding failed: {e}")
        return {}
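

# Illustrative only: multi-vector embeddings like the ones stored above are
# typically compared with late-interaction (MaxSim) scoring - for each query
# vector, take the maximum dot product over all document patch vectors, then
# sum over query vectors. This helper is hypothetical (not part of this
# module) and assumes query embeddings of shape (num_query_vectors, dim)
# produced elsewhere with the same model.
def _maxsim_score(query_vectors: np.ndarray, patch_vectors: np.ndarray) -> float:
    """Late-interaction score between query vectors and one image's patch vectors."""
    # (num_query_vectors, num_patches) similarity matrix
    sims = query_vectors @ patch_vectors.T
    # Best-matching patch per query vector, summed over query vectors
    return float(sims.max(axis=1).sum())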


def is_colpali_available() -> bool:
    """Check if ColPali can be used (note: loads the model on first call)."""
    return COLPALI_AVAILABLE and _get_colpali_model() is not None
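

# Minimal usage sketch when running this file directly. The image path
# "page_001.png" is an assumption (a rendered document page on disk), not
# something this module provides.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    if is_colpali_available():
        result = embed_image("page_001.png")
        if result:
            print(f"{result['image_id']}: {result['num_patches']} patch embeddings")
    else:
        print("colpali-engine not installed; vision embeddings disabled")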