| """ |
| Table Extraction Model Interface |
| |
| Abstract interface for table structure recognition and cell extraction. |
| Handles complex tables with merged cells, headers, and nested structures. |
| """ |
|
|
| from abc import abstractmethod |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from ..chunks.models import BoundingBox, TableCell, TableChunk |
| from .base import ( |
| BaseModel, |
| BatchableModel, |
| ImageInput, |
| ModelCapability, |
| ModelConfig, |
| ) |
| from .layout import LayoutRegion |
|
|
|
|
class TableCellType(str, Enum):
    """Types of table cells.

    Because the enum mixes in ``str``, each member is itself a string and
    compares equal to its value (e.g. ``TableCellType.HEADER == "header"``).

    NOTE(review): not referenced by the code in this module; cells here are
    classified via ``TableCell.is_header`` instead — confirm where this enum
    is consumed.
    """

    HEADER = "header"
    DATA = "data"
    # presumably row-label / index-column cells — TODO confirm
    INDEX = "index"
    MERGED = "merged"
    EMPTY = "empty"
|
|
|
|
@dataclass
class TableConfig(ModelConfig):
    """Settings for table extraction models.

    Extends ``ModelConfig`` with table-specific thresholds, feature toggles
    and size limits. If no ``name`` was given, it falls back to
    ``"table_extractor"``.
    """

    # Confidence threshold; not read by code in this module (presumably
    # consumed by concrete model implementations).
    min_confidence: float = 0.5
    # Feature toggles for header and merged-cell detection.
    detect_headers: bool = True
    detect_merged_cells: bool = True
    # Table size limits; not enforced by code in this module.
    max_rows: int = 500
    max_cols: int = 50
    # Whether the text content of each cell should be extracted.
    extract_cell_text: bool = True

    def __post_init__(self):
        # Run the base configuration's own post-init processing first.
        super().__post_init__()
        if self.name:
            return
        self.name = "table_extractor"
|
|
|
|
@dataclass
class TableStructure:
    """
    Detected table structure with cell grid.

    Represents the logical structure of a table including merged cells,
    headers, and cell relationships. Cells are stored as a flat list; a cell
    anchored at ``(row, col)`` with ``rowspan``/``colspan`` greater than one
    also occupies every position its span covers.
    """

    bbox: BoundingBox
    cells: List[TableCell] = field(default_factory=list)
    num_rows: int = 0
    num_cols: int = 0

    # Row / column indices marked as headers.
    header_rows: List[int] = field(default_factory=list)
    header_cols: List[int] = field(default_factory=list)

    # Confidence scores.
    structure_confidence: float = 0.0
    cell_confidence_avg: float = 0.0

    # Table properties.
    has_merged_cells: bool = False
    is_bordered: bool = True
    table_id: str = ""

    def __post_init__(self):
        if not self.table_id:
            import hashlib
            # Deterministic ID derived from the bbox and the grid shape; MD5 is
            # used only as a short fingerprint, not for security.
            content = f"table_{self.bbox.xyxy}_{self.num_rows}x{self.num_cols}"
            self.table_id = hashlib.md5(content.encode()).hexdigest()[:12]

    def get_cell(self, row: int, col: int) -> Optional[TableCell]:
        """Get the cell occupying position ``(row, col)``.

        A cell anchored exactly at ``(row, col)`` takes precedence. Otherwise
        the first cell (in list order) whose row/column span covers the
        position is returned, so every position of a merged cell resolves to
        that cell.

        Returns:
            The matching cell, or ``None`` if no cell covers the position.
        """
        spanning: Optional[TableCell] = None
        for cell in self.cells:
            if cell.row == row and cell.col == col:
                return cell
            # Remember the first covering span, but keep looking for an
            # exact anchor, which wins over a span from an earlier cell.
            if spanning is None and (
                cell.row <= row < cell.row + cell.rowspan
                and cell.col <= col < cell.col + cell.colspan
            ):
                spanning = cell
        return spanning

    def _cell_grid(self) -> Dict[Tuple[int, int], TableCell]:
        """Map each in-bounds ``(row, col)`` to the cell ``get_cell`` returns.

        Built once so whole-table exports cost O(cells + rows * cols) instead
        of a linear scan of ``self.cells`` for every grid position.
        """
        grid: Dict[Tuple[int, int], TableCell] = {}
        # Spanned positions: the first covering cell in list order wins.
        # Ranges are clipped to the table so bogus spans stay cheap.
        for cell in self.cells:
            row_stop = min(cell.row + cell.rowspan, self.num_rows)
            col_stop = min(cell.col + cell.colspan, self.num_cols)
            for r in range(max(cell.row, 0), row_stop):
                for c in range(max(cell.col, 0), col_stop):
                    grid.setdefault((r, c), cell)
        # Exact anchors override spans, mirroring get_cell's precedence.
        anchors: Dict[Tuple[int, int], TableCell] = {}
        for cell in self.cells:
            anchors.setdefault((cell.row, cell.col), cell)
        grid.update(anchors)
        return grid

    def get_row(self, row_index: int) -> List[TableCell]:
        """Get the cells anchored in row ``row_index``, ordered by column."""
        return sorted(
            (c for c in self.cells if c.row == row_index),
            key=lambda c: c.col,
        )

    def get_col(self, col_index: int) -> List[TableCell]:
        """Get the cells anchored in column ``col_index``, ordered by row."""
        return sorted(
            (c for c in self.cells if c.col == col_index),
            key=lambda c: c.row,
        )

    def get_headers(self) -> List[TableCell]:
        """Get all cells flagged with ``is_header``."""
        return [c for c in self.cells if c.is_header]

    def to_csv(self, delimiter: str = ",") -> str:
        """Convert the table to a CSV string.

        Rows are newline-separated. A field is wrapped in double quotes, with
        embedded quotes doubled, when it contains the delimiter, a double
        quote, a carriage return, or a newline. A merged cell's text is
        repeated in every position it spans.
        """
        grid = self._cell_grid()
        rows = []
        for r in range(self.num_rows):
            fields = []
            for c in range(self.num_cols):
                cell = grid.get((r, c))
                text = cell.text if cell is not None else ""
                if delimiter in text or any(ch in text for ch in '"\r\n'):
                    text = '"' + text.replace('"', '""') + '"'
                fields.append(text)
            rows.append(delimiter.join(fields))
        return "\n".join(rows)

    @staticmethod
    def _markdown_cell_text(text: str) -> str:
        """Make ``text`` safe to place inside one Markdown table cell."""
        # A raw line break would split the table row, so flatten it to a space.
        flat = text.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
        return flat.replace("|", "\\|")

    def to_markdown(self) -> str:
        """Convert the table to a pipe-style Markdown table.

        The first row is rendered as the Markdown header row, followed by a
        ``---`` separator. Pipes in cell text are escaped and line breaks
        are replaced with spaces. A merged cell's text is repeated in every
        position it spans. Returns ``""`` for an empty grid.
        """
        if self.num_rows == 0 or self.num_cols == 0:
            return ""

        grid = self._cell_grid()
        separator = "| " + " | ".join(["---"] * self.num_cols) + " |"
        lines = []
        for r in range(self.num_rows):
            row_texts = []
            for c in range(self.num_cols):
                cell = grid.get((r, c))
                row_texts.append(
                    self._markdown_cell_text(cell.text) if cell is not None else ""
                )
            lines.append("| " + " | ".join(row_texts) + " |")

            # Markdown tables require a separator directly after the header row.
            if r == 0:
                lines.append(separator)

        return "\n".join(lines)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dictionary of grid shape, headers and cells."""
        return {
            "num_rows": self.num_rows,
            "num_cols": self.num_cols,
            "header_rows": self.header_rows,
            "header_cols": self.header_cols,
            "cells": [
                {
                    "row": c.row,
                    "col": c.col,
                    "text": c.text,
                    "rowspan": c.rowspan,
                    "colspan": c.colspan,
                    "is_header": c.is_header,
                    "confidence": c.confidence
                }
                for c in self.cells
            ]
        }

    def to_table_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> TableChunk:
        """Convert to a TableChunk for the chunks module.

        The chunk's ``text`` is the Markdown rendering from ``to_markdown``,
        and its confidence is ``structure_confidence``.
        """
        return TableChunk(
            chunk_id=TableChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="table"
            ),
            doc_id=doc_id,
            text=self.to_markdown(),
            page=page,
            bbox=self.bbox,
            confidence=self.structure_confidence,
            sequence_index=sequence_index,
            cells=self.cells,
            num_rows=self.num_rows,
            num_cols=self.num_cols,
            header_rows=self.header_rows,
            header_cols=self.header_cols,
            has_merged_cells=self.has_merged_cells
        )
|
|
|
|
@dataclass
class TableExtractionResult:
    """Tables extracted from one page, with timing and model metadata."""

    # Extracted table structures.
    tables: List[TableStructure] = field(default_factory=list)
    # Extraction time in milliseconds.
    processing_time_ms: float = 0.0
    # Free-form extra information supplied by the model.
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def table_count(self) -> int:
        """Number of entries in ``tables``."""
        return len(self.tables)

    def get_table_at_region(
        self,
        region: LayoutRegion,
        iou_threshold: float = 0.5
    ) -> Optional[TableStructure]:
        """Return the table whose bbox best overlaps ``region.bbox``.

        A table qualifies only if its IoU with the region is strictly greater
        than ``iou_threshold``. Among the qualifying tables the highest IoU
        wins, and ties keep the earlier table. Returns ``None`` when no table
        qualifies.
        """
        winner: Optional[TableStructure] = None
        winner_iou = 0.0
        for candidate in self.tables:
            overlap = candidate.bbox.iou(region.bbox)
            # Must beat both the threshold and the best overlap seen so far.
            if overlap > max(iou_threshold, winner_iou):
                winner, winner_iou = candidate, overlap
        return winner
|
|
|
|
class TableModel(BatchableModel):
    """
    Abstract base class for table extraction models.

    Implementations should handle:
    - Table structure detection (rows, columns)
    - Cell boundary detection
    - Merged cell handling
    - Header detection
    - Cell content extraction

    Subclasses must implement ``extract_structure`` and ``extract_cell_text``.
    ``extract_all_tables`` and ``process_batch`` are built on top of them.
    """

    def __init__(self, config: Optional[TableConfig] = None):
        super().__init__(config or TableConfig(name="table"))
        # Re-declare the attribute with the narrower config type for type
        # checkers; the value is the one the base __init__ stored.
        self.config: TableConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report that this model provides table extraction."""
        return [ModelCapability.TABLE_EXTRACTION]

    @abstractmethod
    def extract_structure(
        self,
        image: ImageInput,
        table_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> TableStructure:
        """
        Extract table structure from an image.

        Args:
            image: Input image containing a table
            table_region: Optional bounding box of the table region
            **kwargs: Additional parameters

        Returns:
            TableStructure with cells and metadata
        """
        pass

    def extract_all_tables(
        self,
        image: ImageInput,
        table_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> TableExtractionResult:
        """
        Extract all tables from an image.

        Args:
            image: Input document image
            table_regions: Optional list of table bounding boxes. When a list
                is given, even an empty one, only those regions are
                processed, so an empty list yields no tables. When ``None``,
                the whole image is passed to ``extract_structure`` and the
                result is kept only if it has at least one row.
            **kwargs: Additional parameters forwarded to ``extract_structure``

        Returns:
            TableExtractionResult with all detected tables and the elapsed
            time in milliseconds. A region whose extraction raises is logged
            and skipped rather than aborting the whole page.
        """
        import logging
        import time

        start_time = time.perf_counter()
        tables: List[TableStructure] = []

        if table_regions is not None:
            for index, region in enumerate(table_regions):
                try:
                    table = self.extract_structure(image, region, **kwargs)
                except Exception:
                    # Best effort per region, but never silently: keep the
                    # traceback so failing regions can be diagnosed.
                    logging.getLogger(__name__).warning(
                        "Table extraction failed for region %d (%r); skipping",
                        index,
                        region,
                        exc_info=True,
                    )
                    continue
                tables.append(table)
        else:
            # No regions supplied: treat the whole image as one candidate table.
            table = self.extract_structure(image, **kwargs)
            if table.num_rows > 0:
                tables.append(table)

        # perf_counter is monotonic, unlike time.time(), so elapsed time
        # cannot go negative on a wall-clock adjustment.
        processing_time = (time.perf_counter() - start_time) * 1000

        return TableExtractionResult(
            tables=tables,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[TableExtractionResult]:
        """Run ``extract_all_tables`` on each image in turn.

        ``kwargs`` are forwarded unchanged to every call. A ``table_regions``
        keyword would therefore be applied to every image.
        """
        return [self.extract_all_tables(img, **kwargs) for img in inputs]

    @abstractmethod
    def extract_cell_text(
        self,
        image: ImageInput,
        cell_bbox: BoundingBox,
        **kwargs
    ) -> str:
        """
        Extract text from a specific cell region.

        Args:
            image: Image containing the cell
            cell_bbox: Bounding box of the cell
            **kwargs: Additional parameters

        Returns:
            Extracted text content
        """
        pass
|
|