Ahmedik95316 committed
Commit 52d71b1 · 1 Parent(s): dce7bfe

Update model/train.py

Restored previous working version

Files changed (1)
  1. model/train.py +258 -629
model/train.py CHANGED
@@ -27,161 +27,46 @@ from typing import Dict, Tuple, Optional, Any
27
  import warnings
28
  warnings.filterwarnings('ignore')
29
 
 
30
 
31
- # =============================================================================
32
- # CENTRALIZED PATH CONFIGURATION - MATCHES FASTAPI SERVER
33
- # =============================================================================
34
- class PathConfig:
35
- """Centralized path management to ensure consistency across all components"""
36
-
37
- # Base directories
38
- BASE_DIR = Path("/tmp")
39
- DATA_DIR = BASE_DIR / "data"
40
- MODEL_DIR = BASE_DIR / "model" # CONSISTENT: /tmp/model/
41
- LOGS_DIR = BASE_DIR / "logs"
42
- RESULTS_DIR = BASE_DIR / "results"
43
-
44
- # Model files - CONSISTENT PATHS (matches fastapi_server.py)
45
- MODEL_FILE = MODEL_DIR / "model.pkl" # /tmp/model/model.pkl
46
- VECTORIZER_FILE = MODEL_DIR / "vectorizer.pkl" # /tmp/model/vectorizer.pkl
47
- PIPELINE_FILE = MODEL_DIR / "pipeline.pkl" # /tmp/model/pipeline.pkl
48
- METADATA_FILE = BASE_DIR / "metadata.json" # /tmp/metadata.json
49
-
50
- # Data files
51
- COMBINED_DATASET = DATA_DIR / "combined_dataset.csv"
52
- SCRAPED_DATA = DATA_DIR / "scraped_real.csv"
53
- GENERATED_DATA = DATA_DIR / "generated_fake.csv"
54
-
55
- # Log and result files
56
- TRAINING_LOG = LOGS_DIR / "model_training.log"
57
- EVALUATION_RESULTS = RESULTS_DIR / "evaluation_results.json"
58
-
59
- @classmethod
60
- def ensure_directories(cls):
61
- """Create all required directories with proper permissions"""
62
- for attr_name in dir(cls):
63
- attr = getattr(cls, attr_name)
64
- if isinstance(attr, Path) and attr_name.endswith('_DIR'):
65
- attr.mkdir(parents=True, exist_ok=True, mode=0o755)
66
-
67
- # Additional directory creation for safety
68
- for directory in [cls.BASE_DIR, cls.DATA_DIR, cls.MODEL_DIR, cls.LOGS_DIR, cls.RESULTS_DIR]:
69
- directory.mkdir(parents=True, exist_ok=True, mode=0o755)
70
-
71
-
72
- # Initialize directories at startup
73
- PathConfig.ensure_directories()
74
-
75
-
76
- # =============================================================================
77
- # ENHANCED LOGGING CONFIGURATION
78
- # =============================================================================
79
  logging.basicConfig(
80
  level=logging.INFO,
81
- format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
82
  handlers=[
83
- logging.FileHandler(PathConfig.TRAINING_LOG),
84
  logging.StreamHandler()
85
  ]
86
  )
87
  logger = logging.getLogger(__name__)
88
 
89
 
90
- # =============================================================================
91
- # DATA VALIDATION PIPELINE
92
- # =============================================================================
93
- class DataValidator:
94
- """Comprehensive data validation for training pipeline"""
95
-
96
- def __init__(self, min_text_length: int = 10, max_null_ratio: float = 0.1):
97
- self.min_text_length = min_text_length
98
- self.max_null_ratio = max_null_ratio
99
-
100
- def validate_schema(self, df: pd.DataFrame) -> Tuple[bool, list]:
101
- """Validate data schema"""
102
- errors = []
103
- required_columns = ['text', 'label']
104
-
105
- missing_cols = set(required_columns) - set(df.columns)
106
- if missing_cols:
107
- errors.append(f"Missing required columns: {missing_cols}")
108
-
109
- return len(errors) == 0, errors
110
-
111
- def validate_quality(self, df: pd.DataFrame) -> Tuple[bool, list]:
112
- """Validate data quality"""
113
- errors = []
114
-
115
- # Check null ratio
116
- null_ratio = df.isnull().sum().sum() / (len(df) * len(df.columns))
117
- if null_ratio > self.max_null_ratio:
118
- errors.append(f"Too many nulls: {null_ratio:.2%} > {self.max_null_ratio:.2%}")
119
-
120
- # Check text quality
121
- if 'text' in df.columns:
122
- short_texts = (df['text'].astype(str).str.len() < self.min_text_length).sum()
123
- if short_texts > 0:
124
- errors.append(f"{short_texts} texts below minimum length ({self.min_text_length} chars)")
125
-
126
- # Check minimum samples
127
- if len(df) < 100:
128
- errors.append(f"Insufficient samples for training: {len(df)} < 100")
129
-
130
- # Check class distribution
131
- if 'label' in df.columns:
132
- unique_labels = df['label'].unique()
133
- if len(unique_labels) < 2:
134
- errors.append(f"Need at least 2 classes, found: {unique_labels}")
135
-
136
- label_counts = df['label'].value_counts()
137
- min_class_ratio = label_counts.min() / label_counts.max()
138
- if min_class_ratio < 0.05:
139
- errors.append(f"Severe class imbalance: {min_class_ratio:.3f}")
140
- elif min_class_ratio < 0.1:
141
- logger.warning(f"Class imbalance detected: {min_class_ratio:.3f}")
142
-
143
- return len(errors) == 0, errors
144
-
145
- def validate(self, df: pd.DataFrame) -> Tuple[bool, Dict[str, list]]:
146
- """Complete data validation"""
147
- all_valid = True
148
- all_errors = {}
149
-
150
- # Schema validation
151
- schema_valid, schema_errors = self.validate_schema(df)
152
- if not schema_valid:
153
- all_valid = False
154
- all_errors['schema'] = schema_errors
155
-
156
- # Quality validation
157
- quality_valid, quality_errors = self.validate_quality(df)
158
- if not quality_valid:
159
- all_valid = False
160
- all_errors['quality'] = quality_errors
161
-
162
- return all_valid, all_errors
163
-
164
-
165
- # =============================================================================
166
- # ENHANCED MODEL TRAINER WITH FIXED PATHS
167
- # =============================================================================
168
  class RobustModelTrainer:
169
- """Production-ready model trainer with comprehensive evaluation and FIXED PATH MANAGEMENT"""
170
 
171
  def __init__(self):
172
- # Use centralized path configuration
173
- PathConfig.ensure_directories()
174
  self.setup_training_config()
175
  self.setup_models()
176
- self.data_validator = DataValidator()
177
-
178
- # Log path configuration for verification
179
- logger.info("🔧 Path Configuration:")
180
- logger.info(f"Model Directory: {PathConfig.MODEL_DIR}")
181
- logger.info(f"Pipeline File: {PathConfig.PIPELINE_FILE}")
182
- logger.info(f"Model File: {PathConfig.MODEL_FILE}")
183
- logger.info(f"Vectorizer File: {PathConfig.VECTORIZER_FILE}")
184
- logger.info(f"Metadata File: {PathConfig.METADATA_FILE}")
 
 
 
 
 
 
 
 
 
 
185
 
186
  def setup_training_config(self):
187
  """Setup training configuration"""
@@ -227,47 +112,57 @@ class RobustModelTrainer:
227
  }
228
 
229
  def load_and_validate_data(self) -> Tuple[bool, Optional[pd.DataFrame], str]:
230
- """Load and validate training data with enhanced validation"""
231
  try:
232
- logger.info("Loading and validating training data...")
233
 
234
- data_path = PathConfig.COMBINED_DATASET
235
-
236
- if not data_path.exists():
237
- return False, None, f"Data file not found: {data_path}"
238
 
239
  # Load data
240
- df = pd.read_csv(data_path)
241
- logger.info(f"Loaded dataset with {len(df)} samples")
242
-
243
- # Enhanced validation using DataValidator
244
- valid, validation_errors = self.data_validator.validate(df)
245
-
246
- if not valid:
247
- error_msg = "Data validation failed:\n" + "\n".join([
248
- f" {category}: {errors}" for category, errors in validation_errors.items()
249
- ])
250
- logger.error(error_msg)
251
- return False, None, error_msg
252
-
253
- # Clean data
254
- initial_count = len(df)
255
-
256
  # Remove missing values
257
- df = df.dropna(subset=['text', 'label'])
258
-
259
- # Remove short texts
260
- df = df[df['text'].astype(str).str.len() >= self.data_validator.min_text_length]
261
-
262
  if len(df) < initial_count:
263
- logger.info(f"🧹 Cleaned data: removed {initial_count - len(df)} invalid samples")
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
- # Log final statistics
266
  label_counts = df['label'].value_counts()
267
- logger.info(f"Data validation successful: {len(df)} samples")
 
 
 
 
 
 
268
  logger.info(f"Class distribution: {label_counts.to_dict()}")
269
 
270
- return True, df, "Data loaded and validated successfully"
271
 
272
  except Exception as e:
273
  error_msg = f"Error loading data: {str(e)}"
@@ -275,40 +170,33 @@ class RobustModelTrainer:
275
  return False, None, error_msg
276
 
277
  def preprocess_text(self, text):
278
- """Advanced text preprocessing with better error handling"""
279
  import re
280
 
281
- try:
282
- # Convert to string
283
- text = str(text)
284
 
285
- # Remove URLs
286
- text = re.sub(r'http\S+|www\S+|https\S+', '', text)
287
 
288
- # Remove email addresses
289
- text = re.sub(r'\S+@\S+', '', text)
290
 
291
- # Remove excessive punctuation
292
- text = re.sub(r'[!]{2,}', '!', text)
293
- text = re.sub(r'[?]{2,}', '?', text)
294
- text = re.sub(r'[.]{3,}', '...', text)
295
 
296
- # Remove non-alphabetic characters except spaces and basic punctuation
297
- text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
298
 
299
- # Remove excessive whitespace
300
- text = re.sub(r'\s+', ' ', text)
301
 
302
- return text.strip().lower()
303
-
304
- except Exception as e:
305
- logger.warning(f"Text preprocessing failed for text, returning original: {e}")
306
- return str(text).lower()
307
 
308
  def create_preprocessing_pipeline(self) -> Pipeline:
309
- """Create advanced preprocessing pipeline with FIXED saving"""
310
- logger.info("🔧 Creating preprocessing pipeline...")
311
-
312
  # Text preprocessing
313
  text_preprocessor = FunctionTransformer(
314
  func=lambda x: [self.preprocess_text(text) for text in x],
@@ -340,197 +228,95 @@ class RobustModelTrainer:
340
  ('model', None) # Will be set during training
341
  ])
342
 
343
- logger.info("Preprocessing pipeline created successfully")
344
- return pipeline
345
-
346
- def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
347
- """Save model artifacts with FIXED PATHS and comprehensive error handling"""
348
- try:
349
- logger.info("💾 Saving model artifacts with corrected paths...")
350
-
351
- # FIXED: Use centralized path configuration
352
- pipeline_path = PathConfig.PIPELINE_FILE # /tmp/model/pipeline.pkl
353
- model_path = PathConfig.MODEL_FILE # /tmp/model/model.pkl
354
- vectorizer_path = PathConfig.VECTORIZER_FILE # /tmp/model/vectorizer.pkl
355
- metadata_path = PathConfig.METADATA_FILE # /tmp/metadata.json
356
-
357
- logger.info(f"Saving to paths:")
358
- logger.info(f" Pipeline: {pipeline_path}")
359
- logger.info(f" Model: {model_path}")
360
- logger.info(f" Vectorizer: {vectorizer_path}")
361
- logger.info(f" Metadata: {metadata_path}")
362
-
363
- # Save the complete pipeline (FIXED PATH)
364
- joblib.dump(model, pipeline_path)
365
- logger.info("Saved complete pipeline")
366
-
367
- # Save individual components for backward compatibility (FIXED PATHS)
368
- try:
369
- if hasattr(model, 'named_steps'):
370
- # Save individual model
371
- if 'model' in model.named_steps and model.named_steps['model'] is not None:
372
- joblib.dump(model.named_steps['model'], model_path)
373
- logger.info("Saved individual model component")
374
-
375
- # Save individual vectorizer
376
- if 'vectorize' in model.named_steps and model.named_steps['vectorize'] is not None:
377
- joblib.dump(model.named_steps['vectorize'], vectorizer_path)
378
- logger.info("Saved individual vectorizer component")
379
- else:
380
- logger.warning("Model doesn't have named_steps, skipping individual component saves")
381
-
382
- except Exception as e:
383
- logger.warning(f"Could not save individual components: {e}")
384
-
385
- # Generate comprehensive metadata
386
- metadata = self.generate_metadata(model_name, metrics)
387
-
388
- # Save metadata (FIXED PATH)
389
- with open(metadata_path, 'w') as f:
390
- json.dump(metadata, f, indent=2)
391
- logger.info("Saved model metadata")
392
-
393
- # Verify all files were created
394
- verification_results = {
395
- 'pipeline': pipeline_path.exists(),
396
- 'model': model_path.exists(),
397
- 'vectorizer': vectorizer_path.exists(),
398
- 'metadata': metadata_path.exists()
399
- }
400
-
401
- logger.info("🔍 File verification results:")
402
- for file_type, exists in verification_results.items():
403
- status = "✅" if exists else "❌"
404
- logger.info(f" {status} {file_type}: {exists}")
405
-
406
- # Check if at least the pipeline was saved
407
- if not verification_results['pipeline']:
408
- raise Exception("Critical: Pipeline file was not created")
409
-
410
- logger.info("🎉 Model artifacts saved successfully!")
411
- return True
412
-
413
- except Exception as e:
414
- logger.error(f"❌ Failed to save model artifacts: {str(e)}")
415
- return False
416
 
417
- def generate_metadata(self, model_name: str, metrics: Dict) -> Dict:
418
- """Generate comprehensive metadata"""
419
- # Generate data hash for versioning
420
- data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()[:8]
421
-
422
- metadata = {
423
- 'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
424
- 'model_type': model_name,
425
- 'data_version': data_hash,
426
- 'training_metrics': {
427
- 'test_accuracy': metrics.get('accuracy', 'Unknown'),
428
- 'test_f1': metrics.get('f1', 'Unknown'),
429
- 'test_precision': metrics.get('precision', 'Unknown'),
430
- 'test_recall': metrics.get('recall', 'Unknown'),
431
- 'test_roc_auc': metrics.get('roc_auc', 'Unknown'),
432
- 'overfitting_score': metrics.get('overfitting_score', 'Unknown'),
433
- 'cv_score_mean': metrics.get('cv_scores', {}).get('mean', 'Unknown'),
434
- 'cv_score_std': metrics.get('cv_scores', {}).get('std', 'Unknown')
435
- },
436
- 'training_config': {
437
- 'test_size': self.test_size,
438
- 'validation_size': self.validation_size,
439
- 'cv_folds': self.cv_folds,
440
- 'max_features': self.max_features,
441
- 'ngram_range': self.ngram_range,
442
- 'feature_selection_k': self.feature_selection_k,
443
- 'class_weight': self.class_weight
444
- },
445
- 'paths': {
446
- 'pipeline_file': str(PathConfig.PIPELINE_FILE),
447
- 'model_file': str(PathConfig.MODEL_FILE),
448
- 'vectorizer_file': str(PathConfig.VECTORIZER_FILE)
449
- },
450
- 'timestamp': datetime.now().isoformat(),
451
- 'training_completed': True
452
- }
453
-
454
- return metadata
455
 
456
  def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
457
  """Comprehensive model evaluation with multiple metrics"""
458
  logger.info("Starting comprehensive model evaluation...")
459
 
460
- try:
461
- # Predictions
462
- y_pred = model.predict(X_test)
463
- y_pred_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, 'predict_proba') else None
464
-
465
- # Basic metrics
466
- metrics = {
467
- 'accuracy': float(accuracy_score(y_test, y_pred)),
468
- 'precision': float(precision_score(y_test, y_pred, average='weighted', zero_division=0)),
469
- 'recall': float(recall_score(y_test, y_pred, average='weighted', zero_division=0)),
470
- 'f1': float(f1_score(y_test, y_pred, average='weighted', zero_division=0))
471
- }
472
 
473
- # ROC AUC if probabilities available
474
- if y_pred_proba is not None:
475
- try:
476
- metrics['roc_auc'] = float(roc_auc_score(y_test, y_pred_proba))
477
- except Exception as e:
478
- logger.warning(f"Could not calculate ROC AUC: {e}")
479
- metrics['roc_auc'] = 0.0
480
- else:
481
- metrics['roc_auc'] = 0.0
482
-
483
- # Confusion matrix
484
- cm = confusion_matrix(y_test, y_pred)
485
- metrics['confusion_matrix'] = cm.tolist()
486
-
487
- # Classification report
488
  try:
489
- class_report = classification_report(y_test, y_pred, output_dict=True, zero_division=0)
490
- metrics['classification_report'] = class_report
491
  except Exception as e:
492
- logger.warning(f"Could not generate classification report: {e}")
 
493
 
494
- # Cross-validation scores if training data provided
495
- if X_train is not None and y_train is not None:
496
- try:
497
- cv_scores = cross_val_score(
498
- model, X_train, y_train,
499
- cv=StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
500
- scoring='f1_weighted'
501
- )
502
- metrics['cv_scores'] = {
503
- 'mean': float(cv_scores.mean()),
504
- 'std': float(cv_scores.std()),
505
- 'scores': cv_scores.tolist()
506
- }
507
- except Exception as e:
508
- logger.warning(f"Cross-validation failed: {e}")
509
- metrics['cv_scores'] = {'mean': 0.0, 'std': 0.0, 'scores': []}
510
 
 
 
511
  # Training accuracy for overfitting detection
512
  if X_train is not None and y_train is not None:
513
- try:
514
- y_train_pred = model.predict(X_train)
515
- train_accuracy = accuracy_score(y_train, y_train_pred)
516
- metrics['train_accuracy'] = float(train_accuracy)
517
- metrics['overfitting_score'] = float(train_accuracy - metrics['accuracy'])
518
- except Exception as e:
519
- logger.warning(f"Overfitting detection failed: {e}")
520
-
521
- logger.info(f"📈 Evaluation completed - F1: {metrics['f1']:.4f}, Accuracy: {metrics['accuracy']:.4f}")
522
- return metrics
523
-
524
  except Exception as e:
525
- logger.error(f" Evaluation failed: {e}")
526
- return {
527
- 'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0,
528
- 'f1': 0.0, 'roc_auc': 0.0, 'error': str(e)
529
- }
530
 
531
  def hyperparameter_tuning(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
532
  """Perform hyperparameter tuning with cross-validation"""
533
- logger.info(f"🔧 Starting hyperparameter tuning for {model_name}...")
534
 
535
  try:
536
  # Set the model in the pipeline
@@ -543,7 +329,8 @@ class RobustModelTrainer:
543
  grid_search = GridSearchCV(
544
  pipeline,
545
  param_grid,
546
- cv=StratifiedKFold(n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
 
547
  scoring='f1_weighted',
548
  n_jobs=-1,
549
  verbose=1
@@ -560,7 +347,7 @@ class RobustModelTrainer:
560
  'cv_results': {
561
  'mean_test_scores': grid_search.cv_results_['mean_test_score'].tolist(),
562
  'std_test_scores': grid_search.cv_results_['std_test_score'].tolist(),
563
- 'params': [dict(p) for p in grid_search.cv_results_['params']]
564
  }
565
  }
566
 
@@ -571,19 +358,16 @@ class RobustModelTrainer:
571
  return grid_search.best_estimator_, tuning_results
572
 
573
  except Exception as e:
574
- logger.error(f"❌ Hyperparameter tuning failed for {model_name}: {str(e)}")
 
575
  # Return basic model if tuning fails
576
- try:
577
- pipeline.set_params(model=self.models[model_name]['model'])
578
- pipeline.fit(X_train, y_train)
579
- return pipeline, {'error': str(e), 'used_default_params': True}
580
- except Exception as e2:
581
- logger.error(f"❌ Even basic model training failed: {str(e2)}")
582
- raise e2
583
 
584
  def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
585
  """Train and evaluate multiple models"""
586
- logger.info("🚀 Starting model training and evaluation...")
587
 
588
  results = {}
589
 
@@ -591,7 +375,7 @@ class RobustModelTrainer:
591
  logger.info(f"Training {model_name}...")
592
 
593
  try:
594
- # Create fresh pipeline for each model
595
  pipeline = self.create_preprocessing_pipeline()
596
 
597
  # Hyperparameter tuning
@@ -612,18 +396,18 @@ class RobustModelTrainer:
612
  'training_time': datetime.now().isoformat()
613
  }
614
 
615
- logger.info(f"Model {model_name} - F1: {evaluation_metrics['f1']:.4f}, "
616
  f"Accuracy: {evaluation_metrics['accuracy']:.4f}")
617
 
618
  except Exception as e:
619
- logger.error(f"Training failed for {model_name}: {str(e)}")
620
  results[model_name] = {'error': str(e)}
621
 
622
  return results
623
 
624
  def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
625
  """Select the best performing model"""
626
- logger.info("🏆 Selecting best model...")
627
 
628
  best_model_name = None
629
  best_model = None
@@ -632,7 +416,6 @@ class RobustModelTrainer:
632
 
633
  for model_name, result in results.items():
634
  if 'error' in result:
635
- logger.warning(f"Skipping {model_name} due to error: {result['error']}")
636
  continue
637
 
638
  # Use F1 score as primary metric
@@ -645,11 +428,69 @@ class RobustModelTrainer:
645
  best_metrics = result['evaluation_metrics']
646
 
647
  if best_model_name is None:
648
- raise ValueError("No models trained successfully")
649
 
650
- logger.info(f"🏆 Best model: {best_model_name} with F1 score: {best_score:.4f}")
 
651
  return best_model_name, best_model, best_metrics
652
653
  def save_evaluation_results(self, results: Dict) -> bool:
654
  """Save comprehensive evaluation results"""
655
  try:
@@ -662,32 +503,31 @@ class RobustModelTrainer:
662
  clean_results[model_name] = {
663
  'tuning_results': {
664
  k: v for k, v in result['tuning_results'].items()
665
- if k != 'best_estimator' # Can't serialize sklearn objects
666
  },
667
  'evaluation_metrics': result['evaluation_metrics'],
668
  'training_time': result['training_time']
669
  }
670
 
671
- # Save results to centralized path
672
- evaluation_path = PathConfig.EVALUATION_RESULTS
673
- with open(evaluation_path, 'w') as f:
674
  json.dump(clean_results, f, indent=2, default=str)
675
 
676
- logger.info(f"📊 Evaluation results saved to {evaluation_path}")
677
  return True
678
 
679
  except Exception as e:
680
- logger.error(f"Failed to save evaluation results: {str(e)}")
681
  return False
682
 
683
  def train_model(self, data_path: str = None) -> Tuple[bool, str]:
684
  """Main training function with comprehensive pipeline"""
685
  try:
686
- logger.info("🚀 Starting model training pipeline...")
687
-
688
- # Log system information
689
- logger.info(f"Training environment: {PathConfig.BASE_DIR}")
690
- PathConfig.ensure_directories()
691
 
692
  # Load and validate data
693
  success, df, message = self.load_and_validate_data()
@@ -706,261 +546,50 @@ class RobustModelTrainer:
706
  random_state=self.random_state
707
  )
708
 
709
- logger.info(f"Data split: {len(X_train)} train, {len(X_test)} test")
 
710
 
711
  # Train and evaluate models
712
- results = self.train_and_evaluate_models(X_train, X_test, y_train, y_test)
713
-
714
- # Check if any models were trained successfully
715
- successful_models = [name for name, result in results.items() if 'error' not in result]
716
- if not successful_models:
717
- return False, "❌ All model training attempts failed"
718
 
719
  # Select best model
720
- best_model_name, best_model, best_metrics = self.select_best_model(results)
 
721
 
722
- # Save model artifacts with FIXED paths
723
  if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
724
- return False, "Failed to save model artifacts"
725
 
726
  # Save evaluation results
727
  self.save_evaluation_results(results)
728
 
729
  success_message = (
730
- f"Model training completed successfully!\n"
731
- f"Best model: {best_model_name}\n"
732
- f"Performance: F1={best_metrics['f1']:.4f}, Accuracy={best_metrics['accuracy']:.4f}\n"
733
- f"Artifacts saved to: {PathConfig.MODEL_DIR}"
734
  )
735
 
736
  logger.info(success_message)
737
  return True, success_message
738
 
739
  except Exception as e:
740
- error_message = f"Model training failed: {str(e)}"
741
  logger.error(error_message)
742
- logger.error(f"📍 Full traceback: {traceback.format_exc()}")
743
  return False, error_message
744
 
745
 
746
- # =============================================================================
747
- # TRAINING UTILITIES AND DIAGNOSTICS
748
- # =============================================================================
749
- class TrainingDiagnostics:
750
- """Diagnostic utilities for training pipeline"""
751
-
752
- @staticmethod
753
- def check_data_availability():
754
- """Check if training data is available"""
755
- data_path = PathConfig.COMBINED_DATASET
756
-
757
- if not data_path.exists():
758
- logger.error(f"❌ Training data not found at: {data_path}")
759
-
760
- # Check what files are available
761
- if PathConfig.DATA_DIR.exists():
762
- available_files = list(PathConfig.DATA_DIR.iterdir())
763
- logger.info(f"Available files in data directory: {[f.name for f in available_files]}")
764
- else:
765
- logger.error(f"❌ Data directory doesn't exist: {PathConfig.DATA_DIR}")
766
-
767
- return False
768
-
769
- logger.info(f"✅ Training data found at: {data_path}")
770
- return True
771
-
772
- @staticmethod
773
- def verify_model_output():
774
- """Verify that model files were created correctly"""
775
- files_to_check = {
776
- 'Pipeline': PathConfig.PIPELINE_FILE,
777
- 'Model': PathConfig.MODEL_FILE,
778
- 'Vectorizer': PathConfig.VECTORIZER_FILE,
779
- 'Metadata': PathConfig.METADATA_FILE
780
- }
781
-
782
- logger.info("🔍 Verifying model output files:")
783
- all_exist = True
784
-
785
- for file_type, file_path in files_to_check.items():
786
- exists = file_path.exists()
787
- size = file_path.stat().st_size if exists else 0
788
-
789
- status = "✅" if exists else "❌"
790
- logger.info(f" {status} {file_type}: {file_path} ({size} bytes)")
791
-
792
- if not exists:
793
- all_exist = False
794
-
795
- return all_exist
796
-
797
- @staticmethod
798
- def test_model_loading():
799
- """Test if the saved model can be loaded correctly"""
800
- try:
801
- logger.info("🧪 Testing model loading...")
802
-
803
- # Try loading pipeline
804
- if PathConfig.PIPELINE_FILE.exists():
805
- pipeline = joblib.load(PathConfig.PIPELINE_FILE)
806
- logger.info("✅ Pipeline loaded successfully")
807
-
808
- # Test prediction
809
- test_text = ["This is a test article for verification."]
810
- prediction = pipeline.predict(test_text)
811
- logger.info(f"✅ Test prediction successful: {prediction}")
812
-
813
- return True
814
- else:
815
- logger.error("❌ Pipeline file not found")
816
- return False
817
-
818
- except Exception as e:
819
- logger.error(f"❌ Model loading test failed: {e}")
820
- return False
821
-
822
-
823
- # ================================
824
- # ENHANCED MAIN EXECUTION FUNCTION
825
- # ================================
826
  def main():
827
- """Enhanced main execution function with comprehensive diagnostics"""
828
- import traceback
829
-
830
- logger.info("🚀 Starting Enhanced Model Training Pipeline")
831
- logger.info("=" * 60)
832
-
833
- try:
834
- # Step 1: Check data availability
835
- logger.info("📋 Step 1: Checking data availability...")
836
- if not TrainingDiagnostics.check_data_availability():
837
- logger.error("❌ Training aborted: No data available")
838
- print("❌ Training failed: Training data not found")
839
- print(f"📁 Expected data location: {PathConfig.COMBINED_DATASET}")
840
- print("💡 Please ensure the data preparation step has been completed")
841
- exit(1)
842
-
843
- # Step 2: Initialize trainer
844
- logger.info("📋 Step 2: Initializing trainer...")
845
- trainer = RobustModelTrainer()
846
-
847
- # Step 3: Train model
848
- logger.info("📋 Step 3: Training model...")
849
- success, message = trainer.train_model()
850
-
851
- if success:
852
- # Step 4: Verify output
853
- logger.info("📋 Step 4: Verifying model output...")
854
- if TrainingDiagnostics.verify_model_output():
855
- logger.info("✅ All model files created successfully")
856
- else:
857
- logger.warning("⚠️ Some model files may be missing")
858
-
859
- # Step 5: Test model loading
860
- logger.info("📋 Step 5: Testing model loading...")
861
- if TrainingDiagnostics.test_model_loading():
862
- logger.info("✅ Model loading verification successful")
863
- else:
864
- logger.warning("⚠️ Model loading verification failed")
865
-
866
- # Success summary
867
- logger.info("=" * 60)
868
- logger.info("TRAINING COMPLETED SUCCESSFULLY!")
869
- logger.info("=" * 60)
870
- print("✅ Training completed successfully!")
871
- print(f"{message}")
872
- print(f"Model files saved to: {PathConfig.MODEL_DIR}")
873
- print("Next steps:")
874
- print(" 1. Start the FastAPI server to test predictions")
875
- print(" 2. Run the monitoring dashboard")
876
- print(" 3. Perform model validation tests")
877
-
878
- else:
879
- logger.error("=" * 60)
880
- logger.error("❌ TRAINING FAILED!")
881
- logger.error("=" * 60)
882
- print("❌ Training failed!")
883
- print(f"📄 Error: {message}")
884
- print("\n🔧 Troubleshooting steps:")
885
- print(" 1. Check if training data exists and is properly formatted")
886
- print(" 2. Verify sufficient disk space and memory")
887
- print(" 3. Review the training logs for detailed error information")
888
- exit(1)
889
-
890
- except KeyboardInterrupt:
891
- logger.info("⏹️ Training interrupted by user")
892
- print("\n⏹️ Training interrupted by user")
893
- exit(1)
894
-
895
- except Exception as e:
896
- logger.error(f"Unexpected error during training: {str(e)}")
897
- logger.error(f"Full traceback: {traceback.format_exc()}")
898
- print(f"Unexpected error: {str(e)}")
899
- print("Check the training logs for more details")
900
- exit(1)
901
-
902
 
903
- # ============================
904
- # STANDALONE TESTING FUNCTIONS
905
- # ============================
906
- def test_path_configuration():
907
- """Test path configuration and directory creation"""
908
- print("🧪 Testing path configuration...")
909
-
910
- PathConfig.ensure_directories()
911
-
912
- directories = [
913
- PathConfig.BASE_DIR, PathConfig.DATA_DIR,
914
- PathConfig.MODEL_DIR, PathConfig.LOGS_DIR, PathConfig.RESULTS_DIR
915
- ]
916
-
917
- for directory in directories:
918
- if directory.exists():
919
- print(f"✅ {directory}")
920
- else:
921
- print(f"❌ {directory}")
922
-
923
- print("\n Expected file locations:")
924
- print(f" Pipeline: {PathConfig.PIPELINE_FILE}")
925
- print(f" Model: {PathConfig.MODEL_FILE}")
926
- print(f" Vectorizer: {PathConfig.VECTORIZER_FILE}")
927
- print(f" Metadata: {PathConfig.METADATA_FILE}")
928
-
929
-
930
- def quick_data_check():
931
- """Quick check of training data"""
932
- print("Quick data check...")
933
-
934
- data_path = PathConfig.COMBINED_DATASET
935
- if data_path.exists():
936
- try:
937
- df = pd.read_csv(data_path)
938
- print(f"Data loaded: {len(df)} rows, {len(df.columns)} columns")
939
- print(f"Columns: {list(df.columns)}")
940
- if 'label' in df.columns:
941
- print(f"Label distribution: {df['label'].value_counts().to_dict()}")
942
- except Exception as e:
943
- print(f"❌ Error reading data: {e}")
944
  else:
945
- print(f"❌ Data file not found: {data_path}")
 
946
 
947
 
948
  if __name__ == "__main__":
949
- import sys
950
-
951
- # Handle command line arguments for testing
952
- if len(sys.argv) > 1:
953
- if sys.argv[1] == "test-paths":
954
- test_path_configuration()
955
- elif sys.argv[1] == "test-data":
956
- quick_data_check()
957
- elif sys.argv[1] == "test-loading":
958
- TrainingDiagnostics.test_model_loading()
959
- else:
960
- print("Available test commands:")
961
- print(" python train.py test-paths # Test path configuration")
962
- print(" python train.py test-data # Quick data check")
963
- print(" python train.py test-loading # Test model loading")
964
- else:
965
- # Run main training
966
- main()
 
27
  import warnings
28
  warnings.filterwarnings('ignore')
29
 
30
+ # Scikit-learn imports
31
 
32
+ # Configure logging
33
  logging.basicConfig(
34
  level=logging.INFO,
35
+ format='%(asctime)s - %(levelname)s - %(message)s',
36
  handlers=[
37
+ logging.FileHandler('/tmp/model_training.log'),
38
  logging.StreamHandler()
39
  ]
40
  )
41
  logger = logging.getLogger(__name__)
42
 
43
44
  class RobustModelTrainer:
45
+ """Production-ready model trainer with comprehensive evaluation and validation"""
46
 
47
  def __init__(self):
48
+ self.setup_paths()
 
49
  self.setup_training_config()
50
  self.setup_models()
51
+
52
+ def setup_paths(self):
53
+ """Setup all necessary paths"""
54
+ self.base_dir = Path("/tmp")
55
+ self.data_dir = self.base_dir / "data"
56
+ self.model_dir = self.base_dir / "model"
57
+ self.results_dir = self.base_dir / "results"
58
+
59
+ # Create directories
60
+ for dir_path in [self.data_dir, self.model_dir, self.results_dir]:
61
+ dir_path.mkdir(parents=True, exist_ok=True)
62
+
63
+ # File paths
64
+ self.data_path = self.data_dir / "combined_dataset.csv"
65
+ self.model_path = self.model_dir / "model.pkl"
66
+ self.vectorizer_path = self.model_dir / "vectorizer.pkl"
67
+ self.pipeline_path = self.model_dir / "pipeline.pkl"
68
+ self.metadata_path = Path("/tmp/metadata.json")
69
+ self.evaluation_path = self.results_dir / "evaluation_results.json"
70
 
71
  def setup_training_config(self):
72
  """Setup training configuration"""
 
112
  }
113
 
114
  def load_and_validate_data(self) -> Tuple[bool, Optional[pd.DataFrame], str]:
115
+ """Load and validate training data"""
116
  try:
117
+ logger.info("Loading training data...")
118
 
119
+ if not self.data_path.exists():
120
+ return False, None, f"Data file not found: {self.data_path}"
 
 
121
 
122
  # Load data
123
+ df = pd.read_csv(self.data_path)
124
+
125
+ # Basic validation
126
+ if df.empty:
127
+ return False, None, "Dataset is empty"
128
+
129
+ required_columns = ['text', 'label']
130
+ missing_columns = [
131
+ col for col in required_columns if col not in df.columns]
132
+ if missing_columns:
133
+ return False, None, f"Missing required columns: {missing_columns}"
134
+
 
 
 
 
135
  # Remove missing values
136
+ initial_count = len(df)
137
+ df = df.dropna(subset=required_columns)
 
 
 
138
  if len(df) < initial_count:
139
+ logger.warning(
140
+ f"Removed {initial_count - len(df)} rows with missing values")
141
+
142
+ # Validate text content
143
+ df = df[df['text'].astype(str).str.len() > 10]
144
+
145
+ # Validate labels
146
+ unique_labels = df['label'].unique()
147
+ if len(unique_labels) < 2:
148
+ return False, None, f"Need at least 2 classes, found: {unique_labels}"
149
+
150
+ # Check minimum sample size
151
+ if len(df) < 100:
152
+ return False, None, f"Insufficient samples for training: {len(df)}"
153
 
154
+ # Check class balance
155
  label_counts = df['label'].value_counts()
156
+ min_class_ratio = label_counts.min() / label_counts.max()
157
+ if min_class_ratio < 0.1:
158
+ logger.warning(
159
+ f"Severe class imbalance detected: {min_class_ratio:.3f}")
160
+
161
+ logger.info(
162
+ f"Data validation successful: {len(df)} samples, {len(unique_labels)} classes")
163
  logger.info(f"Class distribution: {label_counts.to_dict()}")
164
 
165
+ return True, df, "Data loaded successfully"
166
 
167
  except Exception as e:
168
  error_msg = f"Error loading data: {str(e)}"
 
170
  return False, None, error_msg
171
 
172
  def preprocess_text(self, text):
173
+ """Advanced text preprocessing"""
174
  import re
175
 
176
+ # Convert to string
177
+ text = str(text)
 
178
 
179
+ # Remove URLs
180
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text)
181
 
182
+ # Remove email addresses
183
+ text = re.sub(r'\S+@\S+', '', text)
184
 
185
+ # Remove excessive punctuation
186
+ text = re.sub(r'[!]{2,}', '!', text)
187
+ text = re.sub(r'[?]{2,}', '?', text)
188
+ text = re.sub(r'[.]{3,}', '...', text)
189
 
190
+ # Remove non-alphabetic characters except spaces and basic punctuation
191
+ text = re.sub(r'[^a-zA-Z\s.!?]', '', text)
192
 
193
+ # Remove excessive whitespace
194
+ text = re.sub(r'\s+', ' ', text)
195
 
196
+ return text.strip().lower()
 
 
 
 
197
 
198
  def create_preprocessing_pipeline(self) -> Pipeline:
199
+ """Create advanced preprocessing pipeline"""
 
 
200
  # Text preprocessing
201
  text_preprocessor = FunctionTransformer(
202
  func=lambda x: [self.preprocess_text(text) for text in x],
 
228
  ('model', None) # Will be set during training
229
  ])
230
 
231
+ # After creating the pipeline
232
+ joblib.dump(pipeline, "/tmp/pipeline.pkl") # Save complete pipeline
233
+ # Individual model
234
+ joblib.dump(pipeline.named_steps['model'], "/tmp/model.pkl")
235
+ # Individual vectorizer
236
+ joblib.dump(pipeline.named_steps['vectorize'], "/tmp/vectorizer.pkl")
237
 
238
+ return pipeline
239
 
240
  def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
241
  """Comprehensive model evaluation with multiple metrics"""
242
  logger.info("Starting comprehensive model evaluation...")
243
 
244
+ # Predictions
245
+ y_pred = model.predict(X_test)
246
+ y_pred_proba = model.predict_proba(X_test)[:, 1]
247
+
248
+ # Basic metrics
249
+ metrics = {
250
+ 'accuracy': float(accuracy_score(y_test, y_pred)),
251
+ 'precision': float(precision_score(y_test, y_pred, average='weighted')),
252
+ 'recall': float(recall_score(y_test, y_pred, average='weighted')),
253
+ 'f1': float(f1_score(y_test, y_pred, average='weighted')),
254
+ 'roc_auc': float(roc_auc_score(y_test, y_pred_proba))
255
+ }
256
+
257
+ # Confusion matrix
258
+ cm = confusion_matrix(y_test, y_pred)
259
+ metrics['confusion_matrix'] = cm.tolist()
260
+
261
+ # Classification report
262
+ class_report = classification_report(y_test, y_pred, output_dict=True)
263
+ metrics['classification_report'] = class_report
264
 
265
+ # Cross-validation scores if training data provided
266
+ if X_train is not None and y_train is not None:
267
  try:
268
+ cv_scores = cross_val_score(
269
+ model, X_train, y_train,
270
+ cv=StratifiedKFold(
271
+ n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
272
+ scoring='f1_weighted'
273
+ )
274
+ metrics['cv_scores'] = {
275
+ 'mean': float(cv_scores.mean()),
276
+ 'std': float(cv_scores.std()),
277
+ 'scores': cv_scores.tolist()
278
+ }
279
  except Exception as e:
280
+ logger.warning(f"Cross-validation failed: {e}")
281
+ metrics['cv_scores'] = None
282
 
283
+ # Feature importance (if available)
284
+ try:
285
+ if hasattr(model, 'feature_importances_'):
286
+ feature_importance = model.feature_importances_
287
+ metrics['feature_importance_stats'] = {
288
+ 'mean': float(feature_importance.mean()),
289
+ 'std': float(feature_importance.std()),
290
+ 'top_features': feature_importance.argsort()[-10:][::-1].tolist()
291
+ }
292
+ elif hasattr(model, 'coef_'):
293
+ coefficients = model.coef_[0]
294
+ metrics['coefficient_stats'] = {
295
+ 'mean': float(coefficients.mean()),
296
+ 'std': float(coefficients.std()),
297
+ 'top_positive': coefficients.argsort()[-10:][::-1].tolist(),
298
+ 'top_negative': coefficients.argsort()[:10].tolist()
299
+ }
300
+ except Exception as e:
301
+ logger.warning(f"Feature importance extraction failed: {e}")
302
 
303
+ # Model complexity metrics
304
+ try:
305
  # Training accuracy for overfitting detection
306
  if X_train is not None and y_train is not None:
307
+ y_train_pred = model.predict(X_train)
308
+ train_accuracy = accuracy_score(y_train, y_train_pred)
309
+ metrics['train_accuracy'] = float(train_accuracy)
310
+ metrics['overfitting_score'] = float(
311
+ train_accuracy - metrics['accuracy'])
 
 
 
 
 
 
312
  except Exception as e:
313
+ logger.warning(f"Overfitting detection failed: {e}")
314
+
315
+ return metrics
 
 
316
 
317
  def hyperparameter_tuning(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
318
  """Perform hyperparameter tuning with cross-validation"""
319
+ logger.info(f"Starting hyperparameter tuning for {model_name}...")
320
 
321
  try:
322
  # Set the model in the pipeline
 
329
  grid_search = GridSearchCV(
330
  pipeline,
331
  param_grid,
332
+ cv=StratifiedKFold(n_splits=self.cv_folds,
333
+ shuffle=True, random_state=self.random_state),
334
  scoring='f1_weighted',
335
  n_jobs=-1,
336
  verbose=1
 
347
  'cv_results': {
348
  'mean_test_scores': grid_search.cv_results_['mean_test_score'].tolist(),
349
  'std_test_scores': grid_search.cv_results_['std_test_score'].tolist(),
350
+ 'params': grid_search.cv_results_['params']
351
  }
352
  }
353
 
 
358
  return grid_search.best_estimator_, tuning_results
359
 
360
  except Exception as e:
361
+ logger.error(
362
+ f"Hyperparameter tuning failed for {model_name}: {str(e)}")
363
  # Return basic model if tuning fails
364
+ pipeline.set_params(model=self.models[model_name]['model'])
365
+ pipeline.fit(X_train, y_train)
366
+ return pipeline, {'error': str(e)}
 
 
 
 
367
 
368
  def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
369
  """Train and evaluate multiple models"""
370
+ logger.info("Starting model training and evaluation...")
371
 
372
  results = {}
373
 
 
375
  logger.info(f"Training {model_name}...")
376
 
377
  try:
378
+ # Create pipeline
379
  pipeline = self.create_preprocessing_pipeline()
380
 
381
  # Hyperparameter tuning
 
396
  'training_time': datetime.now().isoformat()
397
  }
398
 
399
+ logger.info(f"Model {model_name} - F1: {evaluation_metrics['f1']:.4f}, "
400
  f"Accuracy: {evaluation_metrics['accuracy']:.4f}")
401
 
402
  except Exception as e:
403
+ logger.error(f"Training failed for {model_name}: {str(e)}")
404
  results[model_name] = {'error': str(e)}
405
 
406
  return results
407
 
408
  def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
409
  """Select the best performing model"""
410
+ logger.info("Selecting best model...")
411
 
412
  best_model_name = None
413
  best_model = None
 
416
 
417
  for model_name, result in results.items():
418
  if 'error' in result:
 
419
  continue
420
 
421
  # Use F1 score as primary metric
 
428
  best_metrics = result['evaluation_metrics']
429
 
430
  if best_model_name is None:
431
+ raise ValueError("No models trained successfully")
432
 
433
+ logger.info(
434
+ f"Best model: {best_model_name} with F1 score: {best_score:.4f}")
435
  return best_model_name, best_model, best_metrics
436
 
437
+ def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
438
+ """Save model artifacts and metadata"""
439
+ try:
440
+ logger.info("Saving model artifacts...")
441
+
442
+ # Save the full pipeline
443
+ joblib.dump(model, self.pipeline_path)
444
+
445
+ # Save individual components for backward compatibility
446
+ joblib.dump(model.named_steps['model'], self.model_path)
447
+ joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
448
+
449
+ # Generate data hash
450
+ data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
451
+
452
+ # Create metadata
453
+ metadata = {
454
+ 'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
455
+ 'model_type': model_name,
456
+ 'data_version': data_hash,
457
+ 'train_size': metrics.get('train_accuracy', 'Unknown'),
458
+ 'test_size': len(metrics.get('confusion_matrix', [[0]])[0]) if 'confusion_matrix' in metrics else 'Unknown',
459
+ 'test_accuracy': metrics['accuracy'],
460
+ 'test_f1': metrics['f1'],
461
+ 'test_precision': metrics['precision'],
462
+ 'test_recall': metrics['recall'],
463
+ 'test_roc_auc': metrics['roc_auc'],
464
+ 'overfitting_score': metrics.get('overfitting_score', 'Unknown'),
465
+ 'cv_score_mean': metrics.get('cv_scores', {}).get('mean', 'Unknown'),
466
+ 'cv_score_std': metrics.get('cv_scores', {}).get('std', 'Unknown'),
467
+ 'timestamp': datetime.now().isoformat(),
468
+ 'training_config': {
469
+ 'test_size': self.test_size,
470
+ 'validation_size': self.validation_size,
471
+ 'cv_folds': self.cv_folds,
472
+ 'max_features': self.max_features,
473
+ 'ngram_range': self.ngram_range,
474
+ 'feature_selection_k': self.feature_selection_k
475
+ }
476
+ }
477
+
478
+ # Save metadata
479
+ with open(self.metadata_path, 'w') as f:
480
+ json.dump(metadata, f, indent=2)
481
+
482
+ logger.info(f"Model artifacts saved successfully")
483
+ logger.info(f"Model path: {self.model_path}")
484
+ logger.info(f"Vectorizer path: {self.vectorizer_path}")
485
+ logger.info(f"Pipeline path: {self.pipeline_path}")
486
+ logger.info(f"Metadata path: {self.metadata_path}")
487
+
488
+ return True
489
+
490
+ except Exception as e:
491
+ logger.error(f"Failed to save model artifacts: {str(e)}")
492
+ return False
493
+
494
  def save_evaluation_results(self, results: Dict) -> bool:
495
  """Save comprehensive evaluation results"""
496
  try:
 
503
  clean_results[model_name] = {
504
  'tuning_results': {
505
  k: v for k, v in result['tuning_results'].items()
506
+ if k != 'best_estimator'
507
  },
508
  'evaluation_metrics': result['evaluation_metrics'],
509
  'training_time': result['training_time']
510
  }
511
 
512
+ # Save results
513
+ with open(self.evaluation_path, 'w') as f:
 
514
  json.dump(clean_results, f, indent=2, default=str)
515
 
516
+ logger.info(f"Evaluation results saved to {self.evaluation_path}")
517
  return True
518
 
519
  except Exception as e:
520
+ logger.error(f"Failed to save evaluation results: {str(e)}")
521
  return False
522
 
523
  def train_model(self, data_path: str = None) -> Tuple[bool, str]:
524
  """Main training function with comprehensive pipeline"""
525
  try:
526
+ logger.info("Starting model training pipeline...")
527
+
528
+ # Override data path if provided
529
+ if data_path:
530
+ self.data_path = Path(data_path)
531
 
532
  # Load and validate data
533
  success, df, message = self.load_and_validate_data()
 
546
  random_state=self.random_state
547
  )
548
 
549
+ logger.info(
550
+ f"Data split: {len(X_train)} train, {len(X_test)} test")
551
 
552
  # Train and evaluate models
553
+ results = self.train_and_evaluate_models(
554
+ X_train, X_test, y_train, y_test)
 
 
 
 
555
 
556
  # Select best model
557
+ best_model_name, best_model, best_metrics = self.select_best_model(
558
+ results)
559
 
560
+ # Save model artifacts
561
  if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
562
+ return False, "Failed to save model artifacts"
563
 
564
  # Save evaluation results
565
  self.save_evaluation_results(results)
566
 
567
  success_message = (
568
+ f"Model training completed successfully. "
569
+ f"Best model: {best_model_name} "
570
+ f"(F1: {best_metrics['f1']:.4f}, Accuracy: {best_metrics['accuracy']:.4f})"
 
571
  )
572
 
573
  logger.info(success_message)
574
  return True, success_message
575
 
576
  except Exception as e:
577
+ error_message = f"Model training failed: {str(e)}"
578
  logger.error(error_message)
 
579
  return False, error_message
580
 
581
582
  def main():
583
+ """Main execution function"""
584
+ trainer = RobustModelTrainer()
585
+ success, message = trainer.train_model()
586
 
587
+ if success:
588
+ print(f"✅ {message}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
589
  else:
590
+ print(f"❌ {message}")
591
+ exit(1)
592
 
593
 
594
  if __name__ == "__main__":
595
+ main()
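
For reference, a minimal usage sketch of the restored trainer (an illustration, not part of the commit). It assumes the file is importable as model.train, that pandas, scikit-learn and joblib are installed, and that /tmp/data/combined_dataset.csv exists with 'text' and 'label' columns:

# Hedged sketch: drive the restored RobustModelTrainer directly instead of via main().
from model.train import RobustModelTrainer  # assumed import path for model/train.py
import joblib

trainer = RobustModelTrainer()             # setup_paths() creates /tmp/data, /tmp/model, /tmp/results
success, message = trainer.train_model()   # loads and validates data, tunes, evaluates, saves artifacts
print(message)

if success:
    # pipeline.pkl bundles the text preprocessing, vectorizer and the selected model.
    pipeline = joblib.load(trainer.pipeline_path)  # /tmp/model/pipeline.pkl
    print(pipeline.predict(["Example article text to classify."]))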