"""
Chart Extraction Model Interface

Abstract interface for chart/graph understanding models.
Extracts data points, axes, legends, and interprets visualizations.
"""

import hashlib
import logging
import math
import time
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from ..chunks.models import BoundingBox, ChartChunk, ChartDataPoint
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class ChartType(str, Enum):
    """Types of charts that can be detected.

    Members also subclass ``str``, so each member compares equal to its raw
    value (``ChartType.BAR == "bar"``) and ``ChartType("bar")`` looks a member
    up by value.
    """

    # Basic charts
    BAR = "bar"
    LINE = "line"
    PIE = "pie"
    SCATTER = "scatter"
    AREA = "area"

    # Statistical and specialised charts
    HISTOGRAM = "histogram"
    BOX_PLOT = "box_plot"
    HEATMAP = "heatmap"
    TREEMAP = "treemap"
    RADAR = "radar"
    BUBBLE = "bubble"
    WATERFALL = "waterfall"
    FUNNEL = "funnel"
    GANTT = "gantt"

    # Multi-series / composite variants
    STACKED_BAR = "stacked_bar"
    GROUPED_BAR = "grouped_bar"
    MULTI_LINE = "multi_line"
    COMBO = "combo"

    # Catch-all values
    DIAGRAM = "diagram"
    UNKNOWN = "unknown"


@dataclass
class ChartConfig(ModelConfig):
    """Configuration for chart extraction models.

    Adds chart-specific options on top of ``ModelConfig``. None of these
    options is read by the code visible in this module, so the comments below
    describe the apparent intent only; concrete models decide how to honour
    them.
    """

    # Confidence threshold; presumably results below it are discarded
    # (TODO confirm against concrete implementations).
    min_confidence: float = 0.5
    # Whether individual data points should be extracted.
    extract_data_points: bool = True
    # Whether trend analysis should be performed.
    extract_trends: bool = True
    # Upper bound on data points; presumably per chart (TODO confirm).
    max_data_points: int = 1000
    # Whether the chart type should be classified.
    detect_chart_type: bool = True

    def __post_init__(self) -> None:
        """Run the base-class hook, then give an empty ``name`` a default."""
        super().__post_init__()
        # ``name`` is declared on ModelConfig (not visible here); a falsy
        # value is replaced by this model family's default name.
        if not self.name:
            self.name = "chart_extractor"


@dataclass
class AxisInfo:
    """Information about a chart axis.

    Every field has a default, so ``AxisInfo()`` describes an unlabeled,
    linear, horizontal axis with no known range or ticks.
    """

    # Axis title text and measurement unit; empty string when not detected.
    label: str = ""
    unit: str = ""
    # Axis range; ``None`` when unknown.
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    # Scale name; defaults to "linear" (other accepted values are not
    # defined in this module).
    scale: str = "linear"
    # Tick text and the corresponding numeric positions.
    tick_labels: List[str] = field(default_factory=list)
    tick_values: List[float] = field(default_factory=list)
    # True when the axis carries date/time values.
    is_datetime: bool = False
    # Axis orientation; defaults to "horizontal".
    orientation: str = "horizontal"


@dataclass
class LegendItem:
    """A single legend entry."""

    # Text shown in the legend (required).
    label: str
    # Colour of the legend marker; ``None`` when not detected. The string
    # format (hex, name, ...) is not fixed by this module.
    color: Optional[str] = None
    # Index of the series this entry refers to; defaults to 0.
    series_index: int = 0


@dataclass
class DataSeries:
    """A named data series in a chart."""

    # Series name (required).
    name: str
    # Points belonging to this series, in order.
    data_points: List[ChartDataPoint] = field(default_factory=list)
    # Series colour; ``None`` when not detected.
    color: Optional[str] = None
    # Chart type of this particular series; ``None`` when not set.
    series_type: Optional[ChartType] = None

    @property
    def x_values(self) -> List[Any]:
        """Return the ``x`` attribute of every data point, in order."""
        return [point.x for point in self.data_points]

    @property
    def y_values(self) -> List[Any]:
        """Return the ``y`` attribute of every data point, in order."""
        return [point.y for point in self.data_points]

    def to_dict(self) -> Dict[str, Any]:
        """Convert the series to a plain dictionary.

        Returns:
            A dict with ``name``, ``color``, ``series_type`` (the enum's
            ``value``, or ``None`` when no type is set) and ``data_points``
            (one dict per point with ``x``, ``y``, ``label`` and ``value``).
        """
        series_type = self.series_type.value if self.series_type else None
        return {
            "name": self.name,
            "color": self.color,
            "series_type": series_type,
            "data_points": [
                {
                    "x": point.x,
                    "y": point.y,
                    "label": point.label,
                    "value": point.value,
                }
                for point in self.data_points
            ],
        }


@dataclass
class TrendInfo:
    """Detected trend in the data."""

    # Human-readable summary of the trend (required).
    description: str
    # Direction label; defaults to "neutral". ChartModel.detect_trends in
    # this module sets "increasing", "decreasing" or "stable".
    direction: str = "neutral"
    # Points marking where the trend starts and ends; ``None`` when unset.
    start_point: Optional[ChartDataPoint] = None
    end_point: Optional[ChartDataPoint] = None
    # Relative change in percent; ``None`` when not computed.
    change_percent: Optional[float] = None
    # Confidence in the detected trend; defaults to 0.0.
    confidence: float = 0.0


@dataclass
class ChartStructure:
    """
    Complete extracted chart structure.

    Contains all detected elements of a chart including
    type, axes, data series, legends, and interpretations.
    """

    # Location and classification
    bbox: BoundingBox
    chart_type: ChartType = ChartType.UNKNOWN
    confidence: float = 0.0

    # Titles
    title: str = ""
    subtitle: str = ""

    # Axes
    x_axis: Optional[AxisInfo] = None
    y_axis: Optional[AxisInfo] = None
    secondary_y_axis: Optional[AxisInfo] = None

    # Data
    series: List[DataSeries] = field(default_factory=list)
    legend_items: List[LegendItem] = field(default_factory=list)

    # Interpretation
    key_values: Dict[str, Any] = field(default_factory=dict)
    trends: List[TrendInfo] = field(default_factory=list)
    summary: str = ""

    # Identification; ``chart_id`` is derived in __post_init__ when empty.
    chart_id: str = ""
    source_text: str = ""

    def __post_init__(self) -> None:
        """Derive a ``chart_id`` when the caller did not supply one.

        The id is the first 12 hex characters of an MD5 digest over the chart
        type and ``bbox.xyxy``. It is a fingerprint, not a security token.
        Two charts with the same type and bounding box receive the same id.
        """
        if not self.chart_id:
            content = f"chart_{self.chart_type.value}_{self.bbox.xyxy}"
            self.chart_id = hashlib.md5(content.encode()).hexdigest()[:12]

    @property
    def total_data_points(self) -> int:
        """Return the number of data points summed over all series."""
        return sum(len(s.data_points) for s in self.series)

    @property
    def all_data_points(self) -> List[ChartDataPoint]:
        """Get all data points from all series, in series order."""
        return [point for s in self.series for point in s.data_points]

    def get_series_by_name(self, name: str) -> Optional[DataSeries]:
        """Find a series by name (case-insensitive).

        Args:
            name: Series name to look for.

        Returns:
            The first series whose name matches, or ``None`` if none does.
        """
        return next(
            (s for s in self.series if s.name.lower() == name.lower()),
            None,
        )

    def to_text_description(self) -> str:
        """Generate a text description of the chart.

        Returns:
            Newline-separated lines: the title (or the chart type when there
            is no title), then axis labels, series names, key values and
            trend descriptions, each only when present.
        """
        parts = []

        if self.title:
            parts.append(f"Chart: {self.title}")
        else:
            parts.append(f"Chart Type: {self.chart_type.value}")

        if self.x_axis and self.x_axis.label:
            parts.append(f"X-Axis: {self.x_axis.label}")
        if self.y_axis and self.y_axis.label:
            parts.append(f"Y-Axis: {self.y_axis.label}")

        # Filter first so that unnamed series do not produce a dangling
        # "Series: " line with nothing after it.
        series_names = [s.name for s in self.series if s.name]
        if series_names:
            parts.append(f"Series: {', '.join(series_names)}")

        if self.key_values:
            kv_str = ", ".join(f"{k}: {v}" for k, v in self.key_values.items())
            parts.append(f"Key Values: {kv_str}")

        trend_strs = [t.description for t in self.trends if t.description]
        if trend_strs:
            parts.append(f"Trends: {'; '.join(trend_strs)}")

        return "\n".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to structured dictionary.

        Returns:
            A dict with ``chart_type``, ``title``, ``x_axis`` / ``y_axis``
            (label and unit, empty strings when the axis is missing),
            ``series``, ``key_values``, ``trends`` (description and
            direction only) and ``summary``.
        """

        def axis_dict(axis: Optional[AxisInfo]) -> Dict[str, str]:
            # A missing axis is reported with empty label and unit.
            return {
                "label": axis.label if axis else "",
                "unit": axis.unit if axis else "",
            }

        return {
            "chart_type": self.chart_type.value,
            "title": self.title,
            "x_axis": axis_dict(self.x_axis),
            "y_axis": axis_dict(self.y_axis),
            "series": [s.to_dict() for s in self.series],
            "key_values": self.key_values,
            "trends": [
                {"description": t.description, "direction": t.direction}
                for t in self.trends
            ],
            "summary": self.summary,
        }

    def to_chart_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> ChartChunk:
        """Convert to ChartChunk for the chunks module.

        Args:
            doc_id: Identifier of the source document.
            page: Page the chart was found on.
            sequence_index: Value passed through as the chunk's
                ``sequence_index``.

        Returns:
            A ``ChartChunk`` whose text is ``to_text_description()``, whose
            data points are the points of all series flattened, and whose
            trends are the trend descriptions.
        """
        return ChartChunk(
            chunk_id=ChartChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="chart",
            ),
            doc_id=doc_id,
            text=self.to_text_description(),
            page=page,
            bbox=self.bbox,
            confidence=self.confidence,
            sequence_index=sequence_index,
            chart_type=self.chart_type.value,
            title=self.title,
            x_axis_label=self.x_axis.label if self.x_axis else None,
            y_axis_label=self.y_axis.label if self.y_axis else None,
            data_points=self.all_data_points,
            key_values=self.key_values,
            trends=[t.description for t in self.trends],
        )


@dataclass
class ChartExtractionResult:
    """Result of chart extraction from a page."""

    # Charts extracted from the input.
    charts: List[ChartStructure] = field(default_factory=list)
    # Time spent on extraction, in milliseconds.
    processing_time_ms: float = 0.0
    # Free-form model details; left empty by ChartModel.extract_all_charts.
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def chart_count(self) -> int:
        """Return the number of extracted charts."""
        return len(self.charts)


class ChartModel(BatchableModel):
    """
    Abstract base class for chart extraction models.

    Implementations should handle:
    - Chart type classification
    - Axis detection and labeling
    - Data point extraction
    - Legend parsing
    - Trend detection
    """

    # Relative difference between the two half-averages (10%) above which
    # detect_trends reports a series as increasing or decreasing.
    _TREND_TOLERANCE = 0.1
    # Fixed confidence assigned to trends found by the default heuristic.
    _TREND_CONFIDENCE = 0.7

    def __init__(self, config: Optional[ChartConfig] = None):
        """Initialise the model.

        Args:
            config: Chart configuration; when omitted, a ``ChartConfig``
                named ``"chart"`` is used.
        """
        super().__init__(config or ChartConfig(name="chart"))
        # Re-declare the attribute with the chart-specific type for the
        # benefit of type checkers; the value itself is unchanged.
        self.config: ChartConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities of this model (chart extraction only)."""
        return [ModelCapability.CHART_EXTRACTION]

    @abstractmethod
    def extract_chart(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> ChartStructure:
        """
        Extract chart structure from an image.

        Args:
            image: Input image containing a chart
            chart_region: Optional bounding box of the chart
            **kwargs: Additional parameters

        Returns:
            ChartStructure with extracted data
        """

    def extract_all_charts(
        self,
        image: ImageInput,
        chart_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> ChartExtractionResult:
        """
        Extract all charts from an image.

        With ``chart_regions``, each region is extracted independently and a
        region whose extraction raises is logged and skipped. Without regions
        (``None`` or an empty list) the whole image is treated as one chart
        and any exception propagates to the caller. Charts classified as
        ``ChartType.UNKNOWN`` are dropped in both cases.

        Args:
            image: Input document image
            chart_regions: Optional list of chart bounding boxes
            **kwargs: Additional parameters forwarded to ``extract_chart``

        Returns:
            ChartExtractionResult with all detected charts and the elapsed
            time in milliseconds
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        started = time.perf_counter()
        charts: List[ChartStructure] = []

        if chart_regions:
            for region in chart_regions:
                try:
                    chart = self.extract_chart(image, region, **kwargs)
                    if chart.chart_type != ChartType.UNKNOWN:
                        charts.append(chart)
                except Exception:
                    # Best effort: one bad region must not abort the whole
                    # page, but the failure is recorded rather than hidden.
                    logging.getLogger(__name__).warning(
                        "Chart extraction failed for region %r; skipping it",
                        region,
                        exc_info=True,
                    )
        else:
            chart = self.extract_chart(image, **kwargs)
            if chart.chart_type != ChartType.UNKNOWN:
                charts.append(chart)

        elapsed_ms = (time.perf_counter() - started) * 1000

        return ChartExtractionResult(
            charts=charts,
            processing_time_ms=elapsed_ms,
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[ChartExtractionResult]:
        """Process multiple images, one ``extract_all_charts`` call each."""
        return [self.extract_all_charts(img, **kwargs) for img in inputs]

    @abstractmethod
    def classify_chart_type(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> Tuple[ChartType, float]:
        """
        Classify the type of chart in an image.

        Args:
            image: Input image
            chart_region: Optional bounding box
            **kwargs: Additional parameters

        Returns:
            Tuple of (ChartType, confidence)
        """

    @staticmethod
    def _numeric_y_values(series: DataSeries) -> List[float]:
        """Return the finite ``y`` values of *series* as floats, in order."""
        values: List[float] = []
        for point in series.data_points:
            if point.y is None:
                continue
            try:
                value = float(point.y)
            except (ValueError, TypeError, OverflowError):
                continue
            # NaN or infinity would poison both averages, so skip them.
            if math.isfinite(value):
                values.append(value)
        return values

    def detect_trends(
        self,
        chart: ChartStructure,
        **kwargs
    ) -> List[TrendInfo]:
        """
        Analyze chart data for trends.

        Default implementation provides basic trend detection.
        Override for more sophisticated analysis.

        For each series with at least two numeric y-values, the average of
        the first half is compared with the average of the second half. A
        difference of more than 10% of the first-half magnitude is reported
        as "increasing" or "decreasing", anything else as "stable".

        Args:
            chart: Chart whose series are analysed
            **kwargs: Additional parameters (unused by this implementation)

        Returns:
            One TrendInfo per analysable series. ``start_point`` and
            ``end_point`` are the first and last data points of the series,
            whether or not their y-values were numeric.
        """
        trends: List[TrendInfo] = []

        for series in chart.series:
            if len(series.data_points) < 2:
                continue

            y_values = self._numeric_y_values(series)
            if len(y_values) < 2:
                continue

            midpoint = len(y_values) // 2
            first_half_avg = sum(y_values[:midpoint]) / midpoint
            second_half_avg = (
                sum(y_values[midpoint:]) / (len(y_values) - midpoint)
            )

            # Measure the change against the magnitude of the baseline.
            # Scaling the baseline itself by 1.1 / 0.9 flips the comparison
            # whenever the first-half average is negative.
            delta = second_half_avg - first_half_avg
            tolerance = abs(first_half_avg) * self._TREND_TOLERANCE

            if delta > tolerance:
                direction = "increasing"
            elif delta < -tolerance:
                direction = "decreasing"
            else:
                direction = "stable"

            # A zero baseline has no meaningful percentage; report 0.
            change_pct = (
                delta / abs(first_half_avg) * 100
                if first_half_avg != 0 else 0.0
            )

            trends.append(
                TrendInfo(
                    description=(
                        f"{series.name}: {direction} trend "
                        f"({change_pct:+.1f}%)"
                    ),
                    direction=direction,
                    start_point=series.data_points[0],
                    end_point=series.data_points[-1],
                    change_percent=change_pct,
                    confidence=self._TREND_CONFIDENCE,
                )
            )

        return trends