File size: 9,717 Bytes
507be68
 
 
 
 
 
 
 
 
 
 
 
 
122cdca
507be68
 
 
122cdca
 
 
 
 
 
 
 
 
 
 
507be68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b7595c
 
 
 
 
 
 
 
507be68
 
7b7595c
 
507be68
7b7595c
 
 
507be68
 
 
 
7b7595c
507be68
 
 
 
 
7b7595c
507be68
 
 
 
 
 
 
 
 
13bb297
507be68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""
Utility functions for the Clinical Trial Agent.

Handles configuration, LanceDB index loading, data normalization, and custom filtering logic.
"""

import os
import streamlit as st
from typing import List, Optional
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.llms.gemini import Gemini
from llama_index.core.postprocessor import SentenceTransformerRerank
import lancedb
from dotenv import load_dotenv

@st.cache_resource
def get_reranker(top_n: int = 50):
    """Build the cross-encoder reranker once and reuse it across reruns.

    Args:
        top_n: Number of candidates the reranker keeps after scoring.

    Returns:
        SentenceTransformerRerank: The cached reranker postprocessor.
    """
    print("🔄 Loading Reranker Model (Cached)...")
    reranker = SentenceTransformerRerank(
        model="cross-encoder/ms-marco-MiniLM-L-12-v2",
        top_n=top_n,
    )
    return reranker

# --- MONKEYPATCH START ---
# Patch LanceDBVectorStore to handle 'nprobes' AttributeError and fix SQL quoting for IN filters.
original_query = LanceDBVectorStore.query

def patched_query(self, query, **kwargs):
    """Wrap ``LanceDBVectorStore.query`` to degrade gracefully on 'nprobes' errors.

    Every failure is logged (including the filters, when present). If the
    error message mentions 'nprobes' — a known LanceDB incompatibility —
    an empty result set is returned so the app keeps running instead of
    crashing. All other errors propagate unchanged.
    """
    try:
        return original_query(self, query, **kwargs)
    except Exception as e:
        print(f"⚠️ LanceDB Query Error: {e}")
        if hasattr(query, "filters"):
            print(f"   Filters: {query.filters}")

        if "nprobes" in str(e):
            from llama_index.core.vector_stores.types import VectorStoreQueryResult
            return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
        # Bare raise preserves the original exception and traceback
        # (``raise e`` re-raised the same object but is less idiomatic).
        raise

LanceDBVectorStore.query = patched_query

# Patch _to_lance_filter to fix SQL quoting for IN operator with strings.
from llama_index.vector_stores.lancedb import base as lancedb_base
from llama_index.core.vector_stores.types import FilterOperator

original_to_lance_filter = lancedb_base._to_lance_filter


def _sql_literal(value):
    """Render *value* as a SQL literal.

    Strings are single-quoted with embedded single quotes doubled
    (standard SQL escaping), so values like "Shire's" no longer produce
    broken — or injectable — WHERE clauses. Non-strings pass through
    via ``str()``, matching the original unquoted rendering.
    """
    if isinstance(value, str):
        return "'" + value.replace("'", "''") + "'"
    return str(value)


def patched_to_lance_filter(standard_filters, metadata_keys):
    """Convert LlamaIndex metadata filters into a LanceDB SQL predicate.

    Args:
        standard_filters: MetadataFilters container (may be falsy/empty).
        metadata_keys: Optional whitelist of filterable metadata keys;
            filters on other keys are silently skipped.

    Returns:
        Optional[str]: An " AND "-joined SQL string, or None when no
        clause applies.
    """
    if not standard_filters:
        return None

    # SQL spellings for the simple comparison operators.
    comparisons = {
        FilterOperator.EQ: "=",
        FilterOperator.NE: "!=",
        FilterOperator.GT: ">",
        FilterOperator.LT: "<",
        FilterOperator.GTE: ">=",
        FilterOperator.LTE: "<=",
    }

    clauses = []
    for flt in standard_filters.filters:  # avoid shadowing builtin `filter`
        key = flt.key
        if metadata_keys and key not in metadata_keys:
            continue

        # LanceDB stores metadata in a struct column named 'metadata'.
        lance_key = f"metadata.{key}"

        # IN with a list: quote each element properly.
        if flt.operator == FilterOperator.IN and isinstance(flt.value, list):
            val_str = ", ".join(_sql_literal(v) for v in flt.value)
            clauses.append(f"{lance_key} IN ({val_str})")
            continue

        sql_op = comparisons.get(flt.operator)
        if sql_op is None:
            # Unsupported operator (or IN with a non-list): skipped,
            # matching the original fall-through behavior.
            continue

        if flt.operator in (FilterOperator.EQ, FilterOperator.NE):
            # Equality comparisons quote string values.
            clauses.append(f"{lance_key} {sql_op} {_sql_literal(flt.value)}")
        else:
            # Ordering comparisons are emitted unquoted, as before
            # (they are expected to be numeric).
            clauses.append(f"{lance_key} {sql_op} {flt.value}")

    if not clauses:
        return None

    return " AND ".join(clauses)

lancedb_base._to_lance_filter = patched_to_lance_filter
# --- MONKEYPATCH END ---
# --- MONKEYPATCH END ---


def load_environment():
    """Read variables from a local ``.env`` file into the process environment."""
    load_dotenv()


# --- Configuration ---
@st.cache_resource
def init_embedding_model():
    """Load the PubMedBERT embedding model once (cached) and register it globally."""
    embed = HuggingFaceEmbedding(
        model_name="pritamdeka/S-PubMedBert-MS-MARCO",
        device="cpu",
    )
    Settings.embed_model = embed

def setup_llama_index(api_key: Optional[str] = None):
    """Configure the global LlamaIndex LLM.

    The embedding model is registered separately via ``init_embedding_model()``.

    Args:
        api_key: Explicit Google API key; falls back to the
            GOOGLE_API_KEY environment variable when omitted.
    """
    # Embeddings must be registered regardless of LLM availability.
    init_embedding_model()

    # Prefer the explicit key; otherwise read it from the environment.
    final_key = api_key if api_key else os.environ.get("GOOGLE_API_KEY")
    if not final_key:
        # No credentials available -- leave the LLM unconfigured.
        return

    try:
        Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=final_key)
    except Exception as e:
        # Deliberate best-effort fallback: keep the app alive with a mock LLM.
        print(f"⚠️ LLM initialization failed: {e}")
        from llama_index.core.llms import MockLLM
        Settings.llm = MockLLM()


@st.cache_resource
def load_index() -> VectorStoreIndex:
    """Load and cache the persistent LanceDB-backed vector index.

    Returns:
        VectorStoreIndex: Index wrapping the 'clinical_trials' table.
    """
    # NOTE: setup_llama_index() is intentionally NOT called here -- the app
    # drives setup; calling it from this cached loader could reset the LLM
    # when the env var is missing.
    db_path = "./ct_gov_lancedb"
    # Result unused; NOTE(review): presumably kept so the DB path is
    # created/validated up front -- confirm before removing.
    _db = lancedb.connect(db_path)

    # Metadata keys that downstream filters are allowed to target.
    metadata_keys = [
        "nct_id", "title", "org", "sponsor", "status", "phase",
        "study_type", "start_year", "condition", "intervention",
        "country", "state",
    ]

    store = LanceDBVectorStore(
        uri=db_path,
        table_name="clinical_trials",
        query_mode="hybrid",
    )
    # The constructor does not accept metadata keys, so set the private
    # attribute directly.
    store._metadata_keys = metadata_keys

    context = StorageContext.from_defaults(vector_store=store)
    index = VectorStoreIndex.from_vector_store(store, storage_context=context)
    return index


def get_hybrid_retriever(index: VectorStoreIndex, similarity_top_k: int = 50, filters=None):
    """Create a retriever over the LanceDB-backed index.

    Args:
        index (VectorStoreIndex): The loaded vector index.
        similarity_top_k (int): Number of top results to retrieve.
        filters (MetadataFilters, optional): Filters to apply.

    Returns:
        VectorIndexRetriever: The configured retriever.
    """
    # Hybrid behavior comes from the store's query_mode="hybrid"; a plain
    # retriever is returned here to sidestep LanceDB hybrid-search issues
    # on small datasets.
    retriever = index.as_retriever(
        similarity_top_k=similarity_top_k,
        filters=filters,
    )
    return retriever


# --- Normalization ---

# Centralized Sponsor Mappings
# Key: Canonical Name
# Value: List of variations/aliases (including the canonical name itself if needed for matching)
SPONSOR_MAPPINGS = {
    "GlaxoSmithKline": [
        "gsk", "glaxo", "glaxosmithkline",
        "GlaxoSmithKline"
    ],
    "Janssen": [
        "j&j", "johnson & johnson", "johnson and johnson", "janssen", "Janssen",
        "Janssen Research & Development, LLC",
        "Janssen Vaccines & Prevention B.V.",
        "Janssen Pharmaceutical K.K.",
        "Janssen-Cilag International NV",
        "Janssen Sciences Ireland UC",
        "Janssen Pharmaceutica N.V., Belgium",
        "Janssen Scientific Affairs, LLC",
        "Janssen-Cilag Ltd.",
        "Xian-Janssen Pharmaceutical Ltd.",
        "Janssen Korea, Ltd., Korea",
        "Janssen-Cilag G.m.b.H",
        "Janssen-Cilag, S.A.",
        "Janssen BioPharma, Inc.",
    ],
    "Bristol-Myers Squibb": [
        "bms", "bristol", "bristol myers squibb", "bristol-myers squibb",
        "Bristol-Myers Squibb"
    ],
    "Merck Sharp & Dohme": [
        "merck", "msd", "merck sharp & dohme",
        "Merck Sharp & Dohme LLC"
    ],
    "Pfizer": ["pfizer", "Pfizer", "Pfizer Inc."],
    "AstraZeneca": ["astrazeneca", "AstraZeneca"],
    "Eli Lilly and Company": ["lilly", "eli lilly", "Eli Lilly and Company"],
    "Sanofi": ["sanofi", "Sanofi"],
    "Novartis": ["novartis", "Novartis"],
}

def normalize_sponsor(sponsor: str) -> Optional[str]:
    """
    Normalizes sponsor names to canonical forms using centralized mappings.

    Args:
        sponsor: Raw sponsor/org string (any casing; may be an alias).

    Returns:
        The canonical name when a mapping matches; the input string
        unchanged when no mapping applies; None for empty/None input.
    """
    if not sponsor:
        return None

    s = sponsor.lower().strip()

    for canonical, variations in SPONSOR_MAPPINGS.items():
        # Exact match against the canonical name itself (case-insensitive).
        if s == canonical.lower():
            return canonical

        # Substring match on the canonical name (e.g. "Pfizer Inc." -> "Pfizer").
        # Hoisted out of the per-variation loop: the check does not depend
        # on the variation, so evaluating it once per canonical is equivalent
        # to the previous per-variation re-check.
        if canonical.lower() in s:
            return canonical

        for v in variations:
            v_lower = v.lower()
            # Exact alias match.
            if v_lower == s:
                return canonical
            # Short aliases (e.g. 'gsk', 'bms') also match as substrings.
            if len(v) < 5 and v_lower in s:
                return canonical

    # No mapping matched: hand the original string back untouched.
    return sponsor


def get_sponsor_variations(sponsor: str) -> Optional[List[str]]:
    """
    Returns list of exact database 'org' values for a given sponsor alias.

    Args:
        sponsor: Raw sponsor name or alias.

    Returns:
        A copy of the known variations for the resolved canonical sponsor,
        or None when the input is empty or maps to no known sponsor.
    """
    if not sponsor:
        return None

    # Resolve aliases (e.g. 'gsk') to the canonical key first.
    canonical = normalize_sponsor(sponsor)

    if canonical in SPONSOR_MAPPINGS:
        # Return a copy so callers cannot mutate the shared mapping.
        return list(SPONSOR_MAPPINGS[canonical])

    return None