| """ |
| Field Extraction Engine |
| |
| Extracts structured data from parsed documents using schemas. |
| """ |
|
|
| import logging |
| import re |
| from dataclasses import dataclass, field |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
| from ..chunks.models import ( |
| DocumentChunk, |
| ExtractionResult, |
| FieldExtraction, |
| EvidenceRef, |
| ParseResult, |
| TableChunk, |
| ChartChunk, |
| ChunkType, |
| ConfidenceLevel, |
| ) |
| from ..grounding.evidence import EvidenceBuilder, EvidenceTracker |
| from .schema import ExtractionSchema, FieldSpec, FieldType |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class ExtractionConfig:
    """
    Tunable settings consumed by the field extractor.

    Every attribute has a default, so ``ExtractionConfig()`` is a valid,
    ready-to-use configuration.
    """

    # --- Confidence thresholds -------------------------------------------
    # Per-chunk candidates scoring below this value are discarded.
    min_field_confidence: float = 0.5
    # NOTE(review): not read anywhere in the visible extractor code;
    # presumably consumed by callers inspecting the overall result - confirm.
    min_overall_confidence: float = 0.5

    # --- Abstention -------------------------------------------------------
    # When enabled, an extracted field whose final confidence falls below
    # ``abstain_threshold`` is reported as abstained and its value is None.
    abstain_on_low_confidence: bool = True
    abstain_threshold: float = 0.3

    # --- Chunk search strategy -------------------------------------------
    # Fall back to scanning every chunk when no chunk looks relevant.
    search_all_chunks: bool = True
    # Always treat table and form-field chunks as relevant.
    prefer_structured_sources: bool = True

    # --- Post-processing --------------------------------------------------
    # Run validation on the chosen value (a failure halves its confidence).
    validate_extracted_values: bool = True
    # Run the type-specific normalizer on the chosen value.
    normalize_values: bool = True
|
|
|
|
class FieldExtractor:
    """
    Extracts fields from parsed documents.

    Uses schema definitions to identify and extract
    structured data with evidence grounding.
    """

    def __init__(
        self,
        config: Optional[ExtractionConfig] = None,
        evidence_builder: Optional[EvidenceBuilder] = None,
    ):
        """
        Set up the extractor.

        Args:
            config: Extraction settings; a default ExtractionConfig is
                created when omitted.
            evidence_builder: Builder used to create evidence records; a
                default EvidenceBuilder is created when omitted.
        """
        self.config = config or ExtractionConfig()
        self.evidence_builder = evidence_builder or EvidenceBuilder()
        self._normalizers: Dict[FieldType, Callable] = self._build_normalizers()
        self._validators: Dict[FieldType, Callable] = self._build_validators()

    def extract(
        self,
        parse_result: ParseResult,
        schema: ExtractionSchema,
    ) -> ExtractionResult:
        """
        Extract fields from a parsed document.

        Args:
            parse_result: Parsed document with chunks
            schema: Extraction schema defining fields

        Returns:
            ExtractionResult with extracted values and evidence
        """
        logger.info(
            "Extracting %d fields from %s",
            len(schema.fields),
            parse_result.filename,
        )

        evidence_tracker = EvidenceTracker()
        field_extractions: List[FieldExtraction] = []
        extracted_data: Dict[str, Any] = {}
        abstained_fields: List[str] = []

        for field_spec in schema.fields:
            extraction = self._extract_field(
                field_spec=field_spec,
                chunks=parse_result.chunks,
                evidence_tracker=evidence_tracker,
            )

            # Compare against None explicitly so a result object that happens
            # to be falsy is never mistaken for "nothing found".
            if extraction is None:
                # Nothing found: fall back to the schema default; a missing
                # required field is additionally reported as abstained.
                if field_spec.required:
                    abstained_fields.append(field_spec.name)
                extracted_data[field_spec.name] = field_spec.default
                continue

            field_extractions.append(extraction)

            low_confidence = extraction.confidence < self.config.abstain_threshold
            if low_confidence and self.config.abstain_on_low_confidence:
                # Keep the extraction record (it still counts towards the
                # overall confidence) but withhold the value itself.
                abstained_fields.append(field_spec.name)
                extracted_data[field_spec.name] = None
            else:
                extracted_data[field_spec.name] = extraction.value

        # Overall confidence is the mean over every field that produced an
        # extraction; 0.0 when nothing was extracted at all.
        if field_extractions:
            overall_confidence = (
                sum(f.confidence for f in field_extractions) / len(field_extractions)
            )
        else:
            overall_confidence = 0.0

        return ExtractionResult(
            data=extracted_data,
            fields=field_extractions,
            evidence=evidence_tracker.get_all(),
            overall_confidence=overall_confidence,
            abstained_fields=abstained_fields,
        )

    def _extract_field(
        self,
        field_spec: FieldSpec,
        chunks: List[DocumentChunk],
        evidence_tracker: EvidenceTracker,
    ) -> Optional[FieldExtraction]:
        """
        Extract a single field from chunks.

        Collects one candidate per relevant chunk, keeps those meeting
        ``config.min_field_confidence``, and builds the result from the
        highest-confidence candidate (the earliest chunk wins ties).

        Returns:
            The FieldExtraction for the best candidate, or None when no
            chunk yields a sufficiently confident value.
        """
        candidates: List[Tuple[Any, float, DocumentChunk]] = []

        for chunk in self._find_relevant_chunks(field_spec, chunks):
            value, confidence = self._extract_from_chunk(field_spec, chunk)
            if value is not None and confidence >= self.config.min_field_confidence:
                candidates.append((value, confidence, chunk))

        if not candidates:
            return None

        # max() returns the first maximal element, i.e. the earliest chunk
        # among equally confident candidates.
        best_value, best_confidence, best_chunk = max(
            candidates, key=lambda candidate: candidate[1]
        )

        if self.config.normalize_values:
            best_value = self._normalize_value(best_value, field_spec.field_type)

        # A value that fails validation is kept, but at half the confidence.
        if self.config.validate_extracted_values:
            if not self._validate_value(best_value, field_spec):
                best_confidence *= 0.5

        evidence = self.evidence_builder.create_evidence(
            chunk=best_chunk,
            value=best_value,
            field_name=field_spec.name,
        )
        evidence_tracker.add(evidence, field_spec.name)

        return FieldExtraction(
            field_name=field_spec.name,
            value=best_value,
            confidence=best_confidence,
            confidence_level=self._confidence_to_level(best_confidence),
            evidence=evidence,
            raw_text=best_chunk.text[:200],
        )

    def _search_terms(
        self,
        field_spec: FieldSpec,
        include_hints: bool = False,
    ) -> List[str]:
        """
        Build the label terms used to locate a field.

        The terms are the field name (underscores replaced by spaces), the
        aliases and, optionally, the context hints. Blank terms are dropped
        because an empty string is a substring of every text and would make
        every chunk look relevant. Case is left untouched; callers lower-case
        when they need to.
        """
        terms = [field_spec.name.replace("_", " ")]
        terms.extend(field_spec.aliases)
        if include_hints:
            terms.extend(field_spec.context_hints)
        return [term for term in terms if term and term.strip()]

    def _find_relevant_chunks(
        self,
        field_spec: FieldSpec,
        chunks: List[DocumentChunk],
    ) -> List[DocumentChunk]:
        """
        Find chunks that might contain the field value.

        A chunk is relevant when its lower-cased text contains any search
        term. With ``config.prefer_structured_sources`` enabled, table chunks
        and form-field chunks are always included. When nothing is relevant
        and ``config.search_all_chunks`` is enabled, the full ``chunks`` list
        is returned unchanged.
        """
        search_terms = [
            term.lower()
            for term in self._search_terms(field_spec, include_hints=True)
        ]
        prefer_structured = self.config.prefer_structured_sources

        relevant = []
        for chunk in chunks:
            if prefer_structured and (
                isinstance(chunk, TableChunk)
                or chunk.chunk_type == ChunkType.FORM_FIELD
            ):
                relevant.append(chunk)
                continue

            text_lower = chunk.text.lower()
            if any(term in text_lower for term in search_terms):
                relevant.append(chunk)

        if not relevant and self.config.search_all_chunks:
            return chunks

        return relevant

    def _extract_from_chunk(
        self,
        field_spec: FieldSpec,
        chunk: DocumentChunk,
    ) -> Tuple[Optional[Any], float]:
        """
        Extract field value from a single chunk.

        Table chunks go through the table lookup; every other chunk is
        handled as plain text.

        Returns:
            ``(value, confidence)``; ``(None, 0.0)`` when nothing is found.
        """
        if isinstance(chunk, TableChunk):
            return self._extract_from_table(field_spec, chunk)
        return self._extract_from_text(field_spec, chunk.text)

    def _extract_from_table(
        self,
        field_spec: FieldSpec,
        table: TableChunk,
    ) -> Tuple[Optional[Any], float]:
        """
        Extract field from a table chunk.

        Two layouts are tried, in this order:
        1. Row 0 holds labels: the value is read from row 1 of the matching
           column.
        2. Column 0 holds labels: the value is read from column 1 of the
           matching row.

        Returns:
            ``(cell text, cell confidence)`` for the first non-empty value
            cell, or ``(None, 0.0)`` when no label matches.
        """
        search_terms = [term.lower() for term in self._search_terms(field_spec)]

        def is_label_for_field(cell: Any) -> bool:
            # Missing cells, or cells without text, can never match a term.
            if cell is None:
                return False
            label = (cell.text or "").lower()
            return any(term in label for term in search_terms)

        for col_idx in range(table.num_cols):
            if is_label_for_field(table.get_cell(0, col_idx)):
                value_cell = table.get_cell(1, col_idx)
                if value_cell and value_cell.text:
                    return value_cell.text, value_cell.confidence

        for row_idx in range(table.num_rows):
            if is_label_for_field(table.get_cell(row_idx, 0)):
                value_cell = table.get_cell(row_idx, 1)
                if value_cell and value_cell.text:
                    return value_cell.text, value_cell.confidence

        return None, 0.0

    @staticmethod
    def _first_nonblank_capture(matches: List[Any]) -> Optional[str]:
        """
        Return the first usable capture from ``re.findall`` output.

        ``re.findall`` yields strings for patterns with at most one group and
        tuples for patterns with several groups; for tuples the first group
        is used. Captures that are empty after stripping are skipped.

        Returns:
            The stripped capture, or None when every capture is blank.
        """
        for match in matches:
            captured = match[0] if isinstance(match, tuple) else match
            captured = captured.strip()
            if captured:
                return captured
        return None

    def _extract_from_text(
        self,
        field_spec: FieldSpec,
        text: str,
    ) -> Tuple[Optional[Any], float]:
        """
        Extract field from text using patterns.

        The schema/type patterns are tried first and score ``0.7`` plus the
        pattern's boost. If none yields a value, a generic
        ``<label><separator><rest of line>`` lookup is tried for the field
        name and each alias and scores ``0.6``. Blank captures are never
        returned as values.

        Returns:
            ``(value, confidence)``; ``(None, 0.0)`` when nothing is found.
        """
        for pattern, confidence_boost in self._get_extraction_patterns(field_spec):
            try:
                matches = re.findall(pattern, text, re.IGNORECASE)
            except re.error as exc:
                # Only a schema-supplied pattern can be malformed; skip it
                # rather than aborting extraction of the whole document.
                logger.warning(
                    "Skipping invalid extraction pattern for field %r: %s",
                    field_spec.name,
                    exc,
                )
                continue

            value = self._first_nonblank_capture(matches)
            if value is not None:
                return value, 0.7 + confidence_boost

        for term in self._search_terms(field_spec):
            # Label followed by whitespace / colon / dash, then the remainder
            # of the line as the value.
            pattern = rf"{re.escape(term)}[\s::\-]+([^\n]+)"
            value = self._first_nonblank_capture(
                re.findall(pattern, text, re.IGNORECASE)
            )
            if value is not None:
                return value, 0.6

        return None, 0.0

    def _get_extraction_patterns(
        self,
        field_spec: FieldSpec,
    ) -> List[Tuple[str, float]]:
        """
        Get regex patterns for field type.

        Returns:
            ``(pattern, confidence_boost)`` pairs in priority order: the
            schema's own pattern (if any) first, then the built-in patterns
            for the field's type.
        """
        patterns: List[Tuple[str, float]] = []

        if field_spec.pattern:
            patterns.append((field_spec.pattern, 0.2))

        type_patterns: Dict[Any, List[Tuple[str, float]]] = {
            FieldType.DATE: [
                (r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b', 0.1),
                (r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b', 0.1),
                (r'\b([A-Z][a-z]+\s+\d{1,2},?\s+\d{4})\b', 0.1),
            ],
            FieldType.CURRENCY: [
                (r'[\$\€\£][\s]*([\d,]+\.?\d*)', 0.2),
                (r'([\d,]+\.?\d*)\s*(?:USD|EUR|GBP)', 0.1),
            ],
            FieldType.PERCENTAGE: [
                (r'([\d.]+)\s*%', 0.2),
            ],
            FieldType.EMAIL: [
                (r'([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', 0.3),
            ],
            FieldType.PHONE: [
                (r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}', 0.2),
                (r'\+\d{1,3}[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}', 0.2),
            ],
            FieldType.INTEGER: [
                (r'\b(\d+)\b', 0.0),
            ],
            FieldType.FLOAT: [
                (r'\b(\d+\.?\d*)\b', 0.0),
            ],
        }
        patterns.extend(type_patterns.get(field_spec.field_type, []))

        return patterns

    def _normalize_value(self, value: Any, field_type: FieldType) -> Any:
        """
        Normalize extracted value.

        Normalization is best-effort: when no normalizer is registered for
        ``field_type``, or the normalizer raises, the value is returned
        unchanged.
        """
        normalizer = self._normalizers.get(field_type)
        if normalizer is None:
            return value
        try:
            return normalizer(value)
        except Exception as exc:
            # Deliberately broad: a value that cannot be normalized must not
            # abort extraction. Record the failure instead of hiding it.
            logger.debug(
                "Could not normalize %r as %s: %s", value, field_type, exc
            )
            return value

    def _validate_value(self, value: Any, field_spec: FieldSpec) -> bool:
        """
        Validate extracted value against field spec.

        Checks, in order: the type validator, the schema pattern (matched at
        the start of ``str(value)``), numeric min/max, string length min/max
        and the allowed-values list. Numeric bounds are skipped for values
        that cannot be converted to float.

        Returns:
            True when every applicable check passes. A None value is valid
            only for non-required fields.
        """
        if value is None:
            return not field_spec.required

        validator = self._validators.get(field_spec.field_type)
        if validator and not validator(value):
            return False

        if field_spec.pattern:
            try:
                pattern_ok = re.match(field_spec.pattern, str(value)) is not None
            except re.error as exc:
                # A malformed schema pattern cannot validate anything; skip
                # the check instead of crashing.
                logger.warning(
                    "Skipping invalid validation pattern for field %r: %s",
                    field_spec.name,
                    exc,
                )
                pattern_ok = True
            if not pattern_ok:
                return False

        if field_spec.min_value is not None or field_spec.max_value is not None:
            try:
                numeric = float(value)
            except (ValueError, TypeError):
                # Non-numeric values are not subject to the numeric bounds.
                numeric = None
            if numeric is not None:
                if field_spec.min_value is not None and numeric < field_spec.min_value:
                    return False
                if field_spec.max_value is not None and numeric > field_spec.max_value:
                    return False

        text_length = len(str(value))
        if field_spec.min_length is not None and text_length < field_spec.min_length:
            return False
        if field_spec.max_length is not None and text_length > field_spec.max_length:
            return False

        if field_spec.allowed_values and value not in field_spec.allowed_values:
            return False

        return True

    def _confidence_to_level(self, confidence: float) -> ConfidenceLevel:
        """
        Convert numeric confidence to level.

        Lower bounds (inclusive): 0.9 VERY_HIGH, 0.7 HIGH, 0.5 MEDIUM,
        0.3 LOW; anything below is VERY_LOW.
        """
        if confidence >= 0.9:
            return ConfidenceLevel.VERY_HIGH
        if confidence >= 0.7:
            return ConfidenceLevel.HIGH
        if confidence >= 0.5:
            return ConfidenceLevel.MEDIUM
        if confidence >= 0.3:
            return ConfidenceLevel.LOW
        return ConfidenceLevel.VERY_LOW

    def _build_normalizers(self) -> Dict[FieldType, Callable]:
        """
        Build value normalizers for each type.

        Each normalizer takes the raw value and returns the normalized one.
        A normalizer may raise on unparseable input; ``_normalize_value``
        handles that. Types without an entry are left as-is.
        """
        return {
            FieldType.STRING: lambda v: str(v).strip(),
            FieldType.INTEGER: lambda v: int(re.sub(r'[^\d-]', '', str(v))),
            FieldType.FLOAT: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
            FieldType.BOOLEAN: lambda v: str(v).lower() in ('true', 'yes', '1', 'y'),
            FieldType.CURRENCY: self._normalize_currency,
            FieldType.PERCENTAGE: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
            FieldType.EMAIL: lambda v: str(v).lower().strip(),
            FieldType.PHONE: self._normalize_phone,
        }

    def _build_validators(self) -> Dict[FieldType, Callable]:
        """
        Build validators for each type.

        Each validator takes a value and returns a truthy result when the
        value looks plausible for the type. These are light sanity checks:
        an email needs ``@`` and ``.``, a phone needs at least 7 digits, a
        date needs at least one digit.
        """
        return {
            FieldType.EMAIL: lambda v: '@' in str(v) and '.' in str(v),
            FieldType.PHONE: lambda v: len(re.sub(r'\D', '', str(v))) >= 7,
            FieldType.DATE: lambda v: bool(re.search(r'\d', str(v))),
        }

    def _normalize_currency(self, value: str) -> str:
        """
        Normalize currency value.

        Strips everything except digits, dots and commas, then rewrites the
        amount so that ``.`` is the only separator left and marks the
        decimals:
        - ``1.234,56`` (comma last)  -> ``1234.56``
        - ``1,234.56`` (dot last)    -> ``1234.56``
        - ``12,50`` (two digits after the last comma) -> ``12.50``
        - ``1,234`` (any other comma-only form)       -> ``1234``

        Returns:
            The normalized amount as a string. The sign and currency symbol
            are not preserved.
        """
        amount = re.sub(r'[^\d.,]', '', str(value))

        has_comma = ',' in amount
        has_dot = '.' in amount

        if has_comma and has_dot:
            if amount.rfind(',') > amount.rfind('.'):
                # Dots group thousands, comma marks the decimals.
                amount = amount.replace('.', '').replace(',', '.')
            else:
                # Commas group thousands, dot marks the decimals.
                amount = amount.replace(',', '')
        elif has_comma:
            head, _, tail = amount.rpartition(',')
            if len(tail) == 2:
                # Exactly two digits after the last comma: decimal comma.
                # Any earlier commas can only be grouping separators.
                amount = f"{head.replace(',', '')}.{tail}"
            else:
                amount = amount.replace(',', '')

        return amount

    def _normalize_phone(self, value: str) -> str:
        """
        Normalize phone number.

        Ten digits are formatted as ``(AAA) BBB-CCCC``; eleven digits with a
        leading ``1`` as ``+1 (AAA) BBB-CCCC``. Anything else is returned
        unchanged.
        """
        digits = re.sub(r'\D', '', str(value))
        if len(digits) == 10:
            return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
        if len(digits) == 11 and digits[0] == '1':
            return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
        return value
|
|