| """ |
| Base Model Interfaces for Document Intelligence |
| |
| Abstract base classes defining the contract for all model components. |
| All models are pluggable and can be swapped without changing the pipeline. |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import numpy as np |
| from PIL import Image |
|
|
|
|
class ModelCapability(str, Enum):
    """
    Enumeration of the capabilities a model can advertise.

    Because the enum also subclasses ``str``, every member compares equal
    to its plain string value (for example ``ModelCapability.OCR == "ocr"``)
    and can be looked up by value with ``ModelCapability("ocr")``.
    """

    # NOTE: member order is part of the public iteration order; append new
    # members at the end rather than inserting them.
    OCR = "ocr"
    LAYOUT_DETECTION = "layout_detection"
    TABLE_EXTRACTION = "table_extraction"
    CHART_EXTRACTION = "chart_extraction"
    READING_ORDER = "reading_order"
    VISION_LANGUAGE = "vision_language"
    EMBEDDING = "embedding"
    CLASSIFICATION = "classification"
|
|
|
|
@dataclass
class ModelConfig:
    """
    Configuration shared by every model implementation.

    Only ``name`` is required; every other field has a default. A
    ``cache_dir`` supplied as a string is coerced to ``pathlib.Path`` once
    the dataclass has been initialised.
    """

    # Required identifier of the model.
    name: str
    version: str = "1.0.0"
    # Device selector string; defaults to "auto".
    device: str = "auto"
    batch_size: int = 1
    max_workers: int = 4
    cache_enabled: bool = True
    # None means "no cache directory configured"; strings become Path objects.
    cache_dir: Optional[Path] = None
    timeout_seconds: float = 300.0
    # Implementation-specific options; a fresh dict is created per instance.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Nothing to normalise when no cache directory was given.
        if self.cache_dir is None:
            return
        self.cache_dir = Path(self.cache_dir)
|
|
|
|
@dataclass
class ModelMetadata:
    """
    Descriptive information about a model instance.

    ``name``, ``version``, ``capabilities`` and ``device`` must be supplied
    (in that positional order); all remaining fields have defaults.
    """

    # --- required identity / placement fields (positional order matters) ---
    name: str
    version: str
    capabilities: List[ModelCapability]
    device: str

    # --- optional state and resource fields ---
    memory_usage_mb: float = 0.0
    is_loaded: bool = False

    # --- optional batching description ---
    supports_batching: bool = False
    max_batch_size: int = 1

    # --- free-form I/O descriptions; each instance gets its own dict ---
    input_requirements: Dict[str, Any] = field(default_factory=dict)
    output_format: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
class BaseModel(ABC):
    """
    Abstract base class for all document intelligence models.

    All model implementations must inherit from this class and implement
    the abstract methods ``load``, ``unload`` and ``get_capabilities``.
    The ``validate_input``, ``preprocess`` and ``postprocess`` hooks are
    optional; their default implementations accept everything and pass
    data through unchanged.

    Instances are context managers. ``with model:`` loads the model if it
    is not already loaded. On exit, the model is unloaded only if that same
    ``with`` statement performed the load. A model the caller had loaded
    beforehand therefore stays loaded after the block.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Args:
            config: Model configuration. When omitted (or falsy), a default
                ``ModelConfig`` named after the concrete class is created.
        """
        self.config = config or ModelConfig(name=self.__class__.__name__)
        self._is_loaded = False
        self._metadata: Optional[ModelMetadata] = None
        # Stack with one entry per active ``with`` block: True when that
        # block's __enter__ performed the load and so must unload on exit.
        self._context_owns_load: List[bool] = []

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._is_loaded

    @property
    def metadata(self) -> Optional[ModelMetadata]:
        """Get model metadata (None until a subclass populates it)."""
        return self._metadata

    @abstractmethod
    def load(self) -> None:
        """
        Load the model into memory.

        Should set self._is_loaded = True upon successful loading.
        Should populate self._metadata with model information.
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """
        Unload the model from memory.

        Should set self._is_loaded = False.
        Should free GPU/CPU memory.
        """
        pass

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return list of capabilities this model provides."""
        pass

    def validate_input(self, input_data: Any) -> bool:
        """
        Validate input data before processing.

        The default implementation accepts every input and returns True.
        Override in subclasses for specific validation.
        """
        return True

    def preprocess(self, input_data: Any) -> Any:
        """
        Preprocess input data before model inference.

        The default implementation returns ``input_data`` unchanged.
        Override in subclasses for specific preprocessing.
        """
        return input_data

    def postprocess(self, output_data: Any) -> Any:
        """
        Postprocess model output.

        The default implementation returns ``output_data`` unchanged.
        Override in subclasses for specific postprocessing.
        """
        return output_data

    def __enter__(self):
        """Enter the context, loading the model if it is not loaded yet."""
        loaded_here = not self.is_loaded
        if loaded_here:
            self.load()
        # Record whether *this* entry did the load, so the matching __exit__
        # only releases what it acquired. The getattr fallback keeps
        # subclasses that skip BaseModel.__init__ working.
        stack = getattr(self, "_context_owns_load", None)
        if stack is None:
            stack = self._context_owns_load = []
        stack.append(loaded_here)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Exit the context, unloading only if the matching __enter__ loaded.

        Returns False so any exception raised inside the block propagates.
        """
        stack = getattr(self, "_context_owns_load", None)
        # An __exit__ without a recorded __enter__ keeps the previous
        # behaviour of always unloading.
        should_unload = stack.pop() if stack else True
        if should_unload:
            self.unload()
        return False
|
|
|
|
class BatchableModel(BaseModel):
    """
    Base class for models whose inference entry point works on batches.

    Subclasses implement ``process_batch``. ``process_single`` is a
    convenience wrapper that sends one item through the batch path.
    """

    @abstractmethod
    def process_batch(
        self,
        inputs: List[Any],
        **kwargs
    ) -> List[Any]:
        """
        Run the model over a batch of inputs.

        Args:
            inputs: Items to process.
            **kwargs: Extra processing options forwarded by callers.

        Returns:
            One output per input, in the same order.
        """
        ...

    def process_single(self, input_data: Any, **kwargs) -> Any:
        """
        Process one input by wrapping it in a one-element batch.

        Returns the first element of the batch result, or None when
        ``process_batch`` returned an empty (falsy) result.
        """
        batch_output = self.process_batch([input_data], **kwargs)
        if not batch_output:
            return None
        return batch_output[0]
|
|
|
|
# Image representations accepted as input: an in-memory numpy array, a PIL
# image, or a filesystem path given either as a Path or as a plain string.
ImageInput = Union[np.ndarray, Image.Image, Path, str]
|
|
|
|
def normalize_image_input(image: ImageInput) -> np.ndarray:
    """
    Normalize various image input formats to a numpy array.

    Args:
        image: Image as a numpy array, PIL Image, or filesystem path
            (``str`` or ``Path``).

    Returns:
        Image as a numpy array. PIL images and image files are converted to
        RGB, giving an array in HWC layout. A numpy array input is returned
        unchanged (same object, no copy, no validation), so its layout and
        channel order are whatever the caller supplied.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
        Errors from ``Image.open`` (missing or unreadable file) propagate.
    """
    if isinstance(image, np.ndarray):
        # NOTE(review): arrays pass through untouched; callers must already
        # provide RGB/HWC data if they depend on that format.
        return image

    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    if isinstance(image, (str, Path)):
        # Open inside a context manager so the file handle is closed
        # deterministically. convert() returns a converted copy, so the
        # resulting array does not depend on the closed file.
        with Image.open(image) as opened:
            return np.array(opened.convert("RGB"))

    raise ValueError(f"Unsupported image input type: {type(image)}")
|
|
|
|
def ensure_pil_image(image: ImageInput) -> Image.Image:
    """
    Ensure input is a PIL Image in RGB mode.

    Args:
        image: Image as a numpy array, PIL Image, or filesystem path
            (``str`` or ``Path``).

    Returns:
        A PIL Image in RGB mode. ``convert("RGB")`` is applied in every
        branch, and it returns a converted copy. The result is therefore a
        new image object even when the input is already an RGB PIL image.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
        Errors from ``Image.fromarray`` or ``Image.open`` (for example an
        unsupported array dtype, or a missing/unreadable file) propagate.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")

    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")

    if isinstance(image, (str, Path)):
        # Open inside a context manager so the file handle is released once
        # convert() has produced its copy. The returned image does not
        # reference the closed file.
        with Image.open(image) as opened:
            return opened.convert("RGB")

    raise ValueError(f"Unsupported image input type: {type(image)}")
|
|