Alessandro Piana committed
Commit
ad3a64a
·
1 Parent(s): d9e7f33

dockerfile with logging 61

Files changed (2)
  1. life_coach_v1.py +343 -1184
  2. life_coach_v1_old.py +1222 -0
life_coach_v1.py CHANGED
@@ -1,1222 +1,381 @@
1
  #!/usr/bin/env python3
2
  """
3
- Life Coach v1 - Phi-4 Fine-tuned Life Coaching Assistant
4
-
5
- A simple command-line life coaching assistant using Microsoft's Phi-4 model.
6
- Fine-tunes on life coaching conversations and provides interactive chat sessions.
7
  """
8
 
9
- import torch
10
- import json
11
  import os
 
 
12
  import gc
13
- import argparse
 
 
 
14
  from pathlib import Path
15
- from typing import Optional
16
- from tqdm import tqdm
17
 
18
- # Set PyTorch CUDA memory allocation config to reduce fragmentation
19
- os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
 
 
20
 
21
- from transformers import (
22
- AutoTokenizer,
23
- AutoModelForCausalLM,
24
- TrainingArguments,
25
- Trainer,
26
- DataCollatorForSeq2Seq
27
- )
28
- from datasets import Dataset, load_dataset, concatenate_datasets
29
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
30
- import logging
31
- import random
32
- import shutil
33
- import gzip
34
- from typing import List, Dict
35
-
36
- # Configure logging
37
  logging.basicConfig(
38
- level=logging.INFO,
39
- format='%(asctime)s - %(levelname)s - %(message)s'
40
  )
41
  logger = logging.getLogger(__name__)
42
 
43
-
44
- def cleanup_gpu_memory():
45
- """
46
- Clean up GPU memory before starting the program.
47
- Clears PyTorch cache and runs garbage collection.
48
- """
49
- logger.info("=" * 80)
50
- logger.info("GPU MEMORY CLEANUP")
51
- logger.info("=" * 80)
52
-
 
 
53
  if torch.cuda.is_available():
54
- # Clear PyTorch CUDA cache
55
- torch.cuda.empty_cache()
56
-
57
- # Run garbage collection
58
- gc.collect()
59
-
60
- # Get GPU memory stats
61
- for i in range(torch.cuda.device_count()):
62
- total = torch.cuda.get_device_properties(i).total_memory / 1024**3
63
- reserved = torch.cuda.memory_reserved(i) / 1024**3
64
- allocated = torch.cuda.memory_allocated(i) / 1024**3
65
- free = total - reserved
66
-
67
- logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
68
- logger.info(f" Total memory: {total:.2f} GB")
69
- logger.info(f" Reserved: {reserved:.2f} GB")
70
- logger.info(f" Allocated: {allocated:.2f} GB")
71
- logger.info(f" Free: {free:.2f} GB")
72
-
73
- if reserved > 1.0: # More than 1GB reserved
74
- logger.warning(f" ⚠️ GPU {i} has {reserved:.2f} GB reserved!")
75
- logger.warning(f" ⚠️ This might be from a previous run.")
76
- logger.warning(f" ⚠️ If you encounter OOM errors, kill other processes using:")
77
- logger.warning(f" ⚠️ nvidia-smi | grep python")
78
- else:
79
- logger.warning("No CUDA GPUs available! Running on CPU (very slow).")
80
-
81
- logger.info("=" * 80)
82
-
83
-
84
- def clear_hf_cache():
85
- """Clear Hugging Face datasets cache to save disk space."""
86
- try:
87
- from datasets import config
88
- cache_dir = config.HF_DATASETS_CACHE
89
- if os.path.exists(cache_dir):
90
- # Get size before clearing
91
- size_mb = sum(os.path.getsize(os.path.join(dirpath,filename))
92
- for dirpath, _, filenames in os.walk(cache_dir)
93
- for filename in filenames) / (1024 * 1024)
94
-
95
- logger.info(f"Clearing HF cache ({size_mb:.1f} MB)...")
96
- shutil.rmtree(cache_dir, ignore_errors=True)
97
- os.makedirs(cache_dir, exist_ok=True)
98
- logger.info("βœ“ Cache cleared")
99
- except Exception as e:
100
- logger.warning(f"Failed to clear cache: {e}")
101
-
102
-
103
- def load_mental_health_counseling() -> List[Dict]:
104
- """Load Amod/mental_health_counseling_conversations dataset - ALL samples."""
105
- logger.info(f"Loading mental health counseling dataset...")
106
- try:
107
- dataset = load_dataset("Amod/mental_health_counseling_conversations", split="train")
108
- logger.info(f" Dataset has {len(dataset)} samples available")
109
-
110
- conversations = []
111
- for item in dataset:
112
- # Format: Context (user) -> Response (assistant)
113
- conversations.append({
114
- "messages": [
115
- {"role": "user", "content": item.get("Context", "").strip()},
116
- {"role": "assistant", "content": item.get("Response", "").strip()}
117
- ]
118
- })
119
-
120
- logger.info(f"βœ“ Loaded {len(conversations)} mental health counseling conversations")
121
- return conversations
122
- except Exception as e:
123
- logger.warning(f"Failed to load mental health counseling dataset: {e}")
124
- return []
125
-
126
-
127
- def load_counsel_chat() -> List[Dict]:
128
- """Load nbertagnolli/counsel-chat dataset - ALL samples."""
129
- logger.info(f"Loading CounselChat (nbertagnolli) dataset...")
130
- try:
131
- dataset = load_dataset("nbertagnolli/counsel-chat", split="train")
132
- logger.info(f" Dataset has {len(dataset)} samples available")
133
-
134
- conversations = []
135
- for item in dataset:
136
- # Try different possible field names
137
- question = None
138
- answer = None
139
-
140
- # Common field patterns
141
- for q_field in ["questionText", "question", "query", "input", "user_message"]:
142
- if q_field in item and item.get(q_field):
143
- question = item[q_field].strip()
144
- break
145
-
146
- for a_field in ["answerText", "answer", "response", "output", "counselor_message"]:
147
- if a_field in item and item.get(a_field):
148
- answer = item[a_field].strip()
149
- break
150
-
151
- if question and answer:
152
- conversations.append({
153
- "messages": [
154
- {"role": "user", "content": question},
155
- {"role": "assistant", "content": answer}
156
- ]
157
- })
158
-
159
- logger.info(f"βœ“ Loaded {len(conversations)} CounselChat conversations")
160
- return conversations
161
- except Exception as e:
162
- logger.warning(f"Failed to load CounselChat dataset: {e}")
163
- return []
164
-
165
-
166
- def load_cbt_cognitive_distortions() -> List[Dict]:
167
- """Load epsilon3/cbt-cognitive-distortions-analysis dataset - ALL samples."""
168
- logger.info(f"Loading CBT Cognitive Distortions dataset...")
169
- try:
170
- dataset = load_dataset("epsilon3/cbt-cognitive-distortions-analysis", split="train")
171
- logger.info(f" Dataset has {len(dataset)} samples available")
172
-
173
- conversations = []
174
- for item in dataset:
175
- # Try different field patterns
176
- user_msg = None
177
- assistant_msg = None
178
-
179
- for u_field in ["input", "text", "thought", "statement", "user_input"]:
180
- if u_field in item and item.get(u_field):
181
- user_msg = item[u_field].strip()
182
- break
183
-
184
- for a_field in ["output", "analysis", "reframe", "response", "cbt_response"]:
185
- if a_field in item and item.get(a_field):
186
- assistant_msg = item[a_field].strip()
187
- break
188
-
189
- if user_msg and assistant_msg:
190
- conversations.append({
191
- "messages": [
192
- {"role": "user", "content": user_msg},
193
- {"role": "assistant", "content": assistant_msg}
194
- ]
195
- })
196
-
197
- logger.info(f"βœ“ Loaded {len(conversations)} CBT Cognitive Distortions conversations")
198
- return conversations
199
- except Exception as e:
200
- logger.warning(f"Failed to load CBT Cognitive Distortions dataset: {e}")
201
- return []
202
-
203
-
204
- def load_peer_counseling_reflections() -> List[Dict]:
205
- """Load emoneil/reflections-in-peer-counseling dataset - ALL samples."""
206
- logger.info(f"Loading Peer Counseling Reflections dataset...")
207
- try:
208
- dataset = load_dataset("emoneil/reflections-in-peer-counseling", split="train")
209
- logger.info(f" Dataset has {len(dataset)} samples available")
210
-
211
- conversations = []
212
- for item in dataset:
213
- # Try different field patterns
214
- user_msg = None
215
- assistant_msg = None
216
-
217
- for u_field in ["question", "statement", "input", "user_message", "counselee"]:
218
- if u_field in item and item.get(u_field):
219
- user_msg = item[u_field].strip()
220
- break
221
-
222
- for a_field in ["reflection", "response", "output", "counselor_response", "counselor"]:
223
- if a_field in item and item.get(a_field):
224
- assistant_msg = item[a_field].strip()
225
- break
226
-
227
- if user_msg and assistant_msg:
228
- conversations.append({
229
- "messages": [
230
- {"role": "user", "content": user_msg},
231
- {"role": "assistant", "content": assistant_msg}
232
- ]
233
- })
234
-
235
- logger.info(f"βœ“ Loaded {len(conversations)} Peer Counseling Reflections conversations")
236
- return conversations
237
- except Exception as e:
238
- logger.warning(f"Failed to load Peer Counseling Reflections dataset: {e}")
239
- return []
240
-
241
-
242
- def load_dolly_dataset() -> List[Dict]:
243
- """Load databricks-dolly-15k dataset (instruction-following) - ALL relevant samples."""
244
- logger.info(f"Loading Dolly instruction dataset...")
245
- try:
246
- dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
247
- logger.info(f" Dataset has {len(dataset)} samples available")
248
-
249
- # Filter for relevant categories (brainstorming, open_qa, creative_writing)
250
- relevant_categories = {"brainstorming", "open_qa", "creative_writing", "general_qa"}
251
-
252
- conversations = []
253
- for item in dataset:
254
- if item.get("category", "") in relevant_categories:
255
- instruction = item.get("instruction", "").strip()
256
- context = item.get("context", "").strip()
257
- response = item.get("response", "").strip()
258
-
259
- # Combine instruction and context if both exist
260
- user_message = f"{instruction}\n\n{context}" if context else instruction
261
-
262
- if user_message and response:
263
- conversations.append({
264
- "messages": [
265
- {"role": "user", "content": user_message},
266
- {"role": "assistant", "content": response}
267
- ]
268
- })
269
-
270
- logger.info(f"βœ“ Loaded {len(conversations)} Dolly instruction conversations (filtered from {len(dataset)} total)")
271
- return conversations
272
- except Exception as e:
273
- logger.warning(f"Failed to load Dolly dataset: {e}")
274
- return []
275
-
276
-
277
- def load_mentalchat16k() -> List[Dict]:
278
- """Load ShenLab/MentalChat16K dataset - ALL samples."""
279
- logger.info(f"Loading MentalChat16K dataset...")
280
- try:
281
- dataset = load_dataset("ShenLab/MentalChat16K", split="train")
282
- logger.info(f" Dataset has {len(dataset)} samples available")
283
-
284
- conversations = []
285
- for item in dataset:
286
- # Try different possible field names
287
- user_msg = None
288
- assistant_msg = None
289
-
290
- # Common field name patterns
291
- for user_field in ["query", "question", "input", "user", "prompt", "instruction"]:
292
- if user_field in item and item.get(user_field):
293
- user_msg = item[user_field].strip()
294
- break
295
-
296
- for assistant_field in ["response", "answer", "output", "assistant", "reply"]:
297
- if assistant_field in item and item.get(assistant_field):
298
- assistant_msg = item[assistant_field].strip()
299
- break
300
-
301
- if user_msg and assistant_msg:
302
- conversations.append({
303
- "messages": [
304
- {"role": "user", "content": user_msg},
305
- {"role": "assistant", "content": assistant_msg}
306
- ]
307
- })
308
-
309
- logger.info(f"βœ“ Loaded {len(conversations)} MentalChat16K conversations")
310
- return conversations
311
- except Exception as e:
312
- logger.warning(f"Failed to load MentalChat16K dataset: {e}")
313
- return []
314
-
315
-
316
- def load_additional_mental_health_datasets() -> List[Dict]:
317
- """Load additional mental health datasets - ALL samples."""
318
- logger.info(f"Loading additional mental health datasets...")
319
-
320
- all_conversations = []
321
-
322
- # List of additional datasets to try
323
- additional_datasets = [
324
- ("heliosbrahma/mental_health_chatbot_dataset", ["prompt", "question"], ["response", "answer"]),
325
- ("mpingale/mental-health-chat-dataset", ["question", "query"], ["answer", "response"]),
326
- ("sauravjoshi23/psychology-dataset", ["input", "question"], ["output", "answer"]),
327
- ]
328
-
329
- for dataset_name, user_fields, assistant_fields in additional_datasets:
330
- try:
331
- logger.info(f" Loading {dataset_name}...")
332
- dataset = load_dataset(dataset_name, split="train")
333
- logger.info(f" Has {len(dataset)} samples available")
334
-
335
- for item in dataset:
336
- # Try different field names
337
- user_msg = None
338
- assistant_msg = None
339
-
340
- for field in user_fields:
341
- if field in item and item.get(field):
342
- user_msg = item[field].strip()
343
- break
344
-
345
- for field in assistant_fields:
346
- if field in item and item.get(field):
347
- assistant_msg = item[field].strip()
348
- break
349
-
350
- if user_msg and assistant_msg:
351
- all_conversations.append({
352
- "messages": [
353
- {"role": "user", "content": user_msg},
354
- {"role": "assistant", "content": assistant_msg}
355
- ]
356
- })
357
-
358
- logger.info(f" βœ“ Loaded {len([c for c in all_conversations if c])} from this dataset")
359
-
360
- except Exception as e:
361
- logger.warning(f" Failed: {e}")
362
- continue
363
-
364
- logger.info(f"βœ“ Loaded {len(all_conversations)} additional mental health conversations total")
365
- return all_conversations
366
-
367
-
368
- def quality_filter_conversation(conv: Dict, min_response_length: int = 50, max_total_length: int = 2048) -> bool:
369
- """Filter conversation based on quality criteria."""
370
- try:
371
- messages = conv.get("messages", [])
372
- if len(messages) < 2:
373
- return False
374
-
375
- # Check response length
376
- assistant_msg = [m for m in messages if m.get("role") == "assistant"]
377
- if not assistant_msg:
378
- return False
379
-
380
- response = assistant_msg[0].get("content", "")
381
- if len(response) < min_response_length:
382
- return False
383
-
384
- # Check total length
385
- total_length = sum(len(m.get("content", "")) for m in messages)
386
- if total_length > max_total_length:
387
- return False
388
-
389
- # Check for empty messages
390
- if any(not m.get("content", "").strip() for m in messages):
391
- return False
392
-
393
- return True
394
- except:
395
- return False
396
-
397
-
398
- def load_mixed_dataset(
399
- total_samples: int = 100000,
400
- cache_file: str = "mixed_lifecoach_dataset_100k.jsonl.gz", # Now compressed by default
401
- use_cache: bool = True
402
- ) -> List[Dict]:
403
- """
404
- Load and mix multiple datasets for comprehensive life coaching training.
405
- Saves compressed cache to save disk space.
406
-
407
- Datasets loaded (ALL available samples):
408
- 1. Mental Health Counseling (Amod/mental_health_counseling_conversations)
409
- 2. CounselChat (nbertagnolli/counsel-chat)
410
- 3. CBT Cognitive Distortions (epsilon3/cbt-cognitive-distortions-analysis)
411
- 4. Peer Counseling Reflections (emoneil/reflections-in-peer-counseling)
412
- 5. MentalChat16K (ShenLab/MentalChat16K)
413
- 6. Dolly Instructions (databricks/databricks-dolly-15k - filtered categories)
414
- 7-8. Additional mental health datasets (heliosbrahma, mpingale, sauravjoshi23)
415
- """
416
- cache_path = Path(cache_file)
417
- cache_path_uncompressed = Path(cache_file.replace('.gz', ''))
418
-
419
- # Try to load from compressed cache first
420
- if use_cache and cache_path.exists():
421
- logger.info(f"Loading cached dataset from {cache_file} (compressed)...")
422
- try:
423
- conversations = []
424
- with gzip.open(cache_path, 'rt', encoding='utf-8') as f:
425
- for line in f:
426
- conversations.append(json.loads(line.strip()))
427
- logger.info(f"βœ“ Loaded {len(conversations)} conversations from compressed cache")
428
- return conversations
429
- except Exception as e:
430
- logger.warning(f"Failed to load compressed cache: {e}. Trying uncompressed...")
431
-
432
- # Try uncompressed cache (backward compatibility)
433
- if use_cache and cache_path_uncompressed.exists():
434
- logger.info(f"Loading cached dataset from {cache_path_uncompressed} (uncompressed)...")
435
  try:
436
- conversations = []
437
- with open(cache_path_uncompressed, 'r', encoding='utf-8') as f:
438
- for line in f:
439
- conversations.append(json.loads(line.strip()))
440
- logger.info(f"βœ“ Loaded {len(conversations)} conversations from uncompressed cache")
441
- return conversations
442
  except Exception as e:
443
- logger.warning(f"Failed to load cache: {e}. Rebuilding dataset...")
444
-
445
- # Load ALL available samples from each dataset
446
- logger.info("=" * 80)
447
- logger.info(f"LOADING MIXED DATASET (Target: ~{total_samples} samples)")
448
- logger.info("Loading ALL available samples from each dataset")
449
- logger.info("=" * 80)
450
-
451
- all_conversations = []
452
-
453
- # Load each dataset ONE AT A TIME and clear cache after each
454
- # This saves disk space by not keeping all downloads simultaneously
455
-
456
- logger.info("Dataset 1/8: Mental Health Counseling (Amod)")
457
- all_conversations.extend(load_mental_health_counseling())
458
- logger.info(f" Running total: {len(all_conversations)} conversations")
459
- clear_hf_cache()
460
- gc.collect()
461
-
462
- # Stop early if we've reached target
463
- if len(all_conversations) >= total_samples:
464
- logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
465
- else:
466
- logger.info("Dataset 2/8: CounselChat (nbertagnolli)")
467
- all_conversations.extend(load_counsel_chat())
468
- logger.info(f" Running total: {len(all_conversations)} conversations")
469
- clear_hf_cache()
470
- gc.collect()
471
-
472
- if len(all_conversations) >= total_samples:
473
- logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
474
- else:
475
- logger.info("Dataset 3/8: CBT Cognitive Distortions (epsilon3)")
476
- all_conversations.extend(load_cbt_cognitive_distortions())
477
- logger.info(f" Running total: {len(all_conversations)} conversations")
478
- clear_hf_cache()
479
- gc.collect()
480
-
481
- if len(all_conversations) >= total_samples:
482
- logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
483
- else:
484
- logger.info("Dataset 4/8: Peer Counseling Reflections (emoneil)")
485
- all_conversations.extend(load_peer_counseling_reflections())
486
- logger.info(f" Running total: {len(all_conversations)} conversations")
487
- clear_hf_cache()
488
- gc.collect()
489
-
490
- if len(all_conversations) >= total_samples:
491
- logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
492
- else:
493
- logger.info("Dataset 5/8: MentalChat16K (ShenLab)")
494
- all_conversations.extend(load_mentalchat16k())
495
- logger.info(f" Running total: {len(all_conversations)} conversations")
496
- clear_hf_cache()
497
- gc.collect()
498
-
499
- if len(all_conversations) >= total_samples:
500
- logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
501
- else:
502
- logger.info("Dataset 6/8: Dolly Instructions (databricks)")
503
- all_conversations.extend(load_dolly_dataset())
504
- logger.info(f" Running total: {len(all_conversations)} conversations")
505
- clear_hf_cache()
506
- gc.collect()
507
-
508
- if len(all_conversations) >= total_samples:
509
- logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
510
- else:
511
- logger.info("Datasets 7-8: Additional Mental Health Datasets")
512
- all_conversations.extend(load_additional_mental_health_datasets())
513
- logger.info(f" Running total: {len(all_conversations)} conversations")
514
- clear_hf_cache()
515
- gc.collect()
516
-
517
- logger.info("=" * 80)
518
- logger.info(f"Total conversations loaded: {len(all_conversations)}")
519
-
520
- # Apply quality filtering
521
- logger.info("Applying quality filters...")
522
- filtered_conversations = [conv for conv in all_conversations if quality_filter_conversation(conv)]
523
- logger.info(f"βœ“ After filtering: {len(filtered_conversations)} conversations")
524
-
525
- # Shuffle to mix datasets
526
- random.shuffle(filtered_conversations)
527
-
528
- # Trim to target size
529
- if len(filtered_conversations) > total_samples:
530
- filtered_conversations = filtered_conversations[:total_samples]
531
-
532
- logger.info(f"Final dataset size: {len(filtered_conversations)} conversations")
533
-
534
- # Save compressed cache to save disk space
535
- if use_cache:
536
- logger.info(f"Saving compressed cache to {cache_file}...")
537
- try:
538
- with gzip.open(cache_path, 'wt', encoding='utf-8') as f:
539
- for conv in filtered_conversations:
540
- f.write(json.dumps(conv, ensure_ascii=False) + '\n')
541
-
542
- # Get file sizes for comparison
543
- compressed_size_mb = cache_path.stat().st_size / (1024 * 1024)
544
- logger.info(f"βœ“ Compressed cache saved successfully ({compressed_size_mb:.1f} MB)")
545
- except Exception as e:
546
- logger.warning(f"Failed to save compressed cache: {e}")
547
-
548
- logger.info("=" * 80)
549
- return filtered_conversations
550
-
551
 
552
  class LifeCoachModel:
553
- """Life coaching assistant using Phi-4 model."""
554
-
555
- def __init__(
556
- self,
557
- model_name: str = "microsoft/Phi-4",
558
- model_save_path: str = "/data/life_coach_model",
559
- train_file: str = "cbt_life_coach_improved_50000.jsonl",
560
- max_length: int = 2048
561
- ):
562
- """
563
- Initialize the Life Coach model.
564
-
565
- Args:
566
- model_name: Hugging Face model identifier
567
- model_save_path: Path to save/load fine-tuned model
568
- train_file: Path to training data file (JSONL format)
569
- max_length: Maximum sequence length for training
570
- """
571
  self.model_name = model_name
572
-
573
- # Check if /data is writable, otherwise use local directory
574
- save_path = Path(model_save_path)
575
- if str(save_path).startswith("/data"):
576
- try:
577
- Path("/data").mkdir(parents=True, exist_ok=True)
578
- # Test write permissions
579
- test_file = Path("/data/.test_write")
580
- test_file.touch()
581
- test_file.unlink()
582
- self.model_save_path = save_path
583
- logger.info(f"Using /data directory for model storage: {save_path}")
584
- except (PermissionError, OSError) as e:
585
- # Fall back to local directory
586
- local_path = Path("./data/life_coach_model")
587
- logger.warning(f"/data directory not writable ({e}), using local directory: {local_path}")
588
- self.model_save_path = local_path
589
  else:
590
- self.model_save_path = save_path
591
-
592
- self.train_file = Path(train_file)
593
- self.max_length = max_length
594
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
595
-
596
- logger.info(f"Device: {self.device}")
597
- logger.info(f"Model: {model_name}")
598
- logger.info(f"Save path: {self.model_save_path}")
599
- logger.info(f"Training file: {self.train_file}")
600
-
601
  self.tokenizer = None
602
  self.model = None
603
-
 
 
604
  def load_tokenizer(self):
605
- """Carica il tokenizer da /data/hf_cache (persistente) o scaricalo una volta."""
606
- logger.info("Loading tokenizer...")
 
607
 
608
- cache_dir = "/data/hf_cache"
609
- os.makedirs(cache_dir, exist_ok=True)
610
-
611
  try:
 
 
612
  self.tokenizer = AutoTokenizer.from_pretrained(
613
  self.model_name,
614
- cache_dir=cache_dir,
615
- local_files_only=False, # Allow download only if not already cached
616
  trust_remote_code=True,
617
- use_fast=True
618
  )
619
- logger.info(f"Tokenizer caricato (cache: {cache_dir})")
 
 
620
  except Exception as e:
621
- logger.error(f"Errore critico nel caricamento tokenizer: {e}")
 
622
  raise
623
- def load_model(self, fine_tuned=True):
624
- """Load the fine-tuned model with safe settings for HF Spaces."""
625
- logger.info(f"Loading {'fine-tuned' if fine_tuned else 'base'} model from {self.model_save_path}")
626
-
627
- # Force safe settings
628
- import torch
629
- from transformers import AutoModelForCausalLM
630
- from peft import PeftModel
631
 
632
- base_model_name = self.model_name
633
-
634
- # Load the base model with device_map and offload
635
- base_model = AutoModelForCausalLM.from_pretrained(
636
- base_model_name,
637
- torch_dtype=torch.float16,
638
- device_map="auto",
639
- trust_remote_code=True,
640
- low_cpu_mem_usage=True,
641
- offload_folder="/tmp/offload", # Usa /tmp per offload
642
- cache_dir="/data/hf_cache"
643
- )
644
-
645
  if fine_tuned:
646
- logger.info(f"Loading adapter from {self.model_save_path}")
647
- self.model = PeftModel.from_pretrained(
648
- base_model,
649
- self.model_save_path,
650
- device_map="auto",
651
- offload_folder="/tmp/offload",
652
- torch_dtype=torch.float16
653
- )
654
- else:
655
- self.model = base_model
656
-
657
- self.model.eval()
658
- logger.info("Model loaded successfully!")
659
-
660
- def load_training_data(self, num_samples: Optional[int] = None) -> Dataset:
661
- """
662
- Load training data from mixed datasets or JSONL file.
663
-
664
- Args:
665
- num_samples: Number of samples to load (None for 100,000 default)
666
-
667
- Returns:
668
- Dataset object
669
- """
670
- # Try to load from mixed datasets first (new method)
671
- # If train_file doesn't exist or is the old one, use mixed datasets
672
- use_mixed_datasets = True
673
-
674
- if self.train_file.exists():
675
- # Check if it's the old single dataset file
676
- if "cbt_life_coach" in str(self.train_file):
677
- logger.info("Found old training file. Using new mixed datasets instead...")
678
- use_mixed_datasets = True
679
  else:
680
- # It might be a cached mixed dataset
681
- logger.info(f"Found training file at {self.train_file}")
682
- use_mixed_datasets = False
683
-
684
- if use_mixed_datasets:
685
- # Load mixed datasets from Hugging Face
686
- logger.info("Loading mixed datasets from Hugging Face...")
687
- if num_samples is None:
688
- num_samples = 100000 # Default to 100k samples
689
-
690
- # Load mixed dataset (will use cache if available)
691
- cache_file = f"mixed_lifecoach_dataset_{num_samples}.jsonl.gz" # Compressed format
692
- data = load_mixed_dataset(
693
- total_samples=num_samples,
694
- cache_file=cache_file,
695
- use_cache=True
696
  )
697
- else:
698
- # Fall back to loading from JSONL file
699
- logger.info(f"Loading training data from {self.train_file}")
700
- data = []
701
- with open(self.train_file, 'r', encoding='utf-8') as f:
702
- for i, line in enumerate(f):
703
- if num_samples and i >= num_samples:
704
- break
705
- try:
706
- data.append(json.loads(line.strip()))
707
- except json.JSONDecodeError:
708
- logger.warning(f"Skipping invalid JSON at line {i+1}")
709
-
710
- logger.info(f"Loaded {len(data)} training examples")
711
-
712
- # Convert to Hugging Face Dataset
713
- dataset = Dataset.from_list(data)
714
-
715
- # Preprocess for Phi-4 format
716
- logger.info("Preprocessing data for Phi-4 format...")
717
- dataset = dataset.map(
718
- self._preprocess_function,
719
- batched=True,
720
- remove_columns=dataset.column_names,
721
- desc="Tokenizing"
722
- )
723
-
724
- return dataset
725
-
726
- def _preprocess_function(self, examples):
727
- """
728
- Preprocess data into Phi-4 chat format.
729
-
730
- Phi-4 uses:
731
- <|system|>
732
- {system message}<|end|>
733
- <|user|>
734
- {user message}<|end|>
735
- <|assistant|>
736
- {assistant response}<|end|>
737
- """
738
- texts = []
739
-
740
- # Handle both 'conversations' (our format) and 'messages' (standard format)
741
- conversations_key = 'conversations' if 'conversations' in examples else 'messages'
742
-
743
- for conversation in examples[conversations_key]:
744
- text = ""
745
- for message in conversation:
746
- # Handle both 'from'/'value' and 'role'/'content' formats
747
- if 'from' in message:
748
- role = message['from']
749
- content = message['value']
750
- else:
751
- role = message['role']
752
- content = message['content']
753
-
754
- # Convert to Phi-4 format
755
- if role == 'system':
756
- text += f"<|system|>\n{content}<|end|>\n"
757
- elif role == 'user':
758
- text += f"<|user|>\n{content}<|end|>\n"
759
- elif role == 'assistant':
760
- text += f"<|assistant|>\n{content}<|end|>\n"
761
-
762
- texts.append(text)
763
-
764
- # Tokenize with dynamic padding (like quantum server)
765
- # Don't pad here - let DataCollatorForSeq2Seq handle it dynamically per batch
766
- model_inputs = self.tokenizer(
767
- texts,
768
- max_length=self.max_length,
769
- truncation=True,
770
- padding=False, # Dynamic padding - saves massive memory!
771
- return_tensors=None # Don't convert to tensors yet
772
- )
773
-
774
- # Set labels (for causal language modeling, labels = input_ids)
775
- # Note: .copy() instead of .clone() since we're not using tensors yet
776
- model_inputs["labels"] = model_inputs["input_ids"].copy()
777
-
778
- return model_inputs
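# A minimal illustration (not part of this diff) of the Phi-4 chat string that
# _preprocess_function above builds; the two-turn conversation is made up.
sample_conversation = [
    {"role": "user", "content": "I feel stuck in my career."},
    {"role": "assistant", "content": "What does 'stuck' look like day to day for you?"},
]
sample_text = ""
for message in sample_conversation:
    sample_text += f"<|{message['role']}|>\n{message['content']}<|end|>\n"
# sample_text now reads:
# <|user|>
# I feel stuck in my career.<|end|>
# <|assistant|>
# What does 'stuck' look like day to day for you?<|end|>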
779
-
780
- def setup_lora(self):
781
- """Setup LoRA (Low-Rank Adaptation) for efficient fine-tuning."""
782
- logger.info("Setting up LoRA adapters...")
783
-
784
- # Prepare model for k-bit training (critical for load_in_8bit=True)
785
- logger.info("Preparing model for 8-bit training...")
786
- self.model = prepare_model_for_kbit_training(self.model)
787
-
788
- # Enable gradient checkpointing to save GPU memory
789
- # This reduces memory usage by 20-30 GB with minimal performance impact
790
- if hasattr(self.model, 'gradient_checkpointing_enable'):
791
- self.model.gradient_checkpointing_enable()
792
- logger.info("βœ“ Gradient checkpointing enabled (saves 20-30 GB GPU memory)")
793
-
794
- # LoRA configuration
795
- lora_config = LoraConfig(
796
- task_type=TaskType.CAUSAL_LM,
797
- r=16, # Rank
798
- lora_alpha=32,
799
- lora_dropout=0.1,
800
- bias="none",
801
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"] # Attention layers
802
- )
803
-
804
- # Apply LoRA
805
- self.model = get_peft_model(self.model, lora_config)
806
-
807
- # Print trainable parameters
808
- trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
809
- total_params = sum(p.numel() for p in self.model.parameters())
810
-
811
- logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} "
812
- f"({100 * trainable_params / total_params:.2f}%)")
813
-
814
- def fine_tune(
815
- self,
816
- num_samples: Optional[int] = 5000,
817
- epochs: int = 3,
818
- batch_size: int = 8,
819
- learning_rate: float = 5e-5,
820
- gradient_accumulation_steps: int = 2
821
- ):
822
- """
823
- Fine-tune the model on life coaching data.
824
-
825
- Args:
826
- num_samples: Number of training samples (None for all)
827
- epochs: Number of training epochs
828
- batch_size: Training batch size
829
- learning_rate: Learning rate
830
- gradient_accumulation_steps: Gradient accumulation steps (for memory efficiency)
831
- """
832
- logger.info("=" * 80)
833
- logger.info("STARTING FINE-TUNING")
834
- logger.info("=" * 80)
835
-
836
- # Load data
837
- dataset = self.load_training_data(num_samples)
838
-
839
- # Setup LoRA
840
- self.setup_lora()
841
-
842
- # Training arguments
843
- training_args = TrainingArguments(
844
- output_dir="./training_output",
845
- num_train_epochs=epochs,
846
- per_device_train_batch_size=batch_size,
847
- gradient_accumulation_steps=gradient_accumulation_steps,
848
- learning_rate=learning_rate,
849
- fp16=True, # Mixed precision training
850
- logging_steps=10,
851
- save_strategy="epoch",
852
- save_total_limit=2,
853
- warmup_steps=100,
854
- weight_decay=0.01,
855
- report_to="none", # Disable wandb/tensorboard
856
- )
857
-
858
- # Data collator
859
- data_collator = DataCollatorForSeq2Seq(
860
- tokenizer=self.tokenizer,
861
- model=self.model,
862
- padding=True
863
- )
864
-
865
- # Trainer
866
- trainer = Trainer(
867
- model=self.model,
868
- args=training_args,
869
- train_dataset=dataset,
870
- data_collator=data_collator,
871
- )
872
-
873
- # Train
874
- logger.info("Training started...")
875
- trainer.train()
876
-
877
- logger.info("=" * 80)
878
- logger.info("TRAINING COMPLETED")
879
- logger.info("=" * 80)
880
-
881
- # Save model
882
- self.save_model()
883
-
884
- def save_model(self):
885
- """Save the fine-tuned model to disk."""
886
- logger.info(f"Saving model to {self.model_save_path}")
887
-
888
- self.model_save_path.mkdir(parents=True, exist_ok=True)
889
-
890
- # Save model and tokenizer
891
- self.model.save_pretrained(str(self.model_save_path))
892
- self.tokenizer.save_pretrained(str(self.model_save_path))
893
-
894
- logger.info("Model saved successfully")
895
-
896
- def generate_response(self, prompt: str, max_new_tokens: int = 128, conversation_history: list = None) -> str:
897
- """
898
- Generate a response to a user prompt.
899
-
900
- Args:
901
- prompt: User's input message
902
- max_new_tokens: Maximum tokens to generate
903
- conversation_history: List of previous messages for context
904
-
905
- Returns:
906
- Generated response
907
- """
908
- # Build full conversation context with system prompt
909
- formatted_prompt = ""
910
-
911
- # Add system prompt to guide the model's behavior
912
- system_prompt = """You are Robert, a friendly and experienced life coach. Here's your background:
913
-
914
- About You:
915
- - Name: Robert (Bob to friends)
916
- - Age: 42 years old
917
- - Experience: 15 years as a certified life coach and motivational speaker
918
- - Education: Master's degree in Psychology from UC Berkeley
919
- - Specialties: Personal growth, career transitions, work-life balance, goal setting, stress management
920
- - Personal: Married with two kids, enjoy hiking and meditation in your free time
921
- - Approach: Warm, empathetic, practical, and solution-focused
922
-
923
- Your Coaching Style:
924
- - Respond ONLY to what the user actually tells you - never make assumptions about their problems
925
- - Start conversations in a welcoming, open manner
926
- - Ask clarifying questions to understand their situation better
927
- - Provide practical, actionable advice based on what they share
928
- - Be encouraging and positive, but also honest and realistic
929
- - Keep responses concise and focused (2-4 sentences usually)
930
- - Share brief personal insights when relevant, but keep the focus on the client
931
-
932
- Important: Never assume clients have problems they haven't mentioned. Let them guide the conversation and share what's on their mind."""
933
-
934
- formatted_prompt += f"<|system|>\n{system_prompt}<|end|>\n"
935
-
936
- # Add conversation history if provided
937
- if conversation_history:
938
- for msg in conversation_history:
939
- if msg["role"] == "user":
940
- formatted_prompt += f"<|user|>\n{msg['content']}<|end|>\n"
941
- elif msg["role"] == "assistant":
942
- formatted_prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
943
-
944
- # Add current prompt
945
- formatted_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
946
-
947
- # DEBUG: Print the full prompt being sent to the model
948
- logger.info("=" * 80)
949
- logger.info("FULL PROMPT SENT TO MODEL:")
950
- logger.info(formatted_prompt)
951
- logger.info("=" * 80)
952
-
953
- # Tokenize
954
- inputs = self.tokenizer(
955
- formatted_prompt,
956
- return_tensors="pt",
957
- truncation=True,
958
- max_length=self.max_length
959
- ).to(self.device)
960
-
961
- # Get input length to extract only new tokens
962
- input_length = inputs['input_ids'].shape[1]
963
-
964
- # Get the token ID for <|end|> to use as a stopping token
965
- end_token_id = self.tokenizer.convert_tokens_to_ids("<|end|>")
966
-
967
- # Build list of EOS token IDs (stop generation at <|end|> or EOS)
968
- eos_token_ids = [self.tokenizer.eos_token_id]
969
- if end_token_id is not None and end_token_id != self.tokenizer.unk_token_id:
970
- eos_token_ids.append(end_token_id)
971
-
972
- # Generate
973
- with torch.no_grad():
974
- outputs = self.model.generate(
975
- **inputs,
976
- max_new_tokens=max_new_tokens,
977
- temperature=0.7, # Balanced - coherent but still creative
978
- top_p=0.9, # Standard setting for focused responses
979
- top_k=50, # Add top-k sampling
980
- do_sample=True,
981
- pad_token_id=self.tokenizer.pad_token_id,
982
- eos_token_id=eos_token_ids, # Stop at <|end|> or EOS
983
- repetition_penalty=1.15 # Stronger penalty to prevent repetition
984
  )
985
-
986
- # Decode ONLY the newly generated tokens (not the input)
987
- generated_tokens = outputs[0][input_length:]
988
-
989
- # Decode without skipping special tokens first to find the end marker
990
- response_with_tokens = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
991
-
992
- # Extract only up to the first <|end|> token (model may generate multi-turn conversations)
993
- if "<|end|>" in response_with_tokens:
994
- response_text = response_with_tokens.split("<|end|>")[0]
995
- else:
996
- response_text = response_with_tokens
997
-
998
- # Clean up any remaining special tokens
999
- response_text = response_text.replace("<|assistant|>", "").replace("<|user|>", "").replace("<|system|>", "")
1000
-
1001
- # Remove any remaining special tokens using the tokenizer
1002
- response_text = response_text.strip()
1003
-
1004
- return response_text
1005
-
1006
- def interactive_chat(self):
1007
- """Start an interactive chat session."""
1008
- logger.info("=" * 80)
1009
- logger.info("LIFE COACH V1 - Interactive Chat Session")
1010
- logger.info("=" * 80)
1011
- print("\nWelcome to Life Coach v1!")
1012
- print("I'm here to help you with life coaching, goal setting, motivation, and personal growth.")
1013
- print("\nCommands:")
1014
- print(" - Type your question or concern to get coaching advice")
1015
- print(" - Type 'quit' or 'exit' to end the session")
1016
- print(" - Type 'clear' to clear conversation history")
1017
- print("=" * 80)
1018
- print()
1019
-
1020
- conversation_history = []
1021
-
1022
- while True:
1023
- try:
1024
- # Get user input
1025
- user_input = input("\n🧑 You: ").strip()
1026
-
1027
- if not user_input:
1028
- continue
1029
-
1030
- # Check for exit commands
1031
- if user_input.lower() in ['quit', 'exit', 'q']:
1032
- print("\nπŸ‘‹ Thank you for using Life Coach v1. Take care!")
1033
- break
1034
-
1035
- # Check for clear command
1036
- if user_input.lower() == 'clear':
1037
- conversation_history = []
1038
- print("βœ… Conversation history cleared.")
1039
- continue
1040
-
1041
- # Generate response with conversation context
1042
- print("\nπŸ€– Life Coach: ", end="", flush=True)
1043
- response = self.generate_response(user_input, conversation_history=conversation_history)
1044
- print(response)
1045
-
1046
- # Update conversation history
1047
- conversation_history.append({
1048
- "role": "user",
1049
- "content": user_input
1050
- })
1051
- conversation_history.append({
1052
- "role": "assistant",
1053
- "content": response
1054
- })
1055
-
1056
- except KeyboardInterrupt:
1057
- print("\n\nπŸ‘‹ Session interrupted. Goodbye!")
1058
- break
1059
- except Exception as e:
1060
- logger.error(f"Error during chat: {e}")
1061
- print(f"\n❌ Error: {e}")
1062
-
1063
-
1064
- def main():
1065
- """Main entry point."""
1066
- parser = argparse.ArgumentParser(
1067
- description="Life Coach v1 - Phi-4 based life coaching assistant"
1068
- )
1069
-
1070
- parser.add_argument(
1071
- "--mode",
1072
- type=str,
1073
- choices=["train", "chat", "both"],
1074
- default="both",
1075
- help="Mode: train (fine-tune only), chat (chat only), both (train then chat)"
1076
- )
1077
-
1078
- parser.add_argument(
1079
- "--model-name",
1080
- type=str,
1081
- default="microsoft/Phi-4",
1082
- help="Hugging Face model name"
1083
- )
1084
-
1085
- parser.add_argument(
1086
- "--model-path",
1087
- type=str,
1088
- default="/data/life_coach_model",
1089
- help="Path to save/load fine-tuned model"
1090
- )
1091
-
1092
- parser.add_argument(
1093
- "--train-file",
1094
- type=str,
1095
- default="cbt_life_coach_improved_50000.jsonl",
1096
- help="Path to training data file (JSONL format)"
1097
- )
1098
-
1099
- parser.add_argument(
1100
- "--num-samples",
1101
- type=int,
1102
- default=-1,
1103
- help="Number of training samples (default: -1 for all 100,000 from mixed datasets)"
1104
- )
1105
-
1106
- parser.add_argument(
1107
- "--epochs",
1108
- type=int,
1109
- default=3,
1110
- help="Number of training epochs"
1111
- )
1112
-
1113
- parser.add_argument(
1114
- "--batch-size",
1115
- type=int,
1116
- default=4,
1117
- help="Training batch size (default: 4 for memory safety)"
1118
- )
1119
-
1120
- parser.add_argument(
1121
- "--learning-rate",
1122
- type=float,
1123
- default=5e-5,
1124
- help="Learning rate (default: 5e-5, matching quantum server)"
1125
- )
1126
-
1127
- parser.add_argument(
1128
- "--gradient-accumulation",
1129
- type=int,
1130
- default=4,
1131
- help="Gradient accumulation steps (default: 4, effective batch=16)"
1132
- )
1133
-
1134
- parser.add_argument(
1135
- "--force-retrain",
1136
- action="store_true",
1137
- help="Force retraining even if fine-tuned model exists"
1138
- )
1139
-
1140
- args = parser.parse_args()
1141
-
1142
- # Clean up GPU memory before starting
1143
- cleanup_gpu_memory()
1144
-
1145
- # Initialize model
1146
- coach = LifeCoachModel(
1147
- model_name=args.model_name,
1148
- model_save_path=args.model_path,
1149
- train_file=args.train_file
1150
- )
1151
-
1152
- # Load tokenizer
1153
- coach.load_tokenizer()
1154
-
1155
- # Check if fine-tuned model already exists
1156
- model_exists = coach.model_save_path.exists() and (coach.model_save_path / "adapter_model.safetensors").exists()
1157
-
1158
- # Training mode
1159
- if args.mode in ["train", "both"]:
1160
- # Check if we should skip training
1161
- if model_exists and not args.force_retrain:
1162
- logger.info("=" * 80)
1163
- logger.info("FINE-TUNED MODEL ALREADY EXISTS")
1164
- logger.info("=" * 80)
1165
- logger.info(f"Found existing model at: {coach.model_save_path}")
1166
- logger.info("Skipping training. Loading existing model...")
1167
- logger.info("(Use --force-retrain to retrain from scratch)")
1168
- logger.info("=" * 80)
1169
-
1170
- # Load the existing fine-tuned model
1171
- coach.load_model(fine_tuned=True)
1172
- else:
1173
- if args.force_retrain and model_exists:
1174
- logger.info("=" * 80)
1175
- logger.info("FORCING RETRAINING (--force-retrain flag set)")
1176
- logger.info("=" * 80)
1177
-
1178
- # Load base model for training
1179
- coach.load_model(fine_tuned=False)
1180
-
1181
- # Fine-tune
1182
- num_samples = None if args.num_samples == -1 else args.num_samples
1183
- coach.fine_tune(
1184
- num_samples=num_samples,
1185
- epochs=args.epochs,
1186
- batch_size=args.batch_size,
1187
- learning_rate=args.learning_rate,
1188
- gradient_accumulation_steps=args.gradient_accumulation
1189
  )
1190
-
1191
- # For "both" mode, reload the fine-tuned model for chat
1192
- if args.mode == "both":
1193
- logger.info("Reloading fine-tuned model for chat...")
1194
- coach.load_model(fine_tuned=True)
1195
-
1196
- # If only training mode, exit
1197
- if args.mode == "train":
1198
- logger.info("Training complete. Use --mode chat to start chatting.")
1199
- return
1200
-
1201
- # Chat mode
1202
- elif args.mode == "chat":
1203
- if not model_exists:
1204
- logger.error("=" * 80)
1205
- logger.error("ERROR: No fine-tuned model found!")
1206
- logger.error("=" * 80)
1207
- logger.error(f"Expected location: {coach.model_save_path}")
1208
- logger.error("Please train the model first using:")
1209
- logger.error(" python3 life_coach_v1.py --mode train")
1210
- logger.error("=" * 80)
1211
- return
1212
-
1213
- # Load fine-tuned model
1214
- logger.info(f"Loading fine-tuned model from {coach.model_save_path}")
1215
- coach.load_model(fine_tuned=True)
1216
-
1217
- # Start interactive chat
1218
- coach.interactive_chat()
1219
-
1220
-
 
 
1221
  if __name__ == "__main__":
1222
- main()
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Life Coach Model - DEBUG VERSION
4
+ Version with extensive logging to diagnose hangs on HF Spaces
 
 
5
  """
6
 
 
 
7
  import os
8
+ import torch
9
+ import logging
10
+ import time
11
+ import traceback
12
  import gc
13
+ import threading
14
+ from datetime import datetime
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
16
+ from peft import PeftModel
17
  from pathlib import Path
 
 
18
 
19
+ # Install psutil if not present (for HF Spaces)
20
+ try:
21
+ import psutil
22
+ except ImportError:
23
+ import subprocess
24
+ subprocess.check_call(["pip", "install", "psutil", "--break-system-packages"])
25
+ import psutil
26
 
27
+ # Set up ultra-detailed logging
 
 
28
  logging.basicConfig(
29
+ level=logging.DEBUG,
30
+ format='%(asctime)s - [PID:%(process)d] - %(levelname)s - %(message)s'
31
  )
32
  logger = logging.getLogger(__name__)
33
 
34
+ def log_system_status(prefix=""):
35
+ """Log dettagliato dello stato del sistema"""
36
+ logger.info(f"{'='*60}")
37
+ logger.info(f"{prefix} SYSTEM STATUS CHECK")
38
+ logger.info(f"PID: {os.getpid()}")
39
+ logger.info(f"Thread ID: {threading.get_ident()}")
40
+
41
+ # CPU info
42
+ cpu_percent = psutil.cpu_percent(interval=0.1)
43
+ logger.info(f"CPU Usage: {cpu_percent}%")
44
+
45
+ # Memory info
46
+ mem = psutil.virtual_memory()
47
+ logger.info(f"RAM: {mem.used/1e9:.2f}GB used / {mem.total/1e9:.2f}GB total ({mem.percent}%)")
48
+
49
+ # GPU info if available
50
  if torch.cuda.is_available():
 
 
51
  try:
52
+ gpu_mem = torch.cuda.mem_get_info()
53
+ logger.info(f"GPU Memory: {gpu_mem[0]/1e9:.2f}GB free / {gpu_mem[1]/1e9:.2f}GB total")
54
+ logger.info(f"GPU Allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")
55
+ logger.info(f"GPU Reserved: {torch.cuda.memory_reserved()/1e9:.2f}GB")
56
+ logger.info(f"CUDA Device: {torch.cuda.get_device_name()}")
 
57
  except Exception as e:
58
+ logger.error(f"Error getting GPU info: {e}")
59
+
60
+ logger.info(f"{'='*60}")
 
 
61
 
62
  class LifeCoachModel:
63
+ def __init__(self, model_name="microsoft/Phi-4", model_save_path="data/life_coach_model",
64
+ train_file=None):
65
+ """Initialize the Life Coach model with extensive logging."""
66
+ logger.info(f"[INIT] Starting LifeCoachModel initialization")
67
+ logger.info(f"[INIT] Model name: {model_name}")
68
+ logger.info(f"[INIT] Save path: {model_save_path}")
69
+
70
+ log_system_status("[INIT-START]")
71
+
 
72
  self.model_name = model_name
73
+ self.model_save_path = model_save_path
74
+ self.train_file = train_file
75
+
76
+ # Device detection with logging
77
+ logger.info(f"[INIT] Checking CUDA availability...")
78
+ if torch.cuda.is_available():
79
+ self.device = torch.device("cuda")
80
+ logger.info(f"[INIT] βœ… CUDA is available")
81
+ logger.info(f"[INIT] CUDA version: {torch.version.cuda}")
82
+ logger.info(f"[INIT] PyTorch version: {torch.__version__}")
83
+
84
+ # Clear GPU memory
85
+ logger.info(f"[INIT] Clearing GPU cache...")
86
+ torch.cuda.empty_cache()
87
+ gc.collect()
88
+ logger.info(f"[INIT] GPU cache cleared")
 
89
  else:
90
+ self.device = torch.device("cpu")
91
+ logger.warning(f"[INIT] ⚠️ CUDA not available, using CPU")
92
+
93
+ logger.info(f"[INIT] Device set to: {self.device}")
94
+
 
 
95
  self.tokenizer = None
96
  self.model = None
97
+
98
+ # System prompt
99
+ self.system_prompt = """You are Robert, a friendly and experienced life coach. Keep responses concise."""
100
+
101
+ logger.info(f"[INIT] LifeCoachModel initialization complete")
102
+ log_system_status("[INIT-END]")
103
+
104
  def load_tokenizer(self):
105
+ """Load tokenizer with detailed logging."""
106
+ logger.info(f"[TOKENIZER] Starting tokenizer loading...")
107
+ logger.info(f"[TOKENIZER] Loading from: {self.model_name}")
108
 
 
 
 
109
  try:
110
+ start_time = time.time()
111
+
112
  self.tokenizer = AutoTokenizer.from_pretrained(
113
  self.model_name,
 
 
114
  trust_remote_code=True,
115
+ cache_dir=os.environ.get('HF_HOME', None)
116
  )
117
+
118
+ load_time = time.time() - start_time
119
+ logger.info(f"[TOKENIZER] βœ… Tokenizer loaded in {load_time:.2f} seconds")
120
+ logger.info(f"[TOKENIZER] Vocab size: {self.tokenizer.vocab_size}")
121
+ logger.info(f"[TOKENIZER] Pad token: {self.tokenizer.pad_token}")
122
+
123
+ # Set padding
124
+ if self.tokenizer.pad_token is None:
125
+ logger.info(f"[TOKENIZER] Setting pad token to eos token")
126
+ self.tokenizer.pad_token = self.tokenizer.eos_token
127
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
128
+
129
+ logger.info(f"[TOKENIZER] Tokenizer ready")
130
+
131
  except Exception as e:
132
+ logger.error(f"[TOKENIZER] ❌ Error loading tokenizer: {e}")
133
+ logger.error(f"[TOKENIZER] Traceback: {traceback.format_exc()}")
134
  raise
 
 
135
 
136
+ def load_model(self, fine_tuned=True):
137
+ """Load model with EXTENSIVE logging at every step."""
138
+ logger.info(f"[MODEL] Starting model loading process...")
139
+ logger.info(f"[MODEL] Fine-tuned: {fine_tuned}")
140
+ log_system_status("[MODEL-LOAD-START]")
141
+
 
 
142
  if fine_tuned:
143
+ adapter_path = Path(self.model_save_path)
144
+ alternate_path = Path(f"./{self.model_save_path}")
145
+
146
+ logger.info(f"[MODEL] Checking for adapter at: {adapter_path}")
147
+ logger.info(f"[MODEL] Alternate path: {alternate_path}")
148
+
149
+ if alternate_path.exists() and (alternate_path / "adapter_model.safetensors").exists():
150
+ model_path = str(alternate_path)
151
+ logger.info(f"[MODEL] βœ… Found adapter at alternate path: {model_path}")
152
+ elif adapter_path.exists() and (adapter_path / "adapter_model.safetensors").exists():
153
+ model_path = str(adapter_path)
154
+ logger.info(f"[MODEL] βœ… Found adapter at primary path: {model_path}")
 
 
155
  else:
156
+ logger.error(f"[MODEL] ❌ No adapter found, loading base model")
157
+ fine_tuned = False
158
+
159
+ try:
160
+ # Quantization config with logging
161
+ logger.info(f"[MODEL] Setting up quantization config...")
162
+ quantization_config = BitsAndBytesConfig(
163
+ load_in_4bit=True,
164
+ bnb_4bit_compute_dtype=torch.float16,
165
+ bnb_4bit_quant_type="nf4",
166
+ bnb_4bit_use_double_quant=False
 
 
167
  )
168
+ logger.info(f"[MODEL] Quantization config created")
169
+
170
+ # Load base model
171
+ logger.info(f"[MODEL] Loading base model from: {self.model_name}")
172
+ logger.info(f"[MODEL] This may take several minutes...")
173
+
174
+ start_time = time.time()
175
+ checkpoint_counter = 0
176
+
177
+ # Hook to monitor checkpoint loading
178
+ original_print = print
179
+ def counting_print(*args, **kwargs):
180
+ nonlocal checkpoint_counter
181
+ msg = ' '.join(str(arg) for arg in args)
182
+ if 'Loading checkpoint' in msg:
183
+ checkpoint_counter += 1
184
+ logger.info(f"[MODEL] Checkpoint {checkpoint_counter} - {msg}")
185
+ original_print(*args, **kwargs)
186
+
187
+ # Temporarily replace print
188
+ import builtins
189
+ builtins.print = counting_print
190
+
191
+ logger.info(f"[MODEL] Calling AutoModelForCausalLM.from_pretrained...")
192
+
193
+ self.model = AutoModelForCausalLM.from_pretrained(
194
+ self.model_name,
195
+ quantization_config=quantization_config,
196
+ device_map="auto",
197
+ trust_remote_code=True,
198
+ torch_dtype=torch.float16,
199
+ cache_dir=os.environ.get('HF_HOME', None)
 
 
200
  )
201
+
202
+ # Restore the original print
203
+ builtins.print = original_print
204
+
205
+ load_time = time.time() - start_time
206
+ logger.info(f"[MODEL] βœ… Base model loaded in {load_time:.2f} seconds")
207
+
208
+ log_system_status("[MODEL-AFTER-BASE-LOAD]")
209
+
210
+ # Load adapter if fine-tuned
211
+ if fine_tuned:
212
+ logger.info(f"[MODEL] Loading adapter from: {model_path}")
213
+ start_time = time.time()
214
+
215
+ self.model = PeftModel.from_pretrained(
216
+ self.model,
217
+ model_path,
218
+ device_map="auto"
219
+ )
220
+
221
+ adapter_time = time.time() - start_time
222
+ logger.info(f"[MODEL] βœ… Adapter loaded in {adapter_time:.2f} seconds")
223
+
224
+ logger.info(f"[MODEL] Merging adapter with base model...")
225
+ self.model = self.model.merge_and_unload()
226
+ logger.info(f"[MODEL] βœ… Model merged")
227
+
228
+ # Set eval mode
229
+ logger.info(f"[MODEL] Setting model to eval mode...")
230
+ self.model.eval()
231
+
232
+ logger.info(f"[MODEL] Model configuration:")
233
+ logger.info(f"[MODEL] - Parameters: {sum(p.numel() for p in self.model.parameters())/1e9:.2f}B")
234
+ logger.info(f"[MODEL] - Device map: {getattr(self.model, 'hf_device_map', 'Not available')}")
235
+
236
+ log_system_status("[MODEL-LOAD-COMPLETE]")
237
+ logger.info(f"[MODEL] βœ…βœ…βœ… Model loading COMPLETE")
238
+
239
+ except Exception as e:
240
+ logger.error(f"[MODEL] ❌❌❌ CRITICAL ERROR during model loading")
241
+ logger.error(f"[MODEL] Error type: {type(e).__name__}")
242
+ logger.error(f"[MODEL] Error message: {str(e)}")
243
+ logger.error(f"[MODEL] Full traceback:\n{traceback.format_exc()}")
244
+ log_system_status("[MODEL-LOAD-ERROR]")
245
+ raise
246
+
247
+ def generate_response(self, prompt, max_new_tokens=256, conversation_history=None):
248
+ """Generate response with DETAILED logging at every step."""
249
+ logger.info(f"{'='*80}")
250
+ logger.info(f"[GENERATE] STARTING GENERATION PROCESS")
251
+ logger.info(f"[GENERATE] Timestamp: {datetime.now().isoformat()}")
252
+ logger.info(f"[GENERATE] Prompt length: {len(prompt)} chars")
253
+ logger.info(f"[GENERATE] Max new tokens: {max_new_tokens}")
254
+ logger.info(f"[GENERATE] History items: {len(conversation_history) if conversation_history else 0}")
255
+
256
+ log_system_status("[GENERATE-START]")
257
+
258
+ try:
259
+ # Step 1: Build prompt
260
+ logger.info(f"[GENERATE-1] Building full prompt...")
261
+ full_prompt = f"<|system|>\n{self.system_prompt}<|end|>\n"
262
+
263
+ if conversation_history:
264
+ for msg in conversation_history:
265
+ role = msg.get('role', 'user')
266
+ content = msg.get('content', '')
267
+ full_prompt += f"<|{role}|>\n{content}<|end|>\n"
268
+ logger.info(f"[GENERATE-1] Added {role} message: {len(content)} chars")
269
+
270
+ full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
271
+ logger.info(f"[GENERATE-1] Full prompt built: {len(full_prompt)} chars")
272
+
273
+ # Step 2: Tokenize
274
+ logger.info(f"[GENERATE-2] Starting tokenization...")
275
+ start_time = time.time()
276
+
277
+ inputs = self.tokenizer(
278
+ full_prompt,
279
+ return_tensors="pt",
280
+ truncation=True,
281
+ max_length=2048
 
 
282
  )
283
+
284
+ tokenize_time = time.time() - start_time
285
+ logger.info(f"[GENERATE-2] Tokenization complete in {tokenize_time:.3f}s")
286
+ logger.info(f"[GENERATE-2] Input shape: {inputs['input_ids'].shape}")
287
+ logger.info(f"[GENERATE-2] Number of tokens: {inputs['input_ids'].shape[-1]}")
288
+
289
+ # Step 3: Move to device
290
+ logger.info(f"[GENERATE-3] Moving tensors to device: {self.device}")
291
+ start_time = time.time()
292
+
293
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
294
+
295
+ move_time = time.time() - start_time
296
+ logger.info(f"[GENERATE-3] Tensors moved in {move_time:.3f}s")
297
+
298
+ log_system_status("[GENERATE-BEFORE-MODEL]")
299
+
300
+ # Step 4: Generate
301
+ logger.info(f"[GENERATE-4] ⚑ CALLING MODEL.GENERATE()...")
302
+ logger.info(f"[GENERATE-4] Generation parameters:")
303
+ logger.info(f"[GENERATE-4] - max_new_tokens: {max_new_tokens}")
304
+ logger.info(f"[GENERATE-4] - temperature: 0.7")
305
+ logger.info(f"[GENERATE-4] - do_sample: True")
306
+
307
+ start_time = time.time()
308
+
309
+ # CRITICAL POINT - This is where it might hang
310
+ logger.info(f"[GENERATE-4] >>> ENTERING model.generate() at {datetime.now().isoformat()}")
311
+
312
+ with torch.no_grad():
313
+ outputs = self.model.generate(
314
+ **inputs,
315
+ max_new_tokens=max_new_tokens,
316
+ temperature=0.7,
317
+ do_sample=True,
318
+ top_p=0.9,
319
+ pad_token_id=self.tokenizer.pad_token_id,
320
+ eos_token_id=self.tokenizer.eos_token_id
321
+ )
322
+
323
+ logger.info(f"[GENERATE-4] <<< EXITED model.generate() at {datetime.now().isoformat()}")
324
+
325
+ generate_time = time.time() - start_time
326
+ logger.info(f"[GENERATE-4] βœ… Generation complete in {generate_time:.2f}s")
327
+ logger.info(f"[GENERATE-4] Output shape: {outputs.shape}")
328
+ logger.info(f"[GENERATE-4] Generated {outputs.shape[-1] - inputs['input_ids'].shape[-1]} new tokens")
329
+
330
+ log_system_status("[GENERATE-AFTER-MODEL]")
331
+
332
+ # Step 5: Decode
333
+ logger.info(f"[GENERATE-5] Decoding output...")
334
+ start_time = time.time()
335
+
336
+ response = self.tokenizer.decode(
337
+ outputs[0][inputs['input_ids'].shape[-1]:],
338
+ skip_special_tokens=True
339
+ )
340
+
341
+ decode_time = time.time() - start_time
342
+ logger.info(f"[GENERATE-5] Decoding complete in {decode_time:.3f}s")
343
+ logger.info(f"[GENERATE-5] Response length: {len(response)} chars")
344
+ logger.info(f"[GENERATE-5] Response preview: {response[:100]}...")
345
+
346
+ # Step 6: Cleanup
347
+ logger.info(f"[GENERATE-6] Cleaning up GPU memory...")
348
+ del inputs, outputs
349
+ torch.cuda.empty_cache()
350
+ gc.collect()
351
+ logger.info(f"[GENERATE-6] Cleanup complete")
352
+
353
+ log_system_status("[GENERATE-COMPLETE]")
354
+
355
+ logger.info(f"[GENERATE] βœ…βœ…βœ… GENERATION SUCCESSFUL")
356
+            logger.info(f"[GENERATE] Total time: {tokenize_time + move_time + generate_time + decode_time:.2f}s")
357
+ logger.info(f"{'='*80}")
358
+
359
+ return response
360
+
361
+ except Exception as e:
362
+ logger.error(f"[GENERATE] ❌❌❌ ERROR DURING GENERATION")
363
+ logger.error(f"[GENERATE] Error type: {type(e).__name__}")
364
+ logger.error(f"[GENERATE] Error message: {str(e)}")
365
+ logger.error(f"[GENERATE] Full traceback:\n{traceback.format_exc()}")
366
+ log_system_status("[GENERATE-ERROR]")
367
+
368
+ # Return fallback message
369
+ return "I apologize, but I encountered an error while generating a response. Please try again."
370
+
371
+ # Quick smoke test when this file is executed directly
372
  if __name__ == "__main__":
373
+ import threading
374
+ logger.info("Running test...")
375
+
376
+ model = LifeCoachModel()
377
+ model.load_tokenizer()
378
+ model.load_model(fine_tuned=True)
379
+
380
+ response = model.generate_response("Hello, how are you?", max_new_tokens=50)
381
+ logger.info(f"Test response: {response}")
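Note on the hang-detection logging above: the smoke test imports threading but never uses it. If model.generate() ever stalls, a small watchdog built on threading.Timer can dump the system status while the call is still blocked. The sketch below is illustrative only, is not part of this commit, and assumes log_system_status from this module is in scope:

    import threading

    def with_watchdog(fn, timeout_s=120, tag="[WATCHDOG]"):
        # Run fn(); if it has not returned after timeout_s seconds, log system status once.
        timer = threading.Timer(timeout_s, lambda: log_system_status(tag))
        timer.start()
        try:
            return fn()
        finally:
            timer.cancel()

    # e.g. wrap the blocking call from the smoke test above:
    # response = with_watchdog(lambda: model.generate_response("Hello, how are you?", max_new_tokens=50))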
life_coach_v1_old.py ADDED
@@ -0,0 +1,1222 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Life Coach v1 - Phi-4 Fine-tuned Life Coaching Assistant
4
+
5
+ A simple command-line life coaching assistant using Microsoft's Phi-4 model.
6
+ Fine-tunes on life coaching conversations and provides interactive chat sessions.
7
+ """
8
+
9
+ import torch
10
+ import json
11
+ import os
12
+ import gc
13
+ import argparse
14
+ from pathlib import Path
15
+ from typing import Optional
16
+ from tqdm import tqdm
17
+
18
+ # Set PyTorch CUDA memory allocation config to reduce fragmentation
19
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
20
+
21
+ from transformers import (
22
+ AutoTokenizer,
23
+ AutoModelForCausalLM,
24
+ TrainingArguments,
25
+ Trainer,
26
+ DataCollatorForSeq2Seq
27
+ )
28
+ from datasets import Dataset, load_dataset, concatenate_datasets
29
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType
30
+ import logging
31
+ import random
32
+ import shutil
33
+ import gzip
34
+ from typing import List, Dict
35
+
36
+ # Configure logging
37
+ logging.basicConfig(
38
+ level=logging.INFO,
39
+ format='%(asctime)s - %(levelname)s - %(message)s'
40
+ )
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ def cleanup_gpu_memory():
45
+ """
46
+ Clean up GPU memory before starting the program.
47
+ Clears PyTorch cache and runs garbage collection.
48
+ """
49
+ logger.info("=" * 80)
50
+ logger.info("GPU MEMORY CLEANUP")
51
+ logger.info("=" * 80)
52
+
53
+ if torch.cuda.is_available():
54
+ # Clear PyTorch CUDA cache
55
+ torch.cuda.empty_cache()
56
+
57
+ # Run garbage collection
58
+ gc.collect()
59
+
60
+ # Get GPU memory stats
61
+ for i in range(torch.cuda.device_count()):
62
+ total = torch.cuda.get_device_properties(i).total_memory / 1024**3
63
+ reserved = torch.cuda.memory_reserved(i) / 1024**3
64
+ allocated = torch.cuda.memory_allocated(i) / 1024**3
65
+ free = total - reserved
66
+
67
+ logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
68
+ logger.info(f" Total memory: {total:.2f} GB")
69
+ logger.info(f" Reserved: {reserved:.2f} GB")
70
+ logger.info(f" Allocated: {allocated:.2f} GB")
71
+ logger.info(f" Free: {free:.2f} GB")
72
+
73
+ if reserved > 1.0: # More than 1GB reserved
74
+ logger.warning(f" ⚠️ GPU {i} has {reserved:.2f} GB reserved!")
75
+ logger.warning(f" ⚠️ This might be from a previous run.")
76
+ logger.warning(f" ⚠️ If you encounter OOM errors, kill other processes using:")
77
+ logger.warning(f" ⚠️ nvidia-smi | grep python")
78
+ else:
79
+ logger.warning("No CUDA GPUs available! Running on CPU (very slow).")
80
+
81
+ logger.info("=" * 80)
82
+
83
+
84
+ def clear_hf_cache():
85
+ """Clear Hugging Face datasets cache to save disk space."""
86
+ try:
87
+ from datasets import config
88
+ cache_dir = config.HF_DATASETS_CACHE
89
+ if os.path.exists(cache_dir):
90
+ # Get size before clearing
91
+ size_mb = sum(os.path.getsize(os.path.join(dirpath,filename))
92
+ for dirpath, _, filenames in os.walk(cache_dir)
93
+ for filename in filenames) / (1024 * 1024)
94
+
95
+ logger.info(f"Clearing HF cache ({size_mb:.1f} MB)...")
96
+ shutil.rmtree(cache_dir, ignore_errors=True)
97
+ os.makedirs(cache_dir, exist_ok=True)
98
+ logger.info("βœ“ Cache cleared")
99
+ except Exception as e:
100
+ logger.warning(f"Failed to clear cache: {e}")
101
+
102
+
103
+ def load_mental_health_counseling() -> List[Dict]:
104
+ """Load Amod/mental_health_counseling_conversations dataset - ALL samples."""
105
+ logger.info(f"Loading mental health counseling dataset...")
106
+ try:
107
+ dataset = load_dataset("Amod/mental_health_counseling_conversations", split="train")
108
+ logger.info(f" Dataset has {len(dataset)} samples available")
109
+
110
+ conversations = []
111
+ for item in dataset:
112
+ # Format: Context (user) -> Response (assistant)
113
+ conversations.append({
114
+ "messages": [
115
+ {"role": "user", "content": item.get("Context", "").strip()},
116
+ {"role": "assistant", "content": item.get("Response", "").strip()}
117
+ ]
118
+ })
119
+
120
+ logger.info(f"βœ“ Loaded {len(conversations)} mental health counseling conversations")
121
+ return conversations
122
+ except Exception as e:
123
+ logger.warning(f"Failed to load mental health counseling dataset: {e}")
124
+ return []
125
+
126
+
127
+ def load_counsel_chat() -> List[Dict]:
128
+ """Load nbertagnolli/counsel-chat dataset - ALL samples."""
129
+ logger.info(f"Loading CounselChat (nbertagnolli) dataset...")
130
+ try:
131
+ dataset = load_dataset("nbertagnolli/counsel-chat", split="train")
132
+ logger.info(f" Dataset has {len(dataset)} samples available")
133
+
134
+ conversations = []
135
+ for item in dataset:
136
+ # Try different possible field names
137
+ question = None
138
+ answer = None
139
+
140
+ # Common field patterns
141
+ for q_field in ["questionText", "question", "query", "input", "user_message"]:
142
+ if q_field in item and item.get(q_field):
143
+ question = item[q_field].strip()
144
+ break
145
+
146
+ for a_field in ["answerText", "answer", "response", "output", "counselor_message"]:
147
+ if a_field in item and item.get(a_field):
148
+ answer = item[a_field].strip()
149
+ break
150
+
151
+ if question and answer:
152
+ conversations.append({
153
+ "messages": [
154
+ {"role": "user", "content": question},
155
+ {"role": "assistant", "content": answer}
156
+ ]
157
+ })
158
+
159
+ logger.info(f"βœ“ Loaded {len(conversations)} CounselChat conversations")
160
+ return conversations
161
+ except Exception as e:
162
+ logger.warning(f"Failed to load CounselChat dataset: {e}")
163
+ return []
164
+
165
+
166
+ def load_cbt_cognitive_distortions() -> List[Dict]:
167
+ """Load epsilon3/cbt-cognitive-distortions-analysis dataset - ALL samples."""
168
+ logger.info(f"Loading CBT Cognitive Distortions dataset...")
169
+ try:
170
+ dataset = load_dataset("epsilon3/cbt-cognitive-distortions-analysis", split="train")
171
+ logger.info(f" Dataset has {len(dataset)} samples available")
172
+
173
+ conversations = []
174
+ for item in dataset:
175
+ # Try different field patterns
176
+ user_msg = None
177
+ assistant_msg = None
178
+
179
+ for u_field in ["input", "text", "thought", "statement", "user_input"]:
180
+ if u_field in item and item.get(u_field):
181
+ user_msg = item[u_field].strip()
182
+ break
183
+
184
+ for a_field in ["output", "analysis", "reframe", "response", "cbt_response"]:
185
+ if a_field in item and item.get(a_field):
186
+ assistant_msg = item[a_field].strip()
187
+ break
188
+
189
+ if user_msg and assistant_msg:
190
+ conversations.append({
191
+ "messages": [
192
+ {"role": "user", "content": user_msg},
193
+ {"role": "assistant", "content": assistant_msg}
194
+ ]
195
+ })
196
+
197
+ logger.info(f"βœ“ Loaded {len(conversations)} CBT Cognitive Distortions conversations")
198
+ return conversations
199
+ except Exception as e:
200
+ logger.warning(f"Failed to load CBT Cognitive Distortions dataset: {e}")
201
+ return []
202
+
203
+
204
+ def load_peer_counseling_reflections() -> List[Dict]:
205
+ """Load emoneil/reflections-in-peer-counseling dataset - ALL samples."""
206
+ logger.info(f"Loading Peer Counseling Reflections dataset...")
207
+ try:
208
+ dataset = load_dataset("emoneil/reflections-in-peer-counseling", split="train")
209
+ logger.info(f" Dataset has {len(dataset)} samples available")
210
+
211
+ conversations = []
212
+ for item in dataset:
213
+ # Try different field patterns
214
+ user_msg = None
215
+ assistant_msg = None
216
+
217
+ for u_field in ["question", "statement", "input", "user_message", "counselee"]:
218
+ if u_field in item and item.get(u_field):
219
+ user_msg = item[u_field].strip()
220
+ break
221
+
222
+ for a_field in ["reflection", "response", "output", "counselor_response", "counselor"]:
223
+ if a_field in item and item.get(a_field):
224
+ assistant_msg = item[a_field].strip()
225
+ break
226
+
227
+ if user_msg and assistant_msg:
228
+ conversations.append({
229
+ "messages": [
230
+ {"role": "user", "content": user_msg},
231
+ {"role": "assistant", "content": assistant_msg}
232
+ ]
233
+ })
234
+
235
+ logger.info(f"βœ“ Loaded {len(conversations)} Peer Counseling Reflections conversations")
236
+ return conversations
237
+ except Exception as e:
238
+ logger.warning(f"Failed to load Peer Counseling Reflections dataset: {e}")
239
+ return []
240
+
241
+
242
+ def load_dolly_dataset() -> List[Dict]:
243
+ """Load databricks-dolly-15k dataset (instruction-following) - ALL relevant samples."""
244
+ logger.info(f"Loading Dolly instruction dataset...")
245
+ try:
246
+ dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
247
+ logger.info(f" Dataset has {len(dataset)} samples available")
248
+
249
+ # Filter for relevant categories (brainstorming, open_qa, creative_writing)
250
+ relevant_categories = {"brainstorming", "open_qa", "creative_writing", "general_qa"}
251
+
252
+ conversations = []
253
+ for item in dataset:
254
+ if item.get("category", "") in relevant_categories:
255
+ instruction = item.get("instruction", "").strip()
256
+ context = item.get("context", "").strip()
257
+ response = item.get("response", "").strip()
258
+
259
+ # Combine instruction and context if both exist
260
+ user_message = f"{instruction}\n\n{context}" if context else instruction
261
+
262
+ if user_message and response:
263
+ conversations.append({
264
+ "messages": [
265
+ {"role": "user", "content": user_message},
266
+ {"role": "assistant", "content": response}
267
+ ]
268
+ })
269
+
270
+ logger.info(f"βœ“ Loaded {len(conversations)} Dolly instruction conversations (filtered from {len(dataset)} total)")
271
+ return conversations
272
+ except Exception as e:
273
+ logger.warning(f"Failed to load Dolly dataset: {e}")
274
+ return []
275
+
276
+
277
+ def load_mentalchat16k() -> List[Dict]:
278
+ """Load ShenLab/MentalChat16K dataset - ALL samples."""
279
+ logger.info(f"Loading MentalChat16K dataset...")
280
+ try:
281
+ dataset = load_dataset("ShenLab/MentalChat16K", split="train")
282
+ logger.info(f" Dataset has {len(dataset)} samples available")
283
+
284
+ conversations = []
285
+ for item in dataset:
286
+ # Try different possible field names
287
+ user_msg = None
288
+ assistant_msg = None
289
+
290
+ # Common field name patterns
291
+ for user_field in ["query", "question", "input", "user", "prompt", "instruction"]:
292
+ if user_field in item and item.get(user_field):
293
+ user_msg = item[user_field].strip()
294
+ break
295
+
296
+ for assistant_field in ["response", "answer", "output", "assistant", "reply"]:
297
+ if assistant_field in item and item.get(assistant_field):
298
+ assistant_msg = item[assistant_field].strip()
299
+ break
300
+
301
+ if user_msg and assistant_msg:
302
+ conversations.append({
303
+ "messages": [
304
+ {"role": "user", "content": user_msg},
305
+ {"role": "assistant", "content": assistant_msg}
306
+ ]
307
+ })
308
+
309
+ logger.info(f"βœ“ Loaded {len(conversations)} MentalChat16K conversations")
310
+ return conversations
311
+ except Exception as e:
312
+ logger.warning(f"Failed to load MentalChat16K dataset: {e}")
313
+ return []
314
+
315
+
316
+ def load_additional_mental_health_datasets() -> List[Dict]:
317
+ """Load additional mental health datasets - ALL samples."""
318
+ logger.info(f"Loading additional mental health datasets...")
319
+
320
+ all_conversations = []
321
+
322
+ # List of additional datasets to try
323
+ additional_datasets = [
324
+ ("heliosbrahma/mental_health_chatbot_dataset", ["prompt", "question"], ["response", "answer"]),
325
+ ("mpingale/mental-health-chat-dataset", ["question", "query"], ["answer", "response"]),
326
+ ("sauravjoshi23/psychology-dataset", ["input", "question"], ["output", "answer"]),
327
+ ]
328
+
329
+ for dataset_name, user_fields, assistant_fields in additional_datasets:
330
+ try:
331
+ logger.info(f" Loading {dataset_name}...")
332
+ dataset = load_dataset(dataset_name, split="train")
333
+ logger.info(f" Has {len(dataset)} samples available")
334
+
335
+ for item in dataset:
336
+ # Try different field names
337
+ user_msg = None
338
+ assistant_msg = None
339
+
340
+ for field in user_fields:
341
+ if field in item and item.get(field):
342
+ user_msg = item[field].strip()
343
+ break
344
+
345
+ for field in assistant_fields:
346
+ if field in item and item.get(field):
347
+ assistant_msg = item[field].strip()
348
+ break
349
+
350
+ if user_msg and assistant_msg:
351
+ all_conversations.append({
352
+ "messages": [
353
+ {"role": "user", "content": user_msg},
354
+ {"role": "assistant", "content": assistant_msg}
355
+ ]
356
+ })
357
+
358
+             logger.info(f"    βœ“ Running total after this dataset: {len(all_conversations)} conversations")
359
+
360
+ except Exception as e:
361
+ logger.warning(f" Failed: {e}")
362
+ continue
363
+
364
+ logger.info(f"βœ“ Loaded {len(all_conversations)} additional mental health conversations total")
365
+ return all_conversations
366
+
367
+
368
+ def quality_filter_conversation(conv: Dict, min_response_length: int = 50, max_total_length: int = 2048) -> bool:
369
+ """Filter conversation based on quality criteria."""
370
+ try:
371
+ messages = conv.get("messages", [])
372
+ if len(messages) < 2:
373
+ return False
374
+
375
+ # Check response length
376
+ assistant_msg = [m for m in messages if m.get("role") == "assistant"]
377
+ if not assistant_msg:
378
+ return False
379
+
380
+ response = assistant_msg[0].get("content", "")
381
+ if len(response) < min_response_length:
382
+ return False
383
+
384
+ # Check total length
385
+ total_length = sum(len(m.get("content", "")) for m in messages)
386
+ if total_length > max_total_length:
387
+ return False
388
+
389
+ # Check for empty messages
390
+ if any(not m.get("content", "").strip() for m in messages):
391
+ return False
392
+
393
+ return True
394
+     except Exception:
395
+ return False
396
+
397
+
398
+ def load_mixed_dataset(
399
+ total_samples: int = 100000,
400
+ cache_file: str = "mixed_lifecoach_dataset_100k.jsonl.gz", # Now compressed by default
401
+ use_cache: bool = True
402
+ ) -> List[Dict]:
403
+ """
404
+ Load and mix multiple datasets for comprehensive life coaching training.
405
+ Saves compressed cache to save disk space.
406
+
407
+ Datasets loaded (ALL available samples):
408
+ 1. Mental Health Counseling (Amod/mental_health_counseling_conversations)
409
+ 2. CounselChat (nbertagnolli/counsel-chat)
410
+ 3. CBT Cognitive Distortions (epsilon3/cbt-cognitive-distortions-analysis)
411
+ 4. Peer Counseling Reflections (emoneil/reflections-in-peer-counseling)
412
+ 5. MentalChat16K (ShenLab/MentalChat16K)
413
+ 6. Dolly Instructions (databricks/databricks-dolly-15k - filtered categories)
414
+ 7-8. Additional mental health datasets (heliosbrahma, mpingale, sauravjoshi23)
415
+ """
416
+ cache_path = Path(cache_file)
417
+ cache_path_uncompressed = Path(cache_file.replace('.gz', ''))
418
+
419
+ # Try to load from compressed cache first
420
+ if use_cache and cache_path.exists():
421
+ logger.info(f"Loading cached dataset from {cache_file} (compressed)...")
422
+ try:
423
+ conversations = []
424
+ with gzip.open(cache_path, 'rt', encoding='utf-8') as f:
425
+ for line in f:
426
+ conversations.append(json.loads(line.strip()))
427
+ logger.info(f"βœ“ Loaded {len(conversations)} conversations from compressed cache")
428
+ return conversations
429
+ except Exception as e:
430
+ logger.warning(f"Failed to load compressed cache: {e}. Trying uncompressed...")
431
+
432
+ # Try uncompressed cache (backward compatibility)
433
+ if use_cache and cache_path_uncompressed.exists():
434
+ logger.info(f"Loading cached dataset from {cache_path_uncompressed} (uncompressed)...")
435
+ try:
436
+ conversations = []
437
+ with open(cache_path_uncompressed, 'r', encoding='utf-8') as f:
438
+ for line in f:
439
+ conversations.append(json.loads(line.strip()))
440
+ logger.info(f"βœ“ Loaded {len(conversations)} conversations from uncompressed cache")
441
+ return conversations
442
+ except Exception as e:
443
+ logger.warning(f"Failed to load cache: {e}. Rebuilding dataset...")
444
+
445
+ # Load ALL available samples from each dataset
446
+ logger.info("=" * 80)
447
+ logger.info(f"LOADING MIXED DATASET (Target: ~{total_samples} samples)")
448
+ logger.info("Loading ALL available samples from each dataset")
449
+ logger.info("=" * 80)
450
+
451
+ all_conversations = []
452
+
453
+ # Load each dataset ONE AT A TIME and clear cache after each
454
+ # This saves disk space by not keeping all downloads simultaneously
455
+
456
+ logger.info("Dataset 1/8: Mental Health Counseling (Amod)")
457
+ all_conversations.extend(load_mental_health_counseling())
458
+ logger.info(f" Running total: {len(all_conversations)} conversations")
459
+ clear_hf_cache()
460
+ gc.collect()
461
+
462
+ # Stop early if we've reached target
463
+ if len(all_conversations) >= total_samples:
464
+ logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
465
+ else:
466
+ logger.info("Dataset 2/8: CounselChat (nbertagnolli)")
467
+ all_conversations.extend(load_counsel_chat())
468
+ logger.info(f" Running total: {len(all_conversations)} conversations")
469
+ clear_hf_cache()
470
+ gc.collect()
471
+
472
+ if len(all_conversations) >= total_samples:
473
+ logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
474
+ else:
475
+ logger.info("Dataset 3/8: CBT Cognitive Distortions (epsilon3)")
476
+ all_conversations.extend(load_cbt_cognitive_distortions())
477
+ logger.info(f" Running total: {len(all_conversations)} conversations")
478
+ clear_hf_cache()
479
+ gc.collect()
480
+
481
+ if len(all_conversations) >= total_samples:
482
+ logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
483
+ else:
484
+ logger.info("Dataset 4/8: Peer Counseling Reflections (emoneil)")
485
+ all_conversations.extend(load_peer_counseling_reflections())
486
+ logger.info(f" Running total: {len(all_conversations)} conversations")
487
+ clear_hf_cache()
488
+ gc.collect()
489
+
490
+ if len(all_conversations) >= total_samples:
491
+ logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
492
+ else:
493
+ logger.info("Dataset 5/8: MentalChat16K (ShenLab)")
494
+ all_conversations.extend(load_mentalchat16k())
495
+ logger.info(f" Running total: {len(all_conversations)} conversations")
496
+ clear_hf_cache()
497
+ gc.collect()
498
+
499
+ if len(all_conversations) >= total_samples:
500
+ logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
501
+ else:
502
+ logger.info("Dataset 6/8: Dolly Instructions (databricks)")
503
+ all_conversations.extend(load_dolly_dataset())
504
+ logger.info(f" Running total: {len(all_conversations)} conversations")
505
+ clear_hf_cache()
506
+ gc.collect()
507
+
508
+ if len(all_conversations) >= total_samples:
509
+ logger.info(f"βœ“ Reached target of {total_samples} samples, stopping dataset loading")
510
+ else:
511
+ logger.info("Datasets 7-8: Additional Mental Health Datasets")
512
+ all_conversations.extend(load_additional_mental_health_datasets())
513
+ logger.info(f" Running total: {len(all_conversations)} conversations")
514
+ clear_hf_cache()
515
+ gc.collect()
516
+
517
+ logger.info("=" * 80)
518
+ logger.info(f"Total conversations loaded: {len(all_conversations)}")
519
+
520
+ # Apply quality filtering
521
+ logger.info("Applying quality filters...")
522
+ filtered_conversations = [conv for conv in all_conversations if quality_filter_conversation(conv)]
523
+ logger.info(f"βœ“ After filtering: {len(filtered_conversations)} conversations")
524
+
525
+ # Shuffle to mix datasets
526
+ random.shuffle(filtered_conversations)
527
+
528
+ # Trim to target size
529
+ if len(filtered_conversations) > total_samples:
530
+ filtered_conversations = filtered_conversations[:total_samples]
531
+
532
+ logger.info(f"Final dataset size: {len(filtered_conversations)} conversations")
533
+
534
+ # Save compressed cache to save disk space
535
+ if use_cache:
536
+ logger.info(f"Saving compressed cache to {cache_file}...")
537
+ try:
538
+ with gzip.open(cache_path, 'wt', encoding='utf-8') as f:
539
+ for conv in filtered_conversations:
540
+ f.write(json.dumps(conv, ensure_ascii=False) + '\n')
541
+
542
+ # Get file sizes for comparison
543
+ compressed_size_mb = cache_path.stat().st_size / (1024 * 1024)
544
+ logger.info(f"βœ“ Compressed cache saved successfully ({compressed_size_mb:.1f} MB)")
545
+ except Exception as e:
546
+ logger.warning(f"Failed to save compressed cache: {e}")
547
+
548
+ logger.info("=" * 80)
549
+ return filtered_conversations
550
+
551
+
552
+ class LifeCoachModel:
553
+ """Life coaching assistant using Phi-4 model."""
554
+
555
+ def __init__(
556
+ self,
557
+ model_name: str = "microsoft/Phi-4",
558
+ model_save_path: str = "/data/life_coach_model",
559
+ train_file: str = "cbt_life_coach_improved_50000.jsonl",
560
+ max_length: int = 2048
561
+ ):
562
+ """
563
+ Initialize the Life Coach model.
564
+
565
+ Args:
566
+ model_name: Hugging Face model identifier
567
+ model_save_path: Path to save/load fine-tuned model
568
+ train_file: Path to training data file (JSONL format)
569
+ max_length: Maximum sequence length for training
570
+ """
571
+ self.model_name = model_name
572
+
573
+ # Check if /data is writable, otherwise use local directory
574
+ save_path = Path(model_save_path)
575
+ if str(save_path).startswith("/data"):
576
+ try:
577
+ Path("/data").mkdir(parents=True, exist_ok=True)
578
+ # Test write permissions
579
+ test_file = Path("/data/.test_write")
580
+ test_file.touch()
581
+ test_file.unlink()
582
+ self.model_save_path = save_path
583
+ logger.info(f"Using /data directory for model storage: {save_path}")
584
+ except (PermissionError, OSError) as e:
585
+ # Fall back to local directory
586
+ local_path = Path("./data/life_coach_model")
587
+ logger.warning(f"/data directory not writable ({e}), using local directory: {local_path}")
588
+ self.model_save_path = local_path
589
+ else:
590
+ self.model_save_path = save_path
591
+
592
+ self.train_file = Path(train_file)
593
+ self.max_length = max_length
594
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
595
+
596
+ logger.info(f"Device: {self.device}")
597
+ logger.info(f"Model: {model_name}")
598
+ logger.info(f"Save path: {self.model_save_path}")
599
+ logger.info(f"Training file: {self.train_file}")
600
+
601
+ self.tokenizer = None
602
+ self.model = None
603
+
604
+ def load_tokenizer(self):
605
+         """Load the tokenizer from /data/hf_cache (persistent) or download it once."""
606
+ logger.info("Loading tokenizer...")
607
+
608
+ cache_dir = "/data/hf_cache"
609
+ os.makedirs(cache_dir, exist_ok=True)
610
+
611
+ try:
612
+ self.tokenizer = AutoTokenizer.from_pretrained(
613
+ self.model_name,
614
+ cache_dir=cache_dir,
615
+                 local_files_only=False,  # Allow a download only if it is not already cached
616
+ trust_remote_code=True,
617
+ use_fast=True
618
+ )
619
+             logger.info(f"Tokenizer loaded (cache: {cache_dir})")
620
+ except Exception as e:
621
+             logger.error(f"Critical error while loading tokenizer: {e}")
622
+ raise
623
+ def load_model(self, fine_tuned=True):
624
+ """Load the fine-tuned model with safe settings for HF Spaces."""
625
+ logger.info(f"Loading {'fine-tuned' if fine_tuned else 'base'} model from {self.model_save_path}")
626
+
627
+         # Force safe loading settings
628
+ import torch
629
+ from transformers import AutoModelForCausalLM
630
+ from peft import PeftModel
631
+
632
+ base_model_name = self.model_name
633
+
634
+         # Load the base model with device_map and CPU offload
635
+ base_model = AutoModelForCausalLM.from_pretrained(
636
+ base_model_name,
637
+ torch_dtype=torch.float16,
638
+ device_map="auto",
639
+ trust_remote_code=True,
640
+ low_cpu_mem_usage=True,
641
+             offload_folder="/tmp/offload",  # Use /tmp for offloading
642
+ cache_dir="/data/hf_cache"
643
+ )
644
+
645
+ if fine_tuned:
646
+ logger.info(f"Loading adapter from {self.model_save_path}")
647
+ self.model = PeftModel.from_pretrained(
648
+ base_model,
649
+ self.model_save_path,
650
+ device_map="auto",
651
+ offload_folder="/tmp/offload",
652
+ torch_dtype=torch.float16
653
+ )
654
+ else:
655
+ self.model = base_model
656
+
657
+ self.model.eval()
658
+ logger.info("Model loaded successfully!")
659
+
660
+ def load_training_data(self, num_samples: Optional[int] = None) -> Dataset:
661
+ """
662
+ Load training data from mixed datasets or JSONL file.
663
+
664
+ Args:
665
+ num_samples: Number of samples to load (None for 100,000 default)
666
+
667
+ Returns:
668
+ Dataset object
669
+ """
670
+ # Try to load from mixed datasets first (new method)
671
+ # If train_file doesn't exist or is the old one, use mixed datasets
672
+ use_mixed_datasets = True
673
+
674
+ if self.train_file.exists():
675
+ # Check if it's the old single dataset file
676
+ if "cbt_life_coach" in str(self.train_file):
677
+ logger.info("Found old training file. Using new mixed datasets instead...")
678
+ use_mixed_datasets = True
679
+ else:
680
+ # It might be a cached mixed dataset
681
+ logger.info(f"Found training file at {self.train_file}")
682
+ use_mixed_datasets = False
683
+
684
+ if use_mixed_datasets:
685
+ # Load mixed datasets from Hugging Face
686
+ logger.info("Loading mixed datasets from Hugging Face...")
687
+ if num_samples is None:
688
+ num_samples = 100000 # Default to 100k samples
689
+
690
+ # Load mixed dataset (will use cache if available)
691
+ cache_file = f"mixed_lifecoach_dataset_{num_samples}.jsonl.gz" # Compressed format
692
+ data = load_mixed_dataset(
693
+ total_samples=num_samples,
694
+ cache_file=cache_file,
695
+ use_cache=True
696
+ )
697
+ else:
698
+ # Fall back to loading from JSONL file
699
+ logger.info(f"Loading training data from {self.train_file}")
700
+ data = []
701
+ with open(self.train_file, 'r', encoding='utf-8') as f:
702
+ for i, line in enumerate(f):
703
+ if num_samples and i >= num_samples:
704
+ break
705
+ try:
706
+ data.append(json.loads(line.strip()))
707
+ except json.JSONDecodeError:
708
+ logger.warning(f"Skipping invalid JSON at line {i+1}")
709
+
710
+ logger.info(f"Loaded {len(data)} training examples")
711
+
712
+ # Convert to Hugging Face Dataset
713
+ dataset = Dataset.from_list(data)
714
+
715
+ # Preprocess for Phi-4 format
716
+ logger.info("Preprocessing data for Phi-4 format...")
717
+ dataset = dataset.map(
718
+ self._preprocess_function,
719
+ batched=True,
720
+ remove_columns=dataset.column_names,
721
+ desc="Tokenizing"
722
+ )
723
+
724
+ return dataset
725
+
726
+ def _preprocess_function(self, examples):
727
+ """
728
+ Preprocess data into Phi-4 chat format.
729
+
730
+ Phi-4 uses:
731
+ <|system|>
732
+ {system message}<|end|>
733
+ <|user|>
734
+ {user message}<|end|>
735
+ <|assistant|>
736
+ {assistant response}<|end|>
737
+ """
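+         # Illustrative example (not taken from any dataset): a two-turn conversation such as
+         #   [{"role": "user", "content": "I feel stuck."},
+         #    {"role": "assistant", "content": "Let's look at what being stuck means for you."}]
+         # is serialized by the loop below into the single training string
+         #   "<|user|>\nI feel stuck.<|end|>\n<|assistant|>\nLet's look at what being stuck means for you.<|end|>\n"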
738
+ texts = []
739
+
740
+ # Handle both 'conversations' (our format) and 'messages' (standard format)
741
+ conversations_key = 'conversations' if 'conversations' in examples else 'messages'
742
+
743
+ for conversation in examples[conversations_key]:
744
+ text = ""
745
+ for message in conversation:
746
+ # Handle both 'from'/'value' and 'role'/'content' formats
747
+ if 'from' in message:
748
+ role = message['from']
749
+ content = message['value']
750
+ else:
751
+ role = message['role']
752
+ content = message['content']
753
+
754
+ # Convert to Phi-4 format
755
+ if role == 'system':
756
+ text += f"<|system|>\n{content}<|end|>\n"
757
+ elif role == 'user':
758
+ text += f"<|user|>\n{content}<|end|>\n"
759
+ elif role == 'assistant':
760
+ text += f"<|assistant|>\n{content}<|end|>\n"
761
+
762
+ texts.append(text)
763
+
764
+ # Tokenize with dynamic padding (like quantum server)
765
+ # Don't pad here - let DataCollatorForSeq2Seq handle it dynamically per batch
766
+ model_inputs = self.tokenizer(
767
+ texts,
768
+ max_length=self.max_length,
769
+ truncation=True,
770
+ padding=False, # Dynamic padding - saves massive memory!
771
+ return_tensors=None # Don't convert to tensors yet
772
+ )
773
+
774
+ # Set labels (for causal language modeling, labels = input_ids)
775
+ # Note: .copy() instead of .clone() since we're not using tensors yet
776
+ model_inputs["labels"] = model_inputs["input_ids"].copy()
777
+
778
+ return model_inputs
779
+
780
+ def setup_lora(self):
781
+ """Setup LoRA (Low-Rank Adaptation) for efficient fine-tuning."""
782
+ logger.info("Setting up LoRA adapters...")
783
+
784
+ # Prepare model for k-bit training (critical for load_in_8bit=True)
785
+ logger.info("Preparing model for 8-bit training...")
786
+ self.model = prepare_model_for_kbit_training(self.model)
787
+
788
+ # Enable gradient checkpointing to save GPU memory
789
+ # This reduces memory usage by 20-30 GB with minimal performance impact
790
+ if hasattr(self.model, 'gradient_checkpointing_enable'):
791
+ self.model.gradient_checkpointing_enable()
792
+ logger.info("βœ“ Gradient checkpointing enabled (saves 20-30 GB GPU memory)")
793
+
794
+ # LoRA configuration
795
+ lora_config = LoraConfig(
796
+ task_type=TaskType.CAUSAL_LM,
797
+ r=16, # Rank
798
+ lora_alpha=32,
799
+ lora_dropout=0.1,
800
+ bias="none",
801
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"] # Attention layers
802
+ )
803
+
804
+ # Apply LoRA
805
+ self.model = get_peft_model(self.model, lora_config)
806
+
807
+ # Print trainable parameters
808
+ trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
809
+ total_params = sum(p.numel() for p in self.model.parameters())
810
+
811
+ logger.info(f"Trainable parameters: {trainable_params:,} / {total_params:,} "
812
+ f"({100 * trainable_params / total_params:.2f}%)")
813
+
814
+ def fine_tune(
815
+ self,
816
+ num_samples: Optional[int] = 5000,
817
+ epochs: int = 3,
818
+ batch_size: int = 8,
819
+ learning_rate: float = 5e-5,
820
+ gradient_accumulation_steps: int = 2
821
+ ):
822
+ """
823
+ Fine-tune the model on life coaching data.
824
+
825
+ Args:
826
+ num_samples: Number of training samples (None for all)
827
+ epochs: Number of training epochs
828
+ batch_size: Training batch size
829
+ learning_rate: Learning rate
830
+ gradient_accumulation_steps: Gradient accumulation steps (for memory efficiency)
831
+ """
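+         # Worked example with this method's defaults: per-device batch 8 with
+         # gradient_accumulation_steps=2 gives an effective batch of 8 * 2 = 16
+         # sequences per optimizer step.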
832
+ logger.info("=" * 80)
833
+ logger.info("STARTING FINE-TUNING")
834
+ logger.info("=" * 80)
835
+
836
+ # Load data
837
+ dataset = self.load_training_data(num_samples)
838
+
839
+ # Setup LoRA
840
+ self.setup_lora()
841
+
842
+ # Training arguments
843
+ training_args = TrainingArguments(
844
+ output_dir="./training_output",
845
+ num_train_epochs=epochs,
846
+ per_device_train_batch_size=batch_size,
847
+ gradient_accumulation_steps=gradient_accumulation_steps,
848
+ learning_rate=learning_rate,
849
+ fp16=True, # Mixed precision training
850
+ logging_steps=10,
851
+ save_strategy="epoch",
852
+ save_total_limit=2,
853
+ warmup_steps=100,
854
+ weight_decay=0.01,
855
+ report_to="none", # Disable wandb/tensorboard
856
+ )
857
+
858
+ # Data collator
859
+ data_collator = DataCollatorForSeq2Seq(
860
+ tokenizer=self.tokenizer,
861
+ model=self.model,
862
+ padding=True
863
+ )
864
+
865
+ # Trainer
866
+ trainer = Trainer(
867
+ model=self.model,
868
+ args=training_args,
869
+ train_dataset=dataset,
870
+ data_collator=data_collator,
871
+ )
872
+
873
+ # Train
874
+ logger.info("Training started...")
875
+ trainer.train()
876
+
877
+ logger.info("=" * 80)
878
+ logger.info("TRAINING COMPLETED")
879
+ logger.info("=" * 80)
880
+
881
+ # Save model
882
+ self.save_model()
883
+
884
+ def save_model(self):
885
+ """Save the fine-tuned model to disk."""
886
+ logger.info(f"Saving model to {self.model_save_path}")
887
+
888
+ self.model_save_path.mkdir(parents=True, exist_ok=True)
889
+
890
+ # Save model and tokenizer
891
+ self.model.save_pretrained(str(self.model_save_path))
892
+ self.tokenizer.save_pretrained(str(self.model_save_path))
893
+
894
+ logger.info("Model saved successfully")
895
+
896
+ def generate_response(self, prompt: str, max_new_tokens: int = 128, conversation_history: list = None) -> str:
897
+ """
898
+ Generate a response to a user prompt.
899
+
900
+ Args:
901
+ prompt: User's input message
902
+ max_new_tokens: Maximum tokens to generate
903
+ conversation_history: List of previous messages for context
904
+
905
+ Returns:
906
+ Generated response
907
+ """
908
+ # Build full conversation context with system prompt
909
+ formatted_prompt = ""
910
+
911
+ # Add system prompt to guide the model's behavior
912
+ system_prompt = """You are Robert, a friendly and experienced life coach. Here's your background:
913
+
914
+ About You:
915
+ - Name: Robert (Bob to friends)
916
+ - Age: 42 years old
917
+ - Experience: 15 years as a certified life coach and motivational speaker
918
+ - Education: Master's degree in Psychology from UC Berkeley
919
+ - Specialties: Personal growth, career transitions, work-life balance, goal setting, stress management
920
+ - Personal: Married with two kids, enjoy hiking and meditation in your free time
921
+ - Approach: Warm, empathetic, practical, and solution-focused
922
+
923
+ Your Coaching Style:
924
+ - Respond ONLY to what the user actually tells you - never make assumptions about their problems
925
+ - Start conversations in a welcoming, open manner
926
+ - Ask clarifying questions to understand their situation better
927
+ - Provide practical, actionable advice based on what they share
928
+ - Be encouraging and positive, but also honest and realistic
929
+ - Keep responses concise and focused (2-4 sentences usually)
930
+ - Share brief personal insights when relevant, but keep the focus on the client
931
+
932
+ Important: Never assume clients have problems they haven't mentioned. Let them guide the conversation and share what's on their mind."""
933
+
934
+ formatted_prompt += f"<|system|>\n{system_prompt}<|end|>\n"
935
+
936
+ # Add conversation history if provided
937
+ if conversation_history:
938
+ for msg in conversation_history:
939
+ if msg["role"] == "user":
940
+ formatted_prompt += f"<|user|>\n{msg['content']}<|end|>\n"
941
+ elif msg["role"] == "assistant":
942
+ formatted_prompt += f"<|assistant|>\n{msg['content']}<|end|>\n"
943
+
944
+ # Add current prompt
945
+ formatted_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
946
+
947
+ # DEBUG: Print the full prompt being sent to the model
948
+ logger.info("=" * 80)
949
+ logger.info("FULL PROMPT SENT TO MODEL:")
950
+ logger.info(formatted_prompt)
951
+ logger.info("=" * 80)
952
+
953
+ # Tokenize
954
+ inputs = self.tokenizer(
955
+ formatted_prompt,
956
+ return_tensors="pt",
957
+ truncation=True,
958
+ max_length=self.max_length
959
+ ).to(self.device)
960
+
961
+ # Get input length to extract only new tokens
962
+ input_length = inputs['input_ids'].shape[1]
963
+
964
+ # Get the token ID for <|end|> to use as a stopping token
965
+ end_token_id = self.tokenizer.convert_tokens_to_ids("<|end|>")
966
+
967
+ # Build list of EOS token IDs (stop generation at <|end|> or EOS)
968
+ eos_token_ids = [self.tokenizer.eos_token_id]
969
+ if end_token_id is not None and end_token_id != self.tokenizer.unk_token_id:
970
+ eos_token_ids.append(end_token_id)
971
+
972
+ # Generate
973
+ with torch.no_grad():
974
+ outputs = self.model.generate(
975
+ **inputs,
976
+ max_new_tokens=max_new_tokens,
977
+ temperature=0.7, # Balanced - coherent but still creative
978
+ top_p=0.9, # Standard setting for focused responses
979
+ top_k=50, # Add top-k sampling
980
+ do_sample=True,
981
+ pad_token_id=self.tokenizer.pad_token_id,
982
+ eos_token_id=eos_token_ids, # Stop at <|end|> or EOS
983
+ repetition_penalty=1.15 # Stronger penalty to prevent repetition
984
+ )
985
+
986
+ # Decode ONLY the newly generated tokens (not the input)
987
+ generated_tokens = outputs[0][input_length:]
988
+
989
+ # Decode without skipping special tokens first to find the end marker
990
+ response_with_tokens = self.tokenizer.decode(generated_tokens, skip_special_tokens=False)
991
+
992
+ # Extract only up to the first <|end|> token (model may generate multi-turn conversations)
993
+ if "<|end|>" in response_with_tokens:
994
+ response_text = response_with_tokens.split("<|end|>")[0]
995
+ else:
996
+ response_text = response_with_tokens
997
+
998
+ # Clean up any remaining special tokens
999
+ response_text = response_text.replace("<|assistant|>", "").replace("<|user|>", "").replace("<|system|>", "")
1000
+
1001
+ # Remove any remaining special tokens using the tokenizer
1002
+ response_text = response_text.strip()
1003
+
1004
+ return response_text
1005
+
1006
+ def interactive_chat(self):
1007
+ """Start an interactive chat session."""
1008
+ logger.info("=" * 80)
1009
+ logger.info("LIFE COACH V1 - Interactive Chat Session")
1010
+ logger.info("=" * 80)
1011
+ print("\nWelcome to Life Coach v1!")
1012
+ print("I'm here to help you with life coaching, goal setting, motivation, and personal growth.")
1013
+ print("\nCommands:")
1014
+ print(" - Type your question or concern to get coaching advice")
1015
+ print(" - Type 'quit' or 'exit' to end the session")
1016
+ print(" - Type 'clear' to clear conversation history")
1017
+ print("=" * 80)
1018
+ print()
1019
+
1020
+ conversation_history = []
1021
+
1022
+ while True:
1023
+ try:
1024
+ # Get user input
1025
+ user_input = input("\nπŸ§‘ You: ").strip()
1026
+
1027
+ if not user_input:
1028
+ continue
1029
+
1030
+ # Check for exit commands
1031
+ if user_input.lower() in ['quit', 'exit', 'q']:
1032
+ print("\nπŸ‘‹ Thank you for using Life Coach v1. Take care!")
1033
+ break
1034
+
1035
+ # Check for clear command
1036
+ if user_input.lower() == 'clear':
1037
+ conversation_history = []
1038
+ print("βœ… Conversation history cleared.")
1039
+ continue
1040
+
1041
+ # Generate response with conversation context
1042
+ print("\nπŸ€– Life Coach: ", end="", flush=True)
1043
+ response = self.generate_response(user_input, conversation_history=conversation_history)
1044
+ print(response)
1045
+
1046
+ # Update conversation history
1047
+ conversation_history.append({
1048
+ "role": "user",
1049
+ "content": user_input
1050
+ })
1051
+ conversation_history.append({
1052
+ "role": "assistant",
1053
+ "content": response
1054
+ })
1055
+
1056
+ except KeyboardInterrupt:
1057
+ print("\n\nπŸ‘‹ Session interrupted. Goodbye!")
1058
+ break
1059
+ except Exception as e:
1060
+ logger.error(f"Error during chat: {e}")
1061
+ print(f"\n❌ Error: {e}")
1062
+
1063
+
1064
+ def main():
1065
+ """Main entry point."""
1066
+ parser = argparse.ArgumentParser(
1067
+ description="Life Coach v1 - Phi-4 based life coaching assistant"
1068
+ )
1069
+
1070
+ parser.add_argument(
1071
+ "--mode",
1072
+ type=str,
1073
+ choices=["train", "chat", "both"],
1074
+ default="both",
1075
+ help="Mode: train (fine-tune only), chat (chat only), both (train then chat)"
1076
+ )
1077
+
1078
+ parser.add_argument(
1079
+ "--model-name",
1080
+ type=str,
1081
+ default="microsoft/Phi-4",
1082
+ help="Hugging Face model name"
1083
+ )
1084
+
1085
+ parser.add_argument(
1086
+ "--model-path",
1087
+ type=str,
1088
+ default="/data/life_coach_model",
1089
+ help="Path to save/load fine-tuned model"
1090
+ )
1091
+
1092
+ parser.add_argument(
1093
+ "--train-file",
1094
+ type=str,
1095
+ default="cbt_life_coach_improved_50000.jsonl",
1096
+ help="Path to training data file (JSONL format)"
1097
+ )
1098
+
1099
+ parser.add_argument(
1100
+ "--num-samples",
1101
+ type=int,
1102
+ default=-1,
1103
+ help="Number of training samples (default: -1 for all 100,000 from mixed datasets)"
1104
+ )
1105
+
1106
+ parser.add_argument(
1107
+ "--epochs",
1108
+ type=int,
1109
+ default=3,
1110
+ help="Number of training epochs"
1111
+ )
1112
+
1113
+ parser.add_argument(
1114
+ "--batch-size",
1115
+ type=int,
1116
+ default=4,
1117
+ help="Training batch size (default: 4 for memory safety)"
1118
+ )
1119
+
1120
+ parser.add_argument(
1121
+ "--learning-rate",
1122
+ type=float,
1123
+ default=5e-5,
1124
+ help="Learning rate (default: 5e-5, matching quantum server)"
1125
+ )
1126
+
1127
+ parser.add_argument(
1128
+ "--gradient-accumulation",
1129
+ type=int,
1130
+ default=4,
1131
+ help="Gradient accumulation steps (default: 4, effective batch=16)"
1132
+ )
1133
+
1134
+ parser.add_argument(
1135
+ "--force-retrain",
1136
+ action="store_true",
1137
+ help="Force retraining even if fine-tuned model exists"
1138
+ )
1139
+
1140
+ args = parser.parse_args()
1141
+
1142
+ # Clean up GPU memory before starting
1143
+ cleanup_gpu_memory()
1144
+
1145
+ # Initialize model
1146
+ coach = LifeCoachModel(
1147
+ model_name=args.model_name,
1148
+ model_save_path=args.model_path,
1149
+ train_file=args.train_file
1150
+ )
1151
+
1152
+ # Load tokenizer
1153
+ coach.load_tokenizer()
1154
+
1155
+ # Check if fine-tuned model already exists
1156
+ model_exists = coach.model_save_path.exists() and (coach.model_save_path / "adapter_model.safetensors").exists()
1157
+
1158
+ # Training mode
1159
+ if args.mode in ["train", "both"]:
1160
+ # Check if we should skip training
1161
+ if model_exists and not args.force_retrain:
1162
+ logger.info("=" * 80)
1163
+ logger.info("FINE-TUNED MODEL ALREADY EXISTS")
1164
+ logger.info("=" * 80)
1165
+ logger.info(f"Found existing model at: {coach.model_save_path}")
1166
+ logger.info("Skipping training. Loading existing model...")
1167
+ logger.info("(Use --force-retrain to retrain from scratch)")
1168
+ logger.info("=" * 80)
1169
+
1170
+ # Load the existing fine-tuned model
1171
+ coach.load_model(fine_tuned=True)
1172
+ else:
1173
+ if args.force_retrain and model_exists:
1174
+ logger.info("=" * 80)
1175
+ logger.info("FORCING RETRAINING (--force-retrain flag set)")
1176
+ logger.info("=" * 80)
1177
+
1178
+ # Load base model for training
1179
+ coach.load_model(fine_tuned=False)
1180
+
1181
+ # Fine-tune
1182
+ num_samples = None if args.num_samples == -1 else args.num_samples
1183
+ coach.fine_tune(
1184
+ num_samples=num_samples,
1185
+ epochs=args.epochs,
1186
+ batch_size=args.batch_size,
1187
+ learning_rate=args.learning_rate,
1188
+ gradient_accumulation_steps=args.gradient_accumulation
1189
+ )
1190
+
1191
+ # For "both" mode, reload the fine-tuned model for chat
1192
+ if args.mode == "both":
1193
+ logger.info("Reloading fine-tuned model for chat...")
1194
+ coach.load_model(fine_tuned=True)
1195
+
1196
+ # If only training mode, exit
1197
+ if args.mode == "train":
1198
+ logger.info("Training complete. Use --mode chat to start chatting.")
1199
+ return
1200
+
1201
+ # Chat mode
1202
+ elif args.mode == "chat":
1203
+ if not model_exists:
1204
+ logger.error("=" * 80)
1205
+ logger.error("ERROR: No fine-tuned model found!")
1206
+ logger.error("=" * 80)
1207
+ logger.error(f"Expected location: {coach.model_save_path}")
1208
+ logger.error("Please train the model first using:")
1209
+ logger.error(" python3 life_coach_v1.py --mode train")
1210
+ logger.error("=" * 80)
1211
+ return
1212
+
1213
+ # Load fine-tuned model
1214
+ logger.info(f"Loading fine-tuned model from {coach.model_save_path}")
1215
+ coach.load_model(fine_tuned=True)
1216
+
1217
+ # Start interactive chat
1218
+ coach.interactive_chat()
1219
+
1220
+
1221
+ if __name__ == "__main__":
1222
+ main()
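For reference, the archived script above is driven entirely through the argparse flags defined in main(); typical invocations (sample counts are illustrative, not prescribed by the commit) would be:

    # fine-tune on a small subset, then exit
    python3 life_coach_v1_old.py --mode train --num-samples 5000 --epochs 1 --batch-size 4

    # chat with an existing fine-tuned adapter (default path /data/life_coach_model)
    python3 life_coach_v1_old.py --mode chat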