Spaces:
Sleeping
Sleeping
File size: 4,845 Bytes
507be68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import pytest
import pandas as pd
import sys
import os
from unittest.mock import MagicMock, patch
# Add project root to path to import app modules
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from modules.utils import normalize_sponsor # noqa: E402
from modules.tools import expand_query # noqa: E402
from modules.graph_viz import build_graph # noqa: E402
from llama_index.core.schema import NodeWithScore, TextNode # noqa: E402
# --- Tests for normalize_sponsor ---
def test_normalize_sponsor_aliases():
    """Check that each known sponsor alias maps to its canonical name."""
    alias_to_canonical = [
        ("J&J", "Janssen"),
        ("Johnson & Johnson", "Janssen"),
        ("GSK", "GlaxoSmithKline"),
        ("Merck", "Merck Sharp & Dohme"),
        ("MSD", "Merck Sharp & Dohme"),
        ("BMS", "Bristol-Myers Squibb"),
    ]
    for alias, canonical in alias_to_canonical:
        assert normalize_sponsor(alias) == canonical
def test_normalize_sponsor_no_change():
    """Check that names without an alias mapping are returned unchanged."""
    for name in ("Pfizer", "Moderna", "Unknown Sponsor"):
        assert normalize_sponsor(name) == name
# --- Tests for Analytics Logic (Mocked) ---
def filter_dataframe(df, phase=None, status=None, sponsor=None, intervention=None):
    """
    Replicating the logic from get_study_analytics for testing purposes.

    Every filter that is supplied narrows the rows further (filters are
    AND-ed); falsy filters (None or "") are skipped.

    Args:
        df: DataFrame of studies. It must contain the column each active
            filter reads: "phase", "status", "org" or "intervention".
        phase: Comma-separated phase names. Matching ignores case and
            spaces and is by substring, so "Phase 2" matches "PHASE2".
        status: Status value, compared for equality ignoring case.
        sponsor: Sponsor name. Both this value and the "org" column are
            passed through normalize_sponsor, lower-cased, and matched as
            a literal substring.
        intervention: Literal, case-insensitive substring to look for in
            the "intervention" column.

    Returns:
        A DataFrame with the matching rows. For each active phase, sponsor
        or intervention filter the result carries the helper column used
        for matching ("phase_upper", "org_lower", "intervention_lower").
        The frame passed in by the caller is never modified; with no
        active filter it is returned as-is.
    """
    if phase:
        target_phases = [p.strip().upper().replace(" ", "") for p in phase.split(",")]
        # assign() builds a new frame, so the helper column never leaks into
        # the caller's DataFrame (which may be a fixture shared across tests).
        df = df.assign(
            phase_upper=df["phase"].astype(str).str.upper().str.replace(" ", "")
        )
        mask = df["phase_upper"].apply(lambda x: any(tp in x for tp in target_phases))
        df = df[mask]
    if status:
        df = df[df["status"].str.upper() == status.upper()]
    if sponsor:
        target_sponsor = normalize_sponsor(sponsor).lower()
        # Normalize the column the same way as the query so aliases compare equal.
        df = df.assign(
            org_lower=df["org"].astype(str).apply(normalize_sponsor).str.lower()
        )
        df = df[df["org_lower"].str.contains(target_sponsor, regex=False)]
    if intervention:
        target_intervention = intervention.lower()
        df = df.assign(
            intervention_lower=df["intervention"].astype(str).str.lower()
        )
        # regex=False: names such as "Drug A + Drug C" contain regex metacharacters.
        df = df[df["intervention_lower"].str.contains(target_intervention, regex=False)]
    return df
@pytest.fixture
def sample_df():
    """Provide a four-row DataFrame of studies for the analytics filter tests."""
    columns = [
        "nct_id",
        "phase",
        "status",
        "org",
        "intervention",
        "start_year",
        "title",
        "condition",
    ]
    # One tuple per study, in the same order as `columns`.
    rows = [
        ("NCT001", "PHASE1", "RECRUITING", "Pfizer", "Drug A", 2020, "Study of Drug A", "Cancer"),
        ("NCT002", "PHASE2", "COMPLETED", "Janssen", "Drug B", 2021, "Study of Drug B", "Diabetes"),
        (
            "NCT003",
            "PHASE3",
            "COMPLETED",
            "Merck Sharp & Dohme",
            "Keytruda",
            2022,
            "Keytruda Trial",
            "Lung Cancer",
        ),
        ("NCT004", "PHASE2", "RECRUITING", "Pfizer", "Drug A + Drug C", 2023, "Combo Study", "Cancer"),
    ]
    return pd.DataFrame(rows, columns=columns)
def test_analytics_filter_intervention(sample_df):
    """Filtering on "Keytruda" keeps exactly the one study that uses it."""
    matches = filter_dataframe(sample_df, intervention="Keytruda")
    assert len(matches) == 1
    first_match = matches.iloc[0]
    assert first_match["nct_id"] == "NCT003"
def test_analytics_filter_intervention_partial(sample_df):
    """A substring filter matches combination interventions as well."""
    # "Drug A" appears alone in NCT001 and inside "Drug A + Drug C" in NCT004.
    matches = filter_dataframe(sample_df, intervention="Drug A")
    assert len(matches) == 2
    matched_ids = set(matches["nct_id"])
    assert matched_ids == {"NCT001", "NCT004"}
# --- Tests for Query Expansion ---
@patch("modules.tools.Settings")
def test_expand_query(mock_settings):
    """A short query is expanded using the text returned by the mocked LLM."""
    # Stand-in for the LLM completion object; only its `.text` is configured.
    llm_reply = MagicMock()
    llm_reply.text = "Expanded Query: cancer OR carcinoma OR tumor"
    complete = mock_settings.llm.complete
    complete.return_value = llm_reply

    expanded = expand_query("cancer")

    assert "cancer OR carcinoma OR tumor" in expanded
    complete.assert_called_once()
def test_expand_query_skip_long():
    """A long query is returned unchanged and the LLM is never invoked."""
    long_query = "this is a very long query that should definitely be skipped because it has too many words"
    # Patch the LLM settings so that a regression in the skip logic fails this
    # test instead of silently calling whatever real model is configured.
    with patch("modules.tools.Settings") as mock_settings:
        assert expand_query(long_query) == long_query
    mock_settings.llm.complete.assert_not_called()
# --- Tests for Graph Visualization ---
def test_build_graph():
    """build_graph yields the expected node and edge counts for two studies."""
    first_study = {
        "nct_id": "NCT1",
        "title": "Study 1",
        "org": "Pfizer",
        "condition": "Cancer",
    }
    second_study = {
        "nct_id": "NCT2",
        "title": "Study 2",
        "org": "Merck",
        "condition": "Cancer, Diabetes",
    }
    nodes, edges, _config = build_graph([first_study, second_study])

    # Nodes: 2 studies + 2 sponsors + 2 distinct conditions (Cancer, Diabetes).
    assert len(nodes) == 6
    node_ids = [node.id for node in nodes]
    for expected_id in ("NCT1", "Pfizer", "Cancer"):
        assert expected_id in node_ids

    # Edges: NCT1 -> Pfizer, Cancer (2) plus NCT2 -> Merck, Cancer, Diabetes (3).
    assert len(edges) == 5
|