Create test_cross_validation_stability.py
Implement proper model validation tests with CV stability checks
- Add cross-validation stability test with coefficient of variation thresholds
- Add overfitting detection test using performance indicators
- Replace empty test function with meaningful implementation
- Include proper error handling and pytest.skip for missing data
tests/test_cross_validation_stability.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import numpy as np
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from model.train import ModelTrainer
|
| 6 |
+
|
| 7 |
+
def test_model_performance_thresholds():
    """Test that model meets minimum performance requirements.

    Skips (rather than errors) when trained-model metadata has not been
    produced yet, matching the skip-on-missing-data behavior of the other
    tests in this module.
    """
    trainer = ModelTrainer()

    # Load persisted evaluation results from the last training run.
    # A missing metadata file means no model has been trained yet, which
    # is a skip condition, not a failure.
    try:
        metadata = trainer.load_metadata()
    except FileNotFoundError:
        pytest.skip("Model metadata not available")

    # Hard floors on held-out performance.
    assert metadata['test_f1'] >= 0.75, "F1 score below threshold"
    assert metadata['test_accuracy'] >= 0.75, "Accuracy below threshold"
|
| 16 |
+
|
| 17 |
+
def test_cross_validation_stability():
    """Test CV results show acceptable stability (low variance).

    Reads cv_results.json and asserts the coefficient of variation
    (std / mean) of each reported fold score stays below 15%. Skips when
    the results file — or the path helper module — is unavailable.
    """
    try:
        from path_config import path_manager

        cv_results_path = path_manager.get_logs_path("cv_results.json")

        if not cv_results_path.exists():
            pytest.skip("No CV results available")

        with open(cv_results_path, 'r') as f:
            cv_data = json.load(f)

        test_scores = cv_data.get('test_scores', {})

        # Same stability criterion for every reported metric: coefficient
        # of variation (std / mean) must stay below 0.15 (15%). The failure
        # messages are kept distinct per metric for readable test output.
        checks = [
            ('f1', "CV results too unstable"),
            ('accuracy', "Accuracy CV too unstable"),
        ]
        for metric, message in checks:
            if metric not in test_scores:
                continue
            std = test_scores[metric].get('std', 0)
            mean = test_scores[metric].get('mean', 0)

            # Guard divide-by-zero: a non-positive mean counts as unstable.
            cv_coefficient = std / mean if mean > 0 else 1
            assert cv_coefficient < 0.15, f"{message}: CV={cv_coefficient:.3f}"

    # ImportError covers a missing path_config module (previously it
    # escaped and errored the test instead of skipping).
    except (ImportError, FileNotFoundError):
        pytest.skip("CV results file not found")
|
| 50 |
+
|
| 51 |
+
def test_model_overfitting_indicators():
    """Test that model doesn't show signs of severe overfitting.

    Checks the optional ``performance_indicators.overfitting_score`` field
    of cv_results.json. Skips when the results file, the path helper
    module, or the indicator itself is unavailable.
    """
    try:
        from path_config import path_manager

        cv_results_path = path_manager.get_logs_path("cv_results.json")

        if not cv_results_path.exists():
            pytest.skip("No CV results available")

        with open(cv_results_path, 'r') as f:
            cv_data = json.load(f)

        # The indicator is optional; the producer may omit it or record a
        # placeholder string when the train/test gap could not be computed.
        perf_indicators = cv_data.get('performance_indicators', {})
        overfitting_score = perf_indicators.get('overfitting_score')

        # Only enforce when the score is an actual number — comparing a
        # non-'Unknown' string against 0.1 would raise TypeError.
        if isinstance(overfitting_score, (int, float)):
            # Train/test gap should be small (< 0.1) for a healthy model.
            assert overfitting_score < 0.1, f"Model shows overfitting: {overfitting_score}"

    # ImportError covers a missing path_config module (previously it
    # escaped and errored the test instead of skipping).
    except (ImportError, FileNotFoundError, KeyError):
        pytest.skip("Performance indicators not available")
|