Spaces:
Sleeping
Sleeping
File size: 4,839 Bytes
507be68 7c4c603 507be68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import json
from langchain.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from modules.tools import get_study_details
from modules.utils import load_environment
import streamlit as st
import os
# Load environment settings once, at import time, so that get_llm() can fall
# back to the GOOGLE_API_KEY environment variable when no key is present in
# Streamlit session state.
# NOTE(review): assumes modules.utils.load_environment populates os.environ
# (e.g. from a .env file) — confirm against that helper.
load_environment()
def get_llm(model: str = "gemini-2.5-flash", temperature: float = 0):
    """Build a Gemini chat model using the API key currently in effect.

    The key is looked up on every call (nothing is cached), so a key added to
    Streamlit session state is picked up by the next call. Lookup order:

    1. ``st.session_state["api_key"]`` (user-provided key), if non-blank;
    2. otherwise the ``GOOGLE_API_KEY`` environment variable, if non-blank.

    Surrounding whitespace is stripped from either source; a key consisting
    only of whitespace is treated as missing.

    Args:
        model: Gemini model name passed to ``ChatGoogleGenerativeAI``.
            Defaults to the model this module has always used.
        temperature: Sampling temperature passed to the model (default 0).

    Returns:
        A configured ``ChatGoogleGenerativeAI`` instance.

    Raises:
        ValueError: If no usable key is found in session state or the
            environment.
    """
    api_key = None
    if hasattr(st, "session_state"):
        api_key = st.session_state.get("api_key")
    # Pasted keys often carry a trailing newline/space, which would likely be
    # rejected by the API; a whitespace-only key should fall through to env.
    if isinstance(api_key, str):
        api_key = api_key.strip()
    if not api_key:
        api_key = os.environ.get("GOOGLE_API_KEY", "").strip()
    if not api_key:
        raise ValueError("Google API Key not found in session state or environment.")
    return ChatGoogleGenerativeAI(model=model, temperature=temperature, google_api_key=api_key)
# Prompt for step 1 of the cohort pipeline: free-text eligibility criteria ->
# structured JSON rules (filled via the single ``criteria`` variable).
# The requested code systems must match the columns SQL_PROMPT queries:
# ICD-10 ``diagnosis_code``, CPT/HCPCS ``procedure_code`` and NDC ``ndc_code``.
# Asking for RxNorm (as before) produced drug codes that can never match
# ``pharmacy_claims.ndc_code``.
# No literal braces may appear in the template besides {criteria}: the
# template is parsed with Python str.format-style placeholders.
EXTRACT_PROMPT = PromptTemplate(
    template="""
You are a Clinical Informatics Expert.
Your task is to extract structured cohort requirements from the following Clinical Trial Eligibility Criteria.
Output a JSON object with two keys: "inclusion" and "exclusion".
Each key should contain a list of rules.
Each rule should have:
- "concept": The medical concept (e.g., "Type 2 Diabetes", "Metformin").
- "domain": The domain (Condition, Drug, Measurement, Procedure, Observation).
- "temporal": Any temporal logic (e.g., "History of", "Within last 6 months").
- "codes": A list of potential codes (make a best guess), using the code system that matches the domain: ICD-10-CM for Condition, NDC for Drug, CPT/HCPCS for Procedure.
Output ONLY the JSON object, with no markdown formatting or commentary.
CRITERIA:
{criteria}
JSON OUTPUT:
""",
    input_variables=["criteria"],
)
# Prompt for step 2 of the cohort pipeline: structured requirements (passed
# in as a JSON string via ``requirements``) -> a single claims-based SQL query
# against the medical_claims / pharmacy_claims tables described inside.
_SQL_TEMPLATE = """
You are a SQL Expert specializing in Healthcare Claims Data Analysis.
Generate a standard SQL query to define a cohort of patients based on the following structured requirements.
### Schema Assumptions
1. **`medical_claims`** (Diagnoses & Procedures):
- `patient_id`, `claim_date`, `diagnosis_code` (ICD-10), `procedure_code` (CPT/HCPCS).
2. **`pharmacy_claims`** (Drugs):
- `patient_id`, `fill_date`, `ndc_code`.
### Logic Rules
1. **Conditions (Diagnoses)**:
- Require **at least 2 distinct claim dates** where the diagnosis code matches.
- These 2 claims must be **at least 30 days apart** (to confirm chronic condition).
2. **Drugs**:
- Require at least 1 claim with a matching NDC code.
3. **Procedures**:
- Require at least 1 claim with a matching CPT/HCPCS code.
4. **Exclusions**:
- Exclude patients who have ANY matching claims for exclusion criteria.
### Requirements (JSON)
{requirements}
### Output
Generate a single SQL query that selects `patient_id` from the claims tables meeting the criteria.
Use Common Table Expressions (CTEs) for clarity.
Do NOT output markdown formatting (```sql), just the raw SQL.
SQL QUERY:
"""

SQL_PROMPT = PromptTemplate(
    input_variables=["requirements"],
    template=_SQL_TEMPLATE,
)
def extract_cohort_requirements(criteria_text: str) -> dict:
    """Use the LLM to parse eligibility-criteria text into structured JSON.

    Args:
        criteria_text: Free text containing the trial's eligibility criteria;
            it is substituted into EXTRACT_PROMPT as ``criteria``.

    Returns:
        The JSON object parsed from the model's answer (EXTRACT_PROMPT asks
        for "inclusion" and "exclusion" keys). If no JSON object can be
        recovered from the answer, returns
        ``{"error": "Failed to parse LLM output", "raw_output": <raw text>}``
        instead of raising.

    Raises:
        ValueError: Propagated from get_llm() when no API key is configured.
    """
    llm = get_llm()
    chain = EXTRACT_PROMPT | llm
    response = chain.invoke({"criteria": criteria_text})
    parsed = _parse_json_object(response.content)
    if parsed is None:
        return {"error": "Failed to parse LLM output", "raw_output": response.content}
    return parsed


def _parse_json_object(text):
    """Best-effort extraction of a JSON object from raw LLM text; None on failure."""
    # Models frequently wrap JSON in markdown fences despite instructions.
    cleaned = text.replace("```json", "").replace("```", "").strip()
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span to tolerate leading/trailing
        # prose ("Here is the JSON: ...") or an uppercase ```JSON fence tag.
        start, end = cleaned.find("{"), cleaned.rfind("}")
        if start == -1 or end <= start:
            return None
        try:
            data = json.loads(cleaned[start:end + 1])
        except json.JSONDecodeError:
            return None
    # Callers expect a mapping; a bare list/string is not a usable result.
    return data if isinstance(data, dict) else None
def generate_cohort_sql(requirements: dict) -> str:
    """Ask the LLM to turn structured cohort requirements into a SQL query.

    Args:
        requirements: Structured requirements (as produced by
            extract_cohort_requirements); serialized to indented JSON and
            substituted into SQL_PROMPT as ``requirements``.

    Returns:
        The model's answer with any ```sql / ``` fence markers removed and
        surrounding whitespace stripped.
    """
    model = get_llm()
    pipeline = SQL_PROMPT | model
    payload = json.dumps(requirements, indent=2)
    sql_text = pipeline.invoke({"requirements": payload}).content
    # Strip the language-tagged fence first so its "sql" tag is not left behind.
    for fence in ("```sql", "```"):
        sql_text = sql_text.replace(fence, "")
    return sql_text.strip()
@tool("get_cohort_sql")
def get_cohort_sql(nct_id: str) -> str:
    """
    Generates a SQL query to define the patient cohort for a specific study (NCT ID).
    Args:
        nct_id (str): The ClinicalTrials.gov identifier (e.g., NCT01234567).
    Returns:
        str: A formatted string containing the Extracted Requirements (JSON) and the Generated SQL,
        or a short error message if the study cannot be found or its criteria cannot be parsed.
    """
    # NOTE: with @tool, this docstring doubles as the tool description the
    # agent sees, so keep it short and user-facing.

    # 1. Fetch study details, reusing the existing tool to get the text.
    study_text = get_study_details.invoke(nct_id)
    if "No study found" in study_text:
        return f"Could not find study {nct_id}."

    # 2. Extract structured requirements.
    requirements = extract_cohort_requirements(study_text)
    # extract_cohort_requirements reports parse failures by returning an
    # {"error": ..., "raw_output": ...} dict rather than raising. Feeding that
    # into the SQL step would waste an LLM call and yield a meaningless query.
    if isinstance(requirements, dict) and "error" in requirements:
        raw_output = requirements.get("raw_output", "")
        return (
            f"Could not extract cohort requirements for study {nct_id}: "
            f"{requirements['error']}.\n\nRaw model output:\n{raw_output}"
        )

    # 3. Generate SQL from the requirements.
    sql_query = generate_cohort_sql(requirements)
    # The SQL targets the claims tables described in SQL_PROMPT
    # (medical_claims / pharmacy_claims), not the OMOP CDM schema.
    return f"""
### 📋 Extracted Cohort Requirements
```json
{json.dumps(requirements, indent=2)}
```
### 💾 Generated SQL Query (Claims Data)
```sql
{sql_query}
```
"""
|