"""
Data Ingestion Script for Clinical Trial Agent.

This script fetches clinical trial data from the ClinicalTrials.gov API (v2),
processes it into a rich text format, and ingests it into a local LanceDB vector store
using LlamaIndex and PubMedBERT embeddings.

Features:
- **Pagination**: Fetches data in batches using the API's pagination tokens.
- **Robustness**: Implements retry logic for network errors.
- **Efficiency**: Uses batch insertion and reuses the existing index.
- **Progress Tracking**: Displays a progress bar using `tqdm`.
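
Example invocation (illustrative; the script filename is an assumption and may differ):

    python ingest_trials.py --limit 500 --years 5 --status COMPLETED,RECRUITING --phases PHASE2,PHASE3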
"""

import requests
import re
from datetime import datetime, timedelta
from dotenv import load_dotenv
import argparse
import time
from tqdm import tqdm
import os
import concurrent.futures

# LlamaIndex Imports
from llama_index.core import Document, VectorStoreIndex, StorageContext, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
import lancedb

# List of US States for extraction
US_STATES = [
    "Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado", "Connecticut", 
    "Delaware", "Florida", "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa", 
    "Kansas", "Kentucky", "Louisiana", "Maine", "Maryland", "Massachusetts", "Michigan", 
    "Minnesota", "Mississippi", "Missouri", "Montana", "Nebraska", "Nevada", "New Hampshire", 
    "New Jersey", "New Mexico", "New York", "North Carolina", "North Dakota", "Ohio", 
    "Oklahoma", "Oregon", "Pennsylvania", "Rhode Island", "South Carolina", "South Dakota", 
    "Tennessee", "Texas", "Utah", "Vermont", "Virginia", "Washington", "West Virginia", 
    "Wisconsin", "Wyoming", "District of Columbia"
]

load_dotenv()

# Disable LLM for ingestion (we only need embeddings, not generation)
Settings.llm = None


def clean_text(text: str) -> str:
    """
    Cleans raw text by removing HTML tags and normalizing whitespace.

    Args:
        text (str): The raw text string.

    Returns:
        str: The cleaned text.
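
    Example (illustrative):
        >>> clean_text(" <p>Hello   <b>world</b></p> ")
        'Hello world'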
    """
    if not text:
        return ""
    # Remove HTML tags
    text = re.sub(r"<[^>]+>", "", text)
    # Remove multiple spaces/newlines and trim
    text = re.sub(r"\s+", " ", text).strip()
    return text


def fetch_trials_generator(
    years: int = 5, max_studies: int = 1000, status: list = None, phases: list = None
):
    """
    Yields batches of clinical trials from the ClinicalTrials.gov API.

    Handles pagination automatically and implements retry logic for API requests.

    Args:
        years (int): Number of years to look back for study start dates.
        max_studies (int): Maximum total number of studies to fetch (-1 for all).
        status (list): List of status strings to filter by (e.g., ["RECRUITING"]).
        phases (list): List of phase strings to filter by (e.g., ["PHASE2"]).

    Yields:
        list: A batch of study dictionaries (JSON objects).
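
    Example (illustrative):
        for batch in fetch_trials_generator(years=2, max_studies=500, status=["RECRUITING"]):
            print(f"Got {len(batch)} studies")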
    """
    base_url = "https://clinicaltrials.gov/api/v2/studies"

    # Calculate start date for filtering
    start_date = (datetime.now() - timedelta(days=365 * years)).strftime("%Y-%m-%d")
    print("📡 Connecting to CT.gov API...")
    print(f"🔎 Fetching trials starting after: {start_date}")
    if status:
        print(f"   Filters - Status: {status}")
    if phases:
        print(f"   Filters - Phases: {phases}")

    fetched_count = 0
    next_page_token = None

    # If max_studies is -1, fetch ALL studies (infinite limit)
    fetch_limit = float("inf") if max_studies == -1 else max_studies

    while fetched_count < fetch_limit:
        # Determine batch size (max 1000 per API limit)
        current_limit = 1000
        if max_studies != -1:
            current_limit = min(1000, max_studies - fetched_count)

        # --- Query Construction ---
        # Build the query term using the API's syntax
        query_parts = [f"AREA[StartDate]RANGE[{start_date},MAX]"]

        if status:
            status_str = " OR ".join(status)
            query_parts.append(f"AREA[OverallStatus]({status_str})")

        if phases:
            phase_str = " OR ".join(phases)
            query_parts.append(f"AREA[Phase]({phase_str})")

        full_query = " AND ".join(query_parts)
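        # Example of a resulting query term (values illustrative):
        #   AREA[StartDate]RANGE[2020-01-01,MAX] AND AREA[OverallStatus](COMPLETED OR RECRUITING)
        #   AND AREA[Phase](PHASE2 OR PHASE3)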

        params = {
            "query.term": full_query,
            "pageSize": current_limit,
            # Request specific fields to minimize payload size
            "fields": ",".join(
                [
                    "protocolSection.identificationModule.nctId",
                    "protocolSection.identificationModule.briefTitle",
                    "protocolSection.identificationModule.officialTitle",
                    "protocolSection.identificationModule.organization",
                    "protocolSection.statusModule.overallStatus",
                    "protocolSection.statusModule.startDateStruct",
                    "protocolSection.statusModule.completionDateStruct",
                    "protocolSection.designModule.phases",
                    "protocolSection.designModule.studyType",
                    "protocolSection.eligibilityModule.eligibilityCriteria",
                    "protocolSection.eligibilityModule.sex",
                    "protocolSection.eligibilityModule.stdAges",
                    "protocolSection.descriptionModule.briefSummary",
                    "protocolSection.conditionsModule.conditions",
                    "protocolSection.outcomesModule.primaryOutcomes",
                    "protocolSection.contactsLocationsModule.locations",
                    "protocolSection.outcomesModule.primaryOutcomes",
                    "protocolSection.contactsLocationsModule.locations",
                    "protocolSection.armsInterventionsModule",
                    "protocolSection.sponsorCollaboratorsModule.leadSponsor",
                ]
            ),
        }

        if next_page_token:
            params["pageToken"] = next_page_token

        # --- Retry Logic ---
        retries = 3
        for attempt in range(retries):
            try:
                response = requests.get(base_url, params=params, timeout=30)
                if response.status_code == 200:
                    data = response.json()
                    studies = data.get("studies", [])

                    if not studies:
                        return  # Stop generator if no studies returned

                    yield studies

                    fetched_count += len(studies)
                    next_page_token = data.get("nextPageToken")

                    if not next_page_token:
                        return  # Stop generator if no more pages

                    break  # Success, exit retry loop
                else:
                    print(f"❌ API Error: {response.status_code} - {response.text}")
                    if attempt < retries - 1:
                        time.sleep(2)
                    else:
                        return  # Stop generator on persistent error
            except Exception as e:
                print(f"❌ Request Error (Attempt {attempt+1}/{retries}): {e}")
                if attempt < retries - 1:
                    time.sleep(2)
                else:
                    return  # Stop generator


def process_study(study):
    """
    Processes a single study dictionary into a LlamaIndex Document.
    This function is designed to be run in parallel.
    """
    try:
        # Extract Modules
        protocol = study.get("protocolSection", {})
        identification = protocol.get("identificationModule", {})
        status_module = protocol.get("statusModule", {})
        design = protocol.get("designModule", {})
        eligibility = protocol.get("eligibilityModule", {})
        description = protocol.get("descriptionModule", {})
        conditions_module = protocol.get("conditionsModule", {})
        outcomes_module = protocol.get("outcomesModule", {})
        arms_interventions_module = protocol.get("armsInterventionsModule", {})
        locations_module = protocol.get("contactsLocationsModule", {})
        sponsor_module = protocol.get("sponsorCollaboratorsModule", {})

        # Extract Fields
        nct_id = identification.get("nctId", "N/A")
        title = identification.get("briefTitle", "N/A")
        official_title = identification.get("officialTitle", "N/A")
        org = identification.get("organization", {}).get("fullName", "N/A")
        sponsor_name = sponsor_module.get("leadSponsor", {}).get("name", "N/A")
        summary = clean_text(description.get("briefSummary", "N/A"))

        overall_status = status_module.get("overallStatus", "N/A")
        start_date = status_module.get("startDateStruct", {}).get("date", "N/A")
        completion_date = status_module.get("completionDateStruct", {}).get(
            "date", "N/A"
        )

        phases = ", ".join(design.get("phases", []))
        study_type = design.get("studyType", "N/A")

        criteria = clean_text(eligibility.get("eligibilityCriteria", "N/A"))
        gender = eligibility.get("sex", "N/A")
        ages = ", ".join(eligibility.get("stdAges", []))

        conditions = ", ".join(conditions_module.get("conditions", []))

        interventions = []
        for interv in arms_interventions_module.get("interventions", []):
            name = interv.get("name", "")
            type_ = interv.get("type", "")
            interventions.append(f"{type_}: {name}")
        interventions_str = "; ".join(interventions)

        primary_outcomes = []
        for outcome in outcomes_module.get("primaryOutcomes", []):
            measure = outcome.get("measure", "")
            desc = outcome.get("description", "")
            primary_outcomes.append(f"- {measure}: {desc}")
        outcomes_str = clean_text("\n".join(primary_outcomes))

        locations = []
        first_country = "Unknown"
        for loc in locations_module.get("locations", []):
            facility = loc.get("facility", "N/A")
            city = loc.get("city", "")
            loc_state = loc.get("state", "")
            country = loc.get("country", "")
            if first_country == "Unknown" and country:
                first_country = country
            # Include the state so the US_STATES match below can actually find it
            place = ", ".join(p for p in (city, loc_state, country) if p)
            locations.append(f"{facility} ({place})")
        locations_str = "; ".join(locations[:5])  # Limit to 5 locations to save space

        # Extract State (First match)
        state = "Unknown"
        # Check locations for US States
        for loc_str in locations:
            if "United States" in loc_str:
                for s in US_STATES:
                    if s in loc_str:
                        state = s
                        break
                if state != "Unknown":
                    break

        # Construct Rich Page Content with Markdown Headers
        # This text is what gets embedded and searched
        page_content = (
            f"# {title}\n"
            f"**NCT ID:** {nct_id}\n"
            f"**Official Title:** {official_title}\n"
            f"**Sponsor:** {sponsor_name}\n"
            f"**Organization:** {org}\n"
            f"**Status:** {overall_status}\n"
            f"**Phase:** {phases}\n"
            f"**Study Type:** {study_type}\n"
            f"**Start Date:** {start_date}\n"
            f"**Completion Date:** {completion_date}\n\n"
            f"## Summary\n{summary}\n\n"
            f"## Conditions\n{conditions}\n\n"
            f"## Interventions\n{interventions_str}\n\n"
            f"## Eligibility Criteria\n"
            f"**Gender:** {gender}\n"
            f"**Ages:** {ages}\n"
            f"**Criteria:**\n{criteria}\n\n"
            f"## Primary Outcomes\n{outcomes_str}\n\n"
            f"## Locations\n{locations_str}"
        )

        # Metadata for filtering (Structured Data)
        metadata = {
            "nct_id": nct_id,
            "title": title,
            "org": org,
            "sponsor": sponsor_name,
            "status": overall_status,
            "phase": phases,
            "study_type": study_type,
            "start_year": (int(start_date.split("-")[0]) if start_date != "N/A" else 0),
            "condition": conditions,
            "intervention": interventions_str,
            "country": (
                locations[0].split(",")[-1].strip() if locations else "Unknown"
            ),
            "state": state,
        }
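        # These metadata keys can back query-time filtering downstream (illustrative,
        # assuming llama_index.core.vector_stores filters are used), e.g.:
        #   MetadataFilters(filters=[ExactMatchFilter(key="phase", value="PHASE2")])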

        return Document(text=page_content, metadata=metadata, id_=nct_id)
    except Exception as e:
        print(
            f"⚠️ Error processing study {study.get('protocolSection', {}).get('identificationModule', {}).get('nctId', 'Unknown')}: {e}"
        )
        return None


def run_ingestion():
    """
    Main execution function for the ingestion script.
    Parses arguments, initializes the index, and runs the ingestion loop.
    """
    parser = argparse.ArgumentParser(description="Ingest Clinical Trials data.")
    parser.add_argument(
        "--limit",
        type=int,
        default=-1,
        help="Number of studies to ingest. Set to -1 for ALL.",
    )
    parser.add_argument(
        "--years", type=int, default=10, help="Number of years to look back."
    )
    parser.add_argument(
        "--status",
        type=str,
        default="COMPLETED",
        help="Comma-separated list of statuses (e.g., COMPLETED,RECRUITING).",
    )
    parser.add_argument(
        "--phases",
        type=str,
        default="PHASE1,PHASE2,PHASE3,PHASE4",
        help="Comma-separated list of phases (e.g., PHASE2,PHASE3).",
    )
    args = parser.parse_args()

    status_list = args.status.split(",") if args.status else []
    phase_list = args.phases.split(",") if args.phases else []

    print(f"⚙️ Configuration: Limit={args.limit}, Years={args.years}")
    print(f"   Status Filter: {status_list}")
    print(f"   Phase Filter: {phase_list}")

    # --- INITIALIZE LLAMAINDEX COMPONENTS ---
    print("🧠 Initializing LlamaIndex Embeddings (PubMedBERT)...")
    embed_model = HuggingFaceEmbedding(model_name="pritamdeka/S-PubMedBert-MS-MARCO")

    # Initialize LanceDB (Persistent)
    print("🚀 Initializing LanceDB...")

    # Determine the project root directory (one level up from this script)
    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    db_path = os.path.join(project_root, "ct_gov_lancedb")

    # Connect to LanceDB
    db = lancedb.connect(db_path)
    
    table_name = "clinical_trials"
    if table_name in db.table_names():
        mode = "append"
        print(f"ℹ️ Table '{table_name}' exists. Appending data.")
    else:
        mode = "create"
        print(f"ℹ️ Table '{table_name}' does not exist. Creating new table.")

    # Initialize Vector Store
    vector_store = LanceDBVectorStore(
        uri=db_path,
        table_name=table_name,
        mode=mode,
        query_type="hybrid",  # Enable hybrid search support
    )
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Initialize Index ONCE
    # We pass the storage context to link it to the vector store
    index = VectorStoreIndex.from_vector_store(
        vector_store, storage_context=storage_context, embed_model=embed_model
    )

    total_ingested = 0

    # Progress Bar
    pbar = tqdm(
        total=args.limit if args.limit > 0 else float("inf"),
        desc="Ingesting Studies",
        unit="study",
    )

    # --- INGESTION LOOP ---
    # Use ProcessPoolExecutor for parallel processing of study data
    with concurrent.futures.ProcessPoolExecutor() as executor:
        for batch_studies in fetch_trials_generator(
            years=args.years,
            max_studies=args.limit,
            status=status_list,
            phases=phase_list,
        ):
            # Parallelize the processing of the batch
            # map returns an iterator, so we convert to list to trigger execution
            documents_iter = executor.map(process_study, batch_studies)

            # Filter out None results (errors)
            documents = [doc for doc in documents_iter if doc is not None]

            if documents:
                # Overwrite Logic:
                # To avoid duplicates, we delete existing records with the same NCT IDs.
                doc_ids = [doc.id_ for doc in documents]
                try:
                    # LanceDB supports deletion via SQL-like filter
                    # We construct a filter string: "nct_id IN ('NCT123', 'NCT456')"
                    ids_str = ", ".join([f"'{id}'" for id in doc_ids])
                    if ids_str:
                        tbl = db.open_table("clinical_trials")
                        tbl.delete(f"nct_id IN ({ids_str})")
                except Exception:
                    # Ignore if the table doesn't exist yet
                    pass

                # Efficient Batch Insertion
                # We convert documents to nodes and insert them into the index.
                # This handles embedding generation automatically.
                node_parser = Settings.node_parser
                nodes = node_parser.get_nodes_from_documents(documents)

                index.insert_nodes(nodes)

                total_ingested += len(documents)
                pbar.update(len(documents))

    pbar.close()
    print(f"🎉 Ingestion Complete! Total studies in DB: {total_ingested}")


if __name__ == "__main__":
    run_ingestion()