""" Clinical Trial Inspector Agent Application. This is the main Streamlit application script. It orchestrates: 1. **LLM & Agents**: Initializes Google Gemini and the LangChain agent. 2. **RAG Pipeline**: Loads the LlamaIndex vector store for semantic retrieval. 3. **User Interface**: Renders the Streamlit UI with tabs for Chat, Analytics, and Raw Data. 4. **Visualization**: Handles dynamic chart generation using Altair. """ import streamlit as st import pandas as pd import os import altair as alt import logging from dotenv import load_dotenv # Suppress logging logging.getLogger("langchain_google_genai._function_utils").setLevel(logging.ERROR) # Load environment variables load_dotenv() # Module Imports from modules.utils import ( load_environment, load_index, setup_llama_index, init_embedding_model, get_hybrid_retriever, ) from modules.constants import COUNTRY_COORDINATES, STATE_COORDINATES # ... (imports) from modules.tools import ( search_trials, find_similar_studies, get_study_analytics, compare_studies, get_study_details, fetch_study_analytics_data, ) from modules.cohort_tools import get_cohort_sql from modules.graph_viz import build_graph from streamlit_agraph import agraph from streamlit_option_menu import option_menu import folium from streamlit_folium import st_folium # LangChain Imports from langchain_google_genai import ChatGoogleGenerativeAI from langchain.agents import AgentExecutor, create_tool_calling_agent from langchain_core.prompts import ChatPromptTemplate from langchain_core.messages import HumanMessage, AIMessage from langchain_core.prompts import MessagesPlaceholder # --- App Configuration --- st.set_page_config( page_title="Clinical Trial Inspector", layout="wide", initial_sidebar_state="expanded", ) # --- Custom CSS for Sidebar Width --- st.markdown( """ """, unsafe_allow_html=True, ) # Initialize global resources (Embeddings) once init_embedding_model() st.title("🧬 Clinical Trial Inspector Agent") # 1. Setup LLM & LlamaIndex Settings # We use Google Gemini-2.5-Flash for fast and accurate responses. api_key = os.environ.get("GOOGLE_API_KEY") # Check session state if env var is missing if not api_key and "api_key" in st.session_state: api_key = st.session_state["api_key"] if not api_key: st.sidebar.warning("⚠️ API Key Missing") user_key = st.sidebar.text_input("Enter Google API Key:", type="password", help="Get one at https://aistudio.google.com/") if user_key: st.session_state["api_key"] = user_key st.rerun() else: st.warning("Please enter your Google API Key in the sidebar to continue.") st.stop() else: # Ensure it's in session state for tools/consistency if "api_key" not in st.session_state: st.session_state["api_key"] = api_key # Ensure LlamaIndex settings (Embeddings, LLM) are applied on every run setup_llama_index(api_key=api_key) llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, google_api_key=api_key) # 2. Load LlamaIndex (Cached) # The index is loaded once and cached to avoid reloading on every interaction. index = load_index() # 3. Define Agent (Cached) @st.cache_resource def get_agent(api_key: str): """Initializes and caches the LangChain agent. Keyed by API key.""" # Create LLM specific to this key (and cache entry) llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0, google_api_key=api_key) tools = [ search_trials, find_similar_studies, get_study_analytics, compare_studies, get_study_details, get_cohort_sql, ] prompt = ChatPromptTemplate.from_messages( [ ( "system", "You are a Clinical Trial Expert Assistant. " "Your goal is to help researchers and analysts understand clinical trial data. " "You have access to a local database of clinical trials (embedded from ClinicalTrials.gov). " "Use the available tools to search for studies, find similar studies, and generate analytics. " "When asked about 'trends', 'counts', 'how many', or 'most common', ALWAYS use the `get_study_analytics` tool. " "Do NOT use `search_trials` for counting questions like 'How many studies...'. " "When asked to 'find studies', 'search', or 'list', use `search_trials`. " "When asked to 'compare' multiple studies or answer complex multi-part questions, use `compare_studies`. " "If the user asks for a specific study by ID (e.g., NCT12345678), `search_trials` handles that automatically. " "However, if the user asks for specific **details**, **criteria**, **summary**, or **protocol** of a single study, " "you MUST use the `get_study_details` tool to fetch the full content. " "If the user asks to **generate SQL**, **build a cohort**, or **translate criteria to code** for a study, " "use the `get_cohort_sql` tool. " "When reporting 'similar studies', ALWAYS include the similarity score provided by the tool " "and DO NOT include the study that was used as the query (the reference study). " "Provide concise, evidence-based answers citing specific studies when possible.", ), MessagesPlaceholder(variable_name="chat_history"), ("human", "{input}"), ("placeholder", "{agent_scratchpad}"), ] ) agent = create_tool_calling_agent(llm, tools, prompt) return AgentExecutor(agent=agent, tools=tools, verbose=True) agent_executor = get_agent(api_key=api_key) # --- Sidebar --- with st.sidebar: st.image( "https://cdn-icons-png.flaticon.com/512/3004/3004458.png", width=50 ) st.title("Clinical Trial Agent") page = option_menu( "Main Menu", ["Chat Assistant", "Analytics Dashboard", "Knowledge Graph", "Study Map", "Raw Data"], icons=["chat-dots", "graph-up", "diagram-3", "map", "database"], menu_icon="cast", default_index=0, ) # --- Helper Functions --- def generate_dashboard_analytics(): """Callback to generate analytics and update session state.""" # Map UI selection to tool arguments group_map = { "Phase": "phase", "Status": "status", "Sponsor": "sponsor", "Start Year": "start_year", "Intervention": "intervention", "Study Type": "study_type", } # Get values from session state # Use .get() to avoid KeyErrors if the widget hasn't initialized yet g_by = st.session_state.get("dash_group_by", "Sponsor") p_filter = st.session_state.get("dash_phase", "") s_filter = st.session_state.get("dash_sponsor", "") with st.spinner(f"Analyzing studies by {g_by}..."): # Call the tool directly result = get_study_analytics.invoke( { "query": "overall", "group_by": group_map.get(g_by, "sponsor"), "phase": p_filter if p_filter else None, "sponsor": s_filter if s_filter else None, } ) # The tool sets session state 'inline_chart_data' if "inline_chart_data" in st.session_state: st.session_state["dashboard_data"] = st.session_state["inline_chart_data"] else: st.warning(result) # --- PAGE 1: CHAT --- if page == "Chat Assistant": st.header("πŸ’¬ Chat Assistant") if "messages" not in st.session_state: st.session_state.messages = [] # Render Chat History for message in st.session_state.messages: with st.chat_message(message["role"]): st.markdown(message["content"]) # Render chart if present in message history (persisted charts) if "chart_data" in message: chart_data = message["chart_data"] st.caption(chart_data["title"]) chart = ( alt.Chart(pd.DataFrame(chart_data["data"])) .mark_bar() .encode( x=alt.X( chart_data["x"], sort="-y", axis=alt.Axis(labelLimit=200) ), y=alt.Y(chart_data["y"], title="Count"), tooltip=[chart_data["x"], chart_data["y"]], ) .interactive() ) st.altair_chart(chart, theme="streamlit", use_container_width=True) # Chat Input if prompt := st.chat_input("Ask about clinical trials..."): st.session_state.messages.append({"role": "user", "content": prompt}) with st.chat_message("user"): st.markdown(prompt) with st.chat_message("assistant"): with st.spinner("Analyzing clinical trials..."): try: # Clear previous inline chart data to avoid stale charts if "inline_chart_data" in st.session_state: del st.session_state["inline_chart_data"] # Construct chat history for the agent context chat_history = [] for msg in st.session_state.messages[:-1]: if msg["role"] == "user": chat_history.append(HumanMessage(content=msg["content"])) else: chat_history.append(AIMessage(content=msg["content"])) # Invoke Agent response = agent_executor.invoke( {"input": prompt, "chat_history": chat_history} ) output = response["output"] st.markdown(output) # Check for inline chart data (set by tools) chart_data = None if "inline_chart_data" in st.session_state: chart_data = st.session_state["inline_chart_data"] st.caption(chart_data["title"]) if chart_data["type"] == "bar": # Use Altair for better charts chart = ( alt.Chart(pd.DataFrame(chart_data["data"])) .mark_bar() .encode( x=alt.X( chart_data["x"], sort="-y", axis=alt.Axis(labelLimit=200), ), y=alt.Y(chart_data["y"], title="Count"), tooltip=[chart_data["x"], chart_data["y"]], ) .interactive() ) st.altair_chart(chart, theme="streamlit", use_container_width=True) # Clean up session state del st.session_state["inline_chart_data"] # Save message with chart data if present msg_obj = {"role": "assistant", "content": output} if chart_data: msg_obj["chart_data"] = chart_data st.session_state.messages.append(msg_obj) except Exception as e: st.error(f"An error occurred: {e}") # --- PAGE 2: ANALYTICS DASHBOARD --- if page == "Analytics Dashboard": st.header("πŸ“Š Global Analytics") st.write( "Analyze trends across the entire clinical trial dataset." ) col1, col2 = st.columns([1, 3]) with col1: st.subheader("Configuration") group_by = st.selectbox( "Group By", ["Phase", "Status", "Sponsor", "Start Year", "Intervention", "Study Type"], index=2, key="dash_group_by", ) # Optional Filters st.markdown("---") st.markdown("**Filters (Optional)**") filter_phase = st.text_input("Phase (e.g., Phase 2)", key="dash_phase") filter_sponsor = st.text_input("Sponsor (e.g., Pfizer)", key="dash_sponsor") st.button( "Generate Analytics", type="primary", on_click=generate_dashboard_analytics ) with col2: # Always render if data exists in session state if "dashboard_data" in st.session_state: c_data = st.session_state["dashboard_data"] st.subheader(c_data["title"]) # Altair Chart Rendering if ( c_data["x"] == "start_year" or group_by == "Start Year" ): # Check both key and UI selection # Line chart for years chart = ( alt.Chart(pd.DataFrame(c_data["data"])) .mark_line(point=True) .encode( x=alt.X( c_data["x"], axis=alt.Axis(format="d"), title="Year" ), # 'd' for integer year y=alt.Y(c_data["y"], title="Count"), tooltip=[c_data["x"], c_data["y"]], ) .interactive() ) else: # Bar chart for others chart = ( alt.Chart(pd.DataFrame(c_data["data"])) .mark_bar() .encode( x=alt.X( c_data["x"], sort="-y", axis=alt.Axis(labelLimit=200), ), y=alt.Y(c_data["y"], title="Count"), tooltip=[c_data["x"], c_data["y"]], ) .interactive() ) st.altair_chart(chart, theme="streamlit", use_container_width=True) # Show raw table with st.expander("View Source Data"): st.dataframe(pd.DataFrame(c_data["data"])) # --- PAGE 3: KNOWLEDGE GRAPH --- if page == "Knowledge Graph": st.header("πŸ•ΈοΈ Interactive Knowledge Graph") st.write("Visualize connections between Studies, Sponsors, and Conditions.") col_g1, col_g2 = st.columns([1, 3]) with col_g1: st.subheader("Graph Settings") graph_query = st.text_input("Search Topic", value="Cancer") limit = st.slider("Max Nodes", 10, 100, 50) if st.button("Build Graph"): with st.spinner("Fetching data and building graph..."): # Use retriever to get relevant nodes retriever = index.as_retriever(similarity_top_k=limit) nodes = retriever.retrieve(graph_query) data = [n.metadata for n in nodes] # Build Graph g_nodes, g_edges, g_config = build_graph(data) st.session_state["graph_data"] = { "nodes": g_nodes, "edges": g_edges, "config": g_config, } with col_g2: if "graph_data" in st.session_state: g_data = st.session_state["graph_data"] st.success( f"Graph built with {len(g_data['nodes'])} nodes and {len(g_data['edges'])} edges." ) agraph( nodes=g_data["nodes"], edges=g_data["edges"], config=g_data["config"] ) else: st.info("Enter a topic and click 'Build Graph' to visualize connections.") # --- PAGE# --- Study Map Tab --- elif page == "Study Map": st.header("🌍 Global Clinical Trial Map") st.markdown("Visualize the geographic distribution of clinical trials.") # Sidebar Filters for Map st.sidebar.markdown("### πŸ—ΊοΈ Map Filters") map_region = st.sidebar.radio("Region", ["World", "USA"], index=0) map_phase = st.sidebar.multiselect( "Phase", ["PHASE1", "PHASE2", "PHASE3", "PHASE4"], default=["PHASE2", "PHASE3"] ) map_status = st.sidebar.selectbox( "Status", ["RECRUITING", "COMPLETED", "ACTIVE_NOT_RECRUITING"], index=0 ) map_sponsor = st.sidebar.text_input("Sponsor (Optional)", "") map_year = st.sidebar.number_input("Start Year (>=)", min_value=2000, value=2020) map_type = st.sidebar.selectbox( "Study Type", ["Interventional", "Observational", "All"], index=0 ) # Convert filters to arguments phase_str = ",".join(map_phase) if map_phase else None type_arg = map_type if map_type != "All" else None if st.button("Update Map"): with st.spinner("Aggregating geographic data..."): # Determine grouping based on Region group_by_field = "state" if map_region == "USA" else "country" # Call analytics logic directly summary = fetch_study_analytics_data( query="overall", group_by=group_by_field, phase=phase_str, status=map_status, sponsor=map_sponsor, start_year=map_year, study_type=type_arg, ) # Retrieve data from session state chart_data = st.session_state.get("inline_chart_data", {}) data_records = chart_data.get("data", []) if not data_records: st.warning("No data found for these filters.") st.session_state["map_data"] = None st.session_state["map_region"] = map_region # Store region too else: # Store in session state for persistence st.session_state["map_data"] = data_records st.session_state["map_region"] = map_region # Render Map (Outside Button Block) if st.session_state.get("map_data"): data_records = st.session_state["map_data"] region_mode = st.session_state.get("map_region", "World") df_map = pd.DataFrame(data_records) # Configure Map Center/Zoom if region_mode == "USA": m = folium.Map(location=[37.0902, -95.7129], zoom_start=4) coord_map = STATE_COORDINATES else: m = folium.Map(location=[20, 0], zoom_start=2) coord_map = COUNTRY_COORDINATES # Add CircleMarkers for _, row in df_map.iterrows(): loc_name = row["category"] count = row["count"] # Clean name if needed (strip trailing parenthesis) loc_clean = loc_name.rstrip(")") coords = coord_map.get(loc_clean) if coords: folium.CircleMarker( location=coords, radius=min(max(count / 5, 3), 20), # Adjust scale popup=f"{loc_clean}: {count} trials", color="blue" if region_mode == "USA" else "crimson", fill=True, fill_color="blue" if region_mode == "USA" else "crimson", ).add_to(m) st_folium(m, width=800, height=500) # Show data table st.subheader(f"{region_mode} Data") st.dataframe(df_map) # --- PAGE 4: RAW DATA --- if page == "Raw Data": st.header("πŸ“‚ Raw Data Explorer") st.write("View and filter the underlying dataset.") # Load a sample (top 100) to avoid performance issues. col_raw_1, col_raw_2 = st.columns([1, 1]) with col_raw_1: if st.button("Load Sample Data (Top 100)"): with st.spinner("Fetching data..."): retriever = index.as_retriever(similarity_top_k=100) nodes = retriever.retrieve("clinical trial") data = [n.metadata for n in nodes] df_raw = pd.DataFrame(data) # Format Year to remove commas (e.g., 2,023 -> 2023) if "start_year" in df_raw.columns: df_raw["start_year"] = ( pd.to_numeric(df_raw["start_year"], errors="coerce") .astype("Int64") .astype(str) .str.replace(",", "") ) # Store in session state to persist the table st.session_state["sample_data"] = df_raw with col_raw_2: # Download Full Dataset Logic if st.button("Prepare Full Download (CSV)"): with st.spinner("Fetching all records from database..."): try: # Access LanceDB directly for speed import lancedb db = lancedb.connect("./ct_gov_lancedb") tbl = db.open_table("clinical_trials") # Fetch all data df_full = tbl.to_pandas() # Handle metadata flattening if needed if "metadata" in df_full.columns: meta_df = pd.json_normalize(df_full["metadata"]) # Combine or just use metadata df_full = meta_df # Convert to CSV csv = df_full.to_csv(index=False).encode("utf-8") st.session_state["full_csv"] = csv st.success(f"Ready! Fetched {len(df_full)} records.") else: st.warning("No data found in database.") except Exception as e: st.error(f"Error fetching data: {e}") if "full_csv" in st.session_state: st.download_button( label="⬇️ Download Full CSV", data=st.session_state["full_csv"], file_name="clinical_trials_full.csv", mime="text/csv", ) # Display Sample Data Table (Full Width) if "sample_data" in st.session_state: st.markdown("### Sample Data (Top 100)") st.dataframe( st.session_state["sample_data"], column_config={ "nct_id": "NCT ID", "title": "Study Title", "start_year": st.column_config.TextColumn( "Start Year" ), # Force text to avoid commas "url": st.column_config.LinkColumn("URL"), }, use_container_width=True, hide_index=True, )