# NOTE: Hugging Face Spaces page-scrape residue (blame hashes and the
# line-number gutter) removed from the top of this file — it was not part
# of the module source and was not valid Python.
"""
Utility functions for the Clinical Trial Agent.
Handles configuration, LanceDB index loading, data normalization, and custom filtering logic.
"""
import os
import streamlit as st
from typing import List, Optional
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.llms.gemini import Gemini
from llama_index.core.postprocessor import SentenceTransformerRerank
import lancedb
from dotenv import load_dotenv
@st.cache_resource
def get_reranker(top_n: int = 50):
    """
    Build the cross-encoder reranker once per session (Streamlit-cached).

    Args:
        top_n (int): Number of documents the reranker keeps after scoring.

    Returns:
        SentenceTransformerRerank: The cached reranker instance.
    """
    print("🔄 Loading Reranker Model (Cached)...")
    reranker = SentenceTransformerRerank(
        model="cross-encoder/ms-marco-MiniLM-L-12-v2",
        top_n=top_n,
    )
    return reranker
# --- MONKEYPATCH START ---
# Patch LanceDBVectorStore to handle 'nprobes' AttributeError and fix SQL quoting for IN filters.
original_query = LanceDBVectorStore.query


def patched_query(self, query, **kwargs):
    """Wrapper around LanceDBVectorStore.query that degrades gracefully.

    Any exception whose message mentions 'nprobes' (a known LanceDB
    incompatibility) is swallowed and an empty result set is returned so the
    app keeps running; every other exception is logged and re-raised.
    """
    try:
        return original_query(self, query, **kwargs)
    except Exception as e:
        print(f"⚠️ LanceDB Query Error: {e}")
        if hasattr(query, "filters"):
            print(f" Filters: {query.filters}")
        if "nprobes" in str(e):
            from llama_index.core.vector_stores.types import VectorStoreQueryResult
            # Return an empty result rather than crashing the whole query path.
            return VectorStoreQueryResult(nodes=[], similarities=[], ids=[])
        # Bare raise preserves the original traceback (unlike `raise e`).
        raise


LanceDBVectorStore.query = patched_query
# Patch _to_lance_filter to fix SQL quoting for IN operator with strings.
from llama_index.vector_stores.lancedb import base as lancedb_base
from llama_index.core.vector_stores.types import FilterOperator

original_to_lance_filter = lancedb_base._to_lance_filter


def _sql_literal(value) -> str:
    """Render a Python value as a LanceDB SQL literal.

    Strings are single-quoted with embedded quotes escaped SQL-style ('' ),
    so values like "St. Jude Children's Research Hospital" no longer break
    the generated WHERE clause. Non-strings use their str() form.
    """
    if isinstance(value, str):
        return "'{}'".format(value.replace("'", "''"))
    return str(value)


def patched_to_lance_filter(standard_filters, metadata_keys):
    """Build a LanceDB SQL WHERE clause from LlamaIndex MetadataFilters.

    Reimplements the upstream helper to (a) prefix keys with 'metadata.'
    for LanceDB's struct column, (b) quote and escape string values for the
    IN / EQ / NE operators, and (c) skip keys not in `metadata_keys`.

    Returns:
        str | None: " AND "-joined clause string, or None when no filter applies.
    """
    if not standard_filters:
        return None
    comparison_ops = {
        FilterOperator.GT: ">",
        FilterOperator.LT: "<",
        FilterOperator.GTE: ">=",
        FilterOperator.LTE: "<=",
    }
    clauses = []
    for flt in standard_filters.filters:
        key = flt.key
        if metadata_keys and key not in metadata_keys:
            continue
        # Prefix key with 'metadata.' for LanceDB struct column.
        lance_key = f"metadata.{key}"
        op = flt.operator
        val = flt.value
        if op == FilterOperator.IN:
            # Only list values are supported; non-lists are silently skipped.
            if isinstance(val, list):
                val_str = ", ".join(_sql_literal(v) for v in val)
                clauses.append(f"{lance_key} IN ({val_str})")
            continue
        if op == FilterOperator.EQ:
            clauses.append(f"{lance_key} = {_sql_literal(val)}")
        elif op == FilterOperator.NE:
            clauses.append(f"{lance_key} != {_sql_literal(val)}")
        elif op in comparison_ops:
            # Numeric comparisons: values are interpolated unquoted, as before.
            clauses.append(f"{lance_key} {comparison_ops[op]} {val}")
        # Add other operators as needed.
    if not clauses:
        return None
    return " AND ".join(clauses)


lancedb_base._to_lance_filter = patched_to_lance_filter
# --- MONKEYPATCH END ---
def load_environment():
    """Populate os.environ with the variables declared in the local .env file."""
    load_dotenv()
# --- Configuration ---
@st.cache_resource
def init_embedding_model():
    """Load the PubMedBERT embedding model once and install it globally.

    The model is cached by Streamlit, so repeated calls are cheap.
    """
    embedder = HuggingFaceEmbedding(
        model_name="pritamdeka/S-PubMedBert-MS-MARCO",
        device="cpu",
    )
    Settings.embed_model = embedder
def setup_llama_index(api_key: Optional[str] = None):
    """
    Configures global LlamaIndex settings (LLM).

    The embedding model is handled by init_embedding_model(). The Gemini LLM
    is configured only when a key is resolvable from the `api_key` argument
    or the GOOGLE_API_KEY environment variable; if Gemini initialization
    fails, a MockLLM is installed as a fallback.
    """
    # Embedding model must be in place regardless of LLM availability.
    init_embedding_model()

    resolved_key = api_key if api_key else os.environ.get("GOOGLE_API_KEY")
    if not resolved_key:
        return

    try:
        # Pass the key explicitly if available.
        Settings.llm = Gemini(model="models/gemini-2.5-flash", temperature=0, api_key=resolved_key)
    except Exception as e:
        print(f"⚠️ LLM initialization failed: {e}")
        from llama_index.core.llms import MockLLM
        Settings.llm = MockLLM()
@st.cache_resource
def load_index() -> VectorStoreIndex:
    """
    Opens the persistent LanceDB store and wraps it in a cached VectorStoreIndex.
    """
    # setup_llama_index() <-- REMOVED: App handles setup. Calling here resets LLM if env var is missing.
    db_path = "./ct_gov_lancedb"
    # NOTE(review): this connection handle is never used afterwards — the
    # vector store below reconnects through `uri`. Kept for parity; confirm
    # whether it can be dropped.
    db = lancedb.connect(db_path)

    # Metadata keys the filter layer is allowed to reference.
    metadata_keys = [
        "nct_id", "title", "org", "sponsor", "status", "phase",
        "study_type", "start_year", "condition", "intervention",
        "country", "state"
    ]

    store = LanceDBVectorStore(
        uri=db_path,
        table_name="clinical_trials",
        query_mode="hybrid",
    )
    # The constructor does not accept metadata keys, so set them directly.
    store._metadata_keys = metadata_keys

    context = StorageContext.from_defaults(vector_store=store)
    return VectorStoreIndex.from_vector_store(store, storage_context=context)
def get_hybrid_retriever(index: VectorStoreIndex, similarity_top_k: int = 50, filters=None):
    """
    Builds a retriever over the given index.

    LanceDB's native hybrid search is already configured at the vector-store
    level (query_mode="hybrid"); here the standard retriever is used, which
    sidesteps hybrid-search issues observed on small datasets.

    Args:
        index (VectorStoreIndex): The loaded vector index.
        similarity_top_k (int): Number of top results to retrieve.
        filters (MetadataFilters, optional): Filters to apply.

    Returns:
        VectorIndexRetriever: The configured retriever.
    """
    retriever = index.as_retriever(
        similarity_top_k=similarity_top_k,
        filters=filters,
    )
    return retriever
# --- Normalization ---
# Centralized Sponsor Mappings
# Key: Canonical Name
# Value: List of variations/aliases (including the canonical name itself if
# needed for matching). These strings also serve as the exact database 'org'
# values used in IN filters, so the mixed-case entries must match the DB.
SPONSOR_MAPPINGS = {
    "GlaxoSmithKline": [
        # fixed: removed a duplicated "glaxosmithkline" entry
        "gsk", "glaxo", "glaxosmithkline",
        "GlaxoSmithKline"
    ],
    "Janssen": [
        "j&j", "johnson & johnson", "johnson and johnson", "janssen", "Janssen",
        "Janssen Research & Development, LLC",
        "Janssen Vaccines & Prevention B.V.",
        "Janssen Pharmaceutical K.K.",
        "Janssen-Cilag International NV",
        "Janssen Sciences Ireland UC",
        "Janssen Pharmaceutica N.V., Belgium",
        "Janssen Scientific Affairs, LLC",
        "Janssen-Cilag Ltd.",
        "Xian-Janssen Pharmaceutical Ltd.",
        "Janssen Korea, Ltd., Korea",
        "Janssen-Cilag G.m.b.H",
        "Janssen-Cilag, S.A.",
        "Janssen BioPharma, Inc.",
    ],
    "Bristol-Myers Squibb": [
        "bms", "bristol", "bristol myers squibb", "bristol-myers squibb",
        "Bristol-Myers Squibb"
    ],
    "Merck Sharp & Dohme": [
        "merck", "msd", "merck sharp & dohme",
        "Merck Sharp & Dohme LLC"
    ],
    "Pfizer": ["pfizer", "Pfizer", "Pfizer Inc."],
    "AstraZeneca": ["astrazeneca", "AstraZeneca"],
    "Eli Lilly and Company": ["lilly", "eli lilly", "Eli Lilly and Company"],
    "Sanofi": ["sanofi", "Sanofi"],
    "Novartis": ["novartis", "Novartis"],
}
def normalize_sponsor(sponsor: str) -> Optional[str]:
    """
    Normalizes a sponsor name to its canonical form via SPONSOR_MAPPINGS.

    Matching order per canonical entry: exact canonical name, exact alias,
    short alias (< 5 chars, e.g. 'gsk') as a substring, then the canonical
    name as a substring. Unrecognized input is returned unchanged; falsy
    input yields None.
    """
    if not sponsor:
        return None

    needle = sponsor.lower().strip()
    for canonical, aliases in SPONSOR_MAPPINGS.items():
        canon_lower = canonical.lower()
        # Exact, case-insensitive hit on the canonical name itself.
        if needle == canon_lower:
            return canonical
        for alias in aliases:
            alias_lower = alias.lower()
            # Exact alias hit, or a short alias (like 'gsk') embedded in the name.
            if alias_lower == needle or (len(alias) < 5 and alias_lower in needle):
                return canonical
        # Canonical name embedded anywhere in the input.
        if canon_lower in needle:
            return canonical
    return sponsor
def get_sponsor_variations(sponsor: str) -> Optional[List[str]]:
    """
    Returns the exact database 'org' values for a given sponsor alias.

    The alias is first resolved to its canonical name; None is returned for
    falsy input or when no mapping exists for the resolved name.
    """
    if not sponsor:
        return None
    canonical = normalize_sponsor(sponsor)
    # dict.get yields None for unmapped canonicals, matching the old fallthrough.
    return SPONSOR_MAPPINGS.get(canonical)