| """ |
| Layout Detector Implementations |
| |
| Rule-based and model-based layout detection. |
| """ |
|
|
| import time |
| import uuid |
| from typing import List, Optional, Dict, Tuple |
| from collections import defaultdict |
| import numpy as np |
| from loguru import logger |
|
|
| from .base import LayoutDetector, LayoutConfig, LayoutResult |
| from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion |
|
|
|
|
class RuleBasedLayoutDetector(LayoutDetector):
    """
    Rule-based layout detector using OCR region analysis.

    Uses heuristics based on:
    - Text positioning and alignment
    - Font size estimation (based on region height)
    - Spacing patterns
    - Structural patterns (tables, lists)

    Every bounding box produced by this class is built from the coordinates
    of the supplied OCR regions with ``normalized=False``.
    """

    # A line counts as a list item when its first OCR fragment starts with
    # one of these prefixes (see ``_detect_lists``).
    _BULLET_PREFIXES = ('•', '-', '–', '—', '*', '○', '●', '■', '□', '▪', '▸', '▹')
    _NUMBER_PREFIXES = ('1.', '2.', '3.', '4.', '5.', '6.', '7.', '8.', '9.',
                        '1)', '2)', '3)', '4)', '5)', 'a.', 'b.', 'c.', 'a)', 'b)', 'c)')

    def __init__(self, config: Optional[LayoutConfig] = None):
        """Initialize rule-based detector."""
        super().__init__(config)

    def initialize(self):
        """Initialize detector (no model loading needed for rule-based)."""
        self._initialized = True
        logger.info("Initialized rule-based layout detector")

    def detect(
        self,
        image: np.ndarray,
        page_number: int = 0,
        ocr_regions: Optional[List[OCRRegion]] = None,
    ) -> LayoutResult:
        """
        Detect layout regions using rule-based analysis.

        Args:
            image: Page image; only ``image.shape[:2]`` (height, width) is read
            page_number: Page number, embedded in region ids and results
            ocr_regions: OCR regions for text-based analysis; when empty or
                None, no text-based detection runs

        Returns:
            LayoutResult with detected regions in reading order
        """
        if not self._initialized:
            self.initialize()

        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the wall clock is adjusted mid-run.
        started = time.perf_counter()
        height, width = image.shape[:2]

        region_counter = 0

        def make_region_id():
            # Ids are unique within one detect() call: region_<page>_<n>.
            nonlocal region_counter
            region_counter += 1
            return f"region_{page_number}_{region_counter}"

        regions: List[LayoutRegion] = []

        if ocr_regions:
            # The detector order determines the numbering of region ids.
            regions.extend(self._detect_titles_headings(ocr_regions, page_number, make_region_id, height))
            regions.extend(self._detect_paragraphs(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_lists(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_tables_from_ocr(ocr_regions, page_number, make_region_id))
            regions.extend(self._detect_headers_footers(ocr_regions, page_number, make_region_id, height))

        if self.config.detect_figures:
            regions.extend(self._detect_figures_from_image(image, page_number, make_region_id, ocr_regions))

        regions = self._merge_overlapping_regions(regions)
        regions = self._assign_reading_order(regions)

        processing_time = (time.perf_counter() - started) * 1000

        return LayoutResult(
            page=page_number,
            regions=regions,
            image_width=width,
            image_height=height,
            processing_time_ms=processing_time,
            success=True,
        )

    @staticmethod
    def _bounds(members: List[OCRRegion]) -> Tuple[float, float, float, float]:
        """Return (x_min, y_min, x_max, y_max) enclosing all members' bboxes.

        ``members`` must be non-empty.
        """
        return (
            min(r.bbox.x_min for r in members),
            min(r.bbox.y_min for r in members),
            max(r.bbox.x_max for r in members),
            max(r.bbox.y_max for r in members),
        )

    @classmethod
    def _enclosing_bbox(cls, members: List[OCRRegion]) -> BoundingBox:
        """Build the non-normalized BoundingBox that encloses all members."""
        x_min, y_min, x_max, y_max = cls._bounds(members)
        return BoundingBox(
            x_min=x_min, y_min=y_min,
            x_max=x_max, y_max=y_max,
            normalized=False,
        )

    @staticmethod
    def _member_indices(ocr_regions: List[OCRRegion], members: List[OCRRegion]) -> List[int]:
        """Return the indices in ``ocr_regions`` of the objects in ``members``.

        Uses object identity, so the lookup is linear instead of an
        element-by-element equality scan for every OCR region.
        """
        member_ids = {id(r) for r in members}
        return [i for i, r in enumerate(ocr_regions) if id(r) in member_ids]

    def _detect_titles_headings(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """Detect title and heading regions based on font size and position.

        A line qualifies when its tallest fragment exceeds the median fragment
        height times ``config.heading_font_ratio`` and its joined text is
        shorter than 100 characters. It is a TITLE when it also starts in the
        top 15% of the page and exceeds 1.2x the threshold; otherwise HEADING.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id
            page_height: Page height in the same units as the bboxes

        Returns:
            Title/heading regions; empty when ``config.detect_titles`` is off
            or no region has a positive height.
        """
        if not ocr_regions or not self.config.detect_titles:
            return []

        heights = [r.bbox.height for r in ocr_regions if r.bbox.height > 0]
        if not heights:
            return []

        # The median is used as the "body text" height so that a few large
        # headings do not inflate the baseline.
        median_height = np.median(heights)
        title_threshold = median_height * self.config.heading_font_ratio

        regions = []
        for line_regions in self._group_into_lines(ocr_regions).values():
            if not line_regions:
                continue

            line_height = max(r.bbox.height for r in line_regions)
            line_text = " ".join(r.text for r in line_regions)

            is_large_text = line_height > title_threshold
            is_short = len(line_text) < 100
            if not (is_large_text and is_short):
                continue

            x_min, y_min, x_max, y_max = self._bounds(line_regions)
            is_top_of_page = y_min < page_height * 0.15

            if is_top_of_page and line_height > title_threshold * 1.2:
                layout_type = LayoutType.TITLE
            else:
                layout_type = LayoutType.HEADING

            regions.append(LayoutRegion(
                id=make_id(),
                type=layout_type,
                confidence=0.8,
                bbox=BoundingBox(
                    x_min=x_min, y_min=y_min,
                    x_max=x_max, y_max=y_max,
                    normalized=False,
                ),
                page=page_number,
                ocr_region_ids=self._member_indices(ocr_regions, line_regions),
            ))

        return regions

    def _detect_paragraphs(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """Detect paragraph regions by grouping nearby text.

        Lines are grouped into paragraphs by ``_group_lines_into_paragraphs``;
        each paragraph becomes one PARAGRAPH region enclosing all of its
        OCR fragments.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id

        Returns:
            One region per non-empty paragraph.
        """
        if not ocr_regions:
            return []

        lines = self._group_into_lines(ocr_regions)
        paragraphs = self._group_lines_into_paragraphs(lines, ocr_regions)

        regions = []
        for para_lines in paragraphs:
            para_regions = [r for line_id in para_lines for r in lines.get(line_id, [])]
            if not para_regions:
                continue

            regions.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.PARAGRAPH,
                confidence=0.7,
                bbox=self._enclosing_bbox(para_regions),
                page=page_number,
                ocr_region_ids=self._member_indices(ocr_regions, para_regions),
            ))

        return regions

    def _detect_lists(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """Detect list structures based on bullet/number patterns.

        A line is a list item when the stripped text of its left-most fragment
        starts with a known bullet or numbering prefix, or is at most three
        characters long and ends with ``.``. Two or more consecutive list-item
        lines form one LIST region.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id

        Returns:
            List regions; empty when ``config.detect_lists`` is off.
        """
        if not ocr_regions or not self.config.detect_lists:
            return []

        item_prefixes = self._BULLET_PREFIXES + self._NUMBER_PREFIXES
        lines = self._group_into_lines(ocr_regions)

        # Collect runs of consecutive list-item lines.
        runs: List[List[int]] = []
        current_run: List[int] = []

        for line_id in sorted(lines):
            line_regions = lines[line_id]
            if not line_regions:
                continue

            first_text = line_regions[0].text.strip()
            is_list_item = (
                first_text.startswith(item_prefixes)
                or (len(first_text) <= 3 and first_text.endswith('.'))
            )

            if is_list_item:
                current_run.append(line_id)
                continue

            # A non-item line ends the run; single-item runs are discarded.
            if len(current_run) >= 2:
                runs.append(current_run)
            current_run = []

        if len(current_run) >= 2:
            runs.append(current_run)

        regions = []
        for run in runs:
            list_regions = [r for line_id in run for r in lines.get(line_id, [])]
            if not list_regions:
                continue

            regions.append(LayoutRegion(
                id=make_id(),
                type=LayoutType.LIST,
                confidence=0.75,
                bbox=self._enclosing_bbox(list_regions),
                page=page_number,
                ocr_region_ids=self._member_indices(ocr_regions, list_regions),
                extra={"item_count": len(run)},
            ))

        return regions

    def _detect_tables_from_ocr(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
    ) -> List[LayoutRegion]:
        """Detect table regions based on aligned text patterns.

        OCR fragments are clustered by horizontal centre (20-unit tolerance
        against the first centre seen in each cluster). If there are at least
        ``config.table_min_cols`` clusters, at least
        ``table_min_rows * table_min_cols`` fragments, and the fragments span
        more than 30% of the page width, a single TABLE region enclosing all
        fragments is emitted.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on the region
            make_id: Zero-argument callable returning a fresh region id

        Returns:
            At most one table region; empty when ``config.detect_tables`` is
            off or the heuristics do not match.
        """
        if not ocr_regions or not self.config.detect_tables:
            return []

        x_tolerance = 20

        # Each anchor is the centre of the first fragment seen in a column.
        column_anchors = []
        for region in ocr_regions:
            x_center = region.bbox.center[0]
            if not any(abs(x_center - anchor) < x_tolerance for anchor in column_anchors):
                column_anchors.append(x_center)

        if len(column_anchors) < self.config.table_min_cols:
            return []
        if len(ocr_regions) < self.config.table_min_rows * self.config.table_min_cols:
            return []

        x_min, y_min, x_max, y_max = self._bounds(ocr_regions)

        # Falls back to a width of 1000 when a bbox carries no page width.
        page_width = max(r.bbox.page_width or 1000 for r in ocr_regions)
        width_ratio = (x_max - x_min) / page_width
        if not (width_ratio > 0.3):
            return []

        return [LayoutRegion(
            id=make_id(),
            type=LayoutType.TABLE,
            confidence=0.6,
            bbox=BoundingBox(
                x_min=x_min, y_min=y_min,
                x_max=x_max, y_max=y_max,
                normalized=False,
            ),
            page=page_number,
            extra={"estimated_cols": len(column_anchors)},
        )]

    def _detect_headers_footers(
        self,
        ocr_regions: List[OCRRegion],
        page_number: int,
        make_id,
        page_height: int,
    ) -> List[LayoutRegion]:
        """Detect header and footer regions.

        Fragments lying entirely within the top 8% of the page form the
        HEADER region; fragments starting below 92% of the page height form
        the FOOTER region.

        Args:
            ocr_regions: OCR regions of the page
            page_number: Page number stored on each region
            make_id: Zero-argument callable returning a fresh region id
            page_height: Page height in the same units as the bboxes

        Returns:
            Zero, one or two regions (header first); empty when
            ``config.detect_headers`` is off.
        """
        if not ocr_regions or not self.config.detect_headers:
            return []

        header_threshold = page_height * 0.08
        footer_threshold = page_height * 0.92

        bands = (
            (LayoutType.HEADER, [r for r in ocr_regions if r.bbox.y_max < header_threshold]),
            (LayoutType.FOOTER, [r for r in ocr_regions if r.bbox.y_min > footer_threshold]),
        )

        regions = []
        for layout_type, members in bands:
            if not members:
                continue
            regions.append(LayoutRegion(
                id=make_id(),
                type=layout_type,
                confidence=0.7,
                bbox=self._enclosing_bbox(members),
                page=page_number,
            ))

        return regions

    def _detect_figures_from_image(
        self,
        image: np.ndarray,
        page_number: int,
        make_id,
        ocr_regions: Optional[List[OCRRegion]],
    ) -> List[LayoutRegion]:
        """Detect figure regions using image analysis.

        Placeholder: no figure detection is implemented yet, so this always
        returns an empty list. The arguments are accepted so the call site in
        ``detect`` does not change when an implementation is added.

        Args:
            image: Page image (currently unused)
            page_number: Page number (currently unused)
            make_id: Zero-argument callable returning a region id (unused)
            ocr_regions: OCR regions of the page (currently unused)

        Returns:
            Always an empty list.
        """
        # The earlier version rasterised every OCR bbox into a full-page mask
        # and then discarded it; that work had no effect on the result.
        return []

    def _group_into_lines(
        self,
        ocr_regions: List[OCRRegion],
    ) -> Dict[int, List[OCRRegion]]:
        """Group OCR regions into lines based on y-position.

        Regions are scanned in order of increasing ``bbox.y_min``. A new line
        starts whenever a region's ``y_min`` differs by more than 10 units
        from the ``y_min`` of the region that opened the current line.

        Args:
            ocr_regions: OCR regions of the page

        Returns:
            Mapping of line id (0, 1, 2, ... from top to bottom) to that
            line's regions sorted left to right; empty dict for no input.
        """
        if not ocr_regions:
            return {}

        y_tolerance = 10
        by_top = sorted(ocr_regions, key=lambda r: r.bbox.y_min)

        lines = defaultdict(list)
        line_id = 0
        line_top = by_top[0].bbox.y_min

        for region in by_top:
            if abs(region.bbox.y_min - line_top) > y_tolerance:
                line_id += 1
                line_top = region.bbox.y_min
            lines[line_id].append(region)

        return {
            key: sorted(members, key=lambda r: r.bbox.x_min)
            for key, members in lines.items()
        }

    def _group_lines_into_paragraphs(
        self,
        lines: Dict[int, List[OCRRegion]],
        all_regions: List[OCRRegion],
    ) -> List[List[int]]:
        """Group lines into paragraphs based on spacing.

        Consecutive lines (by ascending line id) belong to the same paragraph
        unless the vertical gap between them exceeds 1.5x the mean fragment
        height of the two lines.

        Args:
            lines: Mapping of line id to that line's OCR regions
            all_regions: Unused; kept for interface compatibility

        Returns:
            Paragraphs as lists of line ids. Lines with no regions are left
            out and never break up or swallow their neighbours.
        """
        # Empty lines carry no geometry. They are filtered up front so every
        # comparison below is between two real lines; otherwise an empty line
        # would cause the line after it to be dropped as well.
        line_ids = [line_id for line_id in sorted(lines) if lines[line_id]]
        if not line_ids:
            return []

        paragraphs = []
        current_para = [line_ids[0]]

        for prev_id, line_id in zip(line_ids, line_ids[1:]):
            prev_line = lines[prev_id]
            curr_line = lines[line_id]

            prev_y_max = max(r.bbox.y_max for r in prev_line)
            curr_y_min = min(r.bbox.y_min for r in curr_line)
            gap = curr_y_min - prev_y_max

            avg_height = np.mean([r.bbox.height for r in prev_line + curr_line])

            if gap > avg_height * 1.5:
                paragraphs.append(current_para)
                current_para = [line_id]
            else:
                current_para.append(line_id)

        paragraphs.append(current_para)
        return paragraphs

    def _merge_overlapping_regions(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """Group regions by layout type (overlap merging is not implemented).

        Placeholder for merging overlapping regions of the same type. At
        present no region is merged or removed; the input is only reordered
        so that regions of the same type are contiguous, with types in
        first-seen order and original order preserved within each type.

        Args:
            regions: Detected regions

        Returns:
            The same region objects, grouped by type.
        """
        if not regions:
            return []

        by_type = defaultdict(list)
        for region in regions:
            by_type[region.type].append(region)

        return [region for group in by_type.values() for region in group]

    def _assign_reading_order(
        self,
        regions: List[LayoutRegion],
    ) -> List[LayoutRegion]:
        """Assign reading order to regions (top-to-bottom, left-to-right).

        Regions are sorted by ``(bbox.y_min, bbox.x_min)`` and each region's
        ``reading_order`` attribute is set in place to its position.

        Args:
            regions: Regions to order

        Returns:
            The regions in reading order.
        """
        if not regions:
            return []

        ordered = sorted(regions, key=lambda r: (r.bbox.y_min, r.bbox.x_min))
        for position, region in enumerate(ordered):
            region.reading_order = position

        return ordered
|
|
|
|
| |
# Module-level cache for the shared detector returned by get_layout_detector();
# stays None until the first call creates an instance.
_layout_detector: Optional[LayoutDetector] = None
|
|
|
|
def create_layout_detector(
    config: Optional[LayoutConfig] = None,
    initialize: bool = True,
) -> LayoutDetector:
    """Create a layout detector instance.

    Args:
        config: Detector configuration; a default ``LayoutConfig()`` is used
            when omitted.
        initialize: When true, ``initialize()`` is called on the new detector
            before it is returned.

    Returns:
        A new ``RuleBasedLayoutDetector``. Any ``config.method`` other than
        ``"rule_based"`` logs a warning and falls back to the rule-based
        detector.
    """
    config = config if config is not None else LayoutConfig()

    # Only the rule-based detector exists, so an unrecognised method falls
    # back to it after a warning.
    if config.method != "rule_based":
        logger.warning(f"Unknown method {config.method}, using rule_based")

    detector = RuleBasedLayoutDetector(config)
    if initialize:
        detector.initialize()
    return detector
|
|
|
|
def get_layout_detector(
    config: Optional[LayoutConfig] = None,
) -> LayoutDetector:
    """Get or create singleton layout detector.

    Args:
        config: Configuration used only when the singleton is created by
            this call. If a detector already exists, ``config`` is ignored
            and the existing instance is returned unchanged.

    Returns:
        The module-wide shared detector instance.
    """
    global _layout_detector
    if _layout_detector is None:
        _layout_detector = create_layout_detector(config)
    elif config is not None:
        # Make the otherwise silent "config ignored" case traceable.
        logger.debug("Layout detector already created; ignoring supplied config")
    return _layout_detector
|
|