| """ |
| Extraction Validation |
| |
| Validates extracted data and provides confidence scoring. |
| """ |
|
|
| import logging |
| from dataclasses import dataclass, field |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from ..chunks.models import ( |
| ExtractionResult, |
| FieldExtraction, |
| ConfidenceLevel, |
| ) |
| from .schema import ExtractionSchema, FieldSpec, FieldType |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class ValidationIssue:
    """A validation issue found during extraction validation."""

    field_name: str
    issue_type: str
    message: str
    severity: str = "warning"
    suggested_action: Optional[str] = None


@dataclass
class ValidationResult:
    """Result of extraction validation."""

    is_valid: bool
    issues: List[ValidationIssue] = field(default_factory=list)
    confidence_score: float = 0.0
    field_scores: Dict[str, float] = field(default_factory=dict)
    recommendations: List[str] = field(default_factory=list)

    @property
    def error_count(self) -> int:
        return sum(1 for i in self.issues if i.severity == "error")

    @property
    def warning_count(self) -> int:
        return sum(1 for i in self.issues if i.severity == "warning")

    def get_issues_for_field(self, field_name: str) -> List[ValidationIssue]:
        """Get all issues for a specific field."""
        return [i for i in self.issues if i.field_name == field_name]


class ExtractionValidator:
    """
    Validates extraction results against schemas.

    Checks for:
    - Required field presence
    - Type correctness
    - Value constraints
    - Confidence thresholds
    """

    def __init__(
        self,
        min_confidence: float = 0.5,
        strict_mode: bool = False,
    ):
        self.min_confidence = min_confidence
        self.strict_mode = strict_mode

    def validate(
        self,
        extraction: ExtractionResult,
        schema: ExtractionSchema,
    ) -> ValidationResult:
        """
        Validate extraction result against schema.

        Args:
            extraction: Extraction result to validate
            schema: Schema defining expected fields

        Returns:
            ValidationResult with issues and scores
        """
        issues: List[ValidationIssue] = []
        field_scores: Dict[str, float] = {}

        # Score every field declared by the schema.
        for field_spec in schema.fields:
            field_issues, score = self._validate_field(
                field_spec=field_spec,
                extraction=extraction,
            )
            issues.extend(field_issues)
            field_scores[field_spec.name] = score

        # Flag extracted fields that the schema does not define.
        expected_fields = {f.name for f in schema.fields}
        for field_name in extraction.data.keys():
            if field_name not in expected_fields:
                issues.append(ValidationIssue(
                    field_name=field_name,
                    issue_type="unexpected",
                    message=f"Unexpected field: {field_name}",
                    severity="info",
                ))

        # Overall confidence is the mean of the per-field scores.
        if field_scores:
            confidence_score = sum(field_scores.values()) / len(field_scores)
        else:
            confidence_score = 0.0

        # Valid only if nothing raised an error and the overall confidence
        # clears the schema's own threshold.
        is_valid = (
            all(i.severity != "error" for i in issues) and
            confidence_score >= schema.min_overall_confidence
        )

        recommendations = self._generate_recommendations(issues, extraction)

        return ValidationResult(
            is_valid=is_valid,
            issues=issues,
            confidence_score=confidence_score,
            field_scores=field_scores,
            recommendations=recommendations,
        )

    def _validate_field(
        self,
        field_spec: FieldSpec,
        extraction: ExtractionResult,
    ) -> Tuple[List[ValidationIssue], float]:
        """Validate a single field."""
        issues: List[ValidationIssue] = []
        score = 1.0

        value = extraction.data.get(field_spec.name)
        field_extraction = self._get_field_extraction(field_spec.name, extraction)

        # Missing value: an error for required fields, a clean pass otherwise.
        if value is None:
            if field_spec.required:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="missing",
                    message=f"Required field '{field_spec.name}' is missing",
                    severity="error",
                    suggested_action="Manual review required",
                ))
                return issues, 0.0
            else:
                return issues, 1.0

        # The extractor abstained on this field; keep the value but discount it.
        if field_spec.name in extraction.abstained_fields:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="abstained",
                message=f"Field '{field_spec.name}' was abstained due to low confidence",
                severity="warning",
                suggested_action="Manual verification recommended",
            ))
            score *= 0.5

        # Weight the score by the extractor's reported confidence and warn when
        # it falls below the validator's threshold.
        if field_extraction:
            if field_extraction.confidence < self.min_confidence:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="low_confidence",
                    message=f"Field '{field_spec.name}' has low confidence: {field_extraction.confidence:.2f}",
                    severity="warning",
                    suggested_action="Manual verification recommended",
                ))
            score *= field_extraction.confidence

        # Type check (soft unless strict_mode is enabled).
        type_issues = self._validate_type(field_spec, value)
        issues.extend(type_issues)
        if type_issues:
            score *= 0.7

        # Declared constraints: pattern, numeric range, length, allowed values.
        constraint_issues = self._validate_constraints(field_spec, value)
        issues.extend(constraint_issues)
        if constraint_issues:
            score *= 0.8

        return issues, max(0.0, min(1.0, score))

    def _validate_type(
        self,
        field_spec: FieldSpec,
        value: Any,
    ) -> List[ValidationIssue]:
        """Validate field type."""
        issues = []

        expected_type = self._get_expected_python_type(field_spec.field_type)

        if expected_type and not isinstance(value, expected_type):
            # Attempt a lenient coercion (e.g. "42" for an integer field) before
            # flagging a mismatch. bool() accepts any value, so boolean fields
            # are reported directly rather than coerced.
            coercible = False
            if expected_type is not bool:
                try:
                    expected_type(value)
                    coercible = True
                except (ValueError, TypeError):
                    coercible = False
            if not coercible:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="type_mismatch",
                    message=f"Field '{field_spec.name}' expected {field_spec.field_type.value}, got {type(value).__name__}",
                    severity="error" if self.strict_mode else "warning",
                ))

        return issues

    def _validate_constraints(
        self,
        field_spec: FieldSpec,
        value: Any,
    ) -> List[ValidationIssue]:
        """Validate field constraints."""
        issues = []

        # Pattern constraint (re.match anchors at the start of the string only).
        if field_spec.pattern:
            import re
            if not re.match(field_spec.pattern, str(value)):
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="pattern_mismatch",
                    message=f"Field '{field_spec.name}' does not match pattern: {field_spec.pattern}",
                    severity="warning",
                ))

        # Numeric range constraints; skipped silently for non-numeric values.
        try:
            num_value = float(value)
            if field_spec.min_value is not None and num_value < field_spec.min_value:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="below_minimum",
                    message=f"Field '{field_spec.name}' value {num_value} is below minimum {field_spec.min_value}",
                    severity="warning",
                ))
            if field_spec.max_value is not None and num_value > field_spec.max_value:
                issues.append(ValidationIssue(
                    field_name=field_spec.name,
                    issue_type="above_maximum",
                    message=f"Field '{field_spec.name}' value {num_value} is above maximum {field_spec.max_value}",
                    severity="warning",
                ))
        except (ValueError, TypeError):
            pass

        # Length constraints are applied to the string form of the value.
        str_value = str(value)
        if field_spec.min_length is not None and len(str_value) < field_spec.min_length:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="too_short",
                message=f"Field '{field_spec.name}' is too short: {len(str_value)} < {field_spec.min_length}",
                severity="warning",
            ))
        if field_spec.max_length is not None and len(str_value) > field_spec.max_length:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="too_long",
                message=f"Field '{field_spec.name}' is too long: {len(str_value)} > {field_spec.max_length}",
                severity="warning",
            ))

        # Enumerated (allowed) values.
        if field_spec.allowed_values and value not in field_spec.allowed_values:
            issues.append(ValidationIssue(
                field_name=field_spec.name,
                issue_type="not_in_allowed",
                message=f"Field '{field_spec.name}' value '{value}' not in allowed values",
                severity="warning",
            ))

        return issues

    def _get_field_extraction(
        self,
        field_name: str,
        extraction: ExtractionResult,
    ) -> Optional[FieldExtraction]:
        """Get field extraction by name."""
        for fe in extraction.fields:
            if fe.field_name == field_name:
                return fe
        return None

    def _get_expected_python_type(self, field_type: FieldType) -> Optional[type]:
        """Get expected Python type for field type."""
        # Field types without an entry here are not type-checked.
        type_map = {
            FieldType.INTEGER: int,
            FieldType.FLOAT: float,
            FieldType.BOOLEAN: bool,
            FieldType.LIST: list,
            FieldType.OBJECT: dict,
        }
        return type_map.get(field_type)

    def _generate_recommendations(
        self,
        issues: List[ValidationIssue],
        extraction: ExtractionResult,
    ) -> List[str]:
        """Generate recommendations based on issues."""
        recommendations = []

        # Tally the issue types that drive the recommendations below.
        missing_count = sum(1 for i in issues if i.issue_type == "missing")
        low_conf_count = sum(1 for i in issues if i.issue_type == "low_confidence")
        type_count = sum(1 for i in issues if i.issue_type == "type_mismatch")

        if missing_count > 0:
            recommendations.append(
                f"Review document for {missing_count} missing required field(s)"
            )

        if low_conf_count > 0:
            recommendations.append(
                f"Manual verification recommended for {low_conf_count} low-confidence field(s)"
            )

        if type_count > 0:
            recommendations.append(
                f"Check data types for {type_count} field(s) with type mismatches"
            )

        if extraction.overall_confidence < 0.5:
            recommendations.append(
                "Overall extraction confidence is low - consider manual review"
            )

        if len(extraction.abstained_fields) > 0:
            recommendations.append(
                f"System abstained on {len(extraction.abstained_fields)} field(s) due to uncertainty"
            )

        return recommendations


class CrossFieldValidator:
    """
    Validates relationships between fields.

    Checks for:
    - Consistency (e.g., subtotal + tax = total)
    - Logical relationships
    - Date ordering
    """

    def validate_consistency(
        self,
        extraction: ExtractionResult,
        rules: List[Dict[str, Any]],
    ) -> List[ValidationIssue]:
        """
        Validate cross-field consistency rules.

        Rules format:
            {
                "type": "sum",
                "fields": ["subtotal", "tax"],
                "equals": "total",
                "tolerance": 0.01
            }
        """
        issues = []

        for rule in rules:
            rule_type = rule.get("type")

            if rule_type == "sum":
                issue = self._validate_sum_rule(extraction, rule)
                if issue:
                    issues.append(issue)

            elif rule_type == "date_order":
                issue = self._validate_date_order(extraction, rule)
                if issue:
                    issues.append(issue)

            elif rule_type == "required_if":
                issue = self._validate_required_if(extraction, rule)
                if issue:
                    issues.append(issue)

        return issues

    def _validate_sum_rule(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate that sum of fields equals another field."""
        fields = rule.get("fields", [])
        equals_field = rule.get("equals")
        tolerance = rule.get("tolerance", 0.01)

        try:
            # Missing operands are treated as zero so partially extracted
            # documents do not raise here.
            sum_value = sum(
                float(extraction.data.get(f, 0) or 0)
                for f in fields
            )
            expected = float(extraction.data.get(equals_field, 0) or 0)

            if abs(sum_value - expected) > tolerance:
                return ValidationIssue(
                    field_name=equals_field,
                    issue_type="sum_mismatch",
                    message=f"Sum of {fields} ({sum_value}) does not equal {equals_field} ({expected})",
                    severity="warning",
                )
        except (ValueError, TypeError):
            pass

        return None

    def _validate_date_order(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate that dates are in correct order."""
        from datetime import datetime

        before_field = rule.get("before")
        after_field = rule.get("after")

        before_val = extraction.data.get(before_field)
        after_val = extraction.data.get(after_field)

        if not before_val or not after_val:
            return None

        # Try a handful of common date formats; the first one that parses wins.
        formats = ["%Y-%m-%d", "%m/%d/%Y", "%d/%m/%Y", "%B %d, %Y"]

        def parse(value: Any) -> Optional[datetime]:
            for fmt in formats:
                try:
                    return datetime.strptime(str(value), fmt)
                except ValueError:
                    continue
            return None

        before_date = parse(before_val)
        after_date = parse(after_val)

        if before_date and after_date and before_date > after_date:
            return ValidationIssue(
                field_name=after_field,
                issue_type="date_order",
                message=f"Date {before_field} ({before_val}) should be before {after_field} ({after_val})",
                severity="warning",
            )

        return None

    def _validate_required_if(
        self,
        extraction: ExtractionResult,
        rule: Dict[str, Any],
    ) -> Optional[ValidationIssue]:
        """Validate conditional required fields."""
        field = rule.get("field")
        required_if = rule.get("required_if")
        condition_value = rule.get("value")

        condition_field_value = extraction.data.get(required_if)

        # With an explicit "value" the condition is equality; otherwise the mere
        # presence of the conditioning field triggers the requirement.
        if condition_value is not None:
            condition_met = condition_field_value == condition_value
        else:
            condition_met = condition_field_value is not None

        if condition_met:
            field_value = extraction.data.get(field)
            if field_value is None:
                condition_desc = (
                    f"'{required_if}' equals {condition_value!r}"
                    if condition_value is not None
                    else f"'{required_if}' is present"
                )
                return ValidationIssue(
                    field_name=field,
                    issue_type="conditional_required",
                    message=f"Field '{field}' is required when {condition_desc}",
                    severity="warning",
                )

        return None