| """ |
| Base OCR Interface |
| |
| Defines the abstract OCR engine interface and common data structures. |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from typing import List, Optional, Dict, Any, Tuple |
| from dataclasses import dataclass, field |
| from enum import Enum |
| import numpy as np |
| from pydantic import BaseModel, Field |
|
|
| from ..schemas.core import BoundingBox, OCRRegion |
|
|
|
|
class OCRLanguage(str, Enum):
    """Language identifiers that can be requested from an OCR engine.

    Because the enum mixes in ``str``, every member compares equal to its
    raw code (``OCRLanguage.ENGLISH == "en"``) and can be passed anywhere a
    plain string code is expected.
    """

    # NOTE(review): the values are a mix of two-letter codes and longer
    # names ("german", "japan", "chinese_cht"); presumably they mirror the
    # identifiers of the underlying OCR backend -- confirm before changing.

    # English and Chinese variants
    ENGLISH = "en"
    CHINESE_SIMPLIFIED = "ch"
    CHINESE_TRADITIONAL = "chinese_cht"

    # European languages
    FRENCH = "fr"
    GERMAN = "german"
    SPANISH = "es"
    ITALIAN = "it"
    PORTUGUESE = "pt"
    RUSSIAN = "ru"

    # East Asian languages
    JAPANESE = "japan"
    KOREAN = "korean"

    # Remaining entries
    ARABIC = "ar"
    HINDI = "hi"
    LATIN = "latin"
|
|
|
|
class OCRConfig(BaseModel):
    """Configuration for OCR processing."""

    # The docstring above is kept short on purpose: pydantic may surface the
    # class docstring as the schema description, so the detailed notes live
    # in comments instead.  Range limits (``ge``/``le``) are enforced by
    # pydantic when the model is validated.

    # --- Engine selection -------------------------------------------------
    engine: str = Field(
        default="paddle",
        description="OCR engine: paddle or tesseract",
    )

    # --- Languages --------------------------------------------------------
    languages: List[str] = Field(default=["en"], description="Languages to detect (ISO codes)")

    # --- Text detection ---------------------------------------------------
    # Both thresholds are constrained to the closed interval [0, 1].
    det_db_thresh: float = Field(
        default=0.3, ge=0.0, le=1.0,
        description="Detection threshold for text regions",
    )
    det_db_box_thresh: float = Field(
        default=0.5, ge=0.0, le=1.0,
        description="Box detection threshold",
    )

    # --- Text recognition -------------------------------------------------
    # The batch size must be at least 1.
    rec_batch_num: int = Field(default=6, ge=1, description="Recognition batch size")
    min_confidence: float = Field(
        default=0.5, ge=0.0, le=1.0,
        description="Minimum confidence threshold",
    )

    # --- Runtime / hardware -----------------------------------------------
    use_gpu: bool = Field(
        default=True,
        description="Use GPU acceleration",
    )
    gpu_id: int = Field(
        default=0,
        ge=0,
        description="GPU device ID",
    )
    use_angle_cls: bool = Field(default=True, description="Use angle classification for rotated text")
    use_dilation: bool = Field(default=False, description="Use dilation for detection")

    # --- Output shaping ---------------------------------------------------
    # NOTE(review): drop_score and min_confidence share the same default and
    # range; how each one is applied depends on the engine implementation,
    # which is not visible from this class.
    drop_score: float = Field(
        default=0.5, ge=0.0, le=1.0,
        description="Drop results below this score",
    )
    return_word_boxes: bool = Field(default=False, description="Return word-level boxes (vs line-level)")

    # --- Preprocessing ----------------------------------------------------
    # ``None`` is the default for the resize limit (no value configured).
    preprocess_resize: Optional[int] = Field(default=None, description="Resize image max dimension before OCR")
    preprocess_denoise: bool = Field(default=False, description="Apply denoising before OCR")
|
|
|
|
@dataclass
class OCRResult:
    """
    Result of OCR processing for a single image/page.

    Attributes:
        regions: Recognized text regions.
        full_text: Text of the whole image/page.
        confidence_avg: Average confidence over the regions.
        processing_time_ms: Processing time in milliseconds.
        engine: Name of the engine that produced the result.
        language_detected: Detected language code, if any.
        success: Whether processing succeeded.
        error: Error message when processing failed, otherwise ``None``.
    """

    regions: List[OCRRegion] = field(default_factory=list)
    full_text: str = ""
    confidence_avg: float = 0.0
    processing_time_ms: float = 0.0
    engine: str = "unknown"
    language_detected: Optional[str] = None

    # Status of the OCR run.
    success: bool = True
    error: Optional[str] = None

    def get_text_in_bbox(self, bbox: BoundingBox) -> str:
        """
        Get text within a bounding box.

        A region is selected when ``bbox`` contains the region's box or when
        the IoU between the two boxes is greater than 0.5.

        Args:
            bbox: Area of interest.

        Returns:
            Text of the selected regions, in region order, joined by single
            spaces; an empty string when no region is selected.
        """
        return " ".join(
            region.text
            for region in self.regions
            if bbox.contains(region.bbox) or bbox.iou(region.bbox) > 0.5
        )

    def filter_by_confidence(self, min_confidence: float) -> "OCRResult":
        """
        Return new result with regions above confidence threshold.

        The current result is left untouched.  ``full_text`` and
        ``confidence_avg`` are recomputed from the surviving regions; every
        other field is copied over unchanged.

        Args:
            min_confidence: Inclusive lower bound on region confidence.

        Returns:
            A new ``OCRResult`` holding only regions whose confidence is
            greater than or equal to ``min_confidence``.
        """
        kept = [r for r in self.regions if r.confidence >= min_confidence]

        if kept:
            average = sum(r.confidence for r in kept) / len(kept)
        else:
            # Keep the declared float type (and avoid dividing by zero) when
            # nothing survives the filter.
            average = 0.0

        return OCRResult(
            regions=kept,
            full_text=" ".join(r.text for r in kept),
            confidence_avg=average,
            processing_time_ms=self.processing_time_ms,
            engine=self.engine,
            language_detected=self.language_detected,
            success=self.success,
            error=self.error,
        )
|
|
|
|
class OCREngine(ABC):
    """
    Abstract base class for OCR engines.

    Defines the interface that all OCR implementations must follow:
    subclasses provide ``initialize``, ``recognize`` and
    ``get_supported_languages``; batch processing and context-manager
    support are implemented here on top of those.
    """

    def __init__(self, config: Optional[OCRConfig] = None) -> None:
        """
        Initialize OCR engine.

        Args:
            config: OCR configuration; a default ``OCRConfig`` is created
                when omitted.
        """
        self.config = config if config is not None else OCRConfig()
        # The base class only ever sets this flag to False; presumably
        # concrete engines set it to True inside initialize() -- confirm in
        # the implementations.
        self._initialized = False

    @abstractmethod
    def initialize(self):
        """Initialize the OCR engine (load models, etc.)."""
        pass

    @abstractmethod
    def recognize(
        self,
        image: np.ndarray,
        page_number: int = 0,
    ) -> OCRResult:
        """
        Perform OCR on an image.

        Args:
            image: Image as numpy array (RGB, HWC format)
            page_number: Page number for multi-page documents

        Returns:
            OCRResult with recognized text and regions
        """
        pass

    def recognize_batch(
        self,
        images: List[np.ndarray],
        page_numbers: Optional[List[int]] = None,
    ) -> List[OCRResult]:
        """
        Perform OCR on multiple images.

        Images are processed one at a time, in order, by calling
        ``recognize`` for each of them.

        Args:
            images: List of images
            page_numbers: Optional page numbers, one per image.  Defaults
                to ``0 .. len(images) - 1``.

        Returns:
            List of OCRResult, one per input image, in input order.

        Raises:
            ValueError: If ``page_numbers`` is given but its length differs
                from the number of images.
        """
        if page_numbers is None:
            page_numbers = list(range(len(images)))
        elif len(page_numbers) != len(images):
            # zip() would otherwise silently stop at the shorter sequence
            # and drop the remaining images without any error.
            raise ValueError(
                f"page_numbers has {len(page_numbers)} entries but "
                f"{len(images)} images were given"
            )

        return [
            self.recognize(image, page_number)
            for image, page_number in zip(images, page_numbers)
        ]

    @abstractmethod
    def get_supported_languages(self) -> List[str]:
        """Return list of supported language codes."""
        pass

    @property
    def name(self) -> str:
        """Return engine name (the concrete class name)."""
        return self.__class__.__name__

    @property
    def is_initialized(self) -> bool:
        """Check if engine is initialized."""
        return self._initialized

    def __enter__(self):
        """Context manager entry: initialize the engine if needed."""
        if not self._initialized:
            self.initialize()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit.

        Performs no cleanup and returns ``None``, so an exception raised
        inside the ``with`` block propagates to the caller.
        """
        pass
|
|