| """ |
| Extraction Critic for Validation |
| |
| Validates extracted information against source evidence. |
| Provides confidence scoring and abstention recommendations. |
| """ |
|
|
| from typing import List, Optional, Dict, Any, Tuple |
| from enum import Enum |
| from pydantic import BaseModel, Field |
| from loguru import logger |
|
|
| try: |
| import httpx |
| HTTPX_AVAILABLE = True |
| except ImportError: |
| HTTPX_AVAILABLE = False |
|
|
|
|
| class ValidationStatus(str, Enum): |
| """Validation status codes.""" |
| VALID = "valid" |
| INVALID = "invalid" |
| UNCERTAIN = "uncertain" |
| ABSTAIN = "abstain" |
| NO_EVIDENCE = "no_evidence" |
|
|
|
|
| class CriticConfig(BaseModel): |
| """Configuration for extraction critic.""" |
| |
| llm_provider: str = Field(default="ollama", description="LLM provider") |
| ollama_base_url: str = Field(default="http://localhost:11434") |
| ollama_model: str = Field(default="llama3.2:3b") |
|
|
| |
| confidence_threshold: float = Field( |
| default=0.7, |
| ge=0.0, |
| le=1.0, |
| description="Minimum confidence for valid extraction" |
| ) |
| evidence_required: bool = Field( |
| default=True, |
| description="Require evidence for validation" |
| ) |
| strict_mode: bool = Field( |
| default=False, |
| description="Strict validation mode" |
| ) |
|
|
| |
| max_fields_per_request: int = Field(default=10, ge=1) |
| timeout: float = Field(default=60.0, ge=1.0) |
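
# Example (illustrative; the override values below are arbitrary): a stricter
# configuration can be built by overriding the defaults, e.g.
#   CriticConfig(confidence_threshold=0.85, strict_mode=True)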


class FieldValidation(BaseModel):
    """Validation result for a single field."""
    field_name: str
    extracted_value: Any
    status: ValidationStatus
    confidence: float
    reasoning: str

    evidence_found: bool = False
    evidence_snippet: Optional[str] = None
    evidence_page: Optional[int] = None

    suggested_value: Optional[Any] = None
    correction_reason: Optional[str] = None


class ValidationResult(BaseModel):
    """Complete validation result."""
    overall_status: ValidationStatus
    overall_confidence: float
    field_validations: List[FieldValidation]

    valid_count: int = 0
    invalid_count: int = 0
    uncertain_count: int = 0
    abstain_count: int = 0

    should_accept: bool
    abstain_reason: Optional[str] = None


class ExtractionCritic:
    """
    Critic for validating extracted information.

    Features:
    - Validates extracted fields against source evidence
    - Provides confidence scores
    - Recommends abstention when uncertain
    - Suggests corrections when possible
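
    Example (illustrative; the field names and values below are placeholders,
    while the "text" and "page" evidence keys are the ones the critic reads):

        critic = ExtractionCritic()
        result = critic.validate_extraction(
            extracted_fields={"invoice_number": "INV-001"},
            evidence=[{"text": "Invoice INV-001, issued 2024-01-01", "page": 1}],
        )
        if result.should_accept:
            ...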
| """ |
|
|

    VALIDATION_PROMPT = """You are a critical validator for document extraction.
Your task is to validate extracted information against the source evidence.

For each field, determine:
1. Is the extracted value supported by the evidence? (yes/no/partially)
2. Confidence score (0.0 to 1.0)
3. Brief reasoning
4. If incorrect, suggest the correct value

Be strict and skeptical. Only mark as valid if clearly supported.

Evidence:
{evidence}

Extracted Fields to Validate:
{fields}

Respond in JSON format:
{{
    "validations": [
        {{
            "field": "field_name",
            "status": "valid|invalid|uncertain|no_evidence",
            "confidence": 0.0-1.0,
            "reasoning": "explanation",
            "suggested_value": null or corrected value
        }}
    ]
}}"""

    def __init__(self, config: Optional[CriticConfig] = None):
        """Initialize the extraction critic."""
        self.config = config or CriticConfig()

    def validate_extraction(
        self,
        extracted_fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> ValidationResult:
        """
        Validate extracted fields against evidence.

        Args:
            extracted_fields: Dictionary of field_name -> value
            evidence: List of evidence chunks with text, page, etc.

        Returns:
            ValidationResult
        """
        if not extracted_fields:
            return ValidationResult(
                overall_status=ValidationStatus.ABSTAIN,
                overall_confidence=0.0,
                field_validations=[],
                should_accept=False,
                abstain_reason="No fields to validate",
            )

        # Abstain when evidence is required but none was provided.
        if not evidence and self.config.evidence_required:
            return self._create_no_evidence_result(extracted_fields)

        field_validations = self._validate_with_llm(extracted_fields, evidence)

        # Aggregate per-field results.
        valid_count = sum(1 for v in field_validations if v.status == ValidationStatus.VALID)
        invalid_count = sum(1 for v in field_validations if v.status == ValidationStatus.INVALID)
        uncertain_count = sum(1 for v in field_validations if v.status == ValidationStatus.UNCERTAIN)
        abstain_count = sum(1 for v in field_validations if v.status == ValidationStatus.ABSTAIN)

        if field_validations:
            overall_confidence = sum(v.confidence for v in field_validations) / len(field_validations)
        else:
            overall_confidence = 0.0

        # Any invalid field makes the whole extraction invalid; otherwise abstain
        # or uncertain results dominate only when they outnumber valid ones.
        if invalid_count > 0:
            overall_status = ValidationStatus.INVALID
        elif abstain_count > valid_count:
            overall_status = ValidationStatus.ABSTAIN
        elif uncertain_count > valid_count:
            overall_status = ValidationStatus.UNCERTAIN
        else:
            overall_status = ValidationStatus.VALID

        # Accept only when confidence clears the threshold, nothing is invalid,
        # and the overall status is valid or uncertain.
        should_accept = (
            overall_confidence >= self.config.confidence_threshold
            and invalid_count == 0
            and overall_status in [ValidationStatus.VALID, ValidationStatus.UNCERTAIN]
        )

        abstain_reason = None
        if not should_accept:
            if overall_confidence < self.config.confidence_threshold:
                abstain_reason = f"Confidence ({overall_confidence:.2f}) below threshold ({self.config.confidence_threshold})"
            elif invalid_count > 0:
                abstain_reason = f"{invalid_count} field(s) validated as invalid"
            elif overall_status == ValidationStatus.ABSTAIN:
                abstain_reason = "Insufficient evidence to validate"

        return ValidationResult(
            overall_status=overall_status,
            overall_confidence=overall_confidence,
            field_validations=field_validations,
            valid_count=valid_count,
            invalid_count=invalid_count,
            uncertain_count=uncertain_count,
            abstain_count=abstain_count,
            should_accept=should_accept,
            abstain_reason=abstain_reason,
        )

    def _validate_with_llm(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Validate fields using the LLM."""
        evidence_text = self._format_evidence(evidence)

        fields_text = "\n".join(
            f"- {name}: {value}"
            for name, value in fields.items()
        )

        prompt = self.VALIDATION_PROMPT.format(
            evidence=evidence_text,
            fields=fields_text,
        )

        try:
            response = self._call_llm(prompt)
            validations = self._parse_validation_response(response, fields, evidence)
        except Exception as e:
            logger.error(f"LLM validation failed: {e}")
            # Fall back to string-matching heuristics when the LLM call fails.
            validations = self._heuristic_validation(fields, evidence)

        return validations

    def _format_evidence(self, evidence: List[Dict[str, Any]]) -> str:
        """Format evidence for the prompt."""
        parts = []
        for i, ev in enumerate(evidence[:10], 1):
            page = ev.get("page", "?")
            text = ev.get("text", ev.get("snippet", ""))[:500]
            parts.append(f"[{i}] Page {page}: {text}")
        return "\n\n".join(parts)

    def _call_llm(self, prompt: str) -> str:
        """Call the LLM for validation."""
        if not HTTPX_AVAILABLE:
            raise ImportError("httpx required for LLM calls")

        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(
                f"{self.config.ollama_base_url}/api/generate",
                json={
                    "model": self.config.ollama_model,
                    "prompt": prompt,
                    "stream": False,
                    "options": {"temperature": 0.1},
                },
            )
            response.raise_for_status()
            return response.json().get("response", "")

    def _parse_validation_response(
        self,
        response: str,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Parse the LLM validation response."""
        import json
        import re

        validations = []

        # Pull the JSON object out of the model response.
        json_match = re.search(r'\{[\s\S]*\}', response)
        if json_match:
            try:
                data = json.loads(json_match.group())
                llm_validations = data.get("validations", [])

                for v in llm_validations:
                    field_name = v.get("field", "")
                    if field_name not in fields:
                        continue

                    status_str = v.get("status", "uncertain").lower()
                    try:
                        status = ValidationStatus(status_str)
                    except ValueError:
                        status = ValidationStatus.UNCERTAIN

                    validation = FieldValidation(
                        field_name=field_name,
                        extracted_value=fields[field_name],
                        status=status,
                        confidence=float(v.get("confidence", 0.5)),
                        reasoning=v.get("reasoning", ""),
                        evidence_found=status != ValidationStatus.NO_EVIDENCE,
                        suggested_value=v.get("suggested_value"),
                    )
                    validations.append(validation)

            except json.JSONDecodeError:
                pass

        # Any fields the LLM did not cover are marked as uncertain.
        validated_fields = {v.field_name for v in validations}
        for field_name, value in fields.items():
            if field_name not in validated_fields:
                validations.append(FieldValidation(
                    field_name=field_name,
                    extracted_value=value,
                    status=ValidationStatus.UNCERTAIN,
                    confidence=0.5,
                    reasoning="Could not validate",
                    evidence_found=False,
                ))

        return validations

    def _heuristic_validation(
        self,
        fields: Dict[str, Any],
        evidence: List[Dict[str, Any]],
    ) -> List[FieldValidation]:
        """Heuristic validation when the LLM fails."""
        validations = []
        evidence_text = " ".join(
            ev.get("text", ev.get("snippet", "")).lower()
            for ev in evidence
        )

        for field_name, value in fields.items():
            # Substring check: the value counts as supported if it appears
            # verbatim (case-insensitively) in the evidence. Guard against None
            # so the literal string "none" is never matched.
            value_str = str(value).lower() if value is not None else ""
            found = value_str in evidence_text if value_str else False

            if found:
                status = ValidationStatus.VALID
                confidence = 0.7
                reasoning = "Value found in evidence"
            elif evidence:
                status = ValidationStatus.UNCERTAIN
                confidence = 0.4
                reasoning = "Value not directly found in evidence"
            else:
                status = ValidationStatus.NO_EVIDENCE
                confidence = 0.2
                reasoning = "No evidence available"

            validations.append(FieldValidation(
                field_name=field_name,
                extracted_value=value,
                status=status,
                confidence=confidence,
                reasoning=reasoning,
                evidence_found=found,
            ))

        return validations

    def _create_no_evidence_result(
        self,
        fields: Dict[str, Any],
    ) -> ValidationResult:
        """Create a result when no evidence is available."""
        validations = [
            FieldValidation(
                field_name=name,
                extracted_value=value,
                status=ValidationStatus.NO_EVIDENCE,
                confidence=0.0,
                reasoning="No evidence provided for validation",
                evidence_found=False,
            )
            for name, value in fields.items()
        ]

        return ValidationResult(
            overall_status=ValidationStatus.ABSTAIN,
            overall_confidence=0.0,
            field_validations=validations,
            abstain_count=len(validations),
            should_accept=False,
            abstain_reason="No evidence available for validation",
        )


# Module-level singleton instance.
_extraction_critic: Optional[ExtractionCritic] = None


def get_extraction_critic(
    config: Optional[CriticConfig] = None,
) -> ExtractionCritic:
    """Get or create the singleton extraction critic."""
    global _extraction_critic
    if _extraction_critic is None:
        _extraction_critic = ExtractionCritic(config)
    return _extraction_critic


def reset_extraction_critic():
    """Reset the global critic instance."""
    global _extraction_critic
    _extraction_critic = None
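

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the field values and evidence text
    # are made-up placeholders). If the configured Ollama endpoint is not
    # reachable, validation falls back to the string-matching heuristics, so
    # this runs without a local LLM as long as the imports above are satisfied.
    critic = get_extraction_critic()
    demo = critic.validate_extraction(
        extracted_fields={"total_amount": "1,250.00", "currency": "USD"},
        evidence=[{"text": "Total due: 1,250.00 USD", "page": 2}],
    )
    logger.info(
        f"status={demo.overall_status.value} "
        f"confidence={demo.overall_confidence:.2f} "
        f"accept={demo.should_accept}"
    )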
|