Ahmedik95316 committed
Commit 719d51e · 1 Parent(s): c29bcf3

Update model/train.py

Files changed (1):
  1. model/train.py +201 -109
model/train.py CHANGED
@@ -22,7 +22,10 @@ import logging
 import json
 import joblib
 import hashlib
-from datetime import datetime
+import sys
+import os
+import time
+from datetime import datetime, timedelta
 from typing import Dict, Tuple, Optional, Any
 import warnings
 import re
@@ -75,6 +78,114 @@ def preprocess_text_function(texts):
     return processed


+class ProgressTracker:
+    """Progress tracking with time estimation"""
+
+    def __init__(self, total_steps: int, description: str = "Training"):
+        self.total_steps = total_steps
+        self.current_step = 0
+        self.start_time = time.time()
+        self.description = description
+        self.step_times = []
+
+    def update(self, step_name: str = ""):
+        """Update progress and print status"""
+        self.current_step += 1
+        current_time = time.time()
+        elapsed = current_time - self.start_time
+
+        # Calculate progress percentage
+        progress_pct = (self.current_step / self.total_steps) * 100
+
+        # Estimate remaining time
+        if self.current_step > 0:
+            avg_time_per_step = elapsed / self.current_step
+            remaining_steps = self.total_steps - self.current_step
+            eta_seconds = avg_time_per_step * remaining_steps
+            eta = timedelta(seconds=int(eta_seconds))
+        else:
+            eta = "calculating..."
+
+        # Create progress bar
+        bar_length = 30
+        filled_length = int(bar_length * self.current_step // self.total_steps)
+        bar = '█' * filled_length + '░' * (bar_length - filled_length)
+
+        # Print progress
+        status_msg = f"\r{self.description}: [{bar}] {progress_pct:.1f}% | Step {self.current_step}/{self.total_steps}"
+        if step_name:
+            status_msg += f" | {step_name}"
+        if eta != "calculating...":
+            status_msg += f" | ETA: {eta}"
+
+        print(status_msg, end='', flush=True)
+
+        # Store step time for better estimation
+        if len(self.step_times) >= 3:  # Keep last 3 step times for moving average
+            self.step_times.pop(0)
+        self.step_times.append(current_time - (self.start_time + sum(self.step_times)))
+
+    def finish(self):
+        """Complete progress tracking"""
+        total_time = time.time() - self.start_time
+        print(f"\n{self.description} completed in {timedelta(seconds=int(total_time))}")
+
+
+def estimate_training_time(dataset_size: int, enable_tuning: bool = True, cv_folds: int = 3) -> Dict:
+    """Estimate training time based on dataset characteristics"""
+
+    # Base time estimates (in seconds) based on empirical testing
+    base_times = {
+        'preprocessing': max(0.1, dataset_size * 0.001),      # ~1ms per sample
+        'vectorization': max(0.5, dataset_size * 0.01),       # ~10ms per sample
+        'feature_selection': max(0.2, dataset_size * 0.005),  # ~5ms per sample
+        'simple_training': max(1.0, dataset_size * 0.02),     # ~20ms per sample
+        'evaluation': max(0.5, dataset_size * 0.01),          # ~10ms per sample
+    }
+
+    # Hyperparameter tuning multipliers
+    tuning_multipliers = {
+        'logistic_regression': 8 if enable_tuning else 1,  # 8 param combinations
+        'random_forest': 12 if enable_tuning else 1,       # 12 param combinations
+    }
+
+    # Cross-validation multiplier
+    cv_multiplier = cv_folds if dataset_size > 100 else 1
+
+    # Calculate estimates
+    estimates = {}
+
+    # Preprocessing steps
+    estimates['data_loading'] = 0.5
+    estimates['preprocessing'] = base_times['preprocessing']
+    estimates['vectorization'] = base_times['vectorization']
+    estimates['feature_selection'] = base_times['feature_selection']
+
+    # Model training
+    for model_name, multiplier in tuning_multipliers.items():
+        model_time = base_times['simple_training'] * multiplier * cv_multiplier
+        estimates[f'{model_name}_training'] = model_time
+        estimates[f'{model_name}_evaluation'] = base_times['evaluation']
+
+    # Model saving
+    estimates['model_saving'] = 1.0
+
+    # Total estimate
+    total_estimate = sum(estimates.values())
+
+    # Add 20% buffer for overhead
+    total_estimate *= 1.2
+
+    return {
+        'detailed_estimates': estimates,
+        'total_seconds': total_estimate,
+        'total_formatted': str(timedelta(seconds=int(total_estimate))),
+        'dataset_size': dataset_size,
+        'enable_tuning': enable_tuning,
+        'cv_folds': cv_folds
+    }
+
+
 class RobustModelTrainer:
     """Production-ready model trainer with comprehensive evaluation and validation"""

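Note on the new helpers: ProgressTracker and estimate_training_time are self-contained and can be smoke-tested outside the trainer. A minimal sketch (the `model.train` import path is an assumption based on this repo's layout, not part of the commit):

    # Hypothetical smoke test for the helpers added above
    from model.train import ProgressTracker, estimate_training_time

    est = estimate_training_time(dataset_size=1000, enable_tuning=True, cv_folds=3)
    print(est['total_formatted'])      # rough wall-clock estimate as H:MM:SS

    tracker = ProgressTracker(total_steps=3, description="Demo")
    for step in ("load", "fit", "save"):
        tracker.update(step)           # redraws the progress bar with an ETA
    tracker.finish()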
@@ -82,6 +193,7 @@
         self.setup_paths()
         self.setup_training_config()
         self.setup_models()
+        self.progress_tracker = None

     def setup_paths(self):
         """Setup all necessary paths"""
@@ -107,14 +219,14 @@
         self.test_size = 0.2
         self.validation_size = 0.1
         self.random_state = 42
-        self.cv_folds = 5
-        self.max_features = 10000
-        self.min_df = 2
+        self.cv_folds = 3
+        self.max_features = 5000  # Reduced for speed
+        self.min_df = 1  # More lenient for small datasets
         self.max_df = 0.95
-        self.ngram_range = (1, 3)
-        self.max_iter = 1000
+        self.ngram_range = (1, 2)  # Reduced for speed
+        self.max_iter = 500  # Reduced for speed
         self.class_weight = 'balanced'
-        self.feature_selection_k = 5000
+        self.feature_selection_k = 2000  # Reduced for speed

     def setup_models(self):
         """Setup model configurations for comparison"""
@@ -123,24 +235,24 @@
             'model': LogisticRegression(
                 max_iter=self.max_iter,
                 class_weight=self.class_weight,
-                random_state=self.random_state
+                random_state=self.random_state,
+                n_jobs=-1  # Use all cores
             ),
             'param_grid': {
-                'model__C': [0.1, 1, 10, 100],
-                'model__penalty': ['l2'],
-                'model__solver': ['liblinear', 'lbfgs']
+                'model__C': [0.1, 1, 10],  # Reduced grid
+                'model__penalty': ['l2']
             }
         },
         'random_forest': {
             'model': RandomForestClassifier(
-                n_estimators=100,
+                n_estimators=50,  # Reduced for speed
                 class_weight=self.class_weight,
-                random_state=self.random_state
+                random_state=self.random_state,
+                n_jobs=-1  # Use all cores
             ),
             'param_grid': {
-                'model__n_estimators': [50, 100, 200],
-                'model__max_depth': [10, 20, None],
-                'model__min_samples_split': [2, 5, 10]
+                'model__n_estimators': [50, 100],  # Reduced grid
+                'model__max_depth': [10, None]
             }
         }
     }
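The "Reduced grid" comments have a concrete cost model: GridSearchCV fits n_candidates × n_folds pipelines per model. Comparing the candidate counts before and after this hunk (plain arithmetic from the grids shown above):

    # Grid-search fit counts implied by the param grids
    old_lr = 4 * 1 * 2            # C x penalty x solver = 8 candidates
    new_lr = 3 * 1                # C x penalty = 3 candidates
    old_rf = 3 * 3 * 3            # n_estimators x max_depth x min_samples_split = 27
    new_rf = 2 * 2                # n_estimators x max_depth = 4

    print((old_lr + old_rf) * 5)  # 175 fits under the old 5-fold config
    print((new_lr + new_rf) * 3)  # 21 fits under the new 3-fold config

Worth noting that estimate_training_time above still hard-codes 8 and 12 combinations, which is coarser than the actual 3 and 4.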
@@ -149,6 +261,8 @@
         """Load and validate training data"""
         try:
             logger.info("Loading training data...")
+            if self.progress_tracker:
+                self.progress_tracker.update("Loading data")

             if not self.data_path.exists():
                 return False, None, f"Data file not found: {self.data_path}"
@@ -182,7 +296,7 @@
                 return False, None, f"Need at least 2 classes, found: {unique_labels}"

             # Check minimum sample size
-            if len(df) < 100:
+            if len(df) < 10:
                 return False, None, f"Insufficient samples for training: {len(df)}"

             # Check class balance
@@ -204,15 +318,18 @@
             return False, None, error_msg

     def create_preprocessing_pipeline(self) -> Pipeline:
-        """Create advanced preprocessing pipeline - pickle-safe"""
+        """Create preprocessing pipeline"""
+
+        if self.progress_tracker:
+            self.progress_tracker.update("Creating pipeline")

         # Use the standalone function instead of lambda
         text_preprocessor = FunctionTransformer(
-            func=preprocess_text_function,  # ✅ Pickle-safe function reference
+            func=preprocess_text_function,
             validate=False
         )

-        # TF-IDF vectorization
+        # TF-IDF vectorization with optimized parameters
         vectorizer = TfidfVectorizer(
             max_features=self.max_features,
             min_df=self.min_df,
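The min_df=1 loosening configured earlier matters here: with min_df=2, any term that appears in only one document is pruned, which can empty the vocabulary on tiny corpora. A standalone illustration:

    # Effect of min_df on a tiny two-document corpus
    from sklearn.feature_extraction.text import TfidfVectorizer

    docs = ["breaking news today", "totally different words here"]
    print(len(TfidfVectorizer(min_df=1).fit(docs).vocabulary_))  # 7 terms kept
    try:
        TfidfVectorizer(min_df=2).fit(docs)
    except ValueError as err:
        print(err)  # 'After pruning, no terms remain...' - nothing occurs twice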
@@ -226,7 +343,7 @@
         # Feature selection
         feature_selector = SelectKBest(
             score_func=chi2,
-            k=self.feature_selection_k
+            k=min(self.feature_selection_k, self.max_features)
         )

         # Create pipeline
@@ -241,8 +358,10 @@

     def comprehensive_evaluation(self, model, X_test, y_test, X_train=None, y_train=None) -> Dict:
         """Comprehensive model evaluation with multiple metrics"""
-        logger.info("Starting comprehensive model evaluation...")
-
+
+        if self.progress_tracker:
+            self.progress_tracker.update("Evaluating model")
+
         # Predictions
         y_pred = model.predict(X_test)
         y_pred_proba = model.predict_proba(X_test)[:, 1]
@@ -260,18 +379,18 @@
         cm = confusion_matrix(y_test, y_pred)
         metrics['confusion_matrix'] = cm.tolist()

-        # Classification report
-        class_report = classification_report(y_test, y_pred, output_dict=True)
-        metrics['classification_report'] = class_report
-
         # Cross-validation scores if training data provided
-        if X_train is not None and y_train is not None:
+        if X_train is not None and y_train is not None and len(X_train) >= 50:
             try:
                 cv_scores = cross_val_score(
                     model, X_train, y_train,
                     cv=StratifiedKFold(
-                        n_splits=self.cv_folds, shuffle=True, random_state=self.random_state),
-                    scoring='f1_weighted'
+                        n_splits=min(self.cv_folds, len(X_train) // 10),
+                        shuffle=True,
+                        random_state=self.random_state
+                    ),
+                    scoring='f1_weighted',
+                    n_jobs=-1  # Parallel CV
                 )
                 metrics['cv_scores'] = {
                     'mean': float(cv_scores.mean()),
@@ -281,30 +400,11 @@
             except Exception as e:
                 logger.warning(f"Cross-validation failed: {e}")
                 metrics['cv_scores'] = None
+        else:
+            metrics['cv_scores'] = {'note': 'Skipped for small dataset'}

-        # Feature importance (if available)
-        try:
-            if hasattr(model, 'feature_importances_'):
-                feature_importance = model.feature_importances_
-                metrics['feature_importance_stats'] = {
-                    'mean': float(feature_importance.mean()),
-                    'std': float(feature_importance.std()),
-                    'top_features': feature_importance.argsort()[-10:][::-1].tolist()
-                }
-            elif hasattr(model, 'coef_'):
-                coefficients = model.coef_[0]
-                metrics['coefficient_stats'] = {
-                    'mean': float(coefficients.mean()),
-                    'std': float(coefficients.std()),
-                    'top_positive': coefficients.argsort()[-10:][::-1].tolist(),
-                    'top_negative': coefficients.argsort()[:10].tolist()
-                }
-        except Exception as e:
-            logger.warning(f"Feature importance extraction failed: {e}")
-
-        # Model complexity metrics
+        # Training accuracy for overfitting detection
         try:
-            # Training accuracy for overfitting detection
             if X_train is not None and y_train is not None:
                 y_train_pred = model.predict(X_train)
                 train_accuracy = accuracy_score(y_train, y_train_pred)
@@ -318,7 +418,9 @@

     def hyperparameter_tuning(self, pipeline, X_train, y_train, model_name: str) -> Tuple[Any, Dict]:
         """Perform hyperparameter tuning with cross-validation"""
-        logger.info(f"Starting hyperparameter tuning for {model_name}...")
+
+        if self.progress_tracker:
+            self.progress_tracker.update(f"Tuning {model_name}")

         try:
             # Set the model in the pipeline
@@ -327,15 +429,18 @@
             # Get parameter grid
             param_grid = self.models[model_name]['param_grid']

+            # Adaptive CV folds based on dataset size
+            cv_folds = min(self.cv_folds, len(X_train) // 10, 5)
+
             # Create GridSearchCV
             grid_search = GridSearchCV(
                 pipeline,
                 param_grid,
-                cv=StratifiedKFold(n_splits=self.cv_folds,
+                cv=StratifiedKFold(n_splits=cv_folds,
                                    shuffle=True, random_state=self.random_state),
                 scoring='f1_weighted',
-                n_jobs=-1,
-                verbose=1
+                n_jobs=-1,  # Use all cores
+                verbose=0   # Reduce verbosity for speed
             )

             # Fit grid search
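The adaptive fold count trades cross-validation stability for speed on small data. How min(self.cv_folds, len(X_train) // 10, 5) behaves at a few training-set sizes, with cv_folds = 3:

    # Behavior of the adaptive fold count
    for n in (15, 30, 100, 1000):
        print(n, min(3, n // 10, 5))   # -> 1, 3, 3, 3

One caveat: StratifiedKFold requires n_splits >= 2, so a training set under 20 samples yields n_splits=1 and raises at fit time; the 10-sample minimum in the data validation does not guard against that.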
@@ -369,8 +474,7 @@

     def train_and_evaluate_models(self, X_train, X_test, y_train, y_test) -> Dict:
         """Train and evaluate multiple models"""
-        logger.info("Starting model training and evaluation...")
-
+
         results = {}

         for model_name in self.models.keys():
@@ -409,7 +513,9 @@

     def select_best_model(self, results: Dict) -> Tuple[str, Any, Dict]:
         """Select the best performing model"""
-        logger.info("Selecting best model...")
+
+        if self.progress_tracker:
+            self.progress_tracker.update("Selecting best model")

         best_model_name = None
         best_model = None
@@ -439,7 +545,8 @@
     def save_model_artifacts(self, model, model_name: str, metrics: Dict) -> bool:
         """Save model artifacts and metadata"""
         try:
-            logger.info("Saving model artifacts...")
+            if self.progress_tracker:
+                self.progress_tracker.update("Saving model")

             # Save the full pipeline
             joblib.dump(model, self.pipeline_path)
@@ -449,14 +556,10 @@
             if hasattr(model, 'named_steps') and 'model' in model.named_steps:
                 joblib.dump(model.named_steps['model'], self.model_path)
                 logger.info(f"✅ Saved model to {self.model_path}")
-            else:
-                logger.warning("❌ Could not extract model component")

             if hasattr(model, 'named_steps') and 'vectorize' in model.named_steps:
                 joblib.dump(model.named_steps['vectorize'], self.vectorizer_path)
                 logger.info(f"✅ Saved vectorizer to {self.vectorizer_path}")
-            else:
-                logger.warning("❌ Could not extract vectorizer component")

             # Generate data hash
             data_hash = hashlib.md5(str(datetime.now()).encode()).hexdigest()
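A caveat on the hash kept by this hunk: data_hash fingerprints datetime.now(), not the dataset, so data_version changes on every run even for identical data. If content-addressed versioning were wanted, a sketch (not what the commit does) would hash the file bytes instead:

    # Hypothetical content-based data version, for contrast
    import hashlib
    from pathlib import Path

    def file_md5(path) -> str:
        return hashlib.md5(Path(path).read_bytes()).hexdigest()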
@@ -466,8 +569,6 @@
                 'model_version': f"v1.0_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
                 'model_type': model_name,
                 'data_version': data_hash,
-                'train_size': metrics.get('train_accuracy', 'Unknown'),
-                'test_size': len(metrics.get('confusion_matrix', [[0]])[0]) if 'confusion_matrix' in metrics else 'Unknown',
                 'test_accuracy': metrics['accuracy'],
                 'test_f1': metrics['f1'],
                 'test_precision': metrics['precision'],
@@ -479,7 +580,6 @@
                 'timestamp': datetime.now().isoformat(),
                 'training_config': {
                     'test_size': self.test_size,
-                    'validation_size': self.validation_size,
                     'cv_folds': self.cv_folds,
                     'max_features': self.max_features,
                     'ngram_range': self.ngram_range,
@@ -492,46 +592,12 @@
                 json.dump(metadata, f, indent=2)

             logger.info(f"✅ Model artifacts saved successfully")
-            logger.info(f"Model path: {self.model_path}")
-            logger.info(f"Vectorizer path: {self.vectorizer_path}")
-            logger.info(f"Pipeline path: {self.pipeline_path}")
-            logger.info(f"Metadata path: {self.metadata_path}")
-
             return True

         except Exception as e:
             logger.error(f"Failed to save model artifacts: {str(e)}")
             return False

-    def save_evaluation_results(self, results: Dict) -> bool:
-        """Save comprehensive evaluation results"""
-        try:
-            # Clean results for JSON serialization
-            clean_results = {}
-            for model_name, result in results.items():
-                if 'error' in result:
-                    clean_results[model_name] = result
-                else:
-                    clean_results[model_name] = {
-                        'tuning_results': {
-                            k: v for k, v in result['tuning_results'].items()
-                            if k != 'best_estimator'
-                        },
-                        'evaluation_metrics': result['evaluation_metrics'],
-                        'training_time': result['training_time']
-                    }
-
-            # Save results
-            with open(self.evaluation_path, 'w') as f:
-                json.dump(clean_results, f, indent=2, default=str)
-
-            logger.info(f"Evaluation results saved to {self.evaluation_path}")
-            return True
-
-        except Exception as e:
-            logger.error(f"Failed to save evaluation results: {str(e)}")
-            return False
-
     def train_model(self, data_path: str = None) -> Tuple[bool, str]:
         """Main training function with comprehensive pipeline"""
         try:
@@ -546,35 +612,52 @@
             if not success:
                 return False, message

+            # Estimate training time and setup progress tracker
+            time_estimate = estimate_training_time(
+                len(df),
+                enable_tuning=True,
+                cv_folds=self.cv_folds
+            )
+
+            print(f"\n📊 Training Configuration:")
+            print(f"Dataset size: {len(df)} samples")
+            print(f"Estimated time: {time_estimate['total_formatted']}")
+            print(f"Models to train: {len(self.models)}")
+            print(f"Cross-validation folds: {self.cv_folds}")
+            print()
+
+            # Setup progress tracker
+            total_steps = 4 + (len(self.models) * 2) + 1  # Load, split, 2*models, select, save
+            self.progress_tracker = ProgressTracker(total_steps, "Training Progress")
+
             # Prepare data
             X = df['text'].values
             y = df['label'].values

             # Train-test split
+            self.progress_tracker.update("Splitting data")
             X_train, X_test, y_train, y_test = train_test_split(
                 X, y,
                 test_size=self.test_size,
-                stratify=y,
+                stratify=y if len(np.unique(y)) > 1 and len(y) > 10 else None,
                 random_state=self.random_state
             )

-            logger.info(
-                f"Data split: {len(X_train)} train, {len(X_test)} test")
+            logger.info(f"Data split: {len(X_train)} train, {len(X_test)} test")

             # Train and evaluate models
             results = self.train_and_evaluate_models(
                 X_train, X_test, y_train, y_test)

             # Select best model
-            best_model_name, best_model, best_metrics = self.select_best_model(
-                results)
+            best_model_name, best_model, best_metrics = self.select_best_model(results)

             # Save model artifacts
             if not self.save_model_artifacts(best_model, best_model_name, best_metrics):
                 return False, "Failed to save model artifacts"

-            # Save evaluation results
-            self.save_evaluation_results(results)
+            # Finish progress tracking
+            self.progress_tracker.finish()

             success_message = (
                 f"Model training completed successfully. "
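The stratify change lets the split degrade gracefully: train_test_split(..., stratify=y) raises when any class has fewer than two members, so tiny or single-class datasets now fall back to an unstratified split. A toy check of the same conditional:

    # Mirrors the committed conditional-stratification logic
    import numpy as np
    from sklearn.model_selection import train_test_split

    y = np.array([0] * 9 + [1] * 3)               # 12 samples, imbalanced
    X = np.arange(len(y)).reshape(-1, 1)
    strat = y if len(np.unique(y)) > 1 and len(y) > 10 else None
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, stratify=strat, random_state=42)
    print(np.bincount(y_tr), np.bincount(y_te))   # class ratio preserved in both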
@@ -586,6 +669,8 @@
             return True, success_message

         except Exception as e:
+            if self.progress_tracker:
+                print()  # New line after progress bar
             error_message = f"Model training failed: {str(e)}"
             logger.error(error_message)
             return False, error_message
@@ -593,8 +678,15 @@

 def main():
     """Main execution function"""
+    import argparse
+
+    # Parse command line arguments
+    parser = argparse.ArgumentParser(description='Train fake news detection model')
+    parser.add_argument('--data_path', type=str, help='Path to training data CSV file')
+    args = parser.parse_args()
+
     trainer = RobustModelTrainer()
-    success, message = trainer.train_model()
+    success, message = trainer.train_model(data_path=args.data_path)

     if success:
         print(f"✅ {message}")
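With the argparse hook in place, the script can be pointed at a dataset from the shell (assuming the module keeps its usual __main__ guard; the CSV name here is illustrative):

    python model/train.py --data_path data/training_data.csv

Omitting --data_path leaves args.data_path as None, and train_model then falls back to the trainer's default self.data_path.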
 