File size: 4,839 Bytes
507be68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c4c603
507be68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import json
from langchain.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from modules.tools import get_study_details
from modules.utils import load_environment

import streamlit as st
import os

# Load environment configuration once, at import time. Presumably this populates
# os.environ (get_llm below falls back to GOOGLE_API_KEY) — confirm in modules.utils.
load_environment()

def get_llm(*, model: str = "gemini-2.5-flash", temperature: float = 0):
    """Build a Gemini chat model using the best available Google API key.

    The key is read from Streamlit session state (``st.session_state["api_key"]``,
    a user-provided key) first, then from the ``GOOGLE_API_KEY`` environment
    variable. Surrounding whitespace is stripped, so a blank or whitespace-only
    key is treated as missing rather than passed on to the client.

    Args:
        model: Gemini model name passed to ``ChatGoogleGenerativeAI``.
        temperature: Sampling temperature passed to the model (default 0).

    Returns:
        A configured ``ChatGoogleGenerativeAI`` instance.

    Raises:
        ValueError: If no non-blank key is found in either location.
    """
    api_key = None
    # A user-provided key in session state takes precedence over the environment.
    if hasattr(st, "session_state"):
        api_key = st.session_state.get("api_key")
    if isinstance(api_key, str):
        api_key = api_key.strip()

    # Fall back to the environment variable.
    if not api_key:
        api_key = os.environ.get("GOOGLE_API_KEY", "").strip()

    if not api_key:
        raise ValueError("Google API Key not found in session state or environment.")

    return ChatGoogleGenerativeAI(model=model, temperature=temperature, google_api_key=api_key)


# Prompt that turns free-text eligibility criteria into structured JSON rules
# ({"inclusion": [...], "exclusion": [...]}) for extract_cohort_requirements.
# The requested code systems mirror the claims schema assumed by SQL_PROMPT
# (diagnosis_code = ICD-10, procedure_code = CPT/HCPCS, ndc_code = NDC), so the
# extracted codes can actually be matched by the generated SQL.
# NOTE: PromptTemplate treats single braces as input variables; any literal
# braces added to this text must be doubled ("{{" / "}}").
EXTRACT_PROMPT = PromptTemplate(
    template="""
    You are a Clinical Informatics Expert.
    Your task is to extract structured cohort requirements from the following Clinical Trial Eligibility Criteria.

    Output a JSON object with two keys: "inclusion" and "exclusion".
    Each key should contain a list of rules.
    Each rule should have:
    - "concept": The medical concept (e.g., "Type 2 Diabetes", "Metformin").
    - "domain": The domain (Condition, Drug, Measurement, Procedure, Observation).
    - "temporal": Any temporal logic (e.g., "History of", "Within last 6 months").
    - "codes": A list of potential codes (make a best guess), using the code system that matches the domain:
      ICD-10-CM for Conditions, NDC for Drugs, and CPT/HCPCS for Procedures.

    Return ONLY the JSON object, with no explanation before or after it.

    CRITERIA:
    {criteria}

    JSON OUTPUT:
    """,
    input_variables=["criteria"],
)

# Prompt that turns the structured requirements (serialized as JSON by
# generate_cohort_sql) into a single cohort SQL query over an assumed claims
# schema: medical_claims (ICD-10 diagnoses, CPT/HCPCS procedures) and
# pharmacy_claims (NDC drugs).
# NOTE(review): the logic rules only cover Condition, Drug and Procedure rules;
# Measurement/Observation rules that EXTRACT_PROMPT may produce have no mapped
# table here, so how the model handles them is undefined.
# NOTE: PromptTemplate treats single braces as input variables; any literal
# braces added to this text must be doubled ("{{" / "}}").
SQL_PROMPT = PromptTemplate(
    template="""
    You are a SQL Expert specializing in Healthcare Claims Data Analysis.
    Generate a standard SQL query to define a cohort of patients based on the following structured requirements.

    ### Schema Assumptions
    1.  **`medical_claims`** (Diagnoses & Procedures):
        - `patient_id`, `claim_date`, `diagnosis_code` (ICD-10), `procedure_code` (CPT/HCPCS).
    2.  **`pharmacy_claims`** (Drugs):
        - `patient_id`, `fill_date`, `ndc_code`.

    ### Logic Rules
    1.  **Conditions (Diagnoses)**:
        - Require **at least 2 distinct claim dates** where the diagnosis code matches.
        - These 2 claims must be **at least 30 days apart** (to confirm chronic condition).
    2.  **Drugs**:
        - Require at least 1 claim with a matching NDC code.
    3.  **Procedures**:
        - Require at least 1 claim with a matching CPT/HCPCS code.
    4.  **Exclusions**:
        - Exclude patients who have ANY matching claims for exclusion criteria.

    ### Requirements (JSON)
    {requirements}

    ### Output
    Generate a single SQL query that selects `patient_id` from the claims tables meeting the criteria.
    Use Common Table Expressions (CTEs) for clarity.
    Do NOT output markdown formatting (```sql), just the raw SQL.

    SQL QUERY:
    """,
    # Must match the single placeholder used in the template above.
    input_variables=["requirements"],
)


def extract_cohort_requirements(criteria_text: str) -> dict:
    """Parse eligibility-criteria text into structured cohort rules via the LLM.

    Args:
        criteria_text: Free-text study description / eligibility criteria that
            is substituted into EXTRACT_PROMPT.

    Returns:
        The parsed JSON object (EXTRACT_PROMPT asks for "inclusion" and
        "exclusion" rule lists). If the model output cannot be parsed into a
        JSON object, returns ``{"error": "Failed to parse LLM output",
        "raw_output": <model text>}`` instead of raising.

    Raises:
        ValueError: Propagated from ``get_llm`` when no API key is configured.
    """
    llm = get_llm()
    chain = EXTRACT_PROMPT | llm
    response = chain.invoke({"criteria": criteria_text})
    raw = response.content

    # Clean up potential markdown code blocks.
    text = raw.replace("```json", "").replace("```", "").strip()
    candidates = [text]
    # Models sometimes wrap the JSON in prose or use an unexpected fence tag
    # (e.g. "```JSON"); fall back to the outermost {...} span.
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # Callers treat the result as a dict; a bare list/scalar is not usable.
        if isinstance(parsed, dict):
            return parsed

    return {"error": "Failed to parse LLM output", "raw_output": raw}


def generate_cohort_sql(requirements: dict) -> str:
    """Translate structured cohort requirements into a SQL query via the LLM.

    Args:
        requirements: Structured rules, typically the output of
            ``extract_cohort_requirements``. They are serialized to indented
            JSON and substituted into SQL_PROMPT.

    Returns:
        The model's SQL text, stripped of surrounding whitespace. SQL_PROMPT
        asks for raw SQL, but if the model still emits a Markdown fence, only
        the body of the first fenced block is returned. Its language tag
        (``sql``, ``SQL``, ...) and any prose around the block are removed.

    Raises:
        ValueError: Propagated from ``get_llm`` when no API key is configured.
    """
    llm = get_llm()
    chain = SQL_PROMPT | llm
    response = chain.invoke({"requirements": json.dumps(requirements, indent=2)})
    text = response.content.strip()

    if "```" in text:
        # Take what lies between the first and second fence (or everything
        # after the first fence if the closing one is missing).
        body = text.split("```", 2)[1]
        tag, newline, rest = body.partition("\n")
        # A single word on the opening fence line is its info string
        # (any case, e.g. "sql" or "SQL"); it is not part of the query.
        if newline and " " not in tag.strip():
            body = rest
        text = body.strip()

    return text


@tool("get_cohort_sql")
def get_cohort_sql(nct_id: str) -> str:
    """
    Generates a SQL query to define the patient cohort for a specific study (NCT ID).
    
    Args:
        nct_id (str): The ClinicalTrials.gov identifier (e.g., NCT01234567).
        
    Returns:
        str: A formatted string containing the Extracted Requirements (JSON) and the Generated SQL.
    """
    # NOTE: @tool uses the docstring above as the tool description shown to the
    # calling agent, so change its wording with care.

    # 1. Fetch the study record by reusing the existing get_study_details tool.
    study_text = get_study_details.invoke(nct_id)

    # NOTE(review): relies on get_study_details signalling a miss with this
    # exact phrase — confirm against modules.tools.
    if "No study found" in study_text:
        return f"Could not find study {nct_id}."

    # 2. Extract structured inclusion/exclusion rules from the study text.
    requirements = extract_cohort_requirements(study_text)

    # extract_cohort_requirements reports unparseable model output as an
    # {"error": ..., "raw_output": ...} dict. Generating SQL from that would
    # waste an LLM call and yield a meaningless query, so stop here instead.
    if "error" in requirements:
        return (
            f"Could not extract cohort requirements for study {nct_id}: "
            f"{requirements['error']}.\n\n"
            f"Raw model output:\n{requirements.get('raw_output', '')}"
        )

    # 3. Translate the rules into SQL over the claims schema from SQL_PROMPT
    #    (medical_claims / pharmacy_claims), not an OMOP CDM schema.
    sql_query = generate_cohort_sql(requirements)

    return f"""
### 📋 Extracted Cohort Requirements
```json
{json.dumps(requirements, indent=2)}
```

### 💾 Generated SQL Query (Claims Data)
```sql
{sql_query}
```
"""