Commit
Β·
719d51e
1
Parent(s):
c29bcf3
Update model/train.py
Browse files- model/train.py +201 -109
model/train.py
CHANGED
|
@@ -22,7 +22,10 @@ import logging
|
|
| 22 |
import json
|
| 23 |
import joblib
|
| 24 |
import hashlib
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
| 26 |
from typing import Dict, Tuple, Optional, Any
|
| 27 |
import warnings
|
| 28 |
import re
|
|
@@ -75,6 +78,114 @@ def preprocess_text_function(texts):
|
|
| 75 |
return processed
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
class RobustModelTrainer:
|
| 79 |
"""Production-ready model trainer with comprehensive evaluation and validation"""
|
| 80 |
|
|
@@ -82,6 +193,7 @@ class RobustModelTrainer:
|
|
| 82 |
self.setup_paths()
|
| 83 |
self.setup_training_config()
|
| 84 |
self.setup_models()
|
|
|
|
| 85 |
|
| 86 |
def setup_paths(self):
|
| 87 |
"""Setup all necessary paths"""
|
|
@@ -107,14 +219,14 @@ class RobustModelTrainer:
|
|
| 107 |
self.test_size = 0.2
|
| 108 |
self.validation_size = 0.1
|
| 109 |
self.random_state = 42
|
| 110 |
-
self.cv_folds =
|
| 111 |
-
self.max_features =
|
| 112 |
-
self.min_df =
|
| 113 |
self.max_df = 0.95
|
| 114 |
-
self.ngram_range = (1,
|
| 115 |
-
self.max_iter =
|
| 116 |
self.class_weight = 'balanced'
|
| 117 |
-
self.feature_selection_k =
|
| 118 |
|
| 119 |
def setup_models(self):
|
| 120 |
"""Setup model configurations for comparison"""
|
|
@@ -123,24 +235,24 @@ class RobustModelTrainer:
|
|
| 123 |
'model': LogisticRegression(
|
| 124 |
max_iter=self.max_iter,
|
| 125 |
class_weight=self.class_weight,
|
| 126 |
-
random_state=self.random_state
|
|
|
|
| 127 |
),
|
| 128 |
'param_grid': {
|
| 129 |
-
'model__C': [0.1, 1, 10,
|
| 130 |
-
'model__penalty': ['l2']
|
| 131 |
-
'model__solver': ['liblinear', 'lbfgs']
|
| 132 |
}
|
| 133 |
},
|
| 134 |
'random_forest': {
|
| 135 |
'model': RandomForestClassifier(
|
| 136 |
-
n_estimators=
|
| 137 |
class_weight=self.class_weight,
|
| 138 |
-
random_state=self.random_state
|
|
|
|
| 139 |
),
|
| 140 |
'param_grid': {
|
| 141 |
-
'model__n_estimators': [50, 100,
|
| 142 |
-
'model__max_depth': [10,
|
| 143 |
-
'model__min_samples_split': [2, 5, 10]
|
| 144 |
}
|
| 145 |
}
|
| 146 |
}
|
|
@@ -149,6 +261,8 @@ class RobustModelTrainer:
|
|
| 149 |
"""Load and validate training data"""
|
| 150 |
try:
|
| 151 |
logger.info("Loading training data...")
|
|
|
|
|
|
|
| 152 |
|
| 153 |
if not self.data_path.exists():
|
| 154 |
return False, None, f"Data file not found: {self.data_path}"
|
|
@@ -182,7 +296,7 @@ class RobustModelTrainer:
|
|
| 182 |
return False, None, f"Need at least 2 classes, found: {unique_labels}"
|
| 183 |
|
| 184 |
# Check minimum sample size
|
| 185 |
-
if len(df) <
|
| 186 |
return False, None, f"Insufficient samples for training: {len(df)}"
|
| 187 |
|
| 188 |
# Check class balance
|
|
@@ -204,15 +318,18 @@ class RobustModelTrainer:
|
|
| 204 |
return False, None, error_msg
|
| 205 |
|
| 206 |
def create_preprocessing_pipeline(self) -> Pipeline:
|
| 207 |
-
"""Create
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
# Use the standalone function instead of lambda
|
| 210 |
text_preprocessor = FunctionTransformer(
|
| 211 |
-
func=preprocess_text_function,
|
| 212 |
validate=False
|
| 213 |
)
|
| 214 |
|
| 215 |
-
# TF-IDF vectorization
|
| 216 |
vectorizer = TfidfVectorizer(
|
| 217 |
max_features=self.max_features,
|
| 218 |
min_df=self.min_df,
|
|
@@ -226,7 +343,7 @@ class RobustModelTrainer:
|
|
| 226 |
# Feature selection
|
| 227 |
feature_selector = SelectKBest(
|
| 228 |
score_func=chi2,
|
| 229 |
-
k=self.feature_selection_k
|
| 230 |
)
|
| 231 |
|
| 232 |
# Create pipeline
|
|
@@ -241,8 +358,10 @@ class RobustModelTrainer:
|
|
| 241 |
|
| 242 |
def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
|
| 243 |
"""Comprehensive model evaluation with multiple metrics"""
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
| 246 |
# Predictions
|
| 247 |
y_pred = model.predict(X_test)
|
| 248 |
y_pred_proba = model.predict_proba(X_test)[:, 1]
|
|
@@ -260,18 +379,18 @@ class RobustModelTrainer:
|
|
| 260 |
cm = confusion_matrix(y_test, y_pred)
|
| 261 |
metrics['confusion_matrix'] = cm.tolist()
|
| 262 |
|
| 263 |
-
# Classification report
|
| 264 |
-
class_report = classification_report(y_test, y_pred, output_dict=True)
|
| 265 |
-
metrics['classification_report'] = class_report
|
| 266 |
-
|
| 267 |
# Cross-validation scores if training data provided
|
| 268 |
-
if X_train is not None and y_train is not None:
|
| 269 |
try:
|
| 270 |
cv_scores = cross_val_score(
|
| 271 |
model, X_train, y_train,
|
| 272 |
cv=StratifiedKFold(
|
| 273 |
-
n_splits=self.cv_folds,
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
)
|
| 276 |
metrics['cv_scores'] = {
|
| 277 |
'mean': float(cv_scores.mean()),
|
|
@@ -281,30 +400,11 @@ class RobustModelTrainer:
|
|
| 281 |
except Exception as e:
|
| 282 |
logger.warning(f"Cross-validation failed: {e}")
|
| 283 |
metrics['cv_scores'] = None
|
|
|
|
|
|
|
| 284 |
|
| 285 |
-
#
|
| 286 |
-
try:
|
| 287 |
-
if hasattr(model, 'feature_importances_'):
|
| 288 |
-
feature_importance = model.feature_importances_
|
| 289 |
-
metrics['feature_importance_stats'] = {
|
| 290 |
-
'mean': float(feature_importance.mean()),
|
| 291 |
-
'std': float(feature_importance.std()),
|
| 292 |
-
'top_features': feature_importance.argsort()[-10:][::-1].tolist()
|
| 293 |
-
}
|
| 294 |
-
elif hasattr(model, 'coef_'):
|
| 295 |
-
coefficients = model.coef_[0]
|
| 296 |
-
metrics['coefficient_stats'] = {
|
| 297 |
-
'mean': float(coefficients.mean()),
|
| 298 |
-
'std': float(coefficients.std()),
|
| 299 |
-
'top_positive': coefficients.argsort()[-10:][::-1].tolist(),
|
| 300 |
-
'top_negative': coefficients.argsort()[:10].tolist()
|
| 301 |
-
}
|
| 302 |
-
except Exception as e:
|
| 303 |
-
logger.warning(f"Feature importance extraction failed: {e}")
|
| 304 |
-
|
| 305 |
-
# Model complexity metrics
|
| 306 |
try:
|
| 307 |
-
# Training accuracy for overfitting detection
|
| 308 |
if X_train is not None and y_train is not None:
|
| 309 |
y_train_pred = model.predict(X_train)
|
| 310 |
train_accuracy = accuracy_score(y_train, y_train_pred)
|
|
@@ -318,7 +418,9 @@ class RobustModelTrainer:
|
|
| 318 |
|
| 319 |
def hyperparameter_tuning(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
|
| 320 |
"""Perform hyperparameter tuning with cross-validation"""
|
| 321 |
-
|
|
|
|
|
|
|
| 322 |
|
| 323 |
try:
|
| 324 |
# Set the model in the pipeline
|
|
@@ -327,15 +429,18 @@ class RobustModelTrainer:
|
|
| 327 |
# Get parameter grid
|
| 328 |
param_grid = self.models[model_name]['param_grid']
|
| 329 |
|
|
|
|
|
|
|
|
|
|
| 330 |
# Create GridSearchCV
|
| 331 |
grid_search = GridSearchCV(
|
| 332 |
pipeline,
|
| 333 |
param_grid,
|
| 334 |
-
cv=StratifiedKFold(n_splits=
|
| 335 |
shuffle=True, random_state=self.random_state),
|
| 336 |
scoring='f1_weighted',
|
| 337 |
-
n_jobs=-1,
|
| 338 |
-
verbose=
|
| 339 |
)
|
| 340 |
|
| 341 |
# Fit grid search
|
|
@@ -369,8 +474,7 @@ class RobustModelTrainer:
|
|
| 369 |
|
| 370 |
def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
|
| 371 |
"""Train and evaluate multiple models"""
|
| 372 |
-
|
| 373 |
-
|
| 374 |
results = {}
|
| 375 |
|
| 376 |
for model_name in self.models.keys():
|
|
@@ -409,7 +513,9 @@ class RobustModelTrainer:
|
|
| 409 |
|
| 410 |
def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
|
| 411 |
"""Select the best performing model"""
|
| 412 |
-
|
|
|
|
|
|
|
| 413 |
|
| 414 |
best_model_name = None
|
| 415 |
best_model = None
|
|
@@ -439,7 +545,8 @@ class RobustModelTrainer:
|
|
| 439 |
def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
|
| 440 |
"""Save model artifacts and metadata"""
|
| 441 |
try:
|
| 442 |
-
|
|
|
|
| 443 |
|
| 444 |
# Save the full pipeline
|
| 445 |
joblib.dump(model, self.pipeline_path)
|
|
@@ -449,14 +556,10 @@ class RobustModelTrainer:
|
|
| 449 |
if hasattr(model, 'named_steps') and 'model' in model.named_steps:
|
| 450 |
joblib.dump(model.named_steps['model'], self.model_path)
|
| 451 |
logger.info(f"β
Saved model to {self.model_path}")
|
| 452 |
-
else:
|
| 453 |
-
logger.warning("β Could not extract model component")
|
| 454 |
|
| 455 |
if hasattr(model, 'named_steps') and 'vectorize' in model.named_steps:
|
| 456 |
joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
|
| 457 |
logger.info(f"β
Saved vectorizer to {self.vectorizer_path}")
|
| 458 |
-
else:
|
| 459 |
-
logger.warning("β Could not extract vectorizer component")
|
| 460 |
|
| 461 |
# Generate data hash
|
| 462 |
data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
|
|
@@ -466,8 +569,6 @@ class RobustModelTrainer:
|
|
| 466 |
'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
| 467 |
'model_type': model_name,
|
| 468 |
'data_version': data_hash,
|
| 469 |
-
'train_size': metrics.get('train_accuracy', 'Unknown'),
|
| 470 |
-
'test_size': len(metrics.get('confusion_matrix', [[0]])[0]) if 'confusion_matrix' in metrics else 'Unknown',
|
| 471 |
'test_accuracy': metrics['accuracy'],
|
| 472 |
'test_f1': metrics['f1'],
|
| 473 |
'test_precision': metrics['precision'],
|
|
@@ -479,7 +580,6 @@ class RobustModelTrainer:
|
|
| 479 |
'timestamp': datetime.now().isoformat(),
|
| 480 |
'training_config': {
|
| 481 |
'test_size': self.test_size,
|
| 482 |
-
'validation_size': self.validation_size,
|
| 483 |
'cv_folds': self.cv_folds,
|
| 484 |
'max_features': self.max_features,
|
| 485 |
'ngram_range': self.ngram_range,
|
|
@@ -492,46 +592,12 @@ class RobustModelTrainer:
|
|
| 492 |
json.dump(metadata, f, indent=2)
|
| 493 |
|
| 494 |
logger.info(f"β
Model artifacts saved successfully")
|
| 495 |
-
logger.info(f"Model path: {self.model_path}")
|
| 496 |
-
logger.info(f"Vectorizer path: {self.vectorizer_path}")
|
| 497 |
-
logger.info(f"Pipeline path: {self.pipeline_path}")
|
| 498 |
-
logger.info(f"Metadata path: {self.metadata_path}")
|
| 499 |
-
|
| 500 |
return True
|
| 501 |
|
| 502 |
except Exception as e:
|
| 503 |
logger.error(f"Failed to save model artifacts: {str(e)}")
|
| 504 |
return False
|
| 505 |
|
| 506 |
-
def save_evaluation_results(self, results: Dict) -> bool:
|
| 507 |
-
"""Save comprehensive evaluation results"""
|
| 508 |
-
try:
|
| 509 |
-
# Clean results for JSON serialization
|
| 510 |
-
clean_results = {}
|
| 511 |
-
for model_name, result in results.items():
|
| 512 |
-
if 'error' in result:
|
| 513 |
-
clean_results[model_name] = result
|
| 514 |
-
else:
|
| 515 |
-
clean_results[model_name] = {
|
| 516 |
-
'tuning_results': {
|
| 517 |
-
k: v for k, v in result['tuning_results'].items()
|
| 518 |
-
if k != 'best_estimator'
|
| 519 |
-
},
|
| 520 |
-
'evaluation_metrics': result['evaluation_metrics'],
|
| 521 |
-
'training_time': result['training_time']
|
| 522 |
-
}
|
| 523 |
-
|
| 524 |
-
# Save results
|
| 525 |
-
with open(self.evaluation_path, 'w') as f:
|
| 526 |
-
json.dump(clean_results, f, indent=2, default=str)
|
| 527 |
-
|
| 528 |
-
logger.info(f"Evaluation results saved to {self.evaluation_path}")
|
| 529 |
-
return True
|
| 530 |
-
|
| 531 |
-
except Exception as e:
|
| 532 |
-
logger.error(f"Failed to save evaluation results: {str(e)}")
|
| 533 |
-
return False
|
| 534 |
-
|
| 535 |
def train_model(self, data_path: str = None) -> Tuple[bool, str]:
|
| 536 |
"""Main training function with comprehensive pipeline"""
|
| 537 |
try:
|
|
@@ -546,35 +612,52 @@ class RobustModelTrainer:
|
|
| 546 |
if not success:
|
| 547 |
return False, message
|
| 548 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
# Prepare data
|
| 550 |
X = df['text'].values
|
| 551 |
y = df['label'].values
|
| 552 |
|
| 553 |
# Train-test split
|
|
|
|
| 554 |
X_train, X_test, y_train, y_test = train_test_split(
|
| 555 |
X, y,
|
| 556 |
test_size=self.test_size,
|
| 557 |
-
stratify=y,
|
| 558 |
random_state=self.random_state
|
| 559 |
)
|
| 560 |
|
| 561 |
-
logger.info(
|
| 562 |
-
f"Data split: {len(X_train)} train, {len(X_test)} test")
|
| 563 |
|
| 564 |
# Train and evaluate models
|
| 565 |
results = self.train_and_evaluate_models(
|
| 566 |
X_train, X_test, y_train, y_test)
|
| 567 |
|
| 568 |
# Select best model
|
| 569 |
-
best_model_name, best_model, best_metrics = self.select_best_model(
|
| 570 |
-
results)
|
| 571 |
|
| 572 |
# Save model artifacts
|
| 573 |
if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
|
| 574 |
return False, "Failed to save model artifacts"
|
| 575 |
|
| 576 |
-
#
|
| 577 |
-
self.
|
| 578 |
|
| 579 |
success_message = (
|
| 580 |
f"Model training completed successfully. "
|
|
@@ -586,6 +669,8 @@ class RobustModelTrainer:
|
|
| 586 |
return True, success_message
|
| 587 |
|
| 588 |
except Exception as e:
|
|
|
|
|
|
|
| 589 |
error_message = f"Model training failed: {str(e)}"
|
| 590 |
logger.error(error_message)
|
| 591 |
return False, error_message
|
|
@@ -593,8 +678,15 @@ class RobustModelTrainer:
|
|
| 593 |
|
| 594 |
def main():
|
| 595 |
"""Main execution function"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
trainer = RobustModelTrainer()
|
| 597 |
-
success, message = trainer.train_model()
|
| 598 |
|
| 599 |
if success:
|
| 600 |
print(f"β
{message}")
|
|
|
|
| 22 |
import json
|
| 23 |
import joblib
|
| 24 |
import hashlib
|
| 25 |
+
import sys
|
| 26 |
+
import os
|
| 27 |
+
import time
|
| 28 |
+
from datetime import datetime, timedelta
|
| 29 |
from typing import Dict, Tuple, Optional, Any
|
| 30 |
import warnings
|
| 31 |
import re
|
|
|
|
| 78 |
return processed
|
| 79 |
|
| 80 |
|
| 81 |
+
class ProgressTracker:
|
| 82 |
+
"""Progress tracking with time estimation"""
|
| 83 |
+
|
| 84 |
+
def __init__(self, total_steps: int, description: str = "Training"):
|
| 85 |
+
self.total_steps = total_steps
|
| 86 |
+
self.current_step = 0
|
| 87 |
+
self.start_time = time.time()
|
| 88 |
+
self.description = description
|
| 89 |
+
self.step_times = []
|
| 90 |
+
|
| 91 |
+
def update(self, step_name: str = ""):
|
| 92 |
+
"""Update progress and print status"""
|
| 93 |
+
self.current_step += 1
|
| 94 |
+
current_time = time.time()
|
| 95 |
+
elapsed = current_time - self.start_time
|
| 96 |
+
|
| 97 |
+
# Calculate progress percentage
|
| 98 |
+
progress_pct = (self.current_step / self.total_steps) * 100
|
| 99 |
+
|
| 100 |
+
# Estimate remaining time
|
| 101 |
+
if self.current_step > 0:
|
| 102 |
+
avg_time_per_step = elapsed / self.current_step
|
| 103 |
+
remaining_steps = self.total_steps - self.current_step
|
| 104 |
+
eta_seconds = avg_time_per_step * remaining_steps
|
| 105 |
+
eta = timedelta(seconds=int(eta_seconds))
|
| 106 |
+
else:
|
| 107 |
+
eta = "calculating..."
|
| 108 |
+
|
| 109 |
+
# Create progress bar
|
| 110 |
+
bar_length = 30
|
| 111 |
+
filled_length = int(bar_length * self.current_step // self.total_steps)
|
| 112 |
+
bar = 'β' * filled_length + 'β' * (bar_length - filled_length)
|
| 113 |
+
|
| 114 |
+
# Print progress
|
| 115 |
+
status_msg = f"\r{self.description}: [{bar}] {progress_pct:.1f}% | Step {self.current_step}/{self.total_steps}"
|
| 116 |
+
if step_name:
|
| 117 |
+
status_msg += f" | {step_name}"
|
| 118 |
+
if eta != "calculating...":
|
| 119 |
+
status_msg += f" | ETA: {eta}"
|
| 120 |
+
|
| 121 |
+
print(status_msg, end='', flush=True)
|
| 122 |
+
|
| 123 |
+
# Store step time for better estimation
|
| 124 |
+
if len(self.step_times) >= 3: # Keep last 3 step times for moving average
|
| 125 |
+
self.step_times.pop(0)
|
| 126 |
+
self.step_times.append(current_time - (self.start_time + sum(self.step_times)))
|
| 127 |
+
|
| 128 |
+
def finish(self):
|
| 129 |
+
"""Complete progress tracking"""
|
| 130 |
+
total_time = time.time() - self.start_time
|
| 131 |
+
print(f"\n{self.description} completed in {timedelta(seconds=int(total_time))}")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def estimate_training_time(dataset_size: int, enable_tuning: bool = True, cv_folds: int = 3) -> Dict:
|
| 135 |
+
"""Estimate training time based on dataset characteristics"""
|
| 136 |
+
|
| 137 |
+
# Base time estimates (in seconds) based on empirical testing
|
| 138 |
+
base_times = {
|
| 139 |
+
'preprocessing': max(0.1, dataset_size * 0.001), # ~1ms per sample
|
| 140 |
+
'vectorization': max(0.5, dataset_size * 0.01), # ~10ms per sample
|
| 141 |
+
'feature_selection': max(0.2, dataset_size * 0.005), # ~5ms per sample
|
| 142 |
+
'simple_training': max(1.0, dataset_size * 0.02), # ~20ms per sample
|
| 143 |
+
'evaluation': max(0.5, dataset_size * 0.01), # ~10ms per sample
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
# Hyperparameter tuning multipliers
|
| 147 |
+
tuning_multipliers = {
|
| 148 |
+
'logistic_regression': 8 if enable_tuning else 1, # 8 param combinations
|
| 149 |
+
'random_forest': 12 if enable_tuning else 1, # 12 param combinations
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
# Cross-validation multiplier
|
| 153 |
+
cv_multiplier = cv_folds if dataset_size > 100 else 1
|
| 154 |
+
|
| 155 |
+
# Calculate estimates
|
| 156 |
+
estimates = {}
|
| 157 |
+
|
| 158 |
+
# Preprocessing steps
|
| 159 |
+
estimates['data_loading'] = 0.5
|
| 160 |
+
estimates['preprocessing'] = base_times['preprocessing']
|
| 161 |
+
estimates['vectorization'] = base_times['vectorization']
|
| 162 |
+
estimates['feature_selection'] = base_times['feature_selection']
|
| 163 |
+
|
| 164 |
+
# Model training
|
| 165 |
+
for model_name, multiplier in tuning_multipliers.items():
|
| 166 |
+
model_time = base_times['simple_training'] * multiplier * cv_multiplier
|
| 167 |
+
estimates[f'{model_name}_training'] = model_time
|
| 168 |
+
estimates[f'{model_name}_evaluation'] = base_times['evaluation']
|
| 169 |
+
|
| 170 |
+
# Model saving
|
| 171 |
+
estimates['model_saving'] = 1.0
|
| 172 |
+
|
| 173 |
+
# Total estimate
|
| 174 |
+
total_estimate = sum(estimates.values())
|
| 175 |
+
|
| 176 |
+
# Add 20% buffer for overhead
|
| 177 |
+
total_estimate *= 1.2
|
| 178 |
+
|
| 179 |
+
return {
|
| 180 |
+
'detailed_estimates': estimates,
|
| 181 |
+
'total_seconds': total_estimate,
|
| 182 |
+
'total_formatted': str(timedelta(seconds=int(total_estimate))),
|
| 183 |
+
'dataset_size': dataset_size,
|
| 184 |
+
'enable_tuning': enable_tuning,
|
| 185 |
+
'cv_folds': cv_folds
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
|
| 189 |
class RobustModelTrainer:
|
| 190 |
"""Production-ready model trainer with comprehensive evaluation and validation"""
|
| 191 |
|
|
|
|
| 193 |
self.setup_paths()
|
| 194 |
self.setup_training_config()
|
| 195 |
self.setup_models()
|
| 196 |
+
self.progress_tracker = None
|
| 197 |
|
| 198 |
def setup_paths(self):
|
| 199 |
"""Setup all necessary paths"""
|
|
|
|
| 219 |
self.test_size = 0.2
|
| 220 |
self.validation_size = 0.1
|
| 221 |
self.random_state = 42
|
| 222 |
+
self.cv_folds = 3
|
| 223 |
+
self.max_features = 5000 # Reduced for speed
|
| 224 |
+
self.min_df = 1 # More lenient for small datasets
|
| 225 |
self.max_df = 0.95
|
| 226 |
+
self.ngram_range = (1, 2) # Reduced for speed
|
| 227 |
+
self.max_iter = 500 # Reduced for speed
|
| 228 |
self.class_weight = 'balanced'
|
| 229 |
+
self.feature_selection_k = 2000 # Reduced for speed
|
| 230 |
|
| 231 |
def setup_models(self):
|
| 232 |
"""Setup model configurations for comparison"""
|
|
|
|
| 235 |
'model': LogisticRegression(
|
| 236 |
max_iter=self.max_iter,
|
| 237 |
class_weight=self.class_weight,
|
| 238 |
+
random_state=self.random_state,
|
| 239 |
+
n_jobs=-1 # Use all cores
|
| 240 |
),
|
| 241 |
'param_grid': {
|
| 242 |
+
'model__C': [0.1, 1, 10], # Reduced grid
|
| 243 |
+
'model__penalty': ['l2']
|
|
|
|
| 244 |
}
|
| 245 |
},
|
| 246 |
'random_forest': {
|
| 247 |
'model': RandomForestClassifier(
|
| 248 |
+
n_estimators=50, # Reduced for speed
|
| 249 |
class_weight=self.class_weight,
|
| 250 |
+
random_state=self.random_state,
|
| 251 |
+
n_jobs=-1 # Use all cores
|
| 252 |
),
|
| 253 |
'param_grid': {
|
| 254 |
+
'model__n_estimators': [50, 100], # Reduced grid
|
| 255 |
+
'model__max_depth': [10, None]
|
|
|
|
| 256 |
}
|
| 257 |
}
|
| 258 |
}
|
|
|
|
| 261 |
"""Load and validate training data"""
|
| 262 |
try:
|
| 263 |
logger.info("Loading training data...")
|
| 264 |
+
if self.progress_tracker:
|
| 265 |
+
self.progress_tracker.update("Loading data")
|
| 266 |
|
| 267 |
if not self.data_path.exists():
|
| 268 |
return False, None, f"Data file not found: {self.data_path}"
|
|
|
|
| 296 |
return False, None, f"Need at least 2 classes, found: {unique_labels}"
|
| 297 |
|
| 298 |
# Check minimum sample size
|
| 299 |
+
if len(df) < 10:
|
| 300 |
return False, None, f"Insufficient samples for training: {len(df)}"
|
| 301 |
|
| 302 |
# Check class balance
|
|
|
|
| 318 |
return False, None, error_msg
|
| 319 |
|
| 320 |
def create_preprocessing_pipeline(self) -> Pipeline:
|
| 321 |
+
"""Create preprocessing pipeline"""
|
| 322 |
+
|
| 323 |
+
if self.progress_tracker:
|
| 324 |
+
self.progress_tracker.update("Creating pipeline")
|
| 325 |
|
| 326 |
# Use the standalone function instead of lambda
|
| 327 |
text_preprocessor = FunctionTransformer(
|
| 328 |
+
func=preprocess_text_function,
|
| 329 |
validate=False
|
| 330 |
)
|
| 331 |
|
| 332 |
+
# TF-IDF vectorization with optimized parameters
|
| 333 |
vectorizer = TfidfVectorizer(
|
| 334 |
max_features=self.max_features,
|
| 335 |
min_df=self.min_df,
|
|
|
|
| 343 |
# Feature selection
|
| 344 |
feature_selector = SelectKBest(
|
| 345 |
score_func=chi2,
|
| 346 |
+
k=min(self.feature_selection_k, self.max_features)
|
| 347 |
)
|
| 348 |
|
| 349 |
# Create pipeline
|
|
|
|
| 358 |
|
| 359 |
def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
|
| 360 |
"""Comprehensive model evaluation with multiple metrics"""
|
| 361 |
+
|
| 362 |
+
if self.progress_tracker:
|
| 363 |
+
self.progress_tracker.update("Evaluating model")
|
| 364 |
+
|
| 365 |
# Predictions
|
| 366 |
y_pred = model.predict(X_test)
|
| 367 |
y_pred_proba = model.predict_proba(X_test)[:, 1]
|
|
|
|
| 379 |
cm = confusion_matrix(y_test, y_pred)
|
| 380 |
metrics['confusion_matrix'] = cm.tolist()
|
| 381 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
# Cross-validation scores if training data provided
|
| 383 |
+
if X_train is not None and y_train is not None and len(X_train) >= 50:
|
| 384 |
try:
|
| 385 |
cv_scores = cross_val_score(
|
| 386 |
model, X_train, y_train,
|
| 387 |
cv=StratifiedKFold(
|
| 388 |
+
n_splits=min(self.cv_folds, len(X_train) // 10),
|
| 389 |
+
shuffle=True,
|
| 390 |
+
random_state=self.random_state
|
| 391 |
+
),
|
| 392 |
+
scoring='f1_weighted',
|
| 393 |
+
n_jobs=-1 # Parallel CV
|
| 394 |
)
|
| 395 |
metrics['cv_scores'] = {
|
| 396 |
'mean': float(cv_scores.mean()),
|
|
|
|
| 400 |
except Exception as e:
|
| 401 |
logger.warning(f"Cross-validation failed: {e}")
|
| 402 |
metrics['cv_scores'] = None
|
| 403 |
+
else:
|
| 404 |
+
metrics['cv_scores'] = {'note': 'Skipped for small dataset'}
|
| 405 |
|
| 406 |
+
# Training accuracy for overfitting detection
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
try:
|
|
|
|
| 408 |
if X_train is not None and y_train is not None:
|
| 409 |
y_train_pred = model.predict(X_train)
|
| 410 |
train_accuracy = accuracy_score(y_train, y_train_pred)
|
|
|
|
| 418 |
|
| 419 |
def hyperparameter_tuning(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
|
| 420 |
"""Perform hyperparameter tuning with cross-validation"""
|
| 421 |
+
|
| 422 |
+
if self.progress_tracker:
|
| 423 |
+
self.progress_tracker.update(f"Tuning {model_name}")
|
| 424 |
|
| 425 |
try:
|
| 426 |
# Set the model in the pipeline
|
|
|
|
| 429 |
# Get parameter grid
|
| 430 |
param_grid = self.models[model_name]['param_grid']
|
| 431 |
|
| 432 |
+
# Adaptive CV folds based on dataset size
|
| 433 |
+
cv_folds = min(self.cv_folds, len(X_train) // 10, 5)
|
| 434 |
+
|
| 435 |
# Create GridSearchCV
|
| 436 |
grid_search = GridSearchCV(
|
| 437 |
pipeline,
|
| 438 |
param_grid,
|
| 439 |
+
cv=StratifiedKFold(n_splits=cv_folds,
|
| 440 |
shuffle=True, random_state=self.random_state),
|
| 441 |
scoring='f1_weighted',
|
| 442 |
+
n_jobs=-1, # Use all cores
|
| 443 |
+
verbose=0 # Reduce verbosity for speed
|
| 444 |
)
|
| 445 |
|
| 446 |
# Fit grid search
|
|
|
|
| 474 |
|
| 475 |
def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
|
| 476 |
"""Train and evaluate multiple models"""
|
| 477 |
+
|
|
|
|
| 478 |
results = {}
|
| 479 |
|
| 480 |
for model_name in self.models.keys():
|
|
|
|
| 513 |
|
| 514 |
def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
|
| 515 |
"""Select the best performing model"""
|
| 516 |
+
|
| 517 |
+
if self.progress_tracker:
|
| 518 |
+
self.progress_tracker.update("Selecting best model")
|
| 519 |
|
| 520 |
best_model_name = None
|
| 521 |
best_model = None
|
|
|
|
| 545 |
def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
|
| 546 |
"""Save model artifacts and metadata"""
|
| 547 |
try:
|
| 548 |
+
if self.progress_tracker:
|
| 549 |
+
self.progress_tracker.update("Saving model")
|
| 550 |
|
| 551 |
# Save the full pipeline
|
| 552 |
joblib.dump(model, self.pipeline_path)
|
|
|
|
| 556 |
if hasattr(model, 'named_steps') and 'model' in model.named_steps:
|
| 557 |
joblib.dump(model.named_steps['model'], self.model_path)
|
| 558 |
logger.info(f"β
Saved model to {self.model_path}")
|
|
|
|
|
|
|
| 559 |
|
| 560 |
if hasattr(model, 'named_steps') and 'vectorize' in model.named_steps:
|
| 561 |
joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
|
| 562 |
logger.info(f"β
Saved vectorizer to {self.vectorizer_path}")
|
|
|
|
|
|
|
| 563 |
|
| 564 |
# Generate data hash
|
| 565 |
data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
|
|
|
|
| 569 |
'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
|
| 570 |
'model_type': model_name,
|
| 571 |
'data_version': data_hash,
|
|
|
|
|
|
|
| 572 |
'test_accuracy': metrics['accuracy'],
|
| 573 |
'test_f1': metrics['f1'],
|
| 574 |
'test_precision': metrics['precision'],
|
|
|
|
| 580 |
'timestamp': datetime.now().isoformat(),
|
| 581 |
'training_config': {
|
| 582 |
'test_size': self.test_size,
|
|
|
|
| 583 |
'cv_folds': self.cv_folds,
|
| 584 |
'max_features': self.max_features,
|
| 585 |
'ngram_range': self.ngram_range,
|
|
|
|
| 592 |
json.dump(metadata, f, indent=2)
|
| 593 |
|
| 594 |
logger.info(f"β
Model artifacts saved successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
return True
|
| 596 |
|
| 597 |
except Exception as e:
|
| 598 |
logger.error(f"Failed to save model artifacts: {str(e)}")
|
| 599 |
return False
|
| 600 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
def train_model(self, data_path: str = None) -> Tuple[bool, str]:
|
| 602 |
"""Main training function with comprehensive pipeline"""
|
| 603 |
try:
|
|
|
|
| 612 |
if not success:
|
| 613 |
return False, message
|
| 614 |
|
| 615 |
+
# Estimate training time and setup progress tracker
|
| 616 |
+
time_estimate = estimate_training_time(
|
| 617 |
+
len(df),
|
| 618 |
+
enable_tuning=True,
|
| 619 |
+
cv_folds=self.cv_folds
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
print(f"\nπ Training Configuration:")
|
| 623 |
+
print(f"Dataset size: {len(df)} samples")
|
| 624 |
+
print(f"Estimated time: {time_estimate['total_formatted']}")
|
| 625 |
+
print(f"Models to train: {len(self.models)}")
|
| 626 |
+
print(f"Cross-validation folds: {self.cv_folds}")
|
| 627 |
+
print()
|
| 628 |
+
|
| 629 |
+
# Setup progress tracker
|
| 630 |
+
total_steps = 4 + (len(self.models) * 2) + 1 # Load, split, 2*models, select, save
|
| 631 |
+
self.progress_tracker = ProgressTracker(total_steps, "Training Progress")
|
| 632 |
+
|
| 633 |
# Prepare data
|
| 634 |
X = df['text'].values
|
| 635 |
y = df['label'].values
|
| 636 |
|
| 637 |
# Train-test split
|
| 638 |
+
self.progress_tracker.update("Splitting data")
|
| 639 |
X_train, X_test, y_train, y_test = train_test_split(
|
| 640 |
X, y,
|
| 641 |
test_size=self.test_size,
|
| 642 |
+
stratify=y if len(np.unique(y)) > 1 and len(y) > 10 else None,
|
| 643 |
random_state=self.random_state
|
| 644 |
)
|
| 645 |
|
| 646 |
+
logger.info(f"Data split: {len(X_train)} train, {len(X_test)} test")
|
|
|
|
| 647 |
|
| 648 |
# Train and evaluate models
|
| 649 |
results = self.train_and_evaluate_models(
|
| 650 |
X_train, X_test, y_train, y_test)
|
| 651 |
|
| 652 |
# Select best model
|
| 653 |
+
best_model_name, best_model, best_metrics = self.select_best_model(results)
|
|
|
|
| 654 |
|
| 655 |
# Save model artifacts
|
| 656 |
if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
|
| 657 |
return False, "Failed to save model artifacts"
|
| 658 |
|
| 659 |
+
# Finish progress tracking
|
| 660 |
+
self.progress_tracker.finish()
|
| 661 |
|
| 662 |
success_message = (
|
| 663 |
f"Model training completed successfully. "
|
|
|
|
| 669 |
return True, success_message
|
| 670 |
|
| 671 |
except Exception as e:
|
| 672 |
+
if self.progress_tracker:
|
| 673 |
+
print() # New line after progress bar
|
| 674 |
error_message = f"Model training failed: {str(e)}"
|
| 675 |
logger.error(error_message)
|
| 676 |
return False, error_message
|
|
|
|
| 678 |
|
| 679 |
def main():
|
| 680 |
"""Main execution function"""
|
| 681 |
+
import argparse
|
| 682 |
+
|
| 683 |
+
# Parse command line arguments
|
| 684 |
+
parser = argparse.ArgumentParser(description='Train fake news detection model')
|
| 685 |
+
parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
|
| 686 |
+
args = parser.parse_args()
|
| 687 |
+
|
| 688 |
trainer = RobustModelTrainer()
|
| 689 |
+
success, message = trainer.train_model(data_path=args.data_path)
|
| 690 |
|
| 691 |
if success:
|
| 692 |
print(f"β
{message}")
|