Spaces:
Sleeping
Sleeping
| """ | |
| LangChain Tools for the Clinical Trial Agent. | |
| This module defines the tools that the agent can use to interact with the clinical trial data. | |
| Tools include: | |
| 1. **search_trials**: Semantic search with optional strict filtering. | |
| 2. **find_similar_studies**: Finding studies semantically similar to a given text. | |
| 3. **get_study_analytics**: Aggregating data for trends and insights (with inline charts). | |
| """ | |
| import pandas as pd | |
| import streamlit as st | |
| from typing import Optional | |
| from langchain.tools import tool as langchain_tool | |
| from llama_index.core.vector_stores import ( | |
| MetadataFilter, | |
| MetadataFilters, | |
| FilterOperator, | |
| ) | |
| from llama_index.core import Settings | |
| from llama_index.core.postprocessor import MetadataReplacementPostProcessor | |
| from llama_index.core.postprocessor import SentenceTransformerRerank | |
| from llama_index.core.query_engine import SubQuestionQueryEngine | |
| from llama_index.core.tools import QueryEngineTool, ToolMetadata | |
| from modules.utils import ( | |
| load_index, | |
| normalize_sponsor, | |
| get_sponsor_variations, | |
| get_hybrid_retriever, | |
| get_reranker, # Import cached reranker | |
| ) | |
| import re | |
| import traceback | |
| # --- Tools --- | |
def expand_query(query: str) -> str:
    """
    Expands a search query with medical synonyms/acronyms using the LLM.

    Expansion is skipped for empty or long queries (> 10 words) and for
    queries that look like NCT IDs. On any LLM failure the original query is
    returned unchanged (deliberate best-effort behavior).

    Args:
        query (str): The raw search query.

    Returns:
        str: The expanded query, or the original query when expansion is
            skipped or fails.
    """
    # Skip expansion for empty or long queries
    if not query or len(query.split()) > 10:
        return query
    # Skip expansion if it looks like an NCT ID
    if re.search(r"NCT\d+", query, re.IGNORECASE):
        return query
    prompt = (
        f"You are a helpful medical assistant. "
        f"Expand the following search query with relevant medical synonyms and acronyms. "
        f"Return ONLY the expanded query string combined with OR operators. "
        f"Do not add any explanation.\n\n"
        f"Query: {query}\n"
        f"Expanded Query:"
    )
    try:
        # Use the global Settings.llm; initialize lazily if needed
        # (load_index normally does this already).
        if not Settings.llm:
            from modules.utils import setup_llama_index
            setup_llama_index()
        response = Settings.llm.complete(prompt)
        expanded = response.text.strip()
        # Clean up if the LLM echoes the prompt label back
        if "Expanded Query:" in expanded:
            expanded = expanded.split("Expanded Query:")[-1].strip()
        if not expanded:
            print("⚠️ Expansion returned empty. Using original query.")
            return query
        print(f"✨ Expanded Query: '{query}' -> '{expanded}'")
        return expanded
    except Exception as e:
        # Best-effort: never fail the caller just because expansion broke.
        print(f"⚠️ Query expansion failed: {e}")
        return query
def search_trials(
    query: Optional[str] = None,
    status: Optional[str] = None,
    phase: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    year: Optional[int] = None,
) -> str:
    """
    Searches for clinical trials using semantic search with robust filtering.

    Strategy:
      1. Build/expand the search query (injecting sponsor/intervention context).
      2. Strict pre-retrieval metadata filtering at the vector-store level.
      3. Hybrid (vector + BM25) retrieval with reranking; falls back to plain
         vector search if the hybrid retriever fails.
      4. Post-retrieval re-filtering, because BM25 results do not respect the
         vector-store metadata filters.

    Args:
        query (Optional[str]): The natural language search query.
        status (Optional[str]): Filter by recruitment status.
        phase (Optional[str]): Filter by trial phase.
        sponsor (Optional[str]): Filter by sponsor name.
        intervention (Optional[str]): Filter by intervention/drug name.
        year (Optional[int]): Filter for studies starting on or after this year.

    Returns:
        str: A structured list of relevant studies.
    """
    index = load_index()
    TOP_K_STRICT = 200  # Candidate pool size (reduced from 500 for performance)

    # --- Query Construction ---
    if not query:
        parts = [p for p in [sponsor, intervention, phase, status] if p]
        query = " ".join(parts) if parts else "clinical trial"
    else:
        # Inject context for vector search
        if sponsor and normalize_sponsor(sponsor).lower() not in query.lower():
            query = f"{normalize_sponsor(sponsor)} {query}"
        if intervention and intervention.lower() not in query.lower():
            query = f"{intervention} {query}"
    query = expand_query(query)
    print(f"🔍 Tool Called: search_trials(query='{query}', sponsor='{sponsor}')")

    # --- Strategy 1: Strict Pre-Retrieval Filtering (High Precision) ---
    # Filter by NCT ID / status / year / sponsor at the database level first.
    pre_filters = []
    nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
    if nct_match:
        pre_filters.append(
            MetadataFilter(key="nct_id", value=nct_match.group(0).upper(), operator=FilterOperator.EQ)
        )
    if status:
        pre_filters.append(MetadataFilter(key="status", value=status.upper(), operator=FilterOperator.EQ))
    if year:
        pre_filters.append(MetadataFilter(key="start_year", value=year, operator=FilterOperator.GTE))
    if sponsor:
        # get_sponsor_variations is already imported at module level.
        variations = get_sponsor_variations(sponsor)
        if variations:
            print(f"🎯 Applying strict pre-filter for sponsor '{sponsor}' ({len(variations)} variants)")
            # Use 'sponsor' field which is the Lead Sponsor
            pre_filters.append(MetadataFilter(key="sponsor", value=variations, operator=FilterOperator.IN))
        else:
            print(f"⚠️ No strict mapping for sponsor '{sponsor}'. Will rely on fuzzy post-filtering.")
    metadata_filters = MetadataFilters(filters=pre_filters) if pre_filters else None

    # Cached cross-encoder reranker
    reranker = get_reranker(top_n=50)

    # --- Hybrid Search (Vector + BM25), with standard vector fallback ---
    try:
        retriever = get_hybrid_retriever(index, similarity_top_k=TOP_K_STRICT, filters=metadata_filters)
        nodes = retriever.retrieve(query)
        if nodes:
            from llama_index.core.schema import QueryBundle
            nodes = reranker.postprocess_nodes(nodes, query_bundle=QueryBundle(query_str=query))
    except Exception as e:
        print(f"⚠️ Hybrid search failed: {e}. Falling back to standard vector search.")
        traceback.print_exc()
        query_engine = index.as_query_engine(
            similarity_top_k=TOP_K_STRICT,
            filters=metadata_filters,
            node_postprocessors=[reranker],
        )
        nodes = query_engine.query(query).source_nodes

    # --- Strict Metadata Re-Filtering (Post-Fusion) ---
    # BM25 results might not respect the vector-store filters, so re-apply them.
    final_nodes = []
    for node in nodes:
        meta = node.metadata
        keep = True
        if status and meta.get("status", "").upper() != status.upper():
            keep = False
        if year:
            try:
                if int(meta.get("start_year", 0)) < year:
                    keep = False
            except (TypeError, ValueError):
                # Non-numeric start_year: keep the node rather than drop it.
                pass
        if sponsor:
            # If strict variations exist, enforce them; otherwise fall back to
            # fuzzy normalized substring matching.
            variations = get_sponsor_variations(sponsor)
            # 'org' is the legacy fallback field when 'sponsor' is missing.
            node_sponsor = meta.get("sponsor", "") or meta.get("org", "")
            if variations:
                if node_sponsor not in variations:
                    keep = False
            elif normalize_sponsor(sponsor).lower() not in normalize_sponsor(node_sponsor).lower():
                keep = False
        if keep:
            final_nodes.append(node)
    nodes = final_nodes

    # --- Formatting Output ---
    if not nodes:
        return "No matching studies found. Try broadening your search terms or filters."

    # Filter by relevance score for display. The reranker may leave score as
    # None, so coalesce to 0 before comparing/formatting.
    MIN_SCORE = 1.5
    relevant_nodes = [node for node in nodes if (node.score or 0) > MIN_SCORE]
    # If strict filtering removes too much, show at least top 3 to be helpful
    if len(relevant_nodes) < 3:
        relevant_nodes = nodes[:3]
    display_nodes = relevant_nodes[:20]

    results = []
    for node in display_nodes:
        meta = node.metadata
        results.append(
            f"**{meta.get('title', 'Untitled')}**\n"
            f" - ID: {meta.get('nct_id')}\n"
            f" - Phase: {meta.get('phase', 'N/A')}\n"
            f" - Status: {meta.get('status', 'N/A')}\n"
            f" - Sponsor: {meta.get('sponsor', meta.get('org', 'Unknown'))}\n"
            f" - Relevance: {(node.score or 0):.2f}"
        )
    return f"Found {len(results)} relevant studies:\n\n" + "\n\n".join(results)
def find_similar_studies(query: str) -> str:
    """
    Finds studies semantically similar to a given query or study description.

    Useful for "more like this" functionality. Relies purely on vector
    similarity without strict metadata filtering. If the query contains an
    NCT ID, that study's own text is used as the semantic query and the study
    itself is excluded from the results.

    Args:
        query (str): The text to match against (e.g., a study title,
            description, or an NCT ID).

    Returns:
        str: The top 5 unique similar studies with their titles and summaries.
    """
    index = load_index()

    # 1. If the query contains an NCT ID, fetch that study's content and use
    #    its text (title + summary) as the semantic query.
    nct_match = re.search(r"NCT\d+", query, re.IGNORECASE)
    target_nct = None
    search_text = query
    if nct_match:
        target_nct = nct_match.group(0).upper()
        print(f"🎯 Detected NCT ID for similarity: {target_nct}")
        retriever = index.as_retriever(
            filters=MetadataFilters(
                filters=[MetadataFilter(key="nct_id", value=target_nct, operator=FilterOperator.EQ)]
            ),
            similarity_top_k=1,
        )
        nodes = retriever.retrieve(target_nct)
        if nodes:
            search_text = nodes[0].text
            print(f"✅ Found study content. Using {len(search_text)} chars for semantic search.")
        else:
            print(f"⚠️ Study {target_nct} not found. Falling back to text search.")

    # 2. Semantic search: fetch extra candidates (10) to survive filtering.
    retriever = index.as_retriever(similarity_top_k=10)
    nodes = retriever.retrieve(search_text)

    results = []
    for node in nodes:
        nid = node.metadata.get("nct_id")
        # 3. Self-exclusion (when similarity was seeded from an NCT ID)
        if target_nct and nid == target_nct:
            continue
        # Deduplication across multiple chunks of the same study
        if any(r["nct_id"] == nid for r in results):
            continue
        # .get() avoids a KeyError on rows missing a title; score may be None.
        title = node.metadata.get("title", "Untitled")
        results.append({
            "nct_id": nid,
            "text": f"Study: {title} (NCT: {nid})\nScore: {(node.score or 0):.4f}\nSummary: {node.text[:200]}...",
        })
        if len(results) >= 5:  # Limit to top 5 unique results
            break

    if not results:
        return "No similar studies found."
    return "\n\n".join(r["text"] for r in results)
def fetch_study_analytics_data(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
) -> str:
    """
    Underlying logic for fetching and aggregating clinical trial data.

    Retrieval has two paths:
      - query == "overall": load the entire LanceDB table directly (fast path).
      - otherwise: hybrid vector retrieval (top 5000) followed by a strict
        keyword/sponsor filter for accurate counting.
    Filters are then applied in pandas, results are grouped by `group_by`,
    and chart data is injected into st.session_state for the UI.

    See get_study_analytics for the full parameter docstring.
    """
    index = load_index()

    # --- 1. Retrieve Data ---
    if query.lower() == "overall":
        try:
            # Connect to LanceDB directly for speed instead of going through
            # the retriever (which caps the number of rows).
            import lancedb
            db = lancedb.connect("./ct_gov_lancedb")
            tbl = db.open_table("clinical_trials")
            df = tbl.to_pandas()
            # LlamaIndex stores metadata in a 'metadata' struct column
            # (to_pandas() converts it to dicts); flatten it so 'status',
            # 'phase', etc. become real columns. If the schema already has
            # flat columns, leave df as-is.
            if "metadata" in df.columns:
                df = pd.json_normalize(df["metadata"])
        except Exception as e:
            return f"Error fetching full dataset: {e}"
    else:
        filters = []
        if status:
            filters.append(
                MetadataFilter(key="status", value=status.upper(), operator=FilterOperator.EQ)
            )
        if sponsor:
            # Expand to known variations (e.g. "Pfizer" -> ["Pfizer", "Pfizer Inc."])
            sponsor_variations = get_sponsor_variations(sponsor)
            if sponsor_variations:
                print(f"🎯 Using strict pre-filter for sponsor '{sponsor}': {len(sponsor_variations)} variations found.")
                filters.append(
                    MetadataFilter(key="sponsor", value=sponsor_variations, operator=FilterOperator.IN)
                )
        metadata_filters = MetadataFilters(filters=filters) if filters else None

        search_query = query
        if sponsor and sponsor.lower() not in query.lower():
            search_query = f"{sponsor} {query}"

        # Use hybrid search for better recall
        retriever = index.as_retriever(
            similarity_top_k=5000,
            filters=metadata_filters,
            vector_store_query_mode="hybrid",
        )
        nodes = retriever.retrieve(search_query)

        # --- Strict Keyword Filtering ---
        # Strictly check that the query appears in Title or Conditions so the
        # counts are accurate. EXCEPTION: if the query is essentially a
        # sponsor name, also accept matches on the 'org'/'sponsor' fields,
        # comparing normalized values to handle aliases (e.g. J&J -> Janssen).
        # (This branch only runs for non-"overall" queries, so the previous
        # redundant guard was dropped.)
        q_term = query.lower()
        is_sponsor_query = False
        query_normalized = normalize_sponsor(query)
        if query_normalized and query_normalized != query:
            # Normalization changed it (or found a mapping): likely a sponsor.
            is_sponsor_query = True
        norm_sponsor = None
        if sponsor:
            # Normalize both to see if they refer to the same entity
            norm_query = normalize_sponsor(query)
            norm_sponsor = normalize_sponsor(sponsor)
            if norm_query and norm_sponsor and norm_query.lower() == norm_sponsor.lower():
                is_sponsor_query = True
            elif sponsor.lower() in query.lower() or query.lower() in sponsor.lower():
                is_sponsor_query = True

        filtered_nodes = []
        for node in nodes:
            meta = node.metadata
            title = meta.get("title", "").lower()
            conditions = meta.get("condition", "").lower()  # key is 'condition' in DB
            org = meta.get("org", "").lower()
            sponsor_val = meta.get("sponsor", "").lower()
            match = q_term in title or q_term in conditions
            if not match and is_sponsor_query:
                if q_term in org or q_term in sponsor_val:
                    match = True
                else:
                    # Compare normalized org/sponsor against the normalized
                    # target (the requested sponsor, or the query itself).
                    norm_org = normalize_sponsor(org)
                    norm_val = normalize_sponsor(sponsor_val)
                    target_norm = norm_sponsor if sponsor else query_normalized
                    if norm_org and target_norm and norm_org.lower() == target_norm.lower():
                        match = True
                    elif norm_val and target_norm and norm_val.lower() == target_norm.lower():
                        match = True
            if match:
                filtered_nodes.append(node)
        print(f"📉 Strict Filter: {len(nodes)} -> {len(filtered_nodes)} nodes for '{query}'")
        nodes = filtered_nodes

        df = pd.DataFrame([node.metadata for node in nodes])

    if "nct_id" in df.columns:
        df = df.drop_duplicates(subset="nct_id")
    if df.empty:
        return "No studies found for analytics."

    # --- 2. Apply filters in pandas ---
    if phase:
        # Comma-separated phases; match by normalized substring.
        target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
        df["phase_upper"] = df["phase"].astype(str).str.upper().str.replace(" ", "")
        mask = df["phase_upper"].apply(lambda x: any(tp in x for tp in target_phases))
        df = df[mask]
    if status:
        df = df[df["status"].str.upper() == status.upper()]
    if sponsor:
        target_sponsor = normalize_sponsor(sponsor).lower()
        # Prefer the 'sponsor' column; fall back to 'org' for legacy rows
        # (guarding against either column being absent).
        if "sponsor" in df.columns:
            base = df["sponsor"]
            if "org" in df.columns:
                base = base.fillna(df["org"])
        else:
            base = df["org"]
        df["sponsor_check"] = base.astype(str).apply(normalize_sponsor).str.lower()
        df = df[df["sponsor_check"].str.contains(target_sponsor, regex=False)]
    if intervention:
        df["intervention_lower"] = df["intervention"].astype(str).str.lower()
        df = df[df["intervention_lower"].str.contains(intervention.lower(), regex=False)]
    if start_year:
        df["start_year"] = pd.to_numeric(df["start_year"], errors="coerce").fillna(0)
        df = df[df["start_year"] >= start_year]
    if study_type:
        df = df[df["study_type"].str.upper() == study_type.upper()]
    if df.empty:
        return "No studies found after applying filters."

    # --- 3. Group & aggregate ---
    key_map = {
        "phase": "phase",
        "status": "status",
        "sponsor": "sponsor" if "sponsor" in df.columns else "org",
        "start_year": "start_year",
        "condition": "condition",
        "intervention": "intervention",
        "study_type": "study_type",
        "country": "country",
        "state": "state",
    }
    if group_by not in key_map:
        return f"Invalid group_by field: {group_by}. Valid options: phase, status, sponsor, start_year, condition, intervention, study_type, country, state"
    col = key_map[group_by]
    if col == "start_year":
        df[col] = pd.to_numeric(df[col], errors="coerce")
        counts = df[col].value_counts().sort_index()
    elif col == "condition":
        # Conditions are comma-separated within a cell
        counts = df[col].astype(str).str.split(", ").explode().value_counts().head(10)
    elif col == "intervention":
        # Interventions are semicolon-separated within a cell
        all_interventions = []
        for interventions in df[col].dropna():
            all_interventions.extend(i.strip() for i in interventions.split(";") if i.strip())
        counts = pd.Series(all_interventions).value_counts().head(10)
    else:
        counts = df[col].value_counts().head(10)

    summary = counts.to_string()

    # --- 4. Side effect: inject chart data for the UI ---
    chart_df = counts.reset_index()
    chart_df.columns = ["category", "count"]
    # Always overwrite with the latest chart (the original if/else branches
    # here were identical).
    st.session_state["inline_chart_data"] = {
        "type": "bar",
        "title": f"Studies by {group_by.capitalize()}",
        "data": chart_df.to_dict("records"),
        "x": "category",
        "y": "count",
    }
    return f"Found {len(df)} studies. Top counts:\n{summary}\n\n(Chart generated in UI)"
def get_study_analytics(
    query: str,
    group_by: str,
    phase: Optional[str] = None,
    status: Optional[str] = None,
    sponsor: Optional[str] = None,
    intervention: Optional[str] = None,
    start_year: Optional[int] = None,
    study_type: Optional[str] = None,
) -> str:
    """
    Aggregates clinical trial data based on a search query and groups by a specific field.

    This tool performs the following steps:
    1. Retrieves relevant studies (up to 5000 candidates, or the entire table
       when query == "overall").
    2. Applies strict filters (Phase, Status, Sponsor, Intervention, Start
       Year, Study Type) in memory (Pandas).
    3. Groups the data by the requested field (e.g., Sponsor).
    4. Generates a summary string for the LLM.
    5. **Side Effect**: Injects chart data into `st.session_state` to trigger an inline chart in the UI.

    Args:
        query (str): The search query to filter studies (e.g., "cancer"), or "overall" for the full dataset.
        group_by (str): The field to group by. Options: "phase", "status", "sponsor", "start_year", "condition", "intervention", "study_type", "country", "state".
        phase (Optional[str]): Optional filter for phase (e.g., "PHASE2"); comma-separated values allowed.
        status (Optional[str]): Optional filter for status (e.g., "RECRUITING").
        sponsor (Optional[str]): Optional filter for sponsor (e.g., "Pfizer").
        intervention (Optional[str]): Optional filter for intervention (e.g., "Keytruda").
        start_year (Optional[int]): Only include studies starting on or after this year.
        study_type (Optional[str]): Optional filter for study type (e.g., "INTERVENTIONAL").

    Returns:
        str: A summary string of the top counts and a note that a chart has been generated.
    """
    # Thin delegating wrapper: all logic lives in fetch_study_analytics_data.
    return fetch_study_analytics_data(
        query=query,
        group_by=group_by,
        phase=phase,
        status=status,
        sponsor=sponsor,
        intervention=intervention,
        start_year=start_year,
        study_type=study_type,
    )
def compare_studies(query: str) -> str:
    """
    Compares multiple studies or answers complex multi-part questions using query decomposition.

    Use this tool when the user asks to "compare", "contrast", or analyze
    differences/similarities between specific studies, sponsors, or phases.
    It breaks the question down into sub-questions, answers each against the
    vector database, and synthesizes the results.

    Args:
        query (str): The complex comparison query (e.g., "Compare the primary outcomes of Keytruda vs Opdivo").

    Returns:
        str: A detailed response synthesizing the answers to sub-questions,
            or an error string if the query engine fails.
    """
    index = load_index()
    # Wide candidate pool + re-ranking improves recall for comparison queries.
    # NOTE(review): this instantiates a fresh cross-encoder on every call; if
    # the cached get_reranker() wraps the same model, prefer it — confirm first.
    reranker = SentenceTransformerRerank(model="cross-encoder/ms-marco-MiniLM-L-12-v2", top_n=10)
    base_engine = index.as_query_engine(
        similarity_top_k=50,
        node_postprocessors=[reranker],
    )
    # Expose the base engine as a tool for the sub-question engine
    query_tool = QueryEngineTool(
        query_engine=base_engine,
        metadata=ToolMetadata(
            name="clinical_trials_db",
            description="Vector database of clinical trial protocols, results, and metadata.",
        ),
    )
    # Explicitly use the configured LLM (module-level Settings) for question
    # generation; the default path might try to import OpenAI modules.
    from llama_index.core.question_gen import LLMQuestionGenerator
    question_gen = LLMQuestionGenerator.from_defaults(llm=Settings.llm)
    query_engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=[query_tool],
        question_gen=question_gen,
        use_async=True,
    )
    try:
        response = query_engine.query(query)
        return str(response) + "\n\n(Note: This analysis is based on the most relevant studies retrieved from the database, not necessarily an exhaustive list.)"
    except Exception as e:
        return f"Error during comparison: {e}"
def get_study_details(nct_id: str):
    """
    Retrieves the full details of a specific clinical trial by its NCT ID.

    Use this when the user asks about a single study — e.g. "What are the
    inclusion criteria for NCT12345678?" or "Give me a summary of study
    NCT...". The full document text (criteria, outcomes, contacts, ...) is
    returned.

    Args:
        nct_id (str): The NCT ID of the study (e.g., "NCT01234567").

    Returns:
        str: The full text content of the study, or a message if not found.
    """
    index = load_index()
    clean_id = nct_id.strip().upper()
    # Strict metadata filter on the ID; similarity_top_k=20 captures every
    # chunk of a document that was split during indexing.
    id_filter = MetadataFilters(
        filters=[MetadataFilter(key="nct_id", value=clean_id, operator=FilterOperator.EQ)]
    )
    nodes = index.as_retriever(similarity_top_k=20, filters=id_filter).retrieve(clean_id)
    if not nodes:
        return f"Study {clean_id} not found in the database."
    # Concatenate chunks in retrieval order — roughly document order, which
    # is sufficient to reconstruct the study text here.
    full_text = "\n\n".join(node.text for node in nodes)
    return f"Details for {clean_id} (Combined {len(nodes)} parts):\n\n{full_text}"