# clinical_trial_inspector/ct_agent_app.py
"""
Clinical Trial Inspector Agent Application.
This is the main Streamlit application script. It orchestrates:
1. **LLM & Agents**: Initializes Google Gemini and the LangChain agent.
2. **RAG Pipeline**: Loads the LlamaIndex vector store for semantic retrieval.
3. **User Interface**: Renders the Streamlit UI with tabs for Chat, Analytics, and Raw Data.
4. **Visualization**: Handles dynamic chart generation using Altair.
"""
import streamlit as st
import pandas as pd
import os
import altair as alt
import logging
from dotenv import load_dotenv
# Silence noisy tool-schema warnings from langchain_google_genai
logging.getLogger("langchain_google_genai._function_utils").setLevel(logging.ERROR)
# Load environment variables
load_dotenv()
# Module Imports
from modules.utils import (
load_environment,
load_index,
setup_llama_index,
init_embedding_model,
get_hybrid_retriever,
)
from modules.constants import COUNTRY_COORDINATES, STATE_COORDINATES
# ... (imports)
from modules.tools import (
search_trials,
find_similar_studies,
get_study_analytics,
compare_studies,
get_study_details,
fetch_study_analytics_data,
)
from modules.cohort_tools import get_cohort_sql
from modules.graph_viz import build_graph
from streamlit_agraph import agraph
from streamlit_option_menu import option_menu
import folium
from streamlit_folium import st_folium
# LangChain Imports
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import MessagesPlaceholder
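# NOTE (illustrative): several tools hand chart payloads back to the UI through
# st.session_state["inline_chart_data"]. The producer lives in modules.tools; based
# purely on how the payload is consumed below, it is assumed to look roughly like:
#
#   st.session_state["inline_chart_data"] = {
#       "type": "bar",                  # only "bar" is rendered inline in the chat view
#       "title": "Studies by Sponsor",  # caption shown above the chart
#       "x": "category",                # x-axis column name
#       "y": "count",                   # y-axis column name
#       "data": [{"category": "Pfizer", "count": 12}, ...],  # records for pd.DataFrame
#   }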
# --- App Configuration ---
st.set_page_config(
page_title="Clinical Trial Inspector",
layout="wide",
initial_sidebar_state="expanded",
)
# --- Custom CSS for Sidebar Width ---
st.markdown(
"""
<style>
[data-testid="stSidebar"] {
min-width: 200px;
max-width: 250px;
}
</style>
""",
unsafe_allow_html=True,
)
# Initialize global resources (Embeddings) once
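# init_embedding_model (modules.utils) is assumed to cache the embedding model,
# e.g. via @st.cache_resource, so concurrent sessions share one instance instead of
# racing to re-initialize it (the cause of earlier concurrency crashes).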
init_embedding_model()
st.title("🧬 Clinical Trial Inspector Agent")
# 1. Setup LLM & LlamaIndex Settings
# We use Google's Gemini 2.5 Flash for fast, accurate responses.
api_key = os.environ.get("GOOGLE_API_KEY")
# Check session state if env var is missing
if not api_key and "api_key" in st.session_state:
api_key = st.session_state["api_key"]
if not api_key:
st.sidebar.warning("⚠️ API Key Missing")
user_key = st.sidebar.text_input("Enter Google API Key:", type="password", help="Get one at https://aistudio.google.com/")
if user_key:
st.session_state["api_key"] = user_key
st.rerun()
else:
st.warning("Please enter your Google API Key in the sidebar to continue.")
st.stop()
else:
# Ensure it's in session state for tools/consistency
if "api_key" not in st.session_state:
st.session_state["api_key"] = api_key
# Ensure LlamaIndex settings (Embeddings, LLM) are applied on every run
setup_llama_index(api_key=api_key)
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, google_api_key=api_key)
# 2. Load LlamaIndex (Cached)
# The index is loaded once and cached to avoid reloading on every interaction.
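# load_index (modules.utils) is assumed to carry its own Streamlit cache decorator
# (e.g. @st.cache_resource) and to return the LlamaIndex vector index built over the
# LanceDB store used elsewhere in this app ("./ct_gov_lancedb", table "clinical_trials").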
index = load_index()
# 3. Define Agent (Cached)
@st.cache_resource
def get_agent(api_key: str):
"""Initializes and caches the LangChain agent. Keyed by API key."""
# Create LLM specific to this key (and cache entry)
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, google_api_key=api_key)
tools = [
search_trials,
find_similar_studies,
get_study_analytics,
compare_studies,
get_study_details,
get_cohort_sql,
]
prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"You are a Clinical Trial Expert Assistant. "
"Your goal is to help researchers and analysts understand clinical trial data. "
"You have access to a local database of clinical trials (embedded from ClinicalTrials.gov). "
"Use the available tools to search for studies, find similar studies, and generate analytics. "
"When asked about 'trends', 'counts', 'how many', or 'most common', ALWAYS use the `get_study_analytics` tool. "
"Do NOT use `search_trials` for counting questions like 'How many studies...'. "
"When asked to 'find studies', 'search', or 'list', use `search_trials`. "
"When asked to 'compare' multiple studies or answer complex multi-part questions, use `compare_studies`. "
"If the user asks for a specific study by ID (e.g., NCT12345678), `search_trials` handles that automatically. "
"However, if the user asks for specific **details**, **criteria**, **summary**, or **protocol** of a single study, "
"you MUST use the `get_study_details` tool to fetch the full content. "
"If the user asks to **generate SQL**, **build a cohort**, or **translate criteria to code** for a study, "
"use the `get_cohort_sql` tool. "
"When reporting 'similar studies', ALWAYS include the similarity score provided by the tool "
"and DO NOT include the study that was used as the query (the reference study). "
"Provide concise, evidence-based answers citing specific studies when possible.",
),
MessagesPlaceholder(variable_name="chat_history"),
("human", "{input}"),
("placeholder", "{agent_scratchpad}"),
]
)
agent = create_tool_calling_agent(llm, tools, prompt)
return AgentExecutor(agent=agent, tools=tools, verbose=True)
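# st.cache_resource keys the cached executor by the api_key argument, so entering a
# different key in the sidebar builds a fresh agent rather than reusing one bound to a
# stale key.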
agent_executor = get_agent(api_key=api_key)
# --- Sidebar ---
with st.sidebar:
st.image(
"https://cdn-icons-png.flaticon.com/512/3004/3004458.png", width=50
)
st.title("Clinical Trial Agent")
page = option_menu(
"Main Menu",
["Chat Assistant", "Analytics Dashboard", "Knowledge Graph", "Study Map", "Raw Data"],
icons=["chat-dots", "graph-up", "diagram-3", "map", "database"],
menu_icon="cast",
default_index=0,
)
# --- Helper Functions ---
def generate_dashboard_analytics():
"""Callback to generate analytics and update session state."""
# Map UI selection to tool arguments
group_map = {
"Phase": "phase",
"Status": "status",
"Sponsor": "sponsor",
"Start Year": "start_year",
"Intervention": "intervention",
"Study Type": "study_type",
}
# Get values from session state
# Use .get() to avoid KeyErrors if the widget hasn't initialized yet
g_by = st.session_state.get("dash_group_by", "Sponsor")
p_filter = st.session_state.get("dash_phase", "")
s_filter = st.session_state.get("dash_sponsor", "")
with st.spinner(f"Analyzing studies by {g_by}..."):
# Call the tool directly
result = get_study_analytics.invoke(
{
"query": "overall",
"group_by": group_map.get(g_by, "sponsor"),
"phase": p_filter if p_filter else None,
"sponsor": s_filter if s_filter else None,
}
)
# The tool sets session state 'inline_chart_data'
if "inline_chart_data" in st.session_state:
st.session_state["dashboard_data"] = st.session_state["inline_chart_data"]
else:
st.warning(result)
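# Note: generate_dashboard_analytics is wired up as a button on_click callback on the
# Analytics Dashboard page, so it runs before Streamlit reruns the script; stashing the
# result in st.session_state["dashboard_data"] lets the chart survive subsequent reruns.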
# --- PAGE 1: CHAT ---
if page == "Chat Assistant":
st.header("💬 Chat Assistant")
if "messages" not in st.session_state:
st.session_state.messages = []
# Render Chat History
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
# Render chart if present in message history (persisted charts)
if "chart_data" in message:
chart_data = message["chart_data"]
st.caption(chart_data["title"])
chart = (
alt.Chart(pd.DataFrame(chart_data["data"]))
.mark_bar()
.encode(
x=alt.X(
chart_data["x"], sort="-y", axis=alt.Axis(labelLimit=200)
),
y=alt.Y(chart_data["y"], title="Count"),
tooltip=[chart_data["x"], chart_data["y"]],
)
.interactive()
)
st.altair_chart(chart, theme="streamlit", use_container_width=True)
# Chat Input
if prompt := st.chat_input("Ask about clinical trials..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Analyzing clinical trials..."):
try:
# Clear previous inline chart data to avoid stale charts
if "inline_chart_data" in st.session_state:
del st.session_state["inline_chart_data"]
# Construct chat history for the agent context
chat_history = []
for msg in st.session_state.messages[:-1]:
if msg["role"] == "user":
chat_history.append(HumanMessage(content=msg["content"]))
else:
chat_history.append(AIMessage(content=msg["content"]))
# Invoke Agent
response = agent_executor.invoke(
{"input": prompt, "chat_history": chat_history}
)
output = response["output"]
st.markdown(output)
# Check for inline chart data (set by tools)
chart_data = None
if "inline_chart_data" in st.session_state:
chart_data = st.session_state["inline_chart_data"]
st.caption(chart_data["title"])
if chart_data["type"] == "bar":
# Use Altair for better charts
chart = (
alt.Chart(pd.DataFrame(chart_data["data"]))
.mark_bar()
.encode(
x=alt.X(
chart_data["x"],
sort="-y",
axis=alt.Axis(labelLimit=200),
),
y=alt.Y(chart_data["y"], title="Count"),
tooltip=[chart_data["x"], chart_data["y"]],
)
.interactive()
)
st.altair_chart(chart, theme="streamlit", use_container_width=True)
# Clean up session state
del st.session_state["inline_chart_data"]
# Save message with chart data if present
msg_obj = {"role": "assistant", "content": output}
if chart_data:
msg_obj["chart_data"] = chart_data
st.session_state.messages.append(msg_obj)
except Exception as e:
st.error(f"An error occurred: {e}")
# --- PAGE 2: ANALYTICS DASHBOARD ---
if page == "Analytics Dashboard":
st.header("📊 Global Analytics")
st.write(
"Analyze trends across the entire clinical trial dataset."
)
col1, col2 = st.columns([1, 3])
with col1:
st.subheader("Configuration")
group_by = st.selectbox(
"Group By",
["Phase", "Status", "Sponsor", "Start Year", "Intervention", "Study Type"],
index=2,
key="dash_group_by",
)
# Optional Filters
st.markdown("---")
st.markdown("**Filters (Optional)**")
filter_phase = st.text_input("Phase (e.g., Phase 2)", key="dash_phase")
filter_sponsor = st.text_input("Sponsor (e.g., Pfizer)", key="dash_sponsor")
st.button(
"Generate Analytics", type="primary", on_click=generate_dashboard_analytics
)
with col2:
# Always render if data exists in session state
if "dashboard_data" in st.session_state:
c_data = st.session_state["dashboard_data"]
st.subheader(c_data["title"])
# Altair Chart Rendering
if (
c_data["x"] == "start_year" or group_by == "Start Year"
): # Check both key and UI selection
# Line chart for years
chart = (
alt.Chart(pd.DataFrame(c_data["data"]))
.mark_line(point=True)
.encode(
x=alt.X(
c_data["x"], axis=alt.Axis(format="d"), title="Year"
), # 'd' for integer year
y=alt.Y(c_data["y"], title="Count"),
tooltip=[c_data["x"], c_data["y"]],
)
.interactive()
)
else:
# Bar chart for others
chart = (
alt.Chart(pd.DataFrame(c_data["data"]))
.mark_bar()
.encode(
x=alt.X(
c_data["x"],
sort="-y",
axis=alt.Axis(labelLimit=200),
),
y=alt.Y(c_data["y"], title="Count"),
tooltip=[c_data["x"], c_data["y"]],
)
.interactive()
)
st.altair_chart(chart, theme="streamlit", use_container_width=True)
# Show raw table
with st.expander("View Source Data"):
st.dataframe(pd.DataFrame(c_data["data"]))
# --- PAGE 3: KNOWLEDGE GRAPH ---
if page == "Knowledge Graph":
st.header("🕸️ Interactive Knowledge Graph")
st.write("Visualize connections between Studies, Sponsors, and Conditions.")
col_g1, col_g2 = st.columns([1, 3])
with col_g1:
st.subheader("Graph Settings")
graph_query = st.text_input("Search Topic", value="Cancer")
limit = st.slider("Max Nodes", 10, 100, 50)
if st.button("Build Graph"):
with st.spinner("Fetching data and building graph..."):
# Use retriever to get relevant nodes
retriever = index.as_retriever(similarity_top_k=limit)
nodes = retriever.retrieve(graph_query)
data = [n.metadata for n in nodes]
# Build Graph
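                # build_graph (modules.graph_viz) is assumed to convert the retrieved
                # metadata records into streamlit_agraph Node/Edge lists plus a Config
                # object, linking each study to its sponsor and condition nodes.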
g_nodes, g_edges, g_config = build_graph(data)
st.session_state["graph_data"] = {
"nodes": g_nodes,
"edges": g_edges,
"config": g_config,
}
with col_g2:
if "graph_data" in st.session_state:
g_data = st.session_state["graph_data"]
st.success(
f"Graph built with {len(g_data['nodes'])} nodes and {len(g_data['edges'])} edges."
)
agraph(
nodes=g_data["nodes"], edges=g_data["edges"], config=g_data["config"]
)
else:
st.info("Enter a topic and click 'Build Graph' to visualize connections.")
# --- PAGE 4: STUDY MAP ---
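# COUNTRY_COORDINATES / STATE_COORDINATES (modules.constants) are assumed to map
# location names to (lat, lon) pairs usable as folium marker locations, e.g. roughly
# {"United States": (37.09, -95.71), ...}; names missing from the mapping are skipped.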
elif page == "Study Map":
st.header("🌍 Global Clinical Trial Map")
st.markdown("Visualize the geographic distribution of clinical trials.")
# Sidebar Filters for Map
st.sidebar.markdown("### 🗺️ Map Filters")
map_region = st.sidebar.radio("Region", ["World", "USA"], index=0)
map_phase = st.sidebar.multiselect(
"Phase", ["PHASE1", "PHASE2", "PHASE3", "PHASE4"], default=["PHASE2", "PHASE3"]
)
map_status = st.sidebar.selectbox(
"Status", ["RECRUITING", "COMPLETED", "ACTIVE_NOT_RECRUITING"], index=0
)
map_sponsor = st.sidebar.text_input("Sponsor (Optional)", "")
map_year = st.sidebar.number_input("Start Year (>=)", min_value=2000, value=2020)
map_type = st.sidebar.selectbox(
"Study Type", ["Interventional", "Observational", "All"], index=0
)
# Convert filters to arguments
phase_str = ",".join(map_phase) if map_phase else None
type_arg = map_type if map_type != "All" else None
if st.button("Update Map"):
with st.spinner("Aggregating geographic data..."):
# Determine grouping based on Region
group_by_field = "state" if map_region == "USA" else "country"
# Call analytics logic directly
summary = fetch_study_analytics_data(
query="overall",
group_by=group_by_field,
phase=phase_str,
status=map_status,
sponsor=map_sponsor,
start_year=map_year,
study_type=type_arg,
)
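            # fetch_study_analytics_data is assumed to follow the same convention as the
            # analytics tool: it publishes the aggregated records via
            # st.session_state["inline_chart_data"] and returns only a text summary.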
# Retrieve data from session state
chart_data = st.session_state.get("inline_chart_data", {})
data_records = chart_data.get("data", [])
if not data_records:
st.warning("No data found for these filters.")
st.session_state["map_data"] = None
st.session_state["map_region"] = map_region # Store region too
else:
# Store in session state for persistence
st.session_state["map_data"] = data_records
st.session_state["map_region"] = map_region
# Render Map (Outside Button Block)
if st.session_state.get("map_data"):
data_records = st.session_state["map_data"]
region_mode = st.session_state.get("map_region", "World")
df_map = pd.DataFrame(data_records)
# Configure Map Center/Zoom
if region_mode == "USA":
m = folium.Map(location=[37.0902, -95.7129], zoom_start=4)
coord_map = STATE_COORDINATES
else:
m = folium.Map(location=[20, 0], zoom_start=2)
coord_map = COUNTRY_COORDINATES
# Add CircleMarkers
for _, row in df_map.iterrows():
loc_name = row["category"]
count = row["count"]
# Clean name if needed (strip trailing parenthesis)
loc_clean = loc_name.rstrip(")")
coords = coord_map.get(loc_clean)
if coords:
folium.CircleMarker(
location=coords,
radius=min(max(count / 5, 3), 20), # Adjust scale
popup=f"{loc_clean}: {count} trials",
color="blue" if region_mode == "USA" else "crimson",
fill=True,
fill_color="blue" if region_mode == "USA" else "crimson",
).add_to(m)
st_folium(m, width=800, height=500)
# Show data table
st.subheader(f"{region_mode} Data")
st.dataframe(df_map)
# --- PAGE 5: RAW DATA ---
if page == "Raw Data":
st.header("📂 Raw Data Explorer")
st.write("View and filter the underlying dataset.")
# Load a sample (top 100) to avoid performance issues.
col_raw_1, col_raw_2 = st.columns([1, 1])
with col_raw_1:
if st.button("Load Sample Data (Top 100)"):
with st.spinner("Fetching data..."):
retriever = index.as_retriever(similarity_top_k=100)
nodes = retriever.retrieve("clinical trial")
data = [n.metadata for n in nodes]
df_raw = pd.DataFrame(data)
# Format Year to remove commas (e.g., 2,023 -> 2023)
if "start_year" in df_raw.columns:
df_raw["start_year"] = (
pd.to_numeric(df_raw["start_year"], errors="coerce")
.astype("Int64")
.astype(str)
.str.replace(",", "")
)
# Store in session state to persist the table
st.session_state["sample_data"] = df_raw
with col_raw_2:
# Download Full Dataset Logic
if st.button("Prepare Full Download (CSV)"):
with st.spinner("Fetching all records from database..."):
try:
# Access LanceDB directly for speed
import lancedb
db = lancedb.connect("./ct_gov_lancedb")
tbl = db.open_table("clinical_trials")
# Fetch all data
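                        # to_pandas() materializes the whole table (including any vector
                        # columns) in memory; fine for a modest dataset, worth revisiting
                        # if the table grows large.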
df_full = tbl.to_pandas()
# Handle metadata flattening if needed
if "metadata" in df_full.columns:
meta_df = pd.json_normalize(df_full["metadata"])
# Combine or just use metadata
df_full = meta_df
# Convert to CSV
csv = df_full.to_csv(index=False).encode("utf-8")
st.session_state["full_csv"] = csv
st.success(f"Ready! Fetched {len(df_full)} records.")
                    else:
                        st.warning("No 'metadata' column found in the LanceDB table; nothing to export.")
except Exception as e:
st.error(f"Error fetching data: {e}")
if "full_csv" in st.session_state:
st.download_button(
label="⬇️ Download Full CSV",
data=st.session_state["full_csv"],
file_name="clinical_trials_full.csv",
mime="text/csv",
)
# Display Sample Data Table (Full Width)
if "sample_data" in st.session_state:
st.markdown("### Sample Data (Top 100)")
st.dataframe(
st.session_state["sample_data"],
column_config={
"nct_id": "NCT ID",
"title": "Study Title",
"start_year": st.column_config.TextColumn(
"Start Year"
), # Force text to avoid commas
"url": st.column_config.LinkColumn("URL"),
},
use_container_width=True,
hide_index=True,
)