"""
Data Ingestion Script for Clinical Trial Agent.

This script fetches clinical trial data from the ClinicalTrials.gov API (v2),
processes it into a rich text format, and ingests it into a local LanceDB
vector index using LlamaIndex and PubMedBERT embeddings.

Features:
- **Pagination**: Fetches data in batches using the API's pagination tokens.
- **Robustness**: Implements retry logic for network errors.
- **Efficiency**: Uses batch insertion and reuses the existing index.
- **Progress Tracking**: Displays a progress bar using `tqdm`.
"""

import argparse
import concurrent.futures
import os
import re
import time
from datetime import datetime, timedelta
from typing import Iterator, List, Optional

import requests
from dotenv import load_dotenv
from tqdm import tqdm

# LlamaIndex Imports
from llama_index.core import Document, VectorStoreIndex, StorageContext, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
import lancedb

# List of US States for extraction
US_STATES = [
    "Alabama",
    "Alaska",
    "Arizona",
    "Arkansas",
    "California",
    "Colorado",
    "Connecticut",
    "Delaware",
    "Florida",
    "Georgia",
    "Hawaii",
    "Idaho",
    "Illinois",
    "Indiana",
    "Iowa",
    "Kansas",
    "Kentucky",
    "Louisiana",
    "Maine",
    "Maryland",
    "Massachusetts",
    "Michigan",
    "Minnesota",
    "Mississippi",
    "Missouri",
    "Montana",
    "Nebraska",
    "Nevada",
    "New Hampshire",
    "New Jersey",
    "New Mexico",
    "New York",
    "North Carolina",
    "North Dakota",
    "Ohio",
    "Oklahoma",
    "Oregon",
    "Pennsylvania",
    "Rhode Island",
    "South Carolina",
    "South Dakota",
    "Tennessee",
    "Texas",
    "Utah",
    "Vermont",
    "Virginia",
    "Washington",
    "West Virginia",
    "Wisconsin",
    "Wyoming",
    "District of Columbia",
]

# Fast membership test for an explicit state value.
_US_STATE_SET = frozenset(US_STATES)

# Longest names first, so that a substring scan tries "West Virginia" before
# "Virginia" (the sort is stable, so equal-length names keep list order).
_US_STATES_LONGEST_FIRST = sorted(US_STATES, key=len, reverse=True)

# Only spans that start like a real tag ("<b>", "</p>", "<br/>") are removed.
# Requiring a letter after "<" keeps numeric comparisons that are common in
# eligibility criteria, e.g. "ALT < 3x ULN and platelets > 100".
_HTML_TAG_RE = re.compile(r"</?[A-Za-z][^<>]*>")
_WHITESPACE_RE = re.compile(r"\s+")

# Fields requested from the API (restricting them minimizes payload size).
_API_FIELDS = (
    "protocolSection.identificationModule.nctId",
    "protocolSection.identificationModule.briefTitle",
    "protocolSection.identificationModule.officialTitle",
    "protocolSection.identificationModule.organization",
    "protocolSection.statusModule.overallStatus",
    "protocolSection.statusModule.startDateStruct",
    "protocolSection.statusModule.completionDateStruct",
    "protocolSection.designModule.phases",
    "protocolSection.designModule.studyType",
    "protocolSection.eligibilityModule.eligibilityCriteria",
    "protocolSection.eligibilityModule.sex",
    "protocolSection.eligibilityModule.stdAges",
    "protocolSection.descriptionModule.briefSummary",
    "protocolSection.conditionsModule.conditions",
    "protocolSection.outcomesModule.primaryOutcomes",
    "protocolSection.contactsLocationsModule.locations",
    "protocolSection.armsInterventionsModule",
    "protocolSection.sponsorCollaboratorsModule.leadSponsor",
)

load_dotenv()

# Disable LLM for ingestion (we only need embeddings, not generation)
Settings.llm = None


def clean_text(text: str) -> str:
    """
    Cleans raw text by removing HTML tags and normalizing whitespace.

    Args:
        text (str): The raw text string.

    Returns:
        str: The cleaned text ("" for empty or None input).
    """
    if not text:
        return ""
    # Remove HTML tags (tag-like spans only; see _HTML_TAG_RE)
    text = _HTML_TAG_RE.sub("", text)
    # Collapse runs of spaces/newlines and trim
    return _WHITESPACE_RE.sub(" ", text).strip()


def _request_page(
    url: str, params: dict, retries: int = 3, delay: float = 2.0
) -> Optional[dict]:
    """
    Performs one GET request with simple retry logic.

    Returns:
        The decoded JSON payload, or None when every attempt failed.
    """
    for attempt in range(1, retries + 1):
        try:
            response = requests.get(url, params=params, timeout=30)
            if response.status_code == 200:
                return response.json()
            print(f"❌ API Error: {response.status_code} - {response.text}")
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a 200 response whose body is not valid JSON.
            print(f"❌ Request Error (Attempt {attempt}/{retries}): {e}")
        if attempt < retries:
            time.sleep(delay)
    return None


def fetch_trials_generator(
    years: int = 5,
    max_studies: int = 1000,
    status: Optional[List[str]] = None,
    phases: Optional[List[str]] = None,
) -> Iterator[list]:
    """
    Yields batches of clinical trials from the ClinicalTrials.gov API.

    Handles pagination automatically and implements retry logic for API
    requests. The generator stops silently after a persistent request failure.

    Args:
        years (int): Number of years to look back for study start dates.
        max_studies (int): Maximum total number of studies to fetch (-1 for all).
        status (list): List of status strings to filter by (e.g., ["RECRUITING"]).
        phases (list): List of phase strings to filter by (e.g., ["PHASE2"]).

    Yields:
        list: A batch of study dictionaries (JSON objects).
    """
    base_url = "https://clinicaltrials.gov/api/v2/studies"

    # Calculate start date for filtering
    start_date = (datetime.now() - timedelta(days=365 * years)).strftime("%Y-%m-%d")

    print("📡 Connecting to CT.gov API...")
    print(f"🔎 Fetching trials starting after: {start_date}")
    if status:
        print(f" Filters - Status: {status}")
    if phases:
        print(f" Filters - Phases: {phases}")

    # --- Query Construction ---
    # The query and the field list do not change between pages, so they are
    # built once, outside the pagination loop.
    query_parts = [f"AREA[StartDate]RANGE[{start_date},MAX]"]
    if status:
        status_str = " OR ".join(status)
        query_parts.append(f"AREA[OverallStatus]({status_str})")
    if phases:
        phase_str = " OR ".join(phases)
        query_parts.append(f"AREA[Phase]({phase_str})")
    full_query = " AND ".join(query_parts)
    fields = ",".join(_API_FIELDS)

    fetched_count = 0
    next_page_token = None

    # If max_studies is -1, fetch ALL studies (infinite limit)
    fetch_limit = float("inf") if max_studies == -1 else max_studies

    while fetched_count < fetch_limit:
        # Determine batch size (max 1000 per API limit)
        if max_studies == -1:
            current_limit = 1000
        else:
            current_limit = min(1000, max_studies - fetched_count)

        params = {
            "query.term": full_query,
            "pageSize": current_limit,
            "fields": fields,
        }
        if next_page_token:
            params["pageToken"] = next_page_token

        data = _request_page(base_url, params)
        if data is None:
            return  # Stop generator on persistent error

        studies = data.get("studies", [])
        if not studies:
            return  # Stop generator if no studies returned

        # The yield is deliberately outside any try/except so that errors in
        # the consumer are never mistaken for request failures.
        yield studies
        fetched_count += len(studies)

        next_page_token = data.get("nextPageToken")
        if not next_page_token:
            return  # Stop generator if no more pages


def _format_location(loc: dict) -> str:
    """Formats one location record as "Facility (City, Country)"."""
    facility = loc.get("facility", "N/A")
    city = loc.get("city", "")
    country = loc.get("country", "")
    return f"{facility} ({city}, {country})"


def _extract_us_state(location_entries: list) -> str:
    """
    Returns the US state of the first US location that yields one.

    The location's own "state" value is preferred. Scanning the formatted text
    is only a fallback, because facility names can contain other state names
    (e.g. "Washington University" is not in Washington).
    """
    for loc in location_entries:
        formatted = _format_location(loc)
        if "United States" not in formatted:
            continue
        explicit_state = loc.get("state")
        if explicit_state in _US_STATE_SET:
            return explicit_state
        for name in _US_STATES_LONGEST_FIRST:
            if name in formatted:
                return name
    return "Unknown"


def _safe_nct_id(study) -> str:
    """Best-effort NCT ID lookup for error messages; never raises."""
    try:
        return (
            study.get("protocolSection", {})
            .get("identificationModule", {})
            .get("nctId", "Unknown")
        )
    except AttributeError:
        return "Unknown"


def process_study(study: dict) -> Optional[Document]:
    """
    Processes a single study dictionary into a LlamaIndex Document.

    This function is designed to be run in parallel, so it (and the helpers it
    calls) must stay at module level to be picklable.

    Args:
        study (dict): One study object as returned by the API.

    Returns:
        Document: The document with rich text and filter metadata, or None if
        the study could not be processed (the error is printed).
    """
    try:
        # Extract Modules
        protocol = study.get("protocolSection", {})
        identification = protocol.get("identificationModule", {})
        status_module = protocol.get("statusModule", {})
        design = protocol.get("designModule", {})
        eligibility = protocol.get("eligibilityModule", {})
        description = protocol.get("descriptionModule", {})
        conditions_module = protocol.get("conditionsModule", {})
        outcomes_module = protocol.get("outcomesModule", {})
        arms_interventions_module = protocol.get("armsInterventionsModule", {})
        locations_module = protocol.get("contactsLocationsModule", {})
        sponsor_module = protocol.get("sponsorCollaboratorsModule", {})

        # Extract Fields
        nct_id = identification.get("nctId", "N/A")
        title = identification.get("briefTitle", "N/A")
        official_title = identification.get("officialTitle", "N/A")
        org = identification.get("organization", {}).get("fullName", "N/A")
        sponsor_name = sponsor_module.get("leadSponsor", {}).get("name", "N/A")
        summary = clean_text(description.get("briefSummary", "N/A"))

        overall_status = status_module.get("overallStatus", "N/A")
        start_date = status_module.get("startDateStruct", {}).get("date", "N/A")
        completion_date = status_module.get("completionDateStruct", {}).get(
            "date", "N/A"
        )

        phases = ", ".join(design.get("phases", []))
        study_type = design.get("studyType", "N/A")

        criteria = clean_text(eligibility.get("eligibilityCriteria", "N/A"))
        gender = eligibility.get("sex", "N/A")
        ages = ", ".join(eligibility.get("stdAges", []))

        conditions = ", ".join(conditions_module.get("conditions", []))

        interventions_str = "; ".join(
            f"{interv.get('type', '')}: {interv.get('name', '')}"
            for interv in arms_interventions_module.get("interventions", [])
        )

        primary_outcomes = [
            f"- {outcome.get('measure', '')}: {outcome.get('description', '')}"
            for outcome in outcomes_module.get("primaryOutcomes", [])
        ]
        outcomes_str = clean_text("\n".join(primary_outcomes))

        location_entries = locations_module.get("locations", [])
        locations = [_format_location(loc) for loc in location_entries]
        locations_str = "; ".join(locations[:5])  # Limit to 5 locations to save space

        # Extract State (first US location that yields a state)
        state = _extract_us_state(location_entries)

        # Country of the first listed location. It is read from the record
        # itself: parsing it back out of "Facility (City, Country)" leaves a
        # trailing ")" and breaks on commas inside facility names.
        country = "Unknown"
        if location_entries:
            country = (location_entries[0].get("country") or "").strip() or "Unknown"

        # Construct Rich Page Content with Markdown Headers
        # This text is what gets embedded and searched
        page_content = (
            f"# {title}\n"
            f"**NCT ID:** {nct_id}\n"
            f"**Official Title:** {official_title}\n"
            f"**Sponsor:** {sponsor_name}\n"
            f"**Organization:** {org}\n"
            f"**Status:** {overall_status}\n"
            f"**Phase:** {phases}\n"
            f"**Study Type:** {study_type}\n"
            f"**Start Date:** {start_date}\n"
            f"**Completion Date:** {completion_date}\n\n"
            f"## Summary\n{summary}\n\n"
            f"## Conditions\n{conditions}\n\n"
            f"## Interventions\n{interventions_str}\n\n"
            f"## Eligibility Criteria\n"
            f"**Gender:** {gender}\n"
            f"**Ages:** {ages}\n"
            f"**Criteria:**\n{criteria}\n\n"
            f"## Primary Outcomes\n{outcomes_str}\n\n"
            f"## Locations\n{locations_str}"
        )

        # Metadata for filtering (Structured Data)
        metadata = {
            "nct_id": nct_id,
            "title": title,
            "org": org,
            "sponsor": sponsor_name,
            "status": overall_status,
            "phase": phases,
            "study_type": study_type,
            "start_year": (int(start_date.split("-")[0]) if start_date != "N/A" else 0),
            "condition": conditions,
            "intervention": interventions_str,
            "country": country,
            "state": state,
        }

        return Document(text=page_content, metadata=metadata, id_=nct_id)
    except Exception as e:
        # Boundary handler: one malformed study must not abort the whole batch.
        print(f"⚠️ Error processing study {_safe_nct_id(study)}: {e}")
        return None


def _split_csv(value: Optional[str]) -> List[str]:
    """Splits a comma-separated CLI value, dropping blanks and stray spaces."""
    if not value:
        return []
    return [item.strip() for item in value.split(",") if item.strip()]


def _delete_existing_records(db, table_name: str, doc_ids: List[str]) -> None:
    """
    Deletes rows whose nct_id is in doc_ids, if the table already exists.

    NOTE(review): the filter assumes a top-level `nct_id` column. Depending on
    the installed LlamaIndex LanceDB store version, metadata may instead be
    nested (e.g. `metadata.nct_id`) - confirm against the actual table schema.
    """
    if not doc_ids or table_name not in db.table_names():
        return
    # Quote each ID as a SQL string literal (embedded quotes are doubled).
    quoted_ids = ", ".join(
        "'" + str(doc_id).replace("'", "''") + "'" for doc_id in doc_ids
    )
    db.open_table(table_name).delete(f"nct_id IN ({quoted_ids})")


def run_ingestion():
    """
    Main execution function for the ingestion script.

    Parses arguments, initializes the index, and runs the ingestion loop.
    """
    arg_parser = argparse.ArgumentParser(description="Ingest Clinical Trials data.")
    arg_parser.add_argument(
        "--limit",
        type=int,
        default=-1,
        help="Number of studies to ingest. Set to -1 for ALL.",
    )
    arg_parser.add_argument(
        "--years", type=int, default=10, help="Number of years to look back."
    )
    arg_parser.add_argument(
        "--status",
        type=str,
        default="COMPLETED",
        help="Comma-separated list of statuses (e.g., COMPLETED,RECRUITING).",
    )
    arg_parser.add_argument(
        "--phases",
        type=str,
        default="PHASE1,PHASE2,PHASE3,PHASE4",
        help="Comma-separated list of phases (e.g., PHASE2,PHASE3).",
    )
    args = arg_parser.parse_args()

    status_list = _split_csv(args.status)
    phase_list = _split_csv(args.phases)

    print(f"⚙️ Configuration: Limit={args.limit}, Years={args.years}")
    print(f" Status Filter: {status_list}")
    print(f" Phase Filter: {phase_list}")

    # --- INITIALIZE LLAMAINDEX COMPONENTS ---
    print("🧠 Initializing LlamaIndex Embeddings (PubMedBERT)...")
    embed_model = HuggingFaceEmbedding(model_name="pritamdeka/S-PubMedBert-MS-MARCO")

    # Initialize LanceDB (Persistent)
    print("🚀 Initializing LanceDB...")
    # Determine the project root directory (one level up from this script)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    db_path = os.path.join(project_root, "ct_gov_lancedb")

    # Connect to LanceDB
    db = lancedb.connect(db_path)
    table_name = "clinical_trials"

    if table_name in db.table_names():
        mode = "append"
        print(f"ℹ️ Table '{table_name}' exists. Appending data.")
    else:
        mode = "create"
        print(f"ℹ️ Table '{table_name}' does not exist. Creating new table.")

    # Initialize Vector Store
    vector_store = LanceDBVectorStore(
        uri=db_path,
        table_name=table_name,
        mode=mode,
        query_mode="hybrid",  # Enable hybrid search support
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Initialize Index ONCE
    # We pass the storage context to link it to the vector store
    index = VectorStoreIndex.from_vector_store(
        vector_store, storage_context=storage_context, embed_model=embed_model
    )

    total_ingested = 0

    # --- INGESTION LOOP ---
    # The progress bar has no total when the limit is unbounded. Both the bar
    # and the process pool are context-managed so they are released on error.
    with tqdm(
        total=args.limit if args.limit > 0 else None,
        desc="Ingesting Studies",
        unit="study",
    ) as pbar, concurrent.futures.ProcessPoolExecutor() as executor:
        for batch_studies in fetch_trials_generator(
            years=args.years,
            max_studies=args.limit,
            status=status_list,
            phases=phase_list,
        ):
            # Parallelize the processing of the batch. A chunksize above 1
            # avoids one inter-process round trip per (cheap) study.
            documents_iter = executor.map(process_study, batch_studies, chunksize=50)

            # Filter out None results (errors)
            documents = [doc for doc in documents_iter if doc is not None]
            if not documents:
                continue

            # Overwrite Logic:
            # To avoid duplicates, delete existing records with the same NCT
            # IDs. This stays best-effort, but a failure is reported instead
            # of being silently swallowed.
            doc_ids = [doc.id_ for doc in documents]
            try:
                _delete_existing_records(db, table_name, doc_ids)
            except Exception as e:
                tqdm.write(
                    f"⚠️ Could not delete existing records before insert "
                    f"(duplicates may remain): {e}"
                )

            # Efficient Batch Insertion
            # We convert documents to nodes and insert them into the index.
            # This handles embedding generation automatically.
            node_parser = Settings.node_parser
            nodes = node_parser.get_nodes_from_documents(documents)
            index.insert_nodes(nodes)

            total_ingested += len(documents)
            pbar.update(len(documents))

    print(f"🎉 Ingestion Complete! Total studies ingested: {total_ingested}")


if __name__ == "__main__":
    run_ingestion()