| """ |
| OCR Model Interface |
| |
| Abstract interface for Optical Character Recognition models. |
| Supports both local engines and cloud services. |
| """ |
|
|
| from abc import abstractmethod |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from ..chunks.models import BoundingBox |
| from .base import ( |
| BaseModel, |
| BatchableModel, |
| ImageInput, |
| ModelCapability, |
| ModelConfig, |
| ) |
|
|
|
|
class OCREngine(str, Enum):
    """Supported OCR engines.

    Inherits from ``str`` so members compare equal to their string values
    and serialize naturally (e.g. in configs / JSON).
    """

    PADDLEOCR = "paddleocr"
    TESSERACT = "tesseract"
    EASYOCR = "easyocr"
    # Escape hatch for engines not listed here.
    CUSTOM = "custom"
|
|
|
|
@dataclass
class OCRConfig(ModelConfig):
    """Configuration for OCR models.

    Bundles engine selection with shared options and per-engine tuning
    knobs; fields that do not apply to the chosen engine are ignored by it.
    """

    # Engine choice and options common to all engines.
    engine: OCREngine = OCREngine.PADDLEOCR
    languages: List[str] = field(default_factory=lambda: ["en"])
    detect_orientation: bool = True
    detect_tables: bool = True
    min_confidence: float = 0.5

    # NOTE(review): presumed PaddleOCR-specific, judging by field names.
    use_angle_cls: bool = True
    use_gpu: bool = True

    # NOTE(review): presumed Tesseract-specific (page segmentation mode etc.).
    tesseract_config: str = ""
    psm_mode: int = 3

    def __post_init__(self):
        super().__post_init__()
        # Fall back to an engine-derived model name when none was supplied.
        self.name = self.name or f"ocr_{self.engine.value}"
|
|
|
|
@dataclass
class OCRWord:
    """A single recognized word with its bounding box."""

    text: str                           # recognized text for this word
    bbox: BoundingBox                   # location of the word in the image
    confidence: float                   # engine-reported recognition confidence
    language: Optional[str] = None      # language tag, if the engine reports one
    is_handwritten: bool = False        # handwriting flag (engine-dependent)
    font_size: Optional[float] = None   # estimated font size, if available
    is_bold: bool = False               # style flags — only meaningful for
    is_italic: bool = False             # engines that detect text styling
|
|
|
|
@dataclass
class OCRLine:
    """A line of text composed of words."""

    text: str
    bbox: BoundingBox
    confidence: float
    words: List[OCRWord] = field(default_factory=list)
    line_index: int = 0

    @property
    def word_count(self) -> int:
        """Number of words on this line."""
        return len(self.words)

    @classmethod
    def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
        """Create a line from a list of words.

        Text is the space-joined word texts, confidence is the mean word
        confidence, and the bbox is the union of all word boxes.

        Raises:
            ValueError: If *words* is empty.
        """
        if not words:
            raise ValueError("Cannot create line from empty word list")

        boxes = [w.bbox for w in words]
        # Union of all word boxes.  NOTE(review): assumes every word shares
        # one coordinate system — only the first word's flag is consulted.
        union_bbox = BoundingBox(
            x_min=min(b.x_min for b in boxes),
            y_min=min(b.y_min for b in boxes),
            x_max=max(b.x_max for b in boxes),
            y_max=max(b.y_max for b in boxes),
            normalized=boxes[0].normalized,
        )

        return cls(
            text=" ".join(w.text for w in words),
            bbox=union_bbox,
            confidence=sum(w.confidence for w in words) / len(words),
            words=words,
            line_index=line_index,
        )
|
|
|
|
@dataclass
class OCRBlock:
    """A block of text composed of lines (e.g., a paragraph)."""

    text: str
    bbox: BoundingBox
    confidence: float
    lines: List[OCRLine] = field(default_factory=list)
    block_type: str = "text"

    @property
    def line_count(self) -> int:
        """Number of lines in this block."""
        return len(self.lines)

    @classmethod
    def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
        """Create a block from a list of lines.

        Text is the newline-joined line texts, confidence is the mean line
        confidence, and the bbox is the union of all line boxes.

        Raises:
            ValueError: If *lines* is empty.
        """
        if not lines:
            raise ValueError("Cannot create block from empty line list")

        boxes = [ln.bbox for ln in lines]
        # Union of all line boxes.  NOTE(review): assumes every line shares
        # one coordinate system — only the first line's flag is consulted.
        union_bbox = BoundingBox(
            x_min=min(b.x_min for b in boxes),
            y_min=min(b.y_min for b in boxes),
            x_max=max(b.x_max for b in boxes),
            y_max=max(b.y_max for b in boxes),
            normalized=boxes[0].normalized,
        )

        return cls(
            text="\n".join(ln.text for ln in lines),
            bbox=union_bbox,
            confidence=sum(ln.confidence for ln in lines) / len(lines),
            lines=lines,
            block_type=block_type,
        )
|
|
|
|
@dataclass
class OCRResult:
    """Complete OCR result for a single page/image."""

    text: str
    blocks: List[OCRBlock] = field(default_factory=list)
    lines: List[OCRLine] = field(default_factory=list)
    words: List[OCRWord] = field(default_factory=list)
    confidence: float = 0.0
    language_detected: Optional[str] = None
    orientation: float = 0.0
    deskew_angle: float = 0.0
    image_width: int = 0
    image_height: int = 0
    processing_time_ms: float = 0.0
    engine_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def word_count(self) -> int:
        """Total number of recognized words."""
        return len(self.words)

    @property
    def line_count(self) -> int:
        """Total number of recognized lines."""
        return len(self.lines)

    @property
    def block_count(self) -> int:
        """Total number of recognized blocks."""
        return len(self.blocks)

    def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
        """
        Get text within a specific bounding box region.

        A word is included when its IoU with *bbox* reaches *threshold*,
        or when its center point falls inside *bbox*.

        Args:
            bbox: Region to extract text from
            threshold: Minimum IoU overlap required

        Returns:
            Concatenated text of words in region
        """
        matched = [
            w
            for w in self.words
            if w.bbox.iou(bbox) >= threshold or bbox.contains(w.bbox.center)
        ]
        # Rough reading order: top-to-bottom, then left-to-right.
        matched.sort(key=lambda w: (w.bbox.y_min, w.bbox.x_min))
        return " ".join(w.text for w in matched)
|
|
|
|
class OCRModel(BatchableModel):
    """
    Abstract base class for OCR models.

    Implementations should handle:
    - Text detection (finding text regions)
    - Text recognition (converting regions to text)
    - Word/line/block segmentation
    - Confidence scoring
    """

    def __init__(self, config: Optional[OCRConfig] = None):
        """Initialize with *config*, creating a default OCRConfig if omitted."""
        super().__init__(config or OCRConfig(name="ocr"))
        # Re-annotate the attribute set by the base class so type checkers
        # see it as OCRConfig; the value itself is unchanged.
        self.config: OCRConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """This model family provides exactly the OCR capability."""
        return [ModelCapability.OCR]

    @abstractmethod
    def recognize(
        self,
        image: ImageInput,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a single image.

        Args:
            image: Input image (numpy array, PIL Image, or path)
            **kwargs: Additional engine-specific parameters

        Returns:
            OCRResult with detected text and locations
        """
        pass

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[OCRResult]:
        """
        Process multiple images.

        Default implementation processes sequentially.
        Override for optimized batch processing.
        """
        return [self.recognize(img, **kwargs) for img in inputs]

    def detect_text_regions(
        self,
        image: ImageInput,
        **kwargs
    ) -> List[BoundingBox]:
        """
        Detect text regions without performing recognition.

        Useful for layout analysis or selective OCR.  Note: this default
        implementation runs full recognition and keeps only the block
        boxes, so it is no cheaper than ``recognize``; engines with a
        standalone detection stage should override it.

        Args:
            image: Input image
            **kwargs: Additional parameters

        Returns:
            List of bounding boxes containing text
        """
        result = self.recognize(image, **kwargs)
        return [block.bbox for block in result.blocks]

    @staticmethod
    def _shift_bbox(bbox: BoundingBox, dx: float, dy: float) -> BoundingBox:
        """Return a copy of *bbox* translated by (dx, dy), marked as pixel-space."""
        return BoundingBox(
            x_min=bbox.x_min + dx,
            y_min=bbox.y_min + dy,
            x_max=bbox.x_max + dx,
            y_max=bbox.y_max + dy,
            normalized=False
        )

    def recognize_region(
        self,
        image: ImageInput,
        region: BoundingBox,
        **kwargs
    ) -> OCRResult:
        """
        Perform OCR on a specific region of an image.

        The region is cropped out of the image, recognized, and the
        resulting word/line/block boxes are translated back into
        full-image pixel coordinates.

        Args:
            image: Full image
            region: Region to OCR
            **kwargs: Additional parameters

        Returns:
            OCR result for the region, with boxes in full-image coordinates
        """
        from .base import ensure_pil_image

        pil_image = ensure_pil_image(image)

        # Work in pixel space regardless of how the region was expressed.
        if region.normalized:
            pixel_bbox = region.to_pixel(pil_image.width, pil_image.height)
        else:
            pixel_bbox = region

        cropped = pil_image.crop((
            int(pixel_bbox.x_min),
            int(pixel_bbox.y_min),
            int(pixel_bbox.x_max),
            int(pixel_bbox.y_max)
        ))

        result = self.recognize(cropped, **kwargs)

        # Translate every box from crop-local to full-image coordinates.
        # NOTE(review): this assumes the engine returned pixel (not
        # normalized) boxes for the crop, and that the word/line objects
        # nested inside result.lines / result.blocks are the same objects
        # as in result.words / result.lines — confirm per implementation.
        offset_x = pixel_bbox.x_min
        offset_y = pixel_bbox.y_min

        for word in result.words:
            word.bbox = self._shift_bbox(word.bbox, offset_x, offset_y)

        for line in result.lines:
            line.bbox = self._shift_bbox(line.bbox, offset_x, offset_y)

        for block in result.blocks:
            block.bbox = self._shift_bbox(block.bbox, offset_x, offset_y)

        return result
|
|