| """ |
| Grounded Answer Generator |
| |
| Generates answers from retrieved context with citations. |
| Uses local LLMs (Ollama) or cloud APIs. |
| """ |
|
|
| from typing import List, Optional, Dict, Any, Tuple |
| from pydantic import BaseModel, Field |
| from loguru import logger |
| import json |
| import re |
|
|
| from .retriever import RetrievedChunk, DocumentRetriever, get_document_retriever |
|
|
| try: |
| import httpx |
| HTTPX_AVAILABLE = True |
| except ImportError: |
| HTTPX_AVAILABLE = False |
|
|
|
|
class GeneratorConfig(BaseModel):
    """Configuration for grounded generator."""

    # -- Provider selection and Ollama endpoint ------------------------------
    # GroundedGenerator.generate() accepts exactly "ollama" or "openai" here
    # and raises ValueError for anything else.
    llm_provider: str = Field("ollama", description="LLM provider: ollama, openai")
    ollama_base_url: str = Field(
        "http://localhost:11434", description="Ollama API base URL"
    )
    ollama_model: str = Field("llama3.2:3b", description="Ollama model for generation")

    # -- OpenAI ---------------------------------------------------------------
    openai_model: str = Field("gpt-4o-mini", description="OpenAI model for generation")
    # NOTE(review): stored as a plain str; presumably this means the key shows
    # up in repr()/serialized output of the config -- consider SecretStr, confirm.
    openai_api_key: Optional[str] = Field(None, description="OpenAI API key")

    # -- Sampling and request limits -------------------------------------------
    temperature: float = Field(0.1, ge=0.0, le=2.0)
    max_tokens: int = Field(1024, ge=1)
    timeout: float = Field(120.0, ge=1.0)

    # -- Citation and abstention policy ----------------------------------------
    require_citations: bool = Field(True, description="Require citations in answers")
    citation_format: str = Field("[{index}]", description="Citation format template")
    abstain_on_low_confidence: bool = Field(
        True, description="Abstain when confidence is low"
    )
    confidence_threshold: float = Field(
        0.6,
        ge=0.0,
        le=1.0,
        description="Minimum confidence threshold",
    )
|
|
|
|
class Citation(BaseModel):
    """A citation reference."""

    # Number N of the [N] marker found in the answer; 1-based position of the
    # cited chunk in the list of chunks given to the generator.
    index: int
    # Identifier copied from the cited chunk.
    chunk_id: str
    # Page value copied from the cited chunk unchanged; None when the chunk
    # carries no page. Presumably zero-based (the prompt header renders
    # page + 1) -- confirm against the retriever.
    page: Optional[int] = None
    # Short leading excerpt of the cited chunk's text.
    text_snippet: str
    # Similarity score of the cited chunk.
    confidence: float
|
|
|
|
class GeneratedAnswer(BaseModel):
    """Generated answer with citations."""

    # Answer text: either the model output or a fixed abstention message.
    answer: str
    # Citations resolved from [N] markers in the answer; empty when abstaining.
    citations: List[Citation]
    # Mean confidence of the citations, or mean chunk similarity when the
    # answer has no citations; 0.0 when there were no chunks at all.
    confidence: float
    # True when no LLM answer was produced (low confidence / nothing retrieved).
    abstained: bool = False
    # Human-readable explanation; only set when ``abstained`` is True.
    abstain_reason: Optional[str] = None

    # Number of retrieved chunks that were handed to the generator.
    num_chunks_used: int
    # The user question this answer belongs to, echoed back unchanged.
    query: str
|
|
|
|
class GroundedGenerator:
    """
    Generates grounded answers with citations.

    Features:
    - Uses retrieved chunks as context
    - Generates answers with inline citations
    - Confidence-based abstention
    - Support for Ollama and OpenAI

    NOTE(review): ``config.require_citations`` and ``config.citation_format``
    are not read anywhere in this class; confirm whether they are meant to
    be enforced here.
    """

    SYSTEM_PROMPT = """You are a precise document question-answering assistant.
Your task is to answer questions based ONLY on the provided context from documents.

Rules:
1. Only use information from the provided context
2. Cite your sources using [N] notation where N is the chunk number
3. If the context doesn't contain enough information, say "I cannot answer this based on the available context"
4. Be precise and concise
5. If information is uncertain or partial, indicate this clearly

Context format: Each chunk is numbered [1], [2], etc. with page numbers and content.
"""

    # Maximum number of chunk-text characters copied into a citation snippet.
    _SNIPPET_MAX_CHARS = 150

    def __init__(
        self,
        config: Optional[GeneratorConfig] = None,
        retriever: Optional[DocumentRetriever] = None,
    ):
        """
        Initialize generator.

        Args:
            config: Generator configuration; a default GeneratorConfig is
                created when omitted.
            retriever: Document retriever instance; when omitted, one is
                obtained lazily on first access of ``self.retriever``.
        """
        self.config = config if config is not None else GeneratorConfig()
        self._retriever = retriever

    @property
    def retriever(self) -> DocumentRetriever:
        """Get retriever (lazy initialization)."""
        if self._retriever is None:
            self._retriever = get_document_retriever()
        return self._retriever

    def generate(
        self,
        query: str,
        chunks: List[RetrievedChunk],
        additional_context: Optional[str] = None,
    ) -> GeneratedAnswer:
        """
        Generate an answer from retrieved chunks.

        Args:
            query: User question
            chunks: Retrieved context chunks
            additional_context: Optional additional context, placed ahead of
                the document excerpts in the prompt

        Returns:
            GeneratedAnswer with citations. When abstention is enabled and the
            mean chunk similarity is below the configured threshold, an
            abstained answer is returned without calling the LLM.

        Raises:
            ValueError: If ``config.llm_provider`` is neither "ollama" nor
                "openai".
            ImportError: If the client library for the selected provider is
                not installed.
            Any error raised by the provider call itself propagates unchanged.
        """
        cfg = self.config

        # Gate on retrieval quality first so a weak context never reaches the LLM.
        if cfg.abstain_on_low_confidence and chunks:
            avg_confidence = self._mean_similarity(chunks)
            if avg_confidence < cfg.confidence_threshold:
                return GeneratedAnswer(
                    answer="I cannot provide a confident answer based on the available context.",
                    citations=[],
                    confidence=avg_confidence,
                    abstained=True,
                    abstain_reason=f"Average confidence ({avg_confidence:.2f}) below threshold ({cfg.confidence_threshold})",
                    num_chunks_used=len(chunks),
                    query=query,
                )

        context = self._build_context(chunks, additional_context)
        prompt = self._build_prompt(query, context)

        if cfg.llm_provider == "ollama":
            raw_answer = self._generate_ollama(prompt)
        elif cfg.llm_provider == "openai":
            raw_answer = self._generate_openai(prompt)
        else:
            raise ValueError(f"Unknown LLM provider: {cfg.llm_provider}")

        citations = self._extract_citations(raw_answer, chunks)

        # Prefer the confidence of what was actually cited; fall back to the
        # retrieval average when the answer cites nothing.
        if citations:
            confidence = sum(c.confidence for c in citations) / len(citations)
        elif chunks:
            confidence = self._mean_similarity(chunks)
        else:
            confidence = 0.0

        return GeneratedAnswer(
            answer=raw_answer,
            citations=citations,
            confidence=confidence,
            abstained=False,
            num_chunks_used=len(chunks),
            query=query,
        )

    def answer_question(
        self,
        query: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> GeneratedAnswer:
        """
        Retrieve context and generate answer.

        Args:
            query: User question
            top_k: Number of chunks to retrieve
            filters: Optional retrieval filters, forwarded to the retriever

        Returns:
            GeneratedAnswer with citations; an abstained answer with
            confidence 0.0 when the retriever returns no chunks.
        """
        chunks = self.retriever.retrieve(query, top_k=top_k, filters=filters)

        if not chunks:
            return GeneratedAnswer(
                answer="I could not find any relevant information in the documents to answer this question.",
                citations=[],
                confidence=0.0,
                abstained=True,
                abstain_reason="No relevant chunks found",
                num_chunks_used=0,
                query=query,
            )

        return self.generate(query, chunks)

    @staticmethod
    def _mean_similarity(chunks: List[RetrievedChunk]) -> float:
        """Return the average ``similarity`` of a non-empty list of chunks."""
        return sum(c.similarity for c in chunks) / len(chunks)

    def _build_context(
        self,
        chunks: List[RetrievedChunk],
        additional_context: Optional[str] = None,
    ) -> str:
        """
        Build context string from chunks.

        Each chunk is introduced by a "[N]" header (N starting at 1) so the
        model can cite it; ``_extract_citations`` maps N back to
        ``chunks[N - 1]``.
        """
        parts = []

        if additional_context:
            parts.append(f"Additional context:\n{additional_context}\n")

        parts.append("Document excerpts:")

        for i, chunk in enumerate(chunks, 1):
            header = f"\n[{i}]"
            if chunk.page is not None:
                # The stored page is shown incremented by one.
                header += f" (Page {chunk.page + 1}"
                if chunk.chunk_type:
                    header += f", {chunk.chunk_type}"
                header += ")"

            parts.append(f"{header}:")
            parts.append(chunk.text)

        return "\n".join(parts)

    def _build_prompt(self, query: str, context: str) -> str:
        """Build the full user prompt from the context block and the question."""
        return f"""Based on the following context, answer the question.

{context}

Question: {query}

Answer (cite sources using [N] notation):"""

    def _generate_ollama(self, prompt: str) -> str:
        """
        Generate using Ollama.

        Sends one non-streaming request to ``<ollama_base_url>/api/generate``.

        Raises:
            ImportError: If httpx is not installed.
            Errors raised by httpx (including ``raise_for_status``) propagate.
        """
        if not HTTPX_AVAILABLE:
            raise ImportError("httpx required for Ollama")

        # Tolerate a configured base URL that ends in "/" so the request path
        # does not become "//api/generate".
        base_url = self.config.ollama_base_url.rstrip("/")

        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(
                f"{base_url}/api/generate",
                json={
                    "model": self.config.ollama_model,
                    "prompt": prompt,
                    "system": self.SYSTEM_PROMPT,
                    "stream": False,
                    "options": {
                        "temperature": self.config.temperature,
                        "num_predict": self.config.max_tokens,
                    },
                },
            )
            response.raise_for_status()
            result = response.json()

        # `or ""` also covers an explicit JSON null, not just a missing key.
        return (result.get("response") or "").strip()

    def _generate_openai(self, prompt: str) -> str:
        """
        Generate using OpenAI.

        Raises:
            ImportError: If the openai package is not installed.
            Errors raised by the OpenAI client propagate.
        """
        try:
            import openai
        except ImportError as exc:
            raise ImportError("openai package required") from exc

        # Apply the configured timeout here as well, so it is honored for
        # both providers rather than only for Ollama.
        client = openai.OpenAI(
            api_key=self.config.openai_api_key,
            timeout=self.config.timeout,
        )

        response = client.chat.completions.create(
            model=self.config.openai_model,
            messages=[
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
        )

        # The API types message.content as nullable; treat null as an empty
        # answer instead of failing on None.strip().
        content = response.choices[0].message.content
        return (content or "").strip()

    def _extract_citations(
        self,
        answer: str,
        chunks: List[RetrievedChunk],
    ) -> List[Citation]:
        """
        Extract citations from answer text.

        Finds every "[N]" marker in the answer, ignores repeats and numbers
        outside 1..len(chunks), and returns one Citation per remaining number,
        sorted by that number.
        """
        citations = []
        seen_indices = set()
        limit = self._SNIPPET_MAX_CHARS

        pattern = r'\[(\d+)\]'
        for match in re.findall(pattern, answer):
            index = int(match)
            if index in seen_indices:
                continue
            # Markers that do not correspond to a provided chunk are dropped.
            if index < 1 or index > len(chunks):
                continue

            seen_indices.add(index)
            chunk = chunks[index - 1]

            citations.append(
                Citation(
                    index=index,
                    chunk_id=chunk.chunk_id,
                    page=chunk.page,
                    text_snippet=chunk.text[:limit] + ("..." if len(chunk.text) > limit else ""),
                    confidence=chunk.similarity,
                )
            )

        return sorted(citations, key=lambda c: c.index)
|
|
|
|
| |
# Module-level cache behind get_grounded_generator(); None until the first
# call and again after reset_grounded_generator().
_grounded_generator: Optional[GroundedGenerator] = None


def get_grounded_generator(
    config: Optional[GeneratorConfig] = None,
    retriever: Optional[DocumentRetriever] = None,
) -> GroundedGenerator:
    """
    Return the shared GroundedGenerator, building it on the first call.

    The arguments are only used when the shared instance does not exist yet.
    Once it has been created, later calls return that same instance and
    ignore whatever ``config`` / ``retriever`` they are given.

    Args:
        config: Generator configuration for the instance being created
        retriever: Optional retriever instance for the instance being created

    Returns:
        GroundedGenerator instance
    """
    global _grounded_generator

    if _grounded_generator is not None:
        return _grounded_generator

    _grounded_generator = GroundedGenerator(config=config, retriever=retriever)
    return _grounded_generator
|
|
|
|
def reset_grounded_generator():
    """Clear the module-level generator so the next lookup builds a new one."""
    global _grounded_generator
    # Rebinding to None is all that is needed; nothing is closed or torn down.
    _grounded_generator = None
|
|