"""
Data Ingestion Script for the Clinical Trial Agent.
This script fetches clinical trial data from the ClinicalTrials.gov API (v2),
processes it into a rich text format, and ingests it into a local LanceDB vector store
using LlamaIndex and PubMedBERT embeddings.
Features:
- **Pagination**: Fetches data in batches using the API's pagination tokens.
- **Robustness**: Implements retry logic for network errors.
- **Efficiency**: Uses batch insertion and reuses the existing index.
- **Progress Tracking**: Displays a progress bar using `tqdm`.
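Example usage (the file name is illustrative; the flags match the argparse options below):
    python ingest.py --limit 500 --years 5 --status COMPLETED,RECRUITING --phases PHASE2,PHASE3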
"""
import requests
import re
from datetime import datetime, timedelta
from dotenv import load_dotenv
import argparse
import time
from tqdm import tqdm
import os
import concurrent.futures
# LlamaIndex Imports
from llama_index.core import Document, VectorStoreIndex, StorageContext, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
import lancedb
# List of US States for extraction
US_STATES = [
"Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado", "Connecticut",
"Delaware", "Florida", "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa",
"Kansas", "Kentucky", "Louisiana", "Maine", "Maryland", "Massachusetts", "Michigan",
"Minnesota", "Mississippi", "Missouri", "Montana", "Nebraska", "Nevada", "New Hampshire",
"New Jersey", "New Mexico", "New York", "North Carolina", "North Dakota", "Ohio",
"Oklahoma", "Oregon", "Pennsylvania", "Rhode Island", "South Carolina", "South Dakota",
"Tennessee", "Texas", "Utah", "Vermont", "Virginia", "Washington", "West Virginia",
"Wisconsin", "Wyoming", "District of Columbia"
]
load_dotenv()
# Disable LLM for ingestion (we only need embeddings, not generation)
Settings.llm = None
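# Without this, LlamaIndex may try to resolve its default OpenAI LLM and require an API key.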
def clean_text(text: str) -> str:
"""
Cleans raw text by removing HTML tags and normalizing whitespace.
Args:
text (str): The raw text string.
Returns:
str: The cleaned text.
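Example:
    >>> clean_text("<b>Phase  2 trial</b>")
    'Phase 2 trial'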
"""
if not text:
return ""
# Remove HTML tags
text = re.sub(r"<[^>]+>", "", text)
# Remove multiple spaces/newlines and trim
text = re.sub(r"\s+", " ", text).strip()
return text
def fetch_trials_generator(
years: int = 5, max_studies: int = 1000, status: list = None, phases: list = None
):
"""
Yields batches of clinical trials from the ClinicalTrials.gov API.
Handles pagination automatically and implements retry logic for API requests.
Args:
years (int): Number of years to look back for study start dates.
max_studies (int): Maximum total number of studies to fetch (-1 for all).
status (list): List of status strings to filter by (e.g., ["RECRUITING"]).
phases (list): List of phase strings to filter by (e.g., ["PHASE2"]).
Yields:
list: A batch of study dictionaries (JSON objects).
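Example:
    for batch in fetch_trials_generator(years=1, max_studies=200, status=["RECRUITING"]):
        print(f"Fetched {len(batch)} studies")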
"""
base_url = "https://clinicaltrials.gov/api/v2/studies"
# Calculate start date for filtering
start_date = (datetime.now() - timedelta(days=365 * years)).strftime("%Y-%m-%d")
print("📡 Connecting to CT.gov API...")
print(f"🔎 Fetching trials starting after: {start_date}")
if status:
print(f" Filters - Status: {status}")
if phases:
print(f" Filters - Phases: {phases}")
fetched_count = 0
next_page_token = None
# If max_studies is -1, fetch ALL studies (infinite limit)
fetch_limit = float("inf") if max_studies == -1 else max_studies
while fetched_count < fetch_limit:
# Determine batch size (max 1000 per API limit)
current_limit = 1000
if max_studies != -1:
current_limit = min(1000, max_studies - fetched_count)
# --- Query Construction ---
# Build the query term using the API's syntax
query_parts = [f"AREA[StartDate]RANGE[{start_date},MAX]"]
if status:
status_str = " OR ".join(status)
query_parts.append(f"AREA[OverallStatus]({status_str})")
if phases:
phase_str = " OR ".join(phases)
query_parts.append(f"AREA[Phase]({phase_str})")
full_query = " AND ".join(query_parts)
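# Example of a resulting query string (values are illustrative):
#   AREA[StartDate]RANGE[2015-06-01,MAX] AND AREA[OverallStatus](COMPLETED OR RECRUITING) AND AREA[Phase](PHASE2 OR PHASE3)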
params = {
"query.term": full_query,
"pageSize": current_limit,
# Request specific fields to minimize payload size
"fields": ",".join(
[
"protocolSection.identificationModule.nctId",
"protocolSection.identificationModule.briefTitle",
"protocolSection.identificationModule.officialTitle",
"protocolSection.identificationModule.organization",
"protocolSection.statusModule.overallStatus",
"protocolSection.statusModule.startDateStruct",
"protocolSection.statusModule.completionDateStruct",
"protocolSection.designModule.phases",
"protocolSection.designModule.studyType",
"protocolSection.eligibilityModule.eligibilityCriteria",
"protocolSection.eligibilityModule.sex",
"protocolSection.eligibilityModule.stdAges",
"protocolSection.descriptionModule.briefSummary",
"protocolSection.conditionsModule.conditions",
"protocolSection.outcomesModule.primaryOutcomes",
"protocolSection.contactsLocationsModule.locations",
"protocolSection.outcomesModule.primaryOutcomes",
"protocolSection.contactsLocationsModule.locations",
"protocolSection.armsInterventionsModule",
"protocolSection.sponsorCollaboratorsModule.leadSponsor",
]
),
}
if next_page_token:
params["pageToken"] = next_page_token
# --- Retry Logic ---
retries = 3
for attempt in range(retries):
try:
response = requests.get(base_url, params=params, timeout=30)
if response.status_code == 200:
data = response.json()
studies = data.get("studies", [])
if not studies:
return # Stop generator if no studies returned
yield studies
fetched_count += len(studies)
next_page_token = data.get("nextPageToken")
if not next_page_token:
return # Stop generator if no more pages
break # Success, exit retry loop
else:
print(f"❌ API Error: {response.status_code} - {response.text}")
if attempt < retries - 1:
time.sleep(2)
else:
return # Stop generator on persistent error
except Exception as e:
print(f"❌ Request Error (Attempt {attempt+1}/{retries}): {e}")
if attempt < retries - 1:
time.sleep(2)
else:
return # Stop generator
def process_study(study):
"""
Processes a single study dictionary into a LlamaIndex Document.
This function is designed to be run in parallel.
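Note: it must remain a module-level function so that ProcessPoolExecutor can pickle it.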
"""
try:
# Extract Modules
protocol = study.get("protocolSection", {})
identification = protocol.get("identificationModule", {})
status_module = protocol.get("statusModule", {})
design = protocol.get("designModule", {})
eligibility = protocol.get("eligibilityModule", {})
description = protocol.get("descriptionModule", {})
conditions_module = protocol.get("conditionsModule", {})
outcomes_module = protocol.get("outcomesModule", {})
arms_interventions_module = protocol.get("armsInterventionsModule", {})
locations_module = protocol.get("contactsLocationsModule", {})
sponsor_module = protocol.get("sponsorCollaboratorsModule", {})
# Extract Fields
nct_id = identification.get("nctId", "N/A")
title = identification.get("briefTitle", "N/A")
official_title = identification.get("officialTitle", "N/A")
org = identification.get("organization", {}).get("fullName", "N/A")
sponsor_name = sponsor_module.get("leadSponsor", {}).get("name", "N/A")
summary = clean_text(description.get("briefSummary", "N/A"))
overall_status = status_module.get("overallStatus", "N/A")
start_date = status_module.get("startDateStruct", {}).get("date", "N/A")
completion_date = status_module.get("completionDateStruct", {}).get(
"date", "N/A"
)
phases = ", ".join(design.get("phases", []))
study_type = design.get("studyType", "N/A")
criteria = clean_text(eligibility.get("eligibilityCriteria", "N/A"))
gender = eligibility.get("sex", "N/A")
ages = ", ".join(eligibility.get("stdAges", []))
conditions = ", ".join(conditions_module.get("conditions", []))
interventions = []
for interv in arms_interventions_module.get("interventions", []):
name = interv.get("name", "")
type_ = interv.get("type", "")
interventions.append(f"{type_}: {name}")
interventions_str = "; ".join(interventions)
primary_outcomes = []
for outcome in outcomes_module.get("primaryOutcomes", []):
measure = outcome.get("measure", "")
desc = outcome.get("description", "")
primary_outcomes.append(f"- {measure}: {desc}")
outcomes_str = clean_text("\n".join(primary_outcomes))
locations = []
for loc in locations_module.get("locations", []):
facility = loc.get("facility", "N/A")
city = loc.get("city", "")
country = loc.get("country", "")
locations.append(f"{facility} ({city}, {country})")
locations_str = "; ".join(locations[:5]) # Limit to 5 locations to save space
# Extract State (First match)
state = "Unknown"
# Check locations for US States
for loc_str in locations:
if "United States" in loc_str:
for s in US_STATES:
if s in loc_str:
state = s
break
if state != "Unknown":
break
# Construct Rich Page Content with Markdown Headers
# This text is what gets embedded and searched
page_content = (
f"# {title}\n"
f"**NCT ID:** {nct_id}\n"
f"**Official Title:** {official_title}\n"
f"**Sponsor:** {sponsor_name}\n"
f"**Organization:** {org}\n"
f"**Status:** {overall_status}\n"
f"**Phase:** {phases}\n"
f"**Study Type:** {study_type}\n"
f"**Start Date:** {start_date}\n"
f"**Completion Date:** {completion_date}\n\n"
f"## Summary\n{summary}\n\n"
f"## Conditions\n{conditions}\n\n"
f"## Interventions\n{interventions_str}\n\n"
f"## Eligibility Criteria\n"
f"**Gender:** {gender}\n"
f"**Ages:** {ages}\n"
f"**Criteria:**\n{criteria}\n\n"
f"## Primary Outcomes\n{outcomes_str}\n\n"
f"## Locations\n{locations_str}"
)
# Metadata for filtering (Structured Data)
metadata = {
"nct_id": nct_id,
"title": title,
"org": org,
"sponsor": sponsor_name,
"status": overall_status,
"phase": phases,
"study_type": study_type,
"start_year": (int(start_date.split("-")[0]) if start_date != "N/A" else 0),
"condition": conditions,
"intervention": interventions_str,
"country": (
locations[0].split(",")[-1].strip() if locations else "Unknown"
),
"state": state,
}
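# Using the NCT ID as the document id lets the ingestion loop deduplicate studies on re-runs.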
return Document(text=page_content, metadata=metadata, id_=nct_id)
except Exception as e:
print(
f"⚠️ Error processing study {study.get('protocolSection', {}).get('identificationModule', {}).get('nctId', 'Unknown')}: {e}"
)
return None
def run_ingestion():
"""
Main execution function for the ingestion script.
Parses arguments, initializes the index, and runs the ingestion loop.
"""
parser = argparse.ArgumentParser(description="Ingest Clinical Trials data.")
parser.add_argument(
"--limit",
type=int,
default=-1,
help="Number of studies to ingest. Set to -1 for ALL.",
)
parser.add_argument(
"--years", type=int, default=10, help="Number of years to look back."
)
parser.add_argument(
"--status",
type=str,
default="COMPLETED",
help="Comma-separated list of statuses (e.g., COMPLETED,RECRUITING).",
)
parser.add_argument(
"--phases",
type=str,
default="PHASE1,PHASE2,PHASE3,PHASE4",
help="Comma-separated list of phases (e.g., PHASE2,PHASE3).",
)
args = parser.parse_args()
status_list = args.status.split(",") if args.status else []
phase_list = args.phases.split(",") if args.phases else []
print(f"⚙️ Configuration: Limit={args.limit}, Years={args.years}")
print(f" Status Filter: {status_list}")
print(f" Phase Filter: {phase_list}")
# --- INITIALIZE LLAMAINDEX COMPONENTS ---
print("🧠 Initializing LlamaIndex Embeddings (PubMedBERT)...")
embed_model = HuggingFaceEmbedding(model_name="pritamdeka/S-PubMedBert-MS-MARCO")
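# The model weights are downloaded from the Hugging Face Hub on the first run and cached locally.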
# Initialize LanceDB (Persistent)
print("🚀 Initializing LanceDB...")
# Determine the project root directory (one level up from this script)
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(script_dir)
db_path = os.path.join(project_root, "ct_gov_lancedb")
# Connect to LanceDB
db = lancedb.connect(db_path)
table_name = "clinical_trials"
if table_name in db.table_names():
mode = "append"
print(f"ℹ️ Table '{table_name}' exists. Appending data.")
else:
mode = "create"
print(f"ℹ️ Table '{table_name}' does not exist. Creating new table.")
# Initialize Vector Store
vector_store = LanceDBVectorStore(
uri=db_path,
table_name=table_name,
mode=mode,
query_mode="hybrid" # Enable hybrid search support
)
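# Hybrid queries combine vector similarity with LanceDB's full-text search at retrieval time.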
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Initialize Index ONCE
# We pass the storage context to link it to the vector store
index = VectorStoreIndex.from_vector_store(
vector_store, storage_context=storage_context, embed_model=embed_model
)
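# from_vector_store attaches to the existing table, so repeated runs reuse the index instead of rebuilding it.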
total_ingested = 0
# Progress Bar
pbar = tqdm(
total=args.limit if args.limit > 0 else float("inf"),
desc="Ingesting Studies",
unit="study",
)
# --- INGESTION LOOP ---
# Use ProcessPoolExecutor for parallel processing of study data
with concurrent.futures.ProcessPoolExecutor() as executor:
for batch_studies in fetch_trials_generator(
years=args.years,
max_studies=args.limit,
status=status_list,
phases=phase_list,
):
# Parallelize the processing of the batch
# map returns an iterator, so we convert to list to trigger execution
documents_iter = executor.map(process_study, batch_studies)
# Filter out None results (errors)
documents = [doc for doc in documents_iter if doc is not None]
if documents:
# Overwrite Logic:
# To avoid duplicates, we delete existing records with the same NCT IDs.
doc_ids = [doc.id_ for doc in documents]
try:
# LanceDB supports deletion via SQL-like filter
# We construct a filter string: "nct_id IN ('NCT123', 'NCT456')"
ids_str = ", ".join([f"'{doc_id}'" for doc_id in doc_ids])
if ids_str:
tbl = db.open_table("clinical_trials")
tbl.delete(f"nct_id IN ({ids_str})")
except Exception:
# Ignore if table doesn't exist yet
pass
# Efficient Batch Insertion
# We convert documents to nodes and insert them into the index.
# This handles embedding generation automatically.
node_parser = Settings.node_parser
nodes = node_parser.get_nodes_from_documents(documents)
index.insert_nodes(nodes)
total_ingested += len(documents)
pbar.update(len(documents))
pbar.close()
print(f"🎉 Ingestion Complete! Total studies in DB: {total_ingested}")
if __name__ == "__main__":
run_ingestion()