| """ |
| Schema Definitions for Field Extraction |
| |
| Dataclass-based schemas for defining extraction targets, convertible to JSON Schema and Pydantic models. |
| """ |
|
|
| from dataclasses import dataclass, field as dataclass_field |
| from enum import Enum |
| from typing import Any, Callable, Dict, List, Optional, Type, Union |
|
|
| from pydantic import BaseModel, Field, create_model |
|
|
|
|
class FieldType(str, Enum):
    """Enumerate the kinds of value a field extractor can target.

    Members mix in ``str``, so each one compares equal to its lowercase
    string value (``FieldType.DATE == "date"``). A member can also be looked
    up from that value with ``FieldType("date")``.
    """

    # Scalar primitives
    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"

    # Temporal values
    DATE = "date"
    DATETIME = "datetime"

    # Semantic / formatted strings
    CURRENCY = "currency"
    PERCENTAGE = "percentage"
    EMAIL = "email"
    PHONE = "phone"
    ADDRESS = "address"

    # Containers
    LIST = "list"
    OBJECT = "object"
|
|
|
|
@dataclass
class FieldSpec:
    """Specification for a single extraction field.

    A spec combines the following:

    * the field's identity and type;
    * optional validation constraints;
    * optional nesting information;
    * optional extraction hints;
    * a confidence threshold.

    ``to_json_schema`` renders the spec as a JSON Schema fragment.
    """

    # Identity and basic typing
    name: str
    field_type: FieldType = FieldType.STRING
    description: str = ""
    required: bool = True
    default: Any = None

    # Validation constraints. None (or an empty/falsy value for pattern and
    # allowed_values) means "no constraint" and is omitted from the JSON Schema.
    pattern: Optional[str] = None
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    min_length: Optional[int] = None
    max_length: Optional[int] = None
    allowed_values: Optional[List[Any]] = None

    # Nested structure. nested_schema describes an OBJECT field's properties
    # or a LIST field's items. list_item_type gives a scalar item type for
    # LIST fields that have no nested_schema.
    nested_schema: Optional["ExtractionSchema"] = None
    list_item_type: Optional[FieldType] = None

    # Extraction hints (alternative names, sample values, context). These are
    # not used by to_json_schema; presumably read by the extractor — confirm.
    aliases: List[str] = dataclass_field(default_factory=list)
    examples: List[str] = dataclass_field(default_factory=list)
    context_hints: List[str] = dataclass_field(default_factory=list)

    # Per-field confidence threshold. Not read in this class; presumably used
    # by the extractor to accept or abstain on a value — confirm.
    min_confidence: float = 0.5

    def to_json_schema(self) -> Dict[str, Any]:
        """Convert this field to a JSON Schema fragment.

        Returns:
            A dict with ``type`` plus any of the following keywords that
            apply: ``description``, ``pattern``, ``format``,
            ``minimum``/``maximum``, ``minLength``/``maxLength``, ``enum``
            and ``items``.

            LIST fields take ``items`` from ``nested_schema`` when it is
            set, otherwise from ``list_item_type``.

            OBJECT fields with a ``nested_schema`` absorb the sub-schema's
            keywords. Keywords set on this field, notably ``description``,
            take precedence over the sub-schema's.
        """
        type_mapping = {
            FieldType.STRING: "string",
            FieldType.INTEGER: "integer",
            FieldType.FLOAT: "number",
            FieldType.BOOLEAN: "boolean",
            FieldType.DATE: "string",
            FieldType.DATETIME: "string",
            FieldType.CURRENCY: "string",
            FieldType.PERCENTAGE: "string",
            FieldType.EMAIL: "string",
            FieldType.PHONE: "string",
            FieldType.ADDRESS: "string",
            FieldType.LIST: "array",
            FieldType.OBJECT: "object",
        }

        schema: Dict[str, Any] = {
            "type": type_mapping.get(self.field_type, "string"),
        }

        if self.description:
            schema["description"] = self.description

        if self.pattern:
            schema["pattern"] = self.pattern

        # Semantic string types that have a standard JSON Schema "format".
        if self.field_type == FieldType.DATE:
            schema["format"] = "date"
        elif self.field_type == FieldType.DATETIME:
            schema["format"] = "date-time"
        elif self.field_type == FieldType.EMAIL:
            schema["format"] = "email"

        if self.min_value is not None:
            schema["minimum"] = self.min_value
        if self.max_value is not None:
            schema["maximum"] = self.max_value
        if self.min_length is not None:
            schema["minLength"] = self.min_length
        if self.max_length is not None:
            schema["maxLength"] = self.max_length
        if self.allowed_values:
            schema["enum"] = self.allowed_values

        if self.field_type == FieldType.LIST:
            if self.nested_schema is not None:
                schema["items"] = self.nested_schema.to_json_schema()
            elif self.list_item_type is not None:
                # Render the item type through this same method, so item
                # formats (e.g. "date" for DATE items) are emitted as well.
                item_spec = FieldSpec(
                    name=f"{self.name}_item",
                    field_type=self.list_item_type,
                )
                schema["items"] = item_spec.to_json_schema()
        elif self.field_type == FieldType.OBJECT and self.nested_schema is not None:
            # Start from the sub-schema (properties, required, ...) and let
            # field-level keywords override it. This keeps the field's own
            # description instead of replacing it with the sub-schema's.
            merged = self.nested_schema.to_json_schema()
            merged.update(schema)
            schema = merged

        return schema
|
|
|
|
@dataclass
class ExtractionSchema:
    """
    Schema defining fields to extract from a document.

    An ordered collection of ``FieldSpec`` entries plus schema-level policy
    settings. It supports three conversions:

    * to JSON Schema (``to_json_schema``);
    * to a dynamically created Pydantic model (``to_pydantic_model``);
    * back from JSON Schema (``from_json_schema``).

    Schemas can be nested through ``FieldSpec.nested_schema`` for complex
    document structures.
    """

    name: str
    description: str = ""
    fields: List[FieldSpec] = dataclass_field(default_factory=list)

    # Schema-level extraction policy. No method in this class reads these;
    # presumably they are consumed by the extraction pipeline — confirm.
    allow_partial: bool = True
    abstain_on_low_confidence: bool = True
    min_overall_confidence: float = 0.5

    def add_field(self, field: FieldSpec) -> "ExtractionSchema":
        """Append ``field`` to the schema and return ``self`` for chaining."""
        self.fields.append(field)
        return self

    def _add_typed_field(
        self,
        name: str,
        field_type: FieldType,
        description: str,
        required: bool,
        **kwargs: Any,
    ) -> "ExtractionSchema":
        """Build a FieldSpec of ``field_type`` and append it (shared by the add_*_field helpers)."""
        return self.add_field(
            FieldSpec(
                name=name,
                field_type=field_type,
                description=description,
                required=required,
                **kwargs,
            )
        )

    def add_string_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a STRING field; extra keyword arguments are passed to ``FieldSpec``."""
        return self._add_typed_field(name, FieldType.STRING, description, required, **kwargs)

    def add_number_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        is_integer: bool = False,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a numeric field: INTEGER when ``is_integer`` is true, otherwise FLOAT."""
        number_type = FieldType.INTEGER if is_integer else FieldType.FLOAT
        return self._add_typed_field(name, number_type, description, required, **kwargs)

    def add_date_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a DATE field; extra keyword arguments are passed to ``FieldSpec``."""
        return self._add_typed_field(name, FieldType.DATE, description, required, **kwargs)

    def add_currency_field(
        self,
        name: str,
        description: str = "",
        required: bool = True,
        **kwargs
    ) -> "ExtractionSchema":
        """Add a CURRENCY field; extra keyword arguments are passed to ``FieldSpec``."""
        return self._add_typed_field(name, FieldType.CURRENCY, description, required, **kwargs)

    def get_field(self, name: str) -> Optional[FieldSpec]:
        """Return the first field named ``name``, or None if there is none."""
        for spec in self.fields:
            if spec.name == name:
                return spec
        return None

    def get_required_fields(self) -> List[FieldSpec]:
        """Return all fields with ``required=True``, in declaration order."""
        return [f for f in self.fields if f.required]

    def get_optional_fields(self) -> List[FieldSpec]:
        """Return all fields with ``required=False``, in declaration order."""
        return [f for f in self.fields if not f.required]

    def to_json_schema(self) -> Dict[str, Any]:
        """Convert to a JSON Schema object definition.

        Returns:
            ``{"type": "object", "properties": {...}}`` with one property per
            field, rendered by ``FieldSpec.to_json_schema``. Also includes:

            * ``required``: names of the required fields in declaration
              order, present only when there is at least one;
            * ``description``: present only when non-empty.
        """
        properties: Dict[str, Any] = {}
        required: List[str] = []

        for spec in self.fields:
            properties[spec.name] = spec.to_json_schema()
            if spec.required:
                required.append(spec.name)

        schema: Dict[str, Any] = {
            "type": "object",
            "properties": properties,
        }

        if required:
            schema["required"] = required

        if self.description:
            schema["description"] = self.description

        return schema

    def to_pydantic_model(self) -> Type[BaseModel]:
        """Generate a Pydantic model class named ``self.name`` from this schema.

        Required fields have no default (``...``). Optional fields are typed
        ``Optional[T]`` and default to ``FieldSpec.default``, so the model
        also accepts an explicit ``None`` for them.

        Nested schemas are not expanded: OBJECT maps to ``dict`` and LIST
        maps to ``list``.
        """
        field_definitions: Dict[str, Any] = {}

        for spec in self.fields:
            python_type = self._get_python_type(spec.field_type)
            if spec.required:
                annotation: Any = python_type
                default: Any = ...
            else:
                # Optional fields usually default to None, so the annotation
                # must admit None as well or validation of None would fail.
                annotation = Optional[python_type]
                default = spec.default

            field_definitions[spec.name] = (
                annotation,
                Field(default=default, description=spec.description),
            )

        return create_model(
            self.name,
            **field_definitions
        )

    def _get_python_type(self, field_type: FieldType) -> type:
        """Map a FieldType to the Python type used in generated Pydantic models.

        Semantic string types (dates, currency, e-mail, ...) map to ``str``.
        Unknown values fall back to ``str``.
        """
        type_mapping = {
            FieldType.STRING: str,
            FieldType.INTEGER: int,
            FieldType.FLOAT: float,
            FieldType.BOOLEAN: bool,
            FieldType.DATE: str,
            FieldType.DATETIME: str,
            FieldType.CURRENCY: str,
            FieldType.PERCENTAGE: str,
            FieldType.EMAIL: str,
            FieldType.PHONE: str,
            FieldType.ADDRESS: str,
            FieldType.LIST: list,
            FieldType.OBJECT: dict,
        }
        return type_mapping.get(field_type, str)

    @classmethod
    def from_json_schema(cls, schema: Dict[str, Any], name: str = "Schema") -> "ExtractionSchema":
        """Create a schema from a JSON Schema object definition.

        Args:
            schema: A JSON Schema dict with ``properties``, and optionally
                ``required`` and ``description``.
            name: Name for the resulting schema.

        Each property becomes a ``FieldSpec`` carrying its type, description,
        required flag, default and validation keywords.

        Nested structures are handled as follows:

        * object properties that define their own ``properties`` become a
          nested ExtractionSchema;
        * arrays whose ``items`` is an object with ``properties`` become a
          nested ExtractionSchema;
        * other array ``items`` set ``list_item_type``.

        Non-dict property schemas (e.g. boolean ``true``) are treated as an
        empty schema, i.e. an unconstrained string field.
        """
        extraction_schema = cls(
            name=name,
            description=schema.get("description", ""),
        )

        properties = schema.get("properties") or {}
        required = set(schema.get("required") or [])

        for field_name, field_schema in properties.items():
            if not isinstance(field_schema, dict):
                field_schema = {}

            field_type = cls._json_type_to_field_type(field_schema)

            nested_schema: Optional[ExtractionSchema] = None
            list_item_type: Optional[FieldType] = None
            if field_type == FieldType.OBJECT and "properties" in field_schema:
                nested_schema = cls.from_json_schema(field_schema, name=field_name)
            elif field_type == FieldType.LIST:
                items = field_schema.get("items")
                if isinstance(items, dict):
                    item_type = cls._json_type_to_field_type(items)
                    if item_type == FieldType.OBJECT and "properties" in items:
                        nested_schema = cls.from_json_schema(
                            items, name=f"{field_name}_item"
                        )
                    else:
                        list_item_type = item_type

            field = FieldSpec(
                name=field_name,
                field_type=field_type,
                description=field_schema.get("description", ""),
                required=field_name in required,
                default=field_schema.get("default"),
                pattern=field_schema.get("pattern"),
                min_value=field_schema.get("minimum"),
                max_value=field_schema.get("maximum"),
                min_length=field_schema.get("minLength"),
                max_length=field_schema.get("maxLength"),
                allowed_values=field_schema.get("enum"),
                nested_schema=nested_schema,
                list_item_type=list_item_type,
            )

            extraction_schema.add_field(field)

        return extraction_schema

    @staticmethod
    def _json_type_to_field_type(field_schema: Dict[str, Any]) -> FieldType:
        """Convert a JSON Schema property definition to a FieldType.

        Structural types (integer, number, boolean, array, object) are
        classified by ``type``. Other types are classified by ``format``
        (date, date-time, email), falling back to STRING.

        A union ``type`` such as ``["integer", "null"]`` is classified by its
        first non-null member.
        """
        json_type = field_schema.get("type", "string")
        if isinstance(json_type, list):
            # Nullable unions: ignore "null" and classify the concrete type.
            concrete = [t for t in json_type if t != "null"]
            json_type = concrete[0] if concrete else "string"
        format_ = field_schema.get("format", "")

        if json_type == "integer":
            return FieldType.INTEGER
        elif json_type == "number":
            return FieldType.FLOAT
        elif json_type == "boolean":
            return FieldType.BOOLEAN
        elif json_type == "array":
            return FieldType.LIST
        elif json_type == "object":
            return FieldType.OBJECT
        elif format_ == "date":
            return FieldType.DATE
        elif format_ == "date-time":
            return FieldType.DATETIME
        elif format_ == "email":
            return FieldType.EMAIL
        else:
            return FieldType.STRING
|
|
|
|
| |
|
|
def create_invoice_schema() -> ExtractionSchema:
    """Build the predefined extraction schema for invoices.

    Required fields:

    * invoice number;
    * invoice date;
    * vendor name;
    * total amount.

    All other fields are optional: due date, vendor/customer details,
    subtotal, tax, currency code and payment terms.
    """
    return (
        ExtractionSchema(name="Invoice", description="Invoice document extraction schema")
        .add_string_field("invoice_number", "Invoice number or ID", required=True)
        .add_date_field("invoice_date", "Date of invoice", required=True)
        .add_date_field("due_date", "Payment due date", required=False)
        .add_string_field("vendor_name", "Name of vendor/seller", required=True)
        .add_string_field("vendor_address", "Address of vendor", required=False)
        .add_string_field("customer_name", "Name of customer/buyer", required=False)
        .add_string_field("customer_address", "Address of customer", required=False)
        .add_currency_field("subtotal", "Subtotal before tax", required=False)
        .add_currency_field("tax_amount", "Tax amount", required=False)
        .add_currency_field("total_amount", "Total amount due", required=True)
        .add_string_field("currency", "Currency code (USD, EUR, etc.)", required=False)
        .add_string_field("payment_terms", "Payment terms", required=False)
    )
|
|
|
|
def create_receipt_schema() -> ExtractionSchema:
    """Build the predefined extraction schema for receipts.

    Required fields:

    * merchant name;
    * transaction date;
    * total amount paid.

    All other fields are optional: merchant address, transaction time,
    subtotal, tax, payment method and card last-four digits.
    """
    return (
        ExtractionSchema(name="Receipt", description="Receipt document extraction schema")
        .add_string_field("merchant_name", "Name of merchant/store", required=True)
        .add_string_field("merchant_address", "Address of merchant", required=False)
        .add_date_field("transaction_date", "Date of transaction", required=True)
        .add_string_field("transaction_time", "Time of transaction", required=False)
        .add_currency_field("subtotal", "Subtotal before tax", required=False)
        .add_currency_field("tax_amount", "Tax amount", required=False)
        .add_currency_field("total_amount", "Total amount paid", required=True)
        .add_string_field("payment_method", "Method of payment", required=False)
        .add_string_field("last_four_digits", "Last 4 digits of card", required=False)
    )
|
|
|
|
def create_contract_schema() -> ExtractionSchema:
    """Build the predefined extraction schema for contracts.

    Required fields:

    * effective date;
    * names of both parties.

    All other fields are optional: title, expiration date, contract value,
    governing law and a summary of the termination terms.
    """
    return (
        ExtractionSchema(name="Contract", description="Contract document extraction schema")
        .add_string_field("contract_title", "Title of the contract", required=False)
        .add_date_field("effective_date", "Date contract becomes effective", required=True)
        .add_date_field("expiration_date", "Date contract expires", required=False)
        .add_string_field("party_a_name", "Name of first party", required=True)
        .add_string_field("party_b_name", "Name of second party", required=True)
        .add_currency_field("contract_value", "Total contract value", required=False)
        .add_string_field("governing_law", "Governing law/jurisdiction", required=False)
        .add_string_field("termination_clause", "Summary of termination terms", required=False)
    )
|
|