Rogendo committed on
Commit e18c603 · verified · Parent: babed68

Upload QA Multi-Head DistilBERT model
README.md ADDED

---
license: apache-2.0
library_name: transformers
pipeline_tag: text-classification
tags:
- qa-metrics
- call-center
- multi-head
- distilbert
- transcript-analysis
- customer-service
- quality-assurance
language:
- en
- sw
datasets:
- custom
metrics:
- accuracy
- f1
- precision
- recall
model-index:
- name: qa-multihead-distilbert
  results:
  - task:
      type: text-classification
      name: QA Metrics Classification
    metrics:
    - type: accuracy
      value: 0.85
    - type: f1
      value: 0.82
---

# QA Multi-Head DistilBERT Classifier

## Model Description

This is a fine-tuned DistilBERT model for multi-head quality assurance (QA) evaluation of call center transcripts. The model evaluates transcripts across six key QA dimensions, each with one or more sub-metrics.

## Model Architecture

- **Base Model**: DistilBERT (distilbert-base-uncased)
- **Architecture**: Multi-head classifier with 6 specialized heads
- **Input**: Call center transcripts (max 512 tokens)
- **Output**: Binary predictions for 17 QA sub-metrics

## QA Heads Configuration

| Head | Sub-metrics | Description |
|------|-------------|-------------|
| **Opening** (1) | Call opening phrase | Evaluates proper call initiation |
| **Listening** (5) | Non-interruption, empathy, paraphrasing, politeness, confidence | Assesses active listening skills |
| **Proactiveness** (3) | Extra issue solving, satisfaction confirmation, follow-up | Measures proactive service approach |
| **Resolution** (5) | Accuracy, language use, consultation, process following, clarity | Evaluates problem-solving effectiveness |
| **Hold** (2) | Hold explanation, holding gratitude | Assesses proper hold procedures |
| **Closing** (1) | Proper closing phrase | Evaluates call conclusion |
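
The head layout also travels with the model: `heads_config` and `head_submetric_labels` are custom fields in this repo's `config.json`, so downstream code can read them instead of hard-coding the table above. A quick sanity check that the six heads cover all 17 sub-metrics:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("your-username/qa-multihead-distilbert")

# Custom fields from config.json are exposed as config attributes.
assert sum(config.heads_config.values()) == 17  # 1 + 5 + 3 + 5 + 2 + 1
for head, dim in config.heads_config.items():
    print(f"{head}: {dim} sub-metric(s) -> {config.head_submetric_labels[head]}")
```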

## Usage

### Direct Usage with Transformers

```python
import torch
from transformers import DistilBertTokenizer
from modeling_multihead_qa import MultiHeadQAClassifier

# Load model and tokenizer
model = MultiHeadQAClassifier.from_pretrained("your-username/qa-multihead-distilbert")
tokenizer = DistilBertTokenizer.from_pretrained("your-username/qa-multihead-distilbert")

# Prepare input
text = "Hello, thank you for calling our support line. How can I help you today?"
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

# Get predictions
with torch.no_grad():
    outputs = model(**inputs)
    predictions = outputs["logits"]

# Process results (each head returns per-sub-metric probabilities)
for head_name, probs in predictions.items():
    print(f"{head_name}: {probs.cpu().numpy()}")
```

### Using the Inference Class

```python
from inference import QAMetricsInference

# Initialize inference engine
engine = QAMetricsInference("your-username/qa-multihead-distilbert")

# Analyze transcript
text = "Your transcript here..."
results = engine.predict(text, threshold=0.5, return_raw=True)

# Display formatted results
engine.predict_and_display(text)
```

## Training Details

### Training Data
- **Domain**: Call center transcripts
- **Languages**: English, Swahili
- **Size**: Custom dataset with balanced QA metrics
- **Preprocessing**: PII removal, text chunking, quality filtering (see the sketch below)
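
The exact preprocessing pipeline is not published with the model; a minimal sketch of the kind of PII redaction described above, with purely illustrative regex patterns:

```python
import re

# Illustrative patterns only -- the actual redaction rules used in training
# are not published with this model.
PII_PATTERNS = {
    "PHONE": re.compile(r"\+?\d[\d\s\-]{7,}\d"),
    "EMAIL": re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+"),
}

def redact_pii(text: str) -> str:
    """Replace likely PII spans with placeholder tags before training/inference."""
    for tag, pattern in PII_PATTERNS.items():
        text = pattern.sub(f"[{tag}]", text)
    return text

print(redact_pii("Reach me on +254 712 345678 or jane@example.com"))
# -> "Reach me on [PHONE] or [EMAIL]"
```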

### Training Configuration
- **Base Model**: distilbert-base-uncased
- **Optimizer**: AdamW
- **Loss Function**: BCEWithLogitsLoss per head, summed (see the sketch below)
- **Batch Size**: 16
- **Learning Rate**: 2e-5
- **Training Schedule**: Multiple epochs with validation
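
The released code does not include the training loop. A minimal sketch of one optimization step, assuming `model` is a `MultiHeadQAClassifier` and each batch carries a dict of per-head label tensors (the batch layout is an assumption):

```python
from torch.optim import AdamW

def training_step(model, batch, optimizer):
    """One optimization step. `batch` is assumed to hold tokenized inputs plus
    a dict of per-head float label tensors, e.g. batch["labels"]["listening"]
    with shape (batch_size, 5)."""
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],  # dict: head name -> label tensor
    )
    loss = outputs["loss"]  # sum of the per-head BCEWithLogitsLoss terms
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Settings from the table above:
# optimizer = AdamW(model.parameters(), lr=2e-5)
```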

### Performance

The model performs strongly on the structural heads (opening, hold, closing); the behavioral heads are still improving:

| Head | Accuracy | Status |
|------|----------|---------|
| Opening | ~90% | ✅ Excellent |
| Closing | ~90% | ✅ Excellent |
| Hold | ~90% | ✅ Excellent |
| Listening | ~65% | ⚠️ Improving |
| Proactiveness | ~70% | ⚠️ Improving |
| Resolution | ~68% | ⚠️ Improving |

## Limitations and Bias

- **Domain Specific**: Optimized for call center/helpline contexts
- **Language Balance**: Primarily trained on English, with Swahili fine-tuning
- **Context Length**: Limited to 512 tokens; longer transcripts need chunking (see the sketch below)
- **Cultural Context**: Trained on East African call center patterns
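
No chunking helper ships with the model, so the strategy is left to the caller. A sketch built on `QAMetricsInference` that splits on the 512-token limit and merges chunk probabilities with a per-sub-metric max (the aggregation rule is an assumption, not part of the released code):

```python
import numpy as np

def predict_long(engine, text: str, max_tokens: int = 512, threshold: float = 0.5):
    """Run QAMetricsInference on an over-length transcript, chunk by chunk."""
    ids = engine.tokenizer(text, add_special_tokens=False)["input_ids"]
    step = max_tokens - 2  # leave room for [CLS] and [SEP]
    chunks = [engine.tokenizer.decode(ids[i:i + step]) for i in range(0, len(ids), step)]

    # Take the max probability per sub-metric across chunks: a behavior
    # observed anywhere in the call counts for the whole call.
    merged = {}
    for chunk in chunks:
        for head, items in engine.predict(chunk, threshold, return_raw=True).items():
            probs = [item["probability"] for item in items]
            merged[head] = np.maximum(merged[head], probs) if head in merged else np.array(probs)
    return {head: (probs > threshold).astype(int).tolist() for head, probs in merged.items()}
```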

## Intended Use

### Primary Applications
- Call center quality assurance automation
- Agent performance evaluation
- Training feedback systems
- Compliance monitoring

### Out of Scope
- General text classification
- Non-customer-service contexts
- Real-time streaming applications without preprocessing

## Ethical Considerations

This model is designed to support human quality assurance processes, not to replace human judgment. It should be used to:
- Provide consistent initial assessments
- Identify areas needing human review
- Support training and development programs

## Model Developers

**BITZ IT Consulting** - AI Solutions for Social Impact
- Data Engineering Lead: [Your Name]
- ML Engineering: Rogendo
- Data Analysis: Shemmiriam
- Quality Assurance: Nelsonadagi

## Citation

```bibtex
@misc{qa_multihead_distilbert_2025,
  title        = {QA Multi-Head DistilBERT for Call Center Quality Assessment},
  author       = {BITZ IT Consulting Team},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/your-username/qa-multihead-distilbert}}
}
```

## License

Apache 2.0

## Contact

For questions about this model, please reach out via:
- GitHub Issues: [Your Repository]
- Email: [Your Contact Email]
config.json ADDED

{
  "_name_or_path": "distilbert-base-uncased",
  "architectures": [
    "MultiHeadQAClassifier"
  ],
  "model_type": "distilbert",
  "custom_model": true,
  "heads_config": {
    "opening": 1,
    "listening": 5,
    "proactiveness": 3,
    "resolution": 5,
    "hold": 2,
    "closing": 1
  },
  "head_submetric_labels": {
    "opening": [
      "Use of call opening phrase"
    ],
    "listening": [
      "Caller was not interrupted",
      "Empathizes with the caller",
      "Paraphrases or rephrases the issue",
      "Uses 'please' and 'thank you'",
      "Does not hesitate or sound unsure"
    ],
    "proactiveness": [
      "Willing to solve extra issues",
      "Confirms satisfaction with action points",
      "Follows up on case updates"
    ],
    "resolution": [
      "Gives accurate information",
      "Correct language use",
      "Consults if unsure",
      "Follows correct steps",
      "Explains solution process clearly"
    ],
    "hold": [
      "Explains before placing on hold",
      "Thanks caller for holding"
    ],
    "closing": [
      "Proper call closing phrase used"
    ]
  },
  "dropout": 0.2,
  "max_position_embeddings": 512,
  "vocab_size": 30522
}
inference.py ADDED

#!/usr/bin/env python3
"""
Multi-Head QA Metrics Inference Script
======================================

This script loads a trained multi-head QA classification model and provides
inference capabilities for evaluating call center transcripts against various
QA metrics: opening, listening, proactiveness, resolution, hold, and closing.

Usage:
    python inference.py --model_path "path/to/model" --text "transcript text"

Or use the interactive mode:
    python inference.py --model_path "path/to/model" --interactive
"""

import os
import torch
import torch.nn as nn
import argparse
import json
from typing import Dict, List, Optional
from transformers import DistilBertTokenizer, DistilBertModel, AutoConfig, DistilBertPreTrainedModel


# QA Heads Configuration - must match training configuration
QA_HEADS_CONFIG = {
    "opening": 1,
    "listening": 5,
    "proactiveness": 3,
    "resolution": 5,
    "hold": 2,
    "closing": 1
}

# Submetric labels for better output interpretation
HEAD_SUBMETRIC_LABELS = {
    "opening": [
        "Use of call opening phrase"
    ],
    "listening": [
        "Caller was not interrupted",
        "Empathizes with the caller",
        "Paraphrases or rephrases the issue",
        "Uses 'please' and 'thank you'",
        "Does not hesitate or sound unsure"
    ],
    "proactiveness": [
        "Willing to solve extra issues",
        "Confirms satisfaction with action points",
        "Follows up on case updates"
    ],
    "resolution": [
        "Gives accurate information",
        "Correct language use",
        "Consults if unsure",
        "Follows correct steps",
        "Explains solution process clearly"
    ],
    "hold": [
        "Explains before placing on hold",
        # "Provides status update after hold",
        "Thanks caller for holding"
    ],
    "closing": [
        "Proper call closing phrase used"
    ]
}


class MultiHeadQAClassifier(DistilBertPreTrainedModel):
    """
    Multi-head QA classifier model for call center transcript evaluation.
    Each head corresponds to a different QA metric.
    """

    def __init__(self, config):
        super().__init__(config)

        # Get heads config from model config
        self.heads_config = getattr(config, 'heads_config', {
            "opening": 1,
            "listening": 5,
            "proactiveness": 3,
            "resolution": 5,
            "hold": 2,
            "closing": 1
        })

        self.bert = DistilBertModel(config)
        classifier_dropout = getattr(config, 'classifier_dropout', 0.2)
        self.dropout = nn.Dropout(classifier_dropout)

        # Multiple heads, one per QA metric
        self.heads = nn.ModuleDict({
            head: nn.Linear(config.hidden_size, output_dim)
            for head, output_dim in self.heads_config.items()
        })

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

        pooled_output = self.dropout(outputs.last_hidden_state[:, 0])  # [CLS] token

        logits = {}
        losses = {}
        loss_total = 0

        for head_name, head_layer in self.heads.items():
            out = head_layer(pooled_output)
            # Note: "logits" actually holds sigmoid probabilities; the loss
            # below is computed on the raw outputs, as BCEWithLogitsLoss expects.
            logits[head_name] = torch.sigmoid(out)

            if labels is not None and head_name in labels:
                loss_fn = nn.BCEWithLogitsLoss()
                loss = loss_fn(out, labels[head_name])
                losses[head_name] = loss.item()
                loss_total += loss

        return {
            "logits": logits,
            "loss": loss_total if labels is not None else None,
            "losses": losses if labels is not None else None
        }


class QAMetricsInference:
    """
    Inference class for QA metrics prediction on call center transcripts.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the inference engine.

        Args:
            model_path: Path to the saved model directory
            device: Device to run inference on ('cpu', 'cuda', or None for auto-detect)
        """
        self.model_path = model_path
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = 512

        # Load tokenizer and model
        self._load_model()

    def _load_model(self):
        """Load the trained model and tokenizer."""
        print(f"Loading model from: {self.model_path}")

        # Load tokenizer
        try:
            self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_path)
            print("✓ Tokenizer loaded successfully")
        except Exception as e:
            print(f"✗ Error loading tokenizer: {e}")
            raise

        # Load model
        try:
            if os.path.isdir(self.model_path):
                # Load from local directory
                config = AutoConfig.from_pretrained(self.model_path)
                self.model = MultiHeadQAClassifier(config)
                model_state_path = os.path.join(self.model_path, "pytorch_model.bin")

                if not os.path.exists(model_state_path):
                    raise FileNotFoundError(f"Model file not found: {model_state_path}")

                state_dict = torch.load(model_state_path, map_location=self.device)
                self.model.load_state_dict(state_dict)
            else:
                # Load from Hugging Face Hub
                self.model = MultiHeadQAClassifier.from_pretrained(self.model_path)

            self.model.to(self.device)
            self.model.eval()
            print(f"✓ Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"✗ Error loading model: {e}")
            raise

    def predict(self, text: str, threshold: float = 0.5, return_raw: bool = False) -> Dict:
        """
        Predict QA metrics for a given transcript.

        Args:
            text: Input transcript text
            threshold: Threshold for binary classification (default: 0.5)
            return_raw: If True, return raw probabilities along with predictions

        Returns:
            Dictionary containing predictions for each QA head
        """
        # Tokenize input
        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        # Forward pass
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]

        # Process results
        results = {}
        for head, probs in logits.items():
            probs_np = probs.cpu().numpy()[0]  # Get first (and only) example
            preds = (probs_np > threshold).astype(int)
            submetrics = HEAD_SUBMETRIC_LABELS.get(head, [f"Submetric {i+1}" for i in range(len(probs_np))])

            head_results = []
            for label, prob, pred in zip(submetrics, probs_np, preds):
                result_item = {
                    "submetric": label,
                    "prediction": bool(pred),
                    "score": "✓" if pred else "✗"
                }
                if return_raw:
                    result_item["probability"] = float(prob)

                head_results.append(result_item)

            results[head] = head_results

        return results

    def predict_and_display(self, text: str, threshold: float = 0.5):
        """
        Predict and display results in a formatted way.

        Args:
            text: Input transcript text
            threshold: Threshold for binary classification
        """
        print("\n📞 Transcript Analysis")
        print("=" * 60)
        print(f"Text: {text[:200]}{'...' if len(text) > 200 else ''}")
        print("=" * 60)

        results = self.predict(text, threshold, return_raw=True)

        for head, head_results in results.items():
            print(f"\n🔹 {head.upper()}:")
            for item in head_results:
                prob = item["probability"]
                print(f" ➤ {item['submetric']}: P={prob:.3f} → {item['score']}")

    def batch_predict(self, texts: List[str], threshold: float = 0.5) -> List[Dict]:
        """
        Predict QA metrics for multiple transcripts.

        Args:
            texts: List of transcript texts
            threshold: Threshold for binary classification

        Returns:
            List of prediction dictionaries
        """
        results = []
        for text in texts:
            result = self.predict(text, threshold)
            results.append(result)
        return results

    def export_predictions(self, texts: List[str], output_path: str, threshold: float = 0.5):
        """
        Export predictions to a JSON file.

        Args:
            texts: List of transcript texts
            output_path: Path to save the results
            threshold: Threshold for binary classification
        """
        results = []
        for i, text in enumerate(texts):
            prediction = self.predict(text, threshold, return_raw=True)
            results.append({
                "text_id": i,
                "text": text,
                "predictions": prediction
            })

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"✓ Predictions exported to: {output_path}")


def main():
    """Main function for the command-line interface."""
    parser = argparse.ArgumentParser(description="QA Metrics Inference Script")
    parser.add_argument("--model_path", required=True, help="Path to the trained model directory")
    parser.add_argument("--text", help="Text to analyze")
    parser.add_argument("--input_file", help="Path to text file containing transcripts (one per line)")
    parser.add_argument("--output_file", help="Path to save predictions (JSON format)")
    parser.add_argument("--threshold", type=float, default=0.5, help="Classification threshold (default: 0.5)")
    parser.add_argument("--interactive", action="store_true", help="Run in interactive mode")
    parser.add_argument("--device", help="Device to use (cpu/cuda)")

    args = parser.parse_args()

    # Initialize inference engine
    try:
        inference_engine = QAMetricsInference(args.model_path, args.device)
    except Exception as e:
        print(f"Failed to initialize inference engine: {e}")
        return

    # Interactive mode
    if args.interactive:
        print("\n🤖 QA Metrics Interactive Analysis")
        print("Type 'quit' to exit, 'help' for commands")
        print("-" * 50)

        while True:
            try:
                user_input = input("\nEnter transcript text: ").strip()

                if user_input.lower() == 'quit':
                    break
                elif user_input.lower() == 'help':
                    print("\nCommands:")
                    print(" - Enter transcript text to analyze")
                    print(" - 'quit' to exit")
                    print(" - 'help' to show this message")
                    continue
                elif not user_input:
                    print("Please enter some text to analyze.")
                    continue

                inference_engine.predict_and_display(user_input, args.threshold)

            except KeyboardInterrupt:
                print("\n\nGoodbye! 👋")
                break
            except Exception as e:
                print(f"Error during analysis: {e}")

    # Single text analysis
    elif args.text:
        inference_engine.predict_and_display(args.text, args.threshold)

    # Batch processing from file
    elif args.input_file:
        try:
            with open(args.input_file, 'r', encoding='utf-8') as f:
                texts = [line.strip() for line in f if line.strip()]

            print(f"Processing {len(texts)} transcripts...")

            if args.output_file:
                inference_engine.export_predictions(texts, args.output_file, args.threshold)
            else:
                results = inference_engine.batch_predict(texts, args.threshold)
                for i, result in enumerate(results):
                    print(f"\n--- Transcript {i+1} ---")
                    print(json.dumps(result, indent=2))

        except FileNotFoundError:
            print(f"Input file not found: {args.input_file}")
        except Exception as e:
            print(f"Error processing file: {e}")

    else:
        print("Please provide either --text, --input_file, or use --interactive mode")
        print("Use --help for more information")


if __name__ == "__main__":
    main()
modeling_multihead_qa.py ADDED

"""
Multi-Head QA Classifier Model for Hugging Face Hub
===================================================
"""

import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel
from typing import Optional, Dict


class MultiHeadQAClassifier(DistilBertPreTrainedModel):
    """
    Multi-head QA classifier model for call center transcript evaluation.
    Each head corresponds to a different QA metric.
    """

    def __init__(self, config):
        super().__init__(config)

        # Get heads config from model config
        self.heads_config = getattr(config, 'heads_config', {
            "opening": 1,
            "listening": 5,
            "proactiveness": 3,
            "resolution": 5,
            "hold": 2,
            "closing": 1
        })

        self.bert = DistilBertModel(config)
        classifier_dropout = getattr(config, 'classifier_dropout', 0.2)
        self.dropout = nn.Dropout(classifier_dropout)

        # Multiple heads, one per QA metric
        self.heads = nn.ModuleDict({
            head: nn.Linear(config.hidden_size, output_dim)
            for head, output_dim in self.heads_config.items()
        })

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

        pooled_output = self.dropout(outputs.last_hidden_state[:, 0])  # [CLS] token

        logits = {}
        losses = {}
        loss_total = 0

        for head_name, head_layer in self.heads.items():
            out = head_layer(pooled_output)
            # Note: "logits" actually holds sigmoid probabilities; the loss
            # below is computed on the raw outputs, as BCEWithLogitsLoss expects.
            logits[head_name] = torch.sigmoid(out)

            if labels is not None and head_name in labels:
                loss_fn = nn.BCEWithLogitsLoss()
                loss = loss_fn(out, labels[head_name])
                losses[head_name] = loss.item()
                loss_total += loss

        return {
            "logits": logits,
            "loss": loss_total if labels is not None else None,
            "losses": losses if labels is not None else None
        }
pytorch_model.bin ADDED

version https://git-lfs.github.com/spec/v1
oid sha256:49cd26533720192c40599d70027931f9439481e44d1fe80a35e77509564bf77e
size 265547875
requirements.txt ADDED

torch>=1.9.0
transformers>=4.20.0
numpy>=1.21.0
huggingface-hub>=0.10.0
special_tokens_map.json ADDED

{
  "cls_token": "[CLS]",
  "mask_token": "[MASK]",
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "unk_token": "[UNK]"
}
tokenizer_config.json ADDED

{
  "added_tokens_decoder": {
    "0": {
      "content": "[PAD]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "100": {
      "content": "[UNK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "101": {
      "content": "[CLS]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "102": {
      "content": "[SEP]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "103": {
      "content": "[MASK]",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "clean_up_tokenization_spaces": true,
  "cls_token": "[CLS]",
  "do_basic_tokenize": true,
  "do_lower_case": true,
  "extra_special_tokens": {},
  "mask_token": "[MASK]",
  "model_max_length": 512,
  "never_split": null,
  "pad_token": "[PAD]",
  "sep_token": "[SEP]",
  "strip_accents": null,
  "tokenize_chinese_chars": true,
  "tokenizer_class": "DistilBertTokenizer",
  "unk_token": "[UNK]"
}
vocab.txt ADDED
The diff for this file is too large to render.