"""
LangChain Tools for the Clinical Trial Agent.

This module defines the tools that the agent can use to interact with the
clinical trial data. Tools include:

1.  **search_trials**: Semantic search with optional strict filtering.
2.  **find_similar_studies**: Finding studies semantically similar to a given text.
3.  **get_study_analytics**: Aggregating data for trends and insights (with inline charts).
4.  **compare_studies**: Answering comparison questions via sub-question decomposition.
5.  **get_study_details**: Returning the stored text of a single study by NCT ID.
"""

import re
import traceback
from functools import lru_cache
from typing import Optional

import pandas as pd
import streamlit as st
from langchain.tools import tool as langchain_tool
from llama_index.core import Settings
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.core.vector_stores import (
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

from modules.utils import (
    get_hybrid_retriever,
    get_reranker,  # Import cached reranker
    get_sponsor_variations,
    load_index,
    normalize_sponsor,
)

# Matches a ClinicalTrials.gov identifier anywhere in a string.
_NCT_ID_PATTERN = re.compile(r"NCT\d+", re.IGNORECASE)

# Fields accepted by the analytics tools' ``group_by`` argument.
_GROUP_BY_FIELDS = (
    "phase",
    "status",
    "sponsor",
    "start_year",
    "condition",
    "intervention",
    "study_type",
    "country",
    "state",
)


# --- Internal helpers ---


def _format_score(score, precision: int) -> str:
    """Render a relevance score with *precision* decimals, or "N/A" if absent."""
    if score is None:
        return "N/A"
    return f"{score:.{precision}f}"


def _meta_text(meta: dict, key: str) -> str:
    """Return ``meta[key]`` as a string, mapping a missing/None value to ""."""
    value = meta.get(key)
    return "" if value is None else str(value)


# --- Tools ---


def expand_query(query: str) -> str:
    """
    Expands a search query with synonyms using the LLM.

    Expansion is skipped (the query is returned unchanged) when the query is
    empty, longer than 10 words, or contains an NCT ID. Any failure while
    calling the LLM is logged and the original query is returned.

    Args:
        query (str): The search query to expand.

    Returns:
        str: The expanded query, or the original query if expansion was
        skipped, failed, or produced an empty string.
    """
    if not query or len(query.split()) > 10:  # Skip expansion for long queries
        return query

    # Skip expansion if it looks like an NCT ID
    if _NCT_ID_PATTERN.search(query):
        return query

    prompt = (
        f"You are a helpful medical assistant. "
        f"Expand the following search query with relevant medical synonyms and acronyms. "
        f"Return ONLY the expanded query string combined with OR operators. "
        f"Do not add any explanation.\n\n"
        f"Query: {query}\n"
        f"Expanded Query:"
    )

    try:
        # Use the global Settings.llm
        if not Settings.llm:
            # Fallback if not initialized (though load_index does it)
            from modules.utils import setup_llama_index

            setup_llama_index()

        response = Settings.llm.complete(prompt)
        expanded = response.text.strip()

        # Clean up if LLM is chatty
        if "Expanded Query:" in expanded:
            expanded = expanded.split("Expanded Query:")[-1].strip()

        if not expanded:
            print("⚠️ Expansion returned empty. Using original query.")
            return query

        print(f"✨ Expanded Query: '{query}' -> '{expanded}'")
        return expanded
    except Exception as e:
        # Expansion is best-effort: never let it break the search itself.
        print(f"⚠️ Query expansion failed: {e}")
        return query


def _node_passes_filters(meta, status, year, sponsor, sponsor_variations) -> bool:
    """Re-check status/year/sponsor on one node's metadata after retrieval."""
    if status and _meta_text(meta, "status").upper() != status.upper():
        return False

    if year:
        try:
            if int(meta.get("start_year", 0)) < year:
                return False
        except (TypeError, ValueError):
            # Deliberately lenient: a node whose start year cannot be parsed
            # is kept rather than dropped.
            pass

    if sponsor:
        # Fallback to org if sponsor is missing (legacy data)
        node_sponsor = _meta_text(meta, "sponsor") or _meta_text(meta, "org")
        if sponsor_variations:
            # If strict variations exist, enforce them.
            if node_sponsor not in sponsor_variations:
                return False
        elif normalize_sponsor(sponsor).lower() not in normalize_sponsor(node_sponsor).lower():
            # Fuzzy fallback
            return False

    return True


@langchain_tool("search_trials")
def search_trials(
    query: Optional[str] = None,
    status: Optional[str] = None,
    phase: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    year: Optional[int] = None,
) -> str:
    """
    Searches for clinical trials using semantic search with robust filtering.

    Args:
        query (str, optional): The natural language search query.
        status (str, optional): Filter by recruitment status.
        phase (str, optional): Filter by trial phase.
        sponsor (str, optional): Filter by sponsor name.
        intervention (str, optional): Filter by intervention/drug name.
        year (int, optional): Filter for studies starting on or after this year.

    Returns:
        str: A structured list of relevant studies.
    """
    index = load_index()

    # Constants
    TOP_K_STRICT = 200  # Reduced from 500 for performance
    MIN_SCORE = 1.5
    MIN_RESULTS = 3
    DISPLAY_LIMIT = 20

    # --- Query Construction ---
    if not query:
        parts = [p for p in [sponsor, intervention, phase, status] if p]
        query = " ".join(parts) if parts else "clinical trial"
    else:
        # Inject context for vector search
        if sponsor and normalize_sponsor(sponsor).lower() not in query.lower():
            query = f"{normalize_sponsor(sponsor)} {query}"
        if intervention and intervention.lower() not in query.lower():
            query = f"{intervention} {query}"
        query = expand_query(query)

    print(f"🔍 Tool Called: search_trials(query='{query}', sponsor='{sponsor}')")

    # NOTE(review): `phase` (and `intervention`) only influence the query text;
    # no metadata filter is applied for them here — confirm this is intended.

    # Resolve sponsor aliases once; reused by the pre-filter and the post-filter.
    sponsor_variations = get_sponsor_variations(sponsor) if sponsor else None

    # --- Strategy 1: Strict Pre-Retrieval Filtering (High Precision) ---
    # Filter by Sponsor/Status/Year at the database level first.
    pre_filters = []

    # NCT ID Match
    nct_match = _NCT_ID_PATTERN.search(query)
    if nct_match:
        nct_id = nct_match.group(0).upper()
        pre_filters.append(MetadataFilter(key="nct_id", value=nct_id, operator=FilterOperator.EQ))

    if status:
        pre_filters.append(MetadataFilter(key="status", value=status.upper(), operator=FilterOperator.EQ))

    if year:
        pre_filters.append(MetadataFilter(key="start_year", value=year, operator=FilterOperator.GTE))

    # Sponsor Pre-Filter
    if sponsor:
        if sponsor_variations:
            print(f"🎯 Applying strict pre-filter for sponsor '{sponsor}' ({len(sponsor_variations)} variants)")
            # Use 'sponsor' field which is the Lead Sponsor
            pre_filters.append(
                MetadataFilter(key="sponsor", value=sponsor_variations, operator=FilterOperator.IN)
            )
        else:
            print(f"⚠️ No strict mapping for sponsor '{sponsor}'. Will rely on fuzzy post-filtering.")

    metadata_filters = MetadataFilters(filters=pre_filters) if pre_filters else None

    # Post-processors (Reranking)
    # Use cached reranker
    reranker = get_reranker(top_n=50)

    # --- HYBRID SEARCH IMPLEMENTATION ---
    # Combine Vector + BM25 using get_hybrid_retriever
    try:
        retriever = get_hybrid_retriever(index, similarity_top_k=TOP_K_STRICT, filters=metadata_filters)
        nodes = retriever.retrieve(query)

        # (QueryFusionRetriever returns nodes, but we want to rerank them)
        if nodes:
            from llama_index.core.schema import QueryBundle

            nodes = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle(query_str=query))
    except Exception as e:
        print(f"⚠️ Hybrid search failed: {e}. Falling back to standard vector search.")
        traceback.print_exc()
        query_engine = index.as_query_engine(
            similarity_top_k=TOP_K_STRICT,
            filters=metadata_filters,
            node_postprocessors=[reranker],
        )
        nodes = query_engine.query(query).source_nodes

    # --- Strict Metadata Filtering (Post-Fusion) ---
    # BM25 results might not respect the vector filters, so filter them out.
    nodes = [
        node
        for node in nodes
        if _node_passes_filters(node.metadata, status, year, sponsor, sponsor_variations)
    ]

    # --- Strict Keyword Filtering ---
    # BM25 handles keyword relevance naturally, so rely on the Hybrid Search + Reranker
    # rather than applying an aggressive substring check here.

    # --- Formatting Output ---
    if not nodes:
        return "No matching studies found. Try broadening your search terms or filters."

    # Filter by Relevance Score for display (nodes without a score never pass).
    relevant_nodes = [node for node in nodes if node.score is not None and node.score > MIN_SCORE]

    # If strict filtering removes too much, show at least top 3 to be helpful
    if len(relevant_nodes) < MIN_RESULTS:
        relevant_nodes = nodes[:MIN_RESULTS]

    results = []
    for node in relevant_nodes[:DISPLAY_LIMIT]:
        meta = node.metadata
        results.append(
            f"**{meta.get('title', 'Untitled')}**\n"
            f" - ID: {meta.get('nct_id')}\n"
            f" - Phase: {meta.get('phase', 'N/A')}\n"
            f" - Status: {meta.get('status', 'N/A')}\n"
            f" - Sponsor: {meta.get('sponsor', meta.get('org', 'Unknown'))}\n"
            f" - Relevance: {_format_score(node.score, 2)}"
        )

    return f"Found {len(results)} relevant studies:\n\n" + "\n\n".join(results)


@langchain_tool("find_similar_studies")
def find_similar_studies(query: str) -> str:
    """
    Finds studies semantically similar to a given query or study description.

    This tool is useful for "more like this" functionality. It relies purely on
    vector similarity without strict metadata filtering.

    Args:
        query (str): The text to match against (e.g., a study title or description).

    Returns:
        str: A string containing the top 5 similar studies with their titles and summaries.
    """
    index = load_index()

    MAX_RESULTS = 5  # Limit to top 5 unique results

    # 1. Check if query is an NCT ID
    nct_match = _NCT_ID_PATTERN.search(query)
    target_nct = None
    search_text = query

    if nct_match:
        target_nct = nct_match.group(0).upper()
        print(f"🎯 Detected NCT ID for similarity: {target_nct}")

        # Fetch the study content to use as the semantic query
        # Use the vector store directly to get the text
        retriever = index.as_retriever(
            filters=MetadataFilters(
                filters=[MetadataFilter(key="nct_id", value=target_nct, operator=FilterOperator.EQ)]
            ),
            similarity_top_k=1,
        )
        nodes = retriever.retrieve(target_nct)

        if nodes:
            # Use the study's text (Title + Summary) as the query
            search_text = nodes[0].text
            print(f"✅ Found study content. Using {len(search_text)} chars for semantic search.")
        else:
            print(f"⚠️ Study {target_nct} not found. Falling back to text search.")

    # 2. Perform Semantic Search
    # Fetch more candidates (10) to allow for filtering
    retriever = index.as_retriever(similarity_top_k=10)
    nodes = retriever.retrieve(search_text)

    results = []
    seen_ids = set()
    for node in nodes:
        node_nct = node.metadata.get("nct_id")

        # 3. Self-Exclusion
        if target_nct and node_nct == target_nct:
            continue

        # Deduplication (if multiple chunks of same study appear)
        if node_nct in seen_ids:
            continue
        seen_ids.add(node_nct)

        results.append(
            f"Study: {node.metadata.get('title', 'Untitled')} (NCT: {node_nct})\n"
            f"Score: {_format_score(node.score, 4)}\n"
            f"Summary: {node.text[:200]}..."
        )
        if len(results) >= MAX_RESULTS:
            break

    if not results:
        return "No similar studies found."

    return "\n\n".join(results)


def _load_full_dataset() -> pd.DataFrame:
    """Read the whole LanceDB table and return its (flattened) metadata."""
    # Connect to LanceDB directly for speed
    import lancedb

    db = lancedb.connect("./ct_gov_lancedb")
    tbl = db.open_table("clinical_trials")

    # Fetch all data as pandas DataFrame
    df = tbl.to_pandas()

    # LlamaIndex stores metadata in a 'metadata' column (usually as a dict/struct).
    # Flatten it to get columns like 'status', 'phase', etc. If the columns are
    # already flat, the frame is used as-is.
    if "metadata" in df.columns:
        df = pd.json_normalize(df["metadata"])
    return df


def _retrieve_analytics_nodes(index, query: str, status: Optional[str], sponsor: Optional[str]) -> list:
    """Retrieve nodes for *query* and keep only those that strictly match it."""
    filters = []
    if status:
        filters.append(
            MetadataFilter(key="status", value=status.upper(), operator=FilterOperator.EQ)
        )

    if sponsor:
        # Use the helper to get all variations (e.g. "Pfizer" -> ["Pfizer", "Pfizer Inc."])
        sponsor_variations = get_sponsor_variations(sponsor)
        if sponsor_variations:
            print(
                f"🎯 Using strict pre-filter for sponsor '{sponsor}': "
                f"{len(sponsor_variations)} variations found."
            )
            filters.append(
                MetadataFilter(key="sponsor", value=sponsor_variations, operator=FilterOperator.IN)
            )

    metadata_filters = MetadataFilters(filters=filters) if filters else None

    search_query = query
    if sponsor and sponsor.lower() not in query.lower():
        search_query = f"{sponsor} {query}"

    # Use hybrid search for better recall
    retriever = index.as_retriever(
        similarity_top_k=5000,
        filters=metadata_filters,
        vector_store_query_mode="hybrid",
    )
    nodes = retriever.retrieve(search_query)

    # --- Strict Keyword Filtering ---
    # Strictly check if the query appears in Title or Conditions to ensure accurate counting.
    # EXCEPTION: If the query matches the requested sponsor, we also check the 'org' field.
    q_term = query.lower()

    # If normalization changed the query (or found a mapping), it's likely a sponsor
    query_normalized = normalize_sponsor(query)
    is_sponsor_query = bool(query_normalized and query_normalized != query)

    norm_sponsor = normalize_sponsor(sponsor) if sponsor else None
    if sponsor:
        # Normalize both to see if they refer to the same entity
        if query_normalized and norm_sponsor and query_normalized.lower() == norm_sponsor.lower():
            is_sponsor_query = True
        elif sponsor.lower() in q_term or q_term in sponsor.lower():
            is_sponsor_query = True

    # Compare against the normalized sponsor when given, else the normalized query
    target_norm = norm_sponsor if sponsor else query_normalized
    target_norm_lower = target_norm.lower() if target_norm else None

    filtered_nodes = []
    for node in nodes:
        meta = node.metadata
        title = _meta_text(meta, "title").lower()
        conditions = _meta_text(meta, "condition").lower()  # Note: key is 'condition' in DB

        match = q_term in title or q_term in conditions
        if not match and is_sponsor_query:
            # Sponsor queries may also match on the Organization OR Sponsor field,
            # either verbatim or after alias normalization (e.g. J&J -> Janssen).
            org = _meta_text(meta, "org").lower()
            sponsor_val = _meta_text(meta, "sponsor").lower()
            if q_term in org or q_term in sponsor_val:
                match = True
            elif target_norm_lower:
                for candidate in (normalize_sponsor(org), normalize_sponsor(sponsor_val)):
                    if candidate and candidate.lower() == target_norm_lower:
                        match = True
                        break

        if match:
            filtered_nodes.append(node)

    print(f"📉 Strict Filter: {len(nodes)} -> {len(filtered_nodes)} nodes for '{query}'")
    return filtered_nodes


def _text_column(df: pd.DataFrame, name: str) -> pd.Series:
    """Return column *name* as strings, or empty strings if the column is absent."""
    if name in df.columns:
        return df[name].astype(str)
    return pd.Series("", index=df.index, dtype="object")


def _apply_analytics_filters(
    df: pd.DataFrame,
    phase: Optional[str],
    status: Optional[str],
    sponsor: Optional[str],
    intervention: Optional[str],
    start_year: Optional[int],
    study_type: Optional[str],
) -> pd.DataFrame:
    """Apply the optional in-memory filters and return the matching rows."""
    if phase:
        # Ignore empty tokens ("PHASE1,"): an empty string is a substring of
        # every value and would silently disable the filter.
        target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
        target_phases = [p for p in target_phases if p]
        if target_phases:
            phases = _text_column(df, "phase").str.upper().str.replace(" ", "", regex=False)
            df = df[phases.apply(lambda x: any(tp in x for tp in target_phases))]

    if status:
        df = df[_text_column(df, "status").str.upper() == status.upper()]

    if sponsor:
        target_sponsor = normalize_sponsor(sponsor).lower()
        # Use 'sponsor' column if it exists, otherwise fallback to 'org'
        if "sponsor" in df.columns:
            names = df["sponsor"]
            if "org" in df.columns:
                names = names.fillna(df["org"])
            names = names.astype(str)
        else:
            names = _text_column(df, "org")
        normalized = names.apply(normalize_sponsor).str.lower()
        df = df[normalized.str.contains(target_sponsor, regex=False, na=False)]

    if intervention:
        interventions = _text_column(df, "intervention").str.lower()
        df = df[interventions.str.contains(intervention.lower(), regex=False, na=False)]

    if start_year:
        # Rows with a missing/unparseable year are treated as year 0 (excluded).
        if "start_year" in df.columns:
            years = pd.to_numeric(df["start_year"], errors="coerce").fillna(0)
        else:
            years = pd.Series(0, index=df.index)
        df = df[years >= start_year]

    if study_type:
        df = df[_text_column(df, "study_type").str.upper() == study_type.upper()]

    return df


def _count_by(df: pd.DataFrame, col: str) -> pd.Series:
    """Count studies per value of *col* (top 10, or all years sorted for start_year)."""
    if col == "start_year":
        years = pd.to_numeric(df[col], errors="coerce").dropna().astype(int)
        return years.value_counts().sort_index()

    if col == "condition":
        return df[col].astype(str).str.split(", ").explode().value_counts().head(10)

    if col == "intervention":
        all_interventions = [
            part.strip()
            for interventions in df[col].dropna()
            for part in str(interventions).split(";")
            if part.strip()
        ]
        return pd.Series(all_interventions).value_counts().head(10)

    return df[col].value_counts().head(10)


def fetch_study_analytics_data(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
) -> str:
    """
    Underlying logic for fetching and aggregating clinical trial data.

    A query of "overall" (case-insensitive) aggregates the whole LanceDB table;
    any other query retrieves up to 5000 nodes and keeps those whose title or
    condition contains the query (sponsor-like queries may also match on the
    org/sponsor fields). The remaining filters are applied in memory.

    **Side Effect**: stores the chart payload in
    `st.session_state["inline_chart_data"]`.

    Args:
        query (str): Search query, or "overall" for the full dataset.
        group_by (str): One of "phase", "status", "sponsor", "start_year",
            "condition", "intervention", "study_type", "country", "state".
        phase (Optional[str]): Comma-separated phases to keep (substring match).
        status (Optional[str]): Recruitment status to keep (case-insensitive).
        sponsor (Optional[str]): Sponsor to keep (normalized substring match).
        intervention (Optional[str]): Intervention substring to keep.
        start_year (Optional[int]): Keep studies starting on or after this year.
        study_type (Optional[str]): Study type to keep (case-insensitive).

    Returns:
        str: A summary of the counts, or a message explaining why no
        aggregation was produced.
    """
    # Validate up front so an invalid field never triggers a costly retrieval.
    if group_by not in _GROUP_BY_FIELDS:
        return (
            f"Invalid group_by field: {group_by}. "
            "Valid options: phase, status, sponsor, start_year, condition, "
            "intervention, study_type, country, state"
        )

    index = load_index()

    # 1. Retrieve Data
    if query.lower() == "overall":
        try:
            df = _load_full_dataset()
        except Exception as e:
            return f"Error fetching full dataset: {e}"
    else:
        nodes = _retrieve_analytics_nodes(index, query, status, sponsor)
        df = pd.DataFrame([node.metadata for node in nodes])

    # A study can be stored as several chunks; count each study once.
    if "nct_id" in df.columns:
        df = df.drop_duplicates(subset="nct_id")

    if df.empty:
        return "No studies found for analytics."

    # --- APPLY FILTERS (Pandas) ---
    df = _apply_analytics_filters(df, phase, status, sponsor, intervention, start_year, study_type)

    if df.empty:
        return "No studies found after applying filters."

    # Group by 'sponsor' when available, otherwise fall back to 'org'.
    col = group_by
    if group_by == "sponsor" and "sponsor" not in df.columns:
        col = "org"
    if col not in df.columns:
        return f"No '{group_by}' data available for the matching studies."

    counts = _count_by(df, col)
    summary = counts.to_string()

    chart_df = counts.reset_index()
    chart_df.columns = ["category", "count"]

    st.session_state["inline_chart_data"] = {
        "type": "bar",
        "title": f"Studies by {group_by.capitalize()}",
        "data": chart_df.to_dict("records"),
        "x": "category",
        "y": "count",
    }

    return f"Found {len(df)} studies. Top counts:\n{summary}\n\n(Chart generated in UI)"


@langchain_tool("get_study_analytics")
def get_study_analytics(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
) -> str:
    """
    Aggregates clinical trial data based on a search query and groups by a specific field.

    This tool performs the following steps:
    1.  Retrieves a large number of relevant studies (up to 5000).
    2.  Applies strict filters (Phase, Status, Sponsor) in memory (Pandas).
    3.  Groups the data by the requested field (e.g., Sponsor).
    4.  Generates a summary string for the LLM.
    5.  **Side Effect**: Injects chart data into `st.session_state` to trigger an inline chart in the UI.

    Args:
        query (str): The search query to filter studies (e.g., "cancer").
        group_by (str): The field to group by. Options: "phase", "status", "sponsor", "start_year", "condition", "intervention", "study_type", "country", "state".
        phase (Optional[str]): Optional filter for phase (e.g., "PHASE2").
        status (Optional[str]): Optional filter for status (e.g., "RECRUITING").
        sponsor (Optional[str]): Optional filter for sponsor (e.g., "Pfizer").
        intervention (Optional[str]): Optional filter for intervention (e.g., "Keytruda").
        start_year (Optional[int]): Optional filter for studies starting on or after this year.
        study_type (Optional[str]): Optional filter for study type.

    Returns:
        str: A summary string of the top counts and a note that a chart has been generated.
    """
    return fetch_study_analytics_data(
        query=query,
        group_by=group_by,
        phase=phase,
        status=status,
        sponsor=sponsor,
        intervention=intervention,
        start_year=start_year,
        study_type=study_type,
    )


@lru_cache(maxsize=1)
def _get_comparison_reranker() -> SentenceTransformerRerank:
    """Build the comparison cross-encoder once and reuse it across calls."""
    return SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-12-v2", top_n=10)


@langchain_tool("compare_studies")
def compare_studies(query: str) -> str:
    """
    Compares multiple studies or answers complex multi-part questions using query decomposition.

    Use this tool when the user asks to "compare", "contrast", or analyze differences/similarities
    between specific studies, sponsors, or phases. It breaks down the question into sub-questions.

    Args:
        query (str): The complex comparison query (e.g., "Compare the primary outcomes of Keytruda vs Opdivo").

    Returns:
        str: A detailed response synthesizing the answers to sub-questions.
    """
    index = load_index()

    # Create a base query engine for the sub-questions
    # Increase top_k and add re-ranking to improve recall for comparison queries
    base_engine = index.as_query_engine(
        similarity_top_k=50,
        node_postprocessors=[_get_comparison_reranker()],
    )

    # Wrap it in a QueryEngineTool
    query_tool = QueryEngineTool(
        query_engine=base_engine,
        metadata=ToolMetadata(
            name="clinical_trials_db",
            description="Vector database of clinical trial protocols, results, and metadata.",
        ),
    )

    # Create the SubQuestionQueryEngine
    # Explicitly define the question generator to use the configured LLM (Gemini)
    # This avoids the default behavior which might try to import OpenAI modules
    from llama_index.core.question_gen import LLMQuestionGenerator

    question_gen = LLMQuestionGenerator.from_defaults(llm=Settings.llm)

    query_engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=[query_tool],
        question_gen=question_gen,
        use_async=True,
    )

    try:
        response = query_engine.query(query)
    except Exception as e:
        return f"Error during comparison: {e}"

    return (
        str(response)
        + "\n\n(Note: This analysis is based on the most relevant studies retrieved from the database, not necessarily an exhaustive list.)"
    )


@langchain_tool("get_study_details")
def get_study_details(nct_id: str) -> str:
    """
    Retrieves the full details of a specific clinical trial by its NCT ID.

    Use this tool when the user asks for specific information about a single study,
    such as "What are the inclusion criteria for NCT12345678?" or "Give me a summary of study NCT...".
    It returns the full text content of the study document, including criteria, outcomes, and contacts.

    Args:
        nct_id (str): The NCT ID of the study (e.g., "NCT01234567").

    Returns:
        str: The full text content of the study, or a message if not found.
    """
    index = load_index()

    # Clean the ID. If the input carries extra text or punctuation around the
    # identifier (e.g. "study NCT01234567."), keep just the NCT ID.
    clean_id = nct_id.strip().upper()
    embedded_id = _NCT_ID_PATTERN.search(clean_id)
    if embedded_id:
        clean_id = embedded_id.group(0)

    # Use a retriever with a strict metadata filter for the ID
    # Set top_k=20 to capture all chunks if the document was split
    filters = MetadataFilters(
        filters=[MetadataFilter(key="nct_id", value=clean_id, operator=FilterOperator.EQ)]
    )
    retriever = index.as_retriever(similarity_top_k=20, filters=filters)
    nodes = retriever.retrieve(clean_id)

    if not nodes:
        return f"Study {clean_id} not found in the database."

    # Chunks are joined in retrieval order; they are not re-sorted by their
    # position in the original document.
    full_text = "\n\n".join(node.text for node in nodes)

    return f"Details for {clean_id} (Combined {len(nodes)} parts):\n\n{full_text}"