# Geoffrey Kip
# Fix: Add country/state to analytics docstring
# bdac52c
"""
LangChain Tools for the Clinical Trial Agent.
This module defines the tools that the agent can use to interact with the clinical trial data.
Tools include:
1. **search_trials**: Semantic search with optional strict filtering.
2. **find_similar_studies**: Finding studies semantically similar to a given text.
3. **get_study_analytics**: Aggregating data for trends and insights (with inline charts).
"""
import pandas as pd
import streamlit as st
from typing import Optional
from langchain.tools import tool as langchain_tool
from llama_index.core.vector_stores import (
MetadataFilter,
MetadataFilters,
FilterOperator,
)
from llama_index.core import Settings
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from modules.utils import (
load_index,
normalize_sponsor,
get_sponsor_variations,
get_hybrid_retriever,
get_reranker, # Import cached reranker
)
import re
import traceback
# --- Tools ---
def expand_query(query: str) -> str:
    """Expand a search query with medical synonyms/acronyms via the LLM.

    Expansion is skipped for empty queries, long queries (> 10 words), and
    queries containing an NCT ID (exact-match lookups must stay exact).
    Failures are best-effort: the original query is returned unchanged.

    Args:
        query (str): The raw user search query.

    Returns:
        str: The expanded query (OR-combined synonyms) or the original query.
    """
    if not query or len(query.split()) > 10:  # Skip expansion for long queries
        return query
    # Skip expansion if it looks like an NCT ID
    if re.search(r"NCT\d+", query, re.IGNORECASE):
        return query
    prompt = (
        f"You are a helpful medical assistant. "
        f"Expand the following search query with relevant medical synonyms and acronyms. "
        f"Return ONLY the expanded query string combined with OR operators. "
        f"Do not add any explanation.\n\n"
        f"Query: {query}\n"
        f"Expanded Query:"
    )
    try:
        # Use the global Settings.llm
        if not Settings.llm:
            # Fallback if not initialized (though load_index does it)
            from modules.utils import setup_llama_index
            setup_llama_index()
        response = Settings.llm.complete(prompt)
        expanded = response.text.strip()
        # Clean up if the LLM echoes the prompt label back
        if "Expanded Query:" in expanded:
            expanded = expanded.split("Expanded Query:")[-1].strip()
        if not expanded:
            # Fix: was an f-string with no placeholders
            print("⚠️ Expansion returned empty. Using original query.")
            return query
        print(f"✨ Expanded Query: '{query}' -> '{expanded}'")
        return expanded
    except Exception as e:
        # Best-effort: never let expansion failures break the caller's search
        print(f"⚠️ Query expansion failed: {e}")
        return query
@langchain_tool("search_trials")
def search_trials(
    query: Optional[str] = None,
    status: Optional[str] = None,
    phase: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    year: Optional[int] = None,
):
    """
    Searches for clinical trials using hybrid (vector + BM25) search with robust filtering.

    Pipeline:
    1. Build/expand the query, injecting sponsor and intervention context.
    2. Apply strict metadata pre-filters (NCT ID, status, year, sponsor) at the DB level.
    3. Hybrid-retrieve and rerank; fall back to plain vector search on failure.
    4. Re-apply all filters post-fusion, since BM25 results can bypass vector filters.

    Args:
        query (Optional[str]): The natural language search query.
        status (Optional[str]): Filter by recruitment status (e.g., "RECRUITING").
        phase (Optional[str]): Filter by trial phase (substring match, e.g., "PHASE2").
        sponsor (Optional[str]): Filter by sponsor name (strict alias list, or fuzzy fallback).
        intervention (Optional[str]): Filter by intervention/drug name.
        year (Optional[int]): Filter for studies starting on or after this year.

    Returns:
        str: A structured, human-readable list of relevant studies.
    """
    index = load_index()
    # Constants
    TOP_K_STRICT = 200  # Reduced from 500 for performance
    # --- Query Construction ---
    if not query:
        parts = [p for p in [sponsor, intervention, phase, status] if p]
        query = " ".join(parts) if parts else "clinical trial"
    else:
        # Inject context for vector search
        if sponsor and normalize_sponsor(sponsor).lower() not in query.lower():
            query = f"{normalize_sponsor(sponsor)} {query}"
        if intervention and intervention.lower() not in query.lower():
            query = f"{intervention} {query}"
    query = expand_query(query)
    print(f"🔍 Tool Called: search_trials(query='{query}', sponsor='{sponsor}')")
    # --- Strategy 1: Strict Pre-Retrieval Filtering (High Precision) ---
    # Filter by NCT ID/Status/Year/Sponsor at the database level first.
    pre_filters = []
    # NCT ID Match
    nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
    if nct_match:
        nct_id = nct_match.group(0).upper()
        pre_filters.append(MetadataFilter(key="nct_id", value=nct_id, operator=FilterOperator.EQ))
    if status:
        pre_filters.append(MetadataFilter(key="status", value=status.upper(), operator=FilterOperator.EQ))
    if year:
        pre_filters.append(MetadataFilter(key="start_year", value=year, operator=FilterOperator.GTE))
    # Sponsor Pre-Filter (strict alias list when the mapping knows this sponsor).
    # Fix: removed redundant local import of get_sponsor_variations (module-level already).
    if sponsor:
        variations = get_sponsor_variations(sponsor)
        if variations:
            print(f"🎯 Applying strict pre-filter for sponsor '{sponsor}' ({len(variations)} variants)")
            # Use 'sponsor' field which is the Lead Sponsor
            pre_filters.append(MetadataFilter(key="sponsor", value=variations, operator=FilterOperator.IN))
        else:
            print(f"⚠️ No strict mapping for sponsor '{sponsor}'. Will rely on fuzzy post-filtering.")
    metadata_filters = MetadataFilters(filters=pre_filters) if pre_filters else None
    # Post-processors (Reranking) — use the cached reranker
    reranker = get_reranker(top_n=50)
    # --- HYBRID SEARCH IMPLEMENTATION ---
    # Combine Vector + BM25 using get_hybrid_retriever
    try:
        retriever = get_hybrid_retriever(index, similarity_top_k=TOP_K_STRICT, filters=metadata_filters)
        nodes = retriever.retrieve(query)
        # QueryFusionRetriever returns nodes; rerank them explicitly
        if nodes:
            from llama_index.core.schema import QueryBundle
            nodes = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle(query_str=query))
    except Exception as e:
        print(f"⚠️ Hybrid search failed: {e}. Falling back to standard vector search.")
        traceback.print_exc()
        query_engine = index.as_query_engine(
            similarity_top_k=TOP_K_STRICT,
            filters=metadata_filters,
            node_postprocessors=[reranker],
        )
        response = query_engine.query(query)
        nodes = response.source_nodes
    # --- Strict Metadata Filtering (Post-Fusion) ---
    # BM25 results might not respect the vector filters, so re-apply them here.
    final_nodes = []
    for node in nodes:
        meta = node.metadata
        keep = True
        if status and meta.get("status", "").upper() != status.upper():
            keep = False
        # Fix: 'phase' was accepted but never enforced. Substring match handles
        # combined values like "PHASE1, PHASE2" (same convention as analytics).
        if phase:
            target_phase = phase.upper().replace(" ", "")
            if target_phase not in str(meta.get("phase", "")).upper().replace(" ", ""):
                keep = False
        if year:
            try:
                if int(meta.get("start_year", 0)) < year:
                    keep = False
            except (TypeError, ValueError):
                # Fix: was a bare except. Unparseable years are kept (best-effort).
                pass
        if sponsor:
            # Strict variations (when mapped) are enforced again since BM25
            # ignores the pre-filter; otherwise fall back to fuzzy matching.
            variations = get_sponsor_variations(sponsor)
            node_sponsor = meta.get("sponsor", "")
            # Fallback to org if sponsor is missing (legacy data)
            if not node_sponsor:
                node_sponsor = meta.get("org", "")
            if variations:
                if node_sponsor not in variations:
                    keep = False
            else:
                # Fuzzy fallback on normalized names
                if normalize_sponsor(sponsor).lower() not in normalize_sponsor(node_sponsor).lower():
                    keep = False
        if keep:
            final_nodes.append(node)
    nodes = final_nodes
    # --- Formatting Output ---
    # (Keyword relevance is handled by Hybrid Search + Reranker; no extra
    # substring check is applied here.)
    if not nodes:
        return "No matching studies found. Try broadening your search terms or filters."
    # Filter by Relevance Score for display. Reranker scores assumed; a None
    # score is treated as 0 so comparison/formatting cannot crash.
    MIN_SCORE = 1.5
    relevant_nodes = [node for node in nodes if (node.score or 0.0) > MIN_SCORE]
    # If strict filtering removes too much, show at least top 3 to be helpful
    if len(relevant_nodes) < 3 and len(nodes) > 0:
        relevant_nodes = nodes[:3]
    display_limit = 20
    display_nodes = relevant_nodes[:display_limit]
    results = []
    for node in display_nodes:
        meta = node.metadata
        entry = (
            f"**{meta.get('title', 'Untitled')}**\n"
            f" - ID: {meta.get('nct_id')}\n"
            f" - Phase: {meta.get('phase', 'N/A')}\n"
            f" - Status: {meta.get('status', 'N/A')}\n"
            f" - Sponsor: {meta.get('sponsor', meta.get('org', 'Unknown'))}\n"
            f" - Relevance: {(node.score or 0.0):.2f}"
        )
        results.append(entry)
    return f"Found {len(results)} relevant studies:\n\n" + "\n\n".join(results)
@langchain_tool("find_similar_studies")
def find_similar_studies(query: str):
    """
    Finds studies semantically similar to a given query or study description.

    This tool is useful for "more like this" functionality. If the query
    contains an NCT ID, that study's own text is fetched and used as the
    semantic query, and the source study is excluded from the results.
    It relies purely on vector similarity without strict metadata filtering.

    Args:
        query (str): The text to match against (e.g., a study title,
            description, or an NCT ID such as "NCT01234567").

    Returns:
        str: The top 5 unique similar studies with titles, scores, and summaries.
    """
    index = load_index()
    # 1. Check if query is an NCT ID
    nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
    target_nct = None
    search_text = query
    if nct_match:
        target_nct = nct_match.group(0).upper()
        print(f"🎯 Detected NCT ID for similarity: {target_nct}")
        # Fetch the study's own content to use as the semantic query
        retriever = index.as_retriever(
            filters=MetadataFilters(
                filters=[MetadataFilter(key="nct_id", value=target_nct, operator=FilterOperator.EQ)]
            ),
            similarity_top_k=1,
        )
        nodes = retriever.retrieve(target_nct)
        if nodes:
            # Use the study's text (Title + Summary) as the query
            search_text = nodes[0].text
            print(f"✅ Found study content. Using {len(search_text)} chars for semantic search.")
        else:
            print(f"⚠️ Study {target_nct} not found. Falling back to text search.")
    # 2. Perform Semantic Search — fetch extra candidates (10) to allow filtering
    retriever = index.as_retriever(similarity_top_k=10)
    nodes = retriever.retrieve(search_text)
    results = []
    for node in nodes:
        nct = node.metadata.get("nct_id")
        # 3. Self-Exclusion: don't return the study we started from
        if target_nct and nct == target_nct:
            continue
        # Deduplication (if multiple chunks of the same study appear)
        if any(r["nct_id"] == nct for r in results):
            continue
        results.append({
            "nct_id": nct,
            # Fix: use .get() for 'title' — a sparse-metadata chunk previously raised KeyError
            "text": f"Study: {node.metadata.get('title', 'Untitled')} (NCT: {nct})\nScore: {node.score:.4f}\nSummary: {node.text[:200]}..."
        })
        if len(results) >= 5:  # Limit to top 5 unique results
            break
    if not results:
        return "No similar studies found."
    return "\n\n".join([r["text"] for r in results])
def fetch_study_analytics_data(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
) -> str:
    """
    Underlying logic for fetching and aggregating clinical trial data.

    Pipeline:
    1. Retrieve studies: the full dataset straight from LanceDB when
       query == "overall", otherwise hybrid retrieval with metadata pre-filters.
    2. Strictly keyword-filter retrieved nodes (title/conditions, or
       org/sponsor for sponsor-like queries) so counts stay accurate.
    3. Apply remaining filters in pandas (phase, status, sponsor,
       intervention, start_year, study_type).
    4. Group by `group_by`, build a summary string, and stash chart data in
       `st.session_state["inline_chart_data"]` for the UI (side effect).

    See get_study_analytics for the full argument docstring.

    Returns:
        str: Summary of the top counts, or an error/empty-result message.
    """
    index = load_index()
    # 1. Retrieve Data
    if query.lower() == "overall":
        try:
            # Connect to LanceDB directly for speed (skips the retriever entirely)
            import lancedb
            db = lancedb.connect("./ct_gov_lancedb")
            tbl = db.open_table("clinical_trials")
            # Fetch all data as a pandas DataFrame
            df = tbl.to_pandas()
            # LlamaIndex stores metadata in a nested 'metadata' struct column;
            # flatten it so 'status', 'phase', etc. become top-level columns.
            if "metadata" in df.columns:
                df = pd.json_normalize(df["metadata"])
        except Exception as e:
            return f"Error fetching full dataset: {e}"
    else:
        filters = []
        if status:
            filters.append(
                MetadataFilter(
                    key="status", value=status.upper(), operator=FilterOperator.EQ
                )
            )
        # NOTE: phase is deliberately NOT pre-filtered here — combined values
        # like "PHASE1, PHASE2" need substring matching, done in pandas below.
        # (Fix: removed a dead `if phase and "," not in phase: pass` branch.)
        if sponsor:
            # Helper returns all known aliases (e.g. "Pfizer" -> ["Pfizer", "Pfizer Inc."])
            sponsor_variations = get_sponsor_variations(sponsor)
            if sponsor_variations:
                print(f"🎯 Using strict pre-filter for sponsor '{sponsor}': {len(sponsor_variations)} variations found.")
                filters.append(
                    MetadataFilter(
                        key="sponsor", value=sponsor_variations, operator=FilterOperator.IN
                    )
                )
        metadata_filters = MetadataFilters(filters=filters) if filters else None
        search_query = query
        if sponsor and sponsor.lower() not in query.lower():
            search_query = f"{sponsor} {query}"
        # Use hybrid search for better recall
        retriever = index.as_retriever(
            similarity_top_k=5000,
            filters=metadata_filters,
            vector_store_query_mode="hybrid",
        )
        nodes = retriever.retrieve(search_query)
        # --- Strict Keyword Filtering ---
        # Require the query term in Title or Conditions for accurate counting.
        # EXCEPTION: sponsor-like queries may also match the org/sponsor fields.
        if query.lower() != "overall":
            q_term = query.lower()
            # Decide whether the query is essentially a sponsor name
            is_sponsor_query = False
            norm_sponsor = None
            query_normalized = normalize_sponsor(query)
            if query_normalized and query_normalized != query:
                # Normalization changed it (or found a mapping) → likely a sponsor
                is_sponsor_query = True
            if sponsor:
                norm_query = normalize_sponsor(query)
                norm_sponsor = normalize_sponsor(sponsor)
                if norm_query and norm_sponsor and norm_query.lower() == norm_sponsor.lower():
                    is_sponsor_query = True
                elif sponsor.lower() in query.lower() or query.lower() in sponsor.lower():
                    is_sponsor_query = True
            filtered_nodes = []
            for node in nodes:
                meta = node.metadata
                title = meta.get("title", "").lower()
                conditions = meta.get("condition", "").lower()  # Note: key is 'condition' in DB
                org = meta.get("org", "").lower()
                sponsor_val = meta.get("sponsor", "").lower()
                match = False
                if q_term in title or q_term in conditions:
                    match = True
                elif is_sponsor_query:
                    # Sponsor queries may match Organization OR Sponsor fields,
                    # raw or normalized (handles aliases like J&J -> Janssen).
                    if q_term in org or q_term in sponsor_val:
                        match = True
                    else:
                        norm_org = normalize_sponsor(org)
                        norm_val = normalize_sponsor(sponsor_val)
                        # Compare against the normalized sponsor (or query) name
                        target_norm = norm_sponsor if sponsor else query_normalized
                        if norm_org and target_norm and norm_org.lower() == target_norm.lower():
                            match = True
                        elif norm_val and target_norm and norm_val.lower() == target_norm.lower():
                            match = True
                if match:
                    filtered_nodes.append(node)
            print(f"📉 Strict Filter: {len(nodes)} -> {len(filtered_nodes)} nodes for '{query}'")
            nodes = filtered_nodes
        data = [node.metadata for node in nodes]
        df = pd.DataFrame(data)
    # Deduplicate chunks belonging to the same study
    if "nct_id" in df.columns:
        df = df.drop_duplicates(subset="nct_id")
    if df.empty:
        return "No studies found for analytics."
    # --- APPLY FILTERS (Pandas) ---
    if phase:
        # Substring match so "PHASE2" matches combined values like "PHASE1, PHASE2"
        target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
        df["phase_upper"] = df["phase"].astype(str).str.upper().str.replace(" ", "")
        mask = df["phase_upper"].apply(lambda x: any(tp in x for tp in target_phases))
        df = df[mask]
    if status:
        df = df[df["status"].str.upper() == status.upper()]
    if sponsor:
        target_sponsor = normalize_sponsor(sponsor).lower()
        # Use 'sponsor' column if it exists, otherwise fall back to 'org'
        # (assumes an 'org' column is present — TODO confirm for legacy data)
        if "sponsor" in df.columns:
            df["sponsor_check"] = df["sponsor"].fillna(df["org"]).astype(str).apply(normalize_sponsor).str.lower()
        else:
            df["sponsor_check"] = df["org"].astype(str).apply(normalize_sponsor).str.lower()
        df = df[df["sponsor_check"].str.contains(target_sponsor, regex=False)]
    if intervention:
        target_intervention = intervention.lower()
        df["intervention_lower"] = df["intervention"].astype(str).str.lower()
        df = df[df["intervention_lower"].str.contains(target_intervention, regex=False)]
    if start_year:
        df["start_year"] = pd.to_numeric(df["start_year"], errors="coerce").fillna(0)
        df = df[df["start_year"] >= start_year]
    if study_type:
        df = df[df["study_type"].str.upper() == study_type.upper()]
    if df.empty:
        return "No studies found after applying filters."
    # --- GROUP & AGGREGATE ---
    key_map = {
        "phase": "phase",
        "status": "status",
        "sponsor": "sponsor" if "sponsor" in df.columns else "org",
        "start_year": "start_year",
        "condition": "condition",
        "intervention": "intervention",
        "study_type": "study_type",
        "country": "country",
        "state": "state",
    }
    if group_by not in key_map:
        return f"Invalid group_by field: {group_by}. Valid options: phase, status, sponsor, start_year, condition, intervention, study_type, country, state"
    col = key_map[group_by]
    if col == "start_year":
        df[col] = pd.to_numeric(df[col], errors="coerce")
        counts = df[col].value_counts().sort_index()
    elif col == "condition":
        # Conditions are comma-separated; explode before counting
        counts = df[col].astype(str).str.split(", ").explode().value_counts().head(10)
    elif col == "intervention":
        # Interventions are semicolon-separated
        all_interventions = []
        for interventions in df[col].dropna():
            parts = [i.strip() for i in interventions.split(";") if i.strip()]
            all_interventions.extend(parts)
        counts = pd.Series(all_interventions).value_counts().head(10)
    else:
        counts = df[col].value_counts().head(10)
    summary = counts.to_string()
    chart_df = counts.reset_index()
    chart_df.columns = ["category", "count"]
    chart_data = {
        "type": "bar",
        "title": f"Studies by {group_by.capitalize()}",
        "data": chart_df.to_dict("records"),
        "x": "category",
        "y": "count",
    }
    # Side effect: the UI reads this key to render an inline chart.
    # Fix: original if/else both assigned the same value — collapsed to one assignment.
    st.session_state["inline_chart_data"] = chart_data
    return f"Found {len(df)} studies. Top counts:\n{summary}\n\n(Chart generated in UI)"
@langchain_tool("get_study_analytics")
def get_study_analytics(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
):
    """
    Aggregates clinical trial data based on a search query and groups by a specific field.
    This tool performs the following steps:
    1. Retrieves a large number of relevant studies (or the full dataset when query is "overall").
    2. Applies strict filters (Phase, Status, Sponsor, Intervention, Start Year, Study Type) in memory (Pandas).
    3. Groups the data by the requested field (e.g., Sponsor).
    4. Generates a summary string for the LLM.
    5. **Side Effect**: Injects chart data into `st.session_state` to trigger an inline chart in the UI.
    Args:
        query (str): The search query to filter studies (e.g., "cancer"), or "overall" for the full dataset.
        group_by (str): The field to group by. Options: "phase", "status", "sponsor", "start_year", "condition", "intervention", "study_type", "country", "state".
        phase (Optional[str]): Optional filter for phase (e.g., "PHASE2").
        status (Optional[str]): Optional filter for status (e.g., "RECRUITING").
        sponsor (Optional[str]): Optional filter for sponsor (e.g., "Pfizer").
        intervention (Optional[str]): Optional filter for intervention (e.g., "Keytruda").
        start_year (Optional[int]): Only include studies starting on or after this year.
        study_type (Optional[str]): Optional filter for study type (e.g., "INTERVENTIONAL").
    Returns:
        str: A summary string of the top counts and a note that a chart has been generated.
    """
    return fetch_study_analytics_data(
        query=query,
        group_by=group_by,
        phase=phase,
        status=status,
        sponsor=sponsor,
        intervention=intervention,
        start_year=start_year,
        study_type=study_type,
    )
@langchain_tool("compare_studies")
def compare_studies(query: str):
    """
    Compares multiple studies or answers complex multi-part questions using query decomposition.

    Use this tool when the user asks to "compare", "contrast", or analyze differences/similarities
    between specific studies, sponsors, or phases. It breaks the question into sub-questions,
    answers each against the vector DB, and synthesizes a final response.

    Args:
        query (str): The complex comparison query (e.g., "Compare the primary outcomes of Keytruda vs Opdivo").

    Returns:
        str: A detailed response synthesizing the answers to sub-questions,
            or an error message if the comparison fails.
    """
    index = load_index()
    # Base engine for sub-questions: higher top_k + re-ranking improves recall
    # for comparison queries.
    reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-12-v2", top_n=10)
    base_engine = index.as_query_engine(
        similarity_top_k=50,
        node_postprocessors=[reranker],
    )
    # Wrap it in a QueryEngineTool so the sub-question engine can route to it
    query_tool = QueryEngineTool(
        query_engine=base_engine,
        metadata=ToolMetadata(
            name="clinical_trials_db",
            description="Vector database of clinical trial protocols, results, and metadata.",
        ),
    )
    # Explicitly define the question generator to use the configured LLM (Gemini);
    # the default behavior might try to import OpenAI modules.
    # Fix: removed redundant local `from llama_index.core import Settings`
    # (Settings is already imported at module level).
    from llama_index.core.question_gen import LLMQuestionGenerator
    question_gen = LLMQuestionGenerator.from_defaults(llm=Settings.llm)
    query_engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=[query_tool],
        question_gen=question_gen,
        use_async=True,
    )
    try:
        response = query_engine.query(query)
        return str(response) + "\n\n(Note: This analysis is based on the most relevant studies retrieved from the database, not necessarily an exhaustive list.)"
    except Exception as e:
        # Tool boundary: surface the failure to the agent instead of crashing
        return f"Error during comparison: {e}"
@langchain_tool("get_study_details")
def get_study_details(nct_id: str):
    """
    Retrieves the full details of a specific clinical trial by its NCT ID.

    Use this when the user asks about a single study, e.g. "What are the
    inclusion criteria for NCT12345678?" or "Give me a summary of study NCT...".
    Every stored chunk of the study document is fetched and concatenated,
    so criteria, outcomes, and contacts are all included.

    Args:
        nct_id (str): The NCT ID of the study (e.g., "NCT01234567").

    Returns:
        str: The full text content of the study, or a message if not found.
    """
    index = load_index()
    # Normalize the identifier before filtering
    clean_id = nct_id.strip().upper()
    # Strict metadata filter on the ID; top_k=20 captures every chunk
    # in case the document was split.
    id_filter = MetadataFilters(
        filters=[
            MetadataFilter(key="nct_id", value=clean_id, operator=FilterOperator.EQ)
        ]
    )
    chunks = index.as_retriever(similarity_top_k=20, filters=id_filter).retrieve(clean_id)
    if not chunks:
        return f"Study {clean_id} not found in the database."
    # Concatenate chunks in retrieval order — sufficient for a readable
    # reconstruction of the full document.
    full_text = "\n\n".join(chunk.text for chunk in chunks)
    return f"Details for {clean_id} (Combined {len(chunks)} parts):\n\n{full_text}"