| """ |
| DocumentAgent for SPARKNET |
| |
| A ReAct-style agent for document intelligence tasks: |
| - Document parsing and extraction |
| - Field extraction with grounding |
| - Table and chart analysis |
| - Document classification |
| - Question answering over documents |
| """ |
|
|
| from typing import List, Dict, Any, Optional, Tuple |
| from dataclasses import dataclass |
| from enum import Enum |
| import json |
| import time |
| from loguru import logger |
|
|
| from .base_agent import BaseAgent, Task, Message |
| from ..llm.langchain_ollama_client import LangChainOllamaClient |
| from ..document.schemas.core import ( |
| ProcessedDocument, |
| DocumentChunk, |
| EvidenceRef, |
| ExtractionResult, |
| ) |
| from ..document.schemas.extraction import ExtractionSchema, ExtractedField |
| from ..document.schemas.classification import DocumentClassification, DocumentType |
|
|
|
|
class AgentAction(str, Enum):
    """The kinds of action the DocumentAgent may take in a ReAct step.

    Inherits from ``str`` so members compare equal to their string values.
    """

    THINK = "think"
    USE_TOOL = "use_tool"
    ANSWER = "answer"
    ABSTAIN = "abstain"
|
|
|
|
@dataclass
class ThoughtAction:
    """A thought-action pair in the ReAct loop."""
    # The agent's reasoning for this step.
    thought: str
    # What the agent decided to do (think / use_tool / answer / abstain).
    action: AgentAction
    # Tool to invoke; set when action is USE_TOOL (also holds unrecognized
    # action strings, which are later reported as unknown tools).
    tool_name: Optional[str] = None
    # JSON arguments for the tool, or the final answer payload for ANSWER.
    tool_args: Optional[Dict[str, Any]] = None
    # Tool output, filled in after execution.
    observation: Optional[str] = None
    # Evidence references produced by the tool, if any.
    evidence: Optional[List[EvidenceRef]] = None
|
|
|
|
@dataclass
class AgentTrace:
    """Full trace of agent execution for inspection."""
    # The natural-language task the agent was asked to perform.
    task: str
    # Every thought-action step taken, in order.
    steps: List[ThoughtAction]
    # Final result; a dict with "abstained": True when the agent abstained,
    # or None if the loop ended without producing an answer.
    final_answer: Optional[Any] = None
    # Overall confidence derived from collected evidence (0.0 on abstain).
    confidence: float = 0.0
    # Wall-clock duration of the whole run, in milliseconds.
    total_time_ms: float = 0.0
    # Set to False only when run() caught an exception.
    success: bool = True
    # Stringified exception when success is False.
    error: Optional[str] = None
|
|
|
|
class DocumentAgent:
    """
    ReAct-style agent for document intelligence tasks.

    Runs a Think -> Tool -> Observe -> Refine loop over a previously set
    ProcessedDocument, producing inspectable traces and evidence-grounded
    outputs.
    """

    # System prompt template; {tool_descriptions} is filled in per request.
    SYSTEM_PROMPT = """You are a document intelligence agent that analyzes documents
and extracts information with evidence.

You operate in a Think-Act-Observe loop:
1. THINK: Analyze what you need to do and what information you have
2. ACT: Choose a tool to use or provide an answer
3. OBSERVE: Review the tool output and update your understanding

Available tools:
{tool_descriptions}

CRITICAL RULES:
- Every extraction MUST include evidence (page, bbox, text snippet)
- If you cannot find evidence for a value, ABSTAIN rather than guess
- Always cite the source of information with page numbers
- For tables, analyze structure before extracting data
- For charts, describe what you see before extracting values

Output format for each step:
THOUGHT: <your reasoning>
ACTION: <tool_name or ANSWER or ABSTAIN>
ACTION_INPUT: <JSON arguments for tool, or final answer>
"""

    # Registry of tools the agent may call; keys are the ACTION names the
    # LLM emits, dispatched by _execute_tool.
    TOOLS = {
        "extract_text": {
            "description": "Extract text from specific pages or regions",
            "args": ["page_numbers", "region_bbox"],
        },
        "analyze_table": {
            "description": "Analyze and extract structured data from a table region",
            "args": ["page", "bbox", "expected_columns"],
        },
        "analyze_chart": {
            "description": "Analyze a chart/graph and extract insights",
            "args": ["page", "bbox"],
        },
        "extract_fields": {
            "description": "Extract specific fields using a schema",
            "args": ["schema", "context_chunks"],
        },
        "classify_document": {
            "description": "Classify the document type",
            "args": ["first_page_chunks"],
        },
        "search_text": {
            "description": "Search for text patterns in the document",
            "args": ["query", "page_range"],
        },
    }
|
|
    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        memory_agent: Optional[Any] = None,
        max_iterations: int = 10,
        temperature: float = 0.3,
    ):
        """
        Initialize DocumentAgent.

        Args:
            llm_client: LangChain Ollama client
            memory_agent: Optional memory agent for context retrieval
            max_iterations: Maximum ReAct iterations
            temperature: LLM temperature for reasoning
        """
        self.llm_client = llm_client
        self.memory_agent = memory_agent
        self.max_iterations = max_iterations
        self.temperature = temperature

        # Per-document state; populated by set_document().
        self._current_document: Optional[ProcessedDocument] = None
        # Optional page-number -> image mapping for visual analysis.
        self._page_images: Dict[int, Any] = {}

        logger.info(f"Initialized DocumentAgent (max_iterations={max_iterations})")
|
|
    def set_document(
        self,
        document: ProcessedDocument,
        page_images: Optional[Dict[int, Any]] = None,
    ):
        """
        Set the current document context.

        Must be called before run()/extract_fields()/classify()/
        answer_question(); replaces any previously set document.

        Args:
            document: Processed document
            page_images: Optional dict of page number -> image array
        """
        self._current_document = document
        self._page_images = page_images or {}
        logger.info(f"Set document context: {document.metadata.document_id}")
|
|
    async def run(
        self,
        task_description: str,
        extraction_schema: Optional[ExtractionSchema] = None,
    ) -> Tuple[Any, AgentTrace]:
        """
        Run the agent on a task.

        Executes the ReAct loop: generate a step, dispatch it, and feed
        tool observations back into the context until the agent answers,
        abstains, or max_iterations is reached.

        Args:
            task_description: Natural language task description
            extraction_schema: Optional schema for structured extraction

        Returns:
            Tuple of (result, trace)

        Raises:
            ValueError: If no document was set via set_document().
        """
        start_time = time.time()

        if not self._current_document:
            raise ValueError("No document set. Call set_document() first.")

        trace = AgentTrace(task=task_description, steps=[])

        try:
            # Seed the LLM context with document metadata and a preview.
            context = self._build_context(extraction_schema)

            result = None
            for iteration in range(self.max_iterations):
                logger.debug(f"ReAct iteration {iteration + 1}")

                # Ask the LLM for the next thought/action given prior steps.
                step = await self._generate_step(task_description, context, trace.steps)
                trace.steps.append(step)

                if step.action == AgentAction.ANSWER:
                    # Final answer: normalize payload and score confidence.
                    result = self._parse_answer(step.tool_args)
                    trace.final_answer = result
                    trace.confidence = self._calculate_confidence(trace.steps)
                    break

                elif step.action == AgentAction.ABSTAIN:
                    # Deliberate refusal: keep the reasoning, zero confidence.
                    trace.final_answer = {
                        "abstained": True,
                        "reason": step.thought,
                    }
                    trace.confidence = 0.0
                    break

                elif step.action == AgentAction.USE_TOOL:
                    observation, evidence = await self._execute_tool(
                        step.tool_name, step.tool_args
                    )
                    step.observation = observation
                    step.evidence = evidence

                    # Append the observation so later steps can build on it.
                    context += f"\n\nObservation from {step.tool_name}:\n{observation}"

            # NOTE(review): success is also set when the loop exhausts
            # without ANSWER/ABSTAIN; final_answer stays None in that case.
            trace.success = True

        except Exception as e:
            logger.error(f"Agent execution failed: {e}")
            trace.success = False
            trace.error = str(e)

        trace.total_time_ms = (time.time() - start_time) * 1000
        return trace.final_answer, trace
|
|
    async def extract_fields(
        self,
        schema: ExtractionSchema,
    ) -> ExtractionResult:
        """
        Extract fields from the document using a schema.

        Runs the ReAct loop, then folds the extracted data, accumulated
        evidence, and abstention warnings into one ExtractionResult.

        Args:
            schema: Extraction schema defining fields

        Returns:
            ExtractionResult with extracted data and evidence
        """
        task = f"Extract the following fields from this document: {', '.join(f.name for f in schema.fields)}"
        result, trace = await self.run(task, schema)

        data = {}
        evidence = []
        warnings = []
        abstained = []

        if isinstance(result, dict):
            # Prefer an explicit "data" envelope; fall back to whole dict.
            data = result.get("data", result)

        # Pool evidence from every step that produced any.
        for step in trace.steps:
            if step.evidence:
                evidence.extend(step.evidence)

        # Required fields that never surfaced are recorded as abstentions.
        for field in schema.fields:
            if field.name not in data and field.required:
                abstained.append(field.name)
                warnings.append(
                    f"Required field '{field.name}' not found with sufficient confidence"
                )

        return ExtractionResult(
            data=data,
            evidence=evidence,
            warnings=warnings,
            confidence=trace.confidence,
            abstained_fields=abstained,
        )
|
|
    async def classify(self) -> DocumentClassification:
        """
        Classify the document type.

        Runs the ReAct loop with a classification task and maps the LLM's
        reported type string onto the DocumentType enum.

        Returns:
            DocumentClassification with type and confidence
        """
        task = "Classify this document into one of the standard document types (contract, invoice, patent, research_paper, report, letter, form, etc.)"
        result, trace = await self.run(task)

        doc_type = DocumentType.UNKNOWN
        confidence = 0.0

        if isinstance(result, dict):
            type_str = result.get("document_type", "unknown")
            try:
                doc_type = DocumentType(type_str.lower())
            except ValueError:
                # LLM produced a type outside the enum's known values.
                doc_type = DocumentType.OTHER

            confidence = result.get("confidence", trace.confidence)

        return DocumentClassification(
            document_id=self._current_document.metadata.document_id,
            primary_type=doc_type,
            primary_confidence=confidence,
            evidence=[e for step in trace.steps if step.evidence for e in step.evidence],
            method="llm",
            is_confident=confidence >= 0.7,  # fixed acceptance threshold
        )
|
|
| async def answer_question(self, question: str) -> Tuple[str, List[EvidenceRef]]: |
| """ |
| Answer a question about the document. |
| |
| Args: |
| question: Natural language question |
| |
| Returns: |
| Tuple of (answer, evidence) |
| """ |
| task = f"Answer this question about the document: {question}" |
| result, trace = await self.run(task) |
|
|
| answer = "" |
| evidence = [] |
|
|
| if isinstance(result, dict): |
| answer = result.get("answer", str(result)) |
| elif isinstance(result, str): |
| answer = result |
|
|
| |
| for step in trace.steps: |
| if step.evidence: |
| evidence.extend(step.evidence) |
|
|
| return answer, evidence |
|
|
| def _build_context(self, schema: Optional[ExtractionSchema] = None) -> str: |
| """Build initial context from document.""" |
| doc = self._current_document |
| context_parts = [ |
| f"Document: {doc.metadata.filename}", |
| f"Type: {doc.metadata.file_type}", |
| f"Pages: {doc.metadata.num_pages}", |
| f"Chunks: {len(doc.chunks)}", |
| "", |
| "Document content summary:", |
| ] |
|
|
| |
| for chunk in doc.chunks[:10]: |
| context_parts.append( |
| f"[Page {chunk.page + 1}, {chunk.chunk_type.value}]: {chunk.text[:200]}..." |
| ) |
|
|
| if schema: |
| context_parts.append("") |
| context_parts.append("Extraction schema:") |
| for field in schema.fields: |
| req = "required" if field.required else "optional" |
| context_parts.append(f"- {field.name} ({field.type.value}, {req}): {field.description}") |
|
|
| return "\n".join(context_parts) |
|
|
| async def _generate_step( |
| self, |
| task: str, |
| context: str, |
| previous_steps: List[ThoughtAction], |
| ) -> ThoughtAction: |
| """Generate the next thought-action step.""" |
| |
| tool_descriptions = "\n".join( |
| f"- {name}: {info['description']}" |
| for name, info in self.TOOLS.items() |
| ) |
|
|
| system_prompt = self.SYSTEM_PROMPT.format(tool_descriptions=tool_descriptions) |
|
|
| messages = [{"role": "system", "content": system_prompt}] |
|
|
| |
| user_content = f"TASK: {task}\n\nCONTEXT:\n{context}" |
|
|
| |
| if previous_steps: |
| user_content += "\n\nPREVIOUS STEPS:" |
| for i, step in enumerate(previous_steps, 1): |
| user_content += f"\n\nStep {i}:" |
| user_content += f"\nTHOUGHT: {step.thought}" |
| user_content += f"\nACTION: {step.action.value}" |
| if step.tool_name: |
| user_content += f"\nTOOL: {step.tool_name}" |
| if step.observation: |
| user_content += f"\nOBSERVATION: {step.observation[:500]}..." |
|
|
| user_content += "\n\nNow generate your next step:" |
| messages.append({"role": "user", "content": user_content}) |
|
|
| |
| llm = self.llm_client.get_llm(complexity="complex", temperature=self.temperature) |
|
|
| from langchain_core.messages import HumanMessage, SystemMessage |
| lc_messages = [ |
| SystemMessage(content=system_prompt), |
| HumanMessage(content=user_content), |
| ] |
|
|
| response = await llm.ainvoke(lc_messages) |
| response_text = response.content |
|
|
| |
| return self._parse_step(response_text) |
|
|
| def _parse_step(self, response: str) -> ThoughtAction: |
| """Parse LLM response into ThoughtAction.""" |
| thought = "" |
| action = AgentAction.THINK |
| tool_name = None |
| tool_args = None |
|
|
| lines = response.strip().split("\n") |
| current_section = None |
|
|
| for line in lines: |
| line = line.strip() |
|
|
| if line.startswith("THOUGHT:"): |
| current_section = "thought" |
| thought = line[8:].strip() |
| elif line.startswith("ACTION:"): |
| current_section = "action" |
| action_str = line[7:].strip().lower() |
| if action_str == "answer": |
| action = AgentAction.ANSWER |
| elif action_str == "abstain": |
| action = AgentAction.ABSTAIN |
| elif action_str in self.TOOLS: |
| action = AgentAction.USE_TOOL |
| tool_name = action_str |
| else: |
| action = AgentAction.USE_TOOL |
| tool_name = action_str |
| elif line.startswith("ACTION_INPUT:"): |
| current_section = "input" |
| input_str = line[13:].strip() |
| try: |
| tool_args = json.loads(input_str) |
| except json.JSONDecodeError: |
| tool_args = {"raw": input_str} |
| elif current_section == "thought": |
| thought += " " + line |
| elif current_section == "input": |
| try: |
| tool_args = json.loads(line) |
| except: |
| pass |
|
|
| return ThoughtAction( |
| thought=thought, |
| action=action, |
| tool_name=tool_name, |
| tool_args=tool_args, |
| ) |
|
|
| async def _execute_tool( |
| self, |
| tool_name: str, |
| tool_args: Optional[Dict[str, Any]], |
| ) -> Tuple[str, List[EvidenceRef]]: |
| """Execute a tool and return observation.""" |
| if not tool_args: |
| tool_args = {} |
|
|
| doc = self._current_document |
| evidence = [] |
|
|
| try: |
| if tool_name == "extract_text": |
| return self._tool_extract_text(tool_args) |
|
|
| elif tool_name == "analyze_table": |
| return await self._tool_analyze_table(tool_args) |
|
|
| elif tool_name == "analyze_chart": |
| return await self._tool_analyze_chart(tool_args) |
|
|
| elif tool_name == "extract_fields": |
| return await self._tool_extract_fields(tool_args) |
|
|
| elif tool_name == "classify_document": |
| return self._tool_classify_document(tool_args) |
|
|
| elif tool_name == "search_text": |
| return self._tool_search_text(tool_args) |
|
|
| else: |
| return f"Unknown tool: {tool_name}", [] |
|
|
| except Exception as e: |
| logger.error(f"Tool {tool_name} failed: {e}") |
| return f"Error executing {tool_name}: {e}", [] |
|
|
    def _tool_extract_text(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Extract text chunks from the requested pages.

        Args:
            args: Tool arguments; "page_numbers" may be an int or a list of
                0-based page indices (defaults to all pages).

        Returns:
            Tuple of (joined chunk texts, evidence refs). Output is capped
            at 20 text lines and 10 evidence refs to bound prompt size.
        """
        doc = self._current_document
        page_numbers = args.get("page_numbers", list(range(doc.metadata.num_pages)))

        # Accept a single page index as well as a list.
        if isinstance(page_numbers, int):
            page_numbers = [page_numbers]

        texts = []
        evidence = []

        for page in page_numbers:
            page_chunks = doc.get_page_chunks(page)
            for chunk in page_chunks:
                texts.append(f"[Page {page + 1}]: {chunk.text}")
                evidence.append(EvidenceRef(
                    chunk_id=chunk.chunk_id,
                    page=chunk.page,
                    bbox=chunk.bbox,
                    source_type="text",
                    snippet=chunk.text[:100],
                    confidence=chunk.confidence,
                ))

        return "\n".join(texts[:20]), evidence[:10]
|
|
    async def _tool_analyze_table(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Analyze the first table chunk on a page with the LLM.

        Args:
            args: Tool arguments; "page" is the 0-based page index
                (default 0).

        Returns:
            Tuple of (LLM analysis text, evidence for the analyzed table),
            or a "no table" message with empty evidence.
        """
        page = args.get("page", 0)
        doc = self._current_document

        # Only chunks tagged as tables on the requested page qualify.
        table_chunks = [c for c in doc.chunks if c.chunk_type.value == "table" and c.page == page]

        if not table_chunks:
            return "No table found on this page", []

        # Only the first table on the page is analyzed.
        table_text = table_chunks[0].text
        llm = self.llm_client.get_llm(complexity="standard")

        from langchain_core.messages import HumanMessage
        prompt = f"Analyze this table and extract structured data as JSON:\n\n{table_text}"
        response = await llm.ainvoke([HumanMessage(content=prompt)])

        evidence = [EvidenceRef(
            chunk_id=table_chunks[0].chunk_id,
            page=page,
            bbox=table_chunks[0].bbox,
            source_type="table",
            snippet=table_text[:200],
            confidence=table_chunks[0].confidence,
        )]

        return response.content, evidence
|
|
| async def _tool_analyze_chart(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]: |
| """Analyze a chart region.""" |
| page = args.get("page", 0) |
| doc = self._current_document |
|
|
| |
| chart_chunks = [ |
| c for c in doc.chunks |
| if c.chunk_type.value in ("chart", "figure") and c.page == page |
| ] |
|
|
| if not chart_chunks: |
| return "No chart/figure found on this page", [] |
|
|
| |
| if page in self._page_images: |
| |
| pass |
|
|
| return f"Chart found on page {page + 1}: {chart_chunks[0].caption or 'No caption'}", [] |
|
|
    async def _tool_extract_fields(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
        """Ask the LLM to extract schema fields from the document text.

        Args:
            args: Tool arguments; "schema" is a dict describing the fields.

        Returns:
            Tuple of (LLM response text, empty evidence list).
            NOTE(review): this tool returns no evidence refs, so its output
            is ungrounded; evidence comes only from the other tools.
        """
        schema_dict = args.get("schema", {})
        doc = self._current_document

        # Only the first 20 chunks are used as context to bound prompt size.
        context = "\n".join(c.text for c in doc.chunks[:20])

        llm = self.llm_client.get_llm(complexity="complex")

        from langchain_core.messages import HumanMessage, SystemMessage
        system = "Extract the requested fields from the document. Output JSON with field names as keys."
        user = f"Fields to extract: {json.dumps(schema_dict)}\n\nDocument content:\n{context}"

        response = await llm.ainvoke([
            SystemMessage(content=system),
            HumanMessage(content=user),
        ])

        return response.content, []
|
|
| def _tool_classify_document(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]: |
| """Classify document type based on first page.""" |
| doc = self._current_document |
| first_page_chunks = doc.get_page_chunks(0) |
| text = " ".join(c.text for c in first_page_chunks[:5]) |
|
|
| return f"First page content for classification:\n{text[:500]}", [] |
|
|
| def _tool_search_text(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]: |
| """Search for text in document.""" |
| query = args.get("query", "").lower() |
| doc = self._current_document |
|
|
| matches = [] |
| evidence = [] |
|
|
| for chunk in doc.chunks: |
| if query in chunk.text.lower(): |
| matches.append(f"[Page {chunk.page + 1}]: ...{chunk.text}...") |
| evidence.append(EvidenceRef( |
| chunk_id=chunk.chunk_id, |
| page=chunk.page, |
| bbox=chunk.bbox, |
| source_type="text", |
| snippet=chunk.text[:100], |
| confidence=chunk.confidence, |
| )) |
|
|
| if not matches: |
| return f"No matches found for '{query}'", [] |
|
|
| return f"Found {len(matches)} matches:\n" + "\n".join(matches[:10]), evidence[:10] |
|
|
| def _parse_answer(self, answer_input: Optional[Dict[str, Any]]) -> Any: |
| """Parse the final answer from tool args.""" |
| if not answer_input: |
| return None |
|
|
| if isinstance(answer_input, dict): |
| return answer_input |
|
|
| return {"answer": answer_input} |
|
|
| def _calculate_confidence(self, steps: List[ThoughtAction]) -> float: |
| """Calculate overall confidence from trace.""" |
| if not steps: |
| return 0.0 |
|
|
| |
| all_evidence = [e for s in steps if s.evidence for e in s.evidence] |
| if all_evidence: |
| return sum(e.confidence for e in all_evidence) / len(all_evidence) |
|
|
| return 0.5 |
|
|