""" |
|
|
Multi-Head QA Metrics Inference Script |
|
|
===================================== |
|
|
|
|
|
This script loads a trained multi-head QA classification model and provides |
|
|
inference capabilities for evaluating call center transcripts against various |
|
|
QA metrics including opening, listening, proactiveness, resolution, hold, and closing. |
|
|
|
|
|
Usage: |
|
|
python inference.py --model_path "path/to/model" --text "transcript text" |
|
|
|
|
|
Or use the interactive mode: |
|
|
python inference.py --model_path "path/to/model" --interactive |
|
|
""" |
|
|
|
|
|
import os
import torch
import torch.nn as nn
import argparse
import json
from typing import Dict, List, Optional
from transformers import DistilBertTokenizer, DistilBertModel, AutoConfig, DistilBertPreTrainedModel


# Number of sub-metric outputs produced by each QA head.
QA_HEADS_CONFIG = {
    "opening": 1,
    "listening": 5,
    "proactiveness": 3,
    "resolution": 5,
    "hold": 2,
    "closing": 1
}

# Human-readable labels for each sub-metric, listed in the same order as the
# corresponding head's outputs.
HEAD_SUBMETRIC_LABELS = {
    "opening": [
        "Use of call opening phrase"
    ],
    "listening": [
        "Caller was not interrupted",
        "Empathizes with the caller",
        "Paraphrases or rephrases the issue",
        "Uses 'please' and 'thank you'",
        "Does not hesitate or sound unsure"
    ],
    "proactiveness": [
        "Willing to solve extra issues",
        "Confirms satisfaction with action points",
        "Follows up on case updates"
    ],
    "resolution": [
        "Gives accurate information",
        "Correct language use",
        "Consults if unsure",
        "Follows correct steps",
        "Explains solution process clearly"
    ],
    "hold": [
        "Explains before placing on hold",
        "Thanks caller for holding"
    ],
    "closing": [
        "Proper call closing phrase used"
    ]
}
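
# Illustrative consistency check (not part of the original pipeline): each head's
# label list should line up one-to-one with that head's configured output size,
# since predict() zips the labels against the per-head probabilities.
for _head, _labels in HEAD_SUBMETRIC_LABELS.items():
    assert len(_labels) == QA_HEADS_CONFIG[_head], (
        f"Sub-metric labels for '{_head}' do not match its configured output size"
    )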


class MultiHeadQAClassifier(DistilBertPreTrainedModel):
    """
    Multi-head QA classifier model for call center transcript evaluation.
    Each head corresponds to a different QA metric.
    """

    def __init__(self, config):
        super().__init__(config)

        # Per-head output sizes; fall back to the default layout if the saved
        # config does not carry a custom heads_config.
        self.heads_config = getattr(config, 'heads_config', dict(QA_HEADS_CONFIG))

        self.bert = DistilBertModel(config)
        classifier_dropout = getattr(config, 'classifier_dropout', 0.2)
        self.dropout = nn.Dropout(classifier_dropout)

        # One linear classification head per QA metric.
        self.heads = nn.ModuleDict({
            head: nn.Linear(config.hidden_size, output_dim)
            for head, output_dim in self.heads_config.items()
        })

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

        # Use the [CLS] token representation as the pooled sequence embedding.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0])

        logits = {}
        losses = {}
        loss_total = 0

        for head_name, head_layer in self.heads.items():
            out = head_layer(pooled_output)
            # The returned "logits" are sigmoid probabilities; the loss below is
            # computed from the raw outputs, as BCEWithLogitsLoss expects.
            logits[head_name] = torch.sigmoid(out)

            if labels is not None and head_name in labels:
                loss_fn = nn.BCEWithLogitsLoss()
                loss = loss_fn(out, labels[head_name])
                losses[head_name] = loss.item()
                loss_total += loss

        return {
            "logits": logits,
            "loss": loss_total if labels is not None else None,
            "losses": losses if labels is not None else None
        }
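

# Minimal usage sketch for the model on its own (illustrative; the checkpoint path
# is a placeholder and the batch holds dummy token ids):
#
#     config = AutoConfig.from_pretrained("path/to/model")
#     model = MultiHeadQAClassifier(config)
#     input_ids = torch.zeros((2, 16), dtype=torch.long)
#     attention_mask = torch.ones((2, 16), dtype=torch.long)
#     out = model(input_ids=input_ids, attention_mask=attention_mask)
#     # out["logits"]["listening"] is a (2, 5) tensor of per-sub-metric probabilities.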


class QAMetricsInference:
    """
    Inference class for QA metrics prediction on call center transcripts.
    """

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the inference engine.

        Args:
            model_path: Path to the saved model directory
            device: Device to run inference on ('cpu', 'cuda', or None for auto-detect)
        """
        self.model_path = model_path
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = 512

        self._load_model()

    def _load_model(self):
        """Load the trained model and tokenizer."""
        print(f"Loading model from: {self.model_path}")

        try:
            self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_path)
            print("✓ Tokenizer loaded successfully")
        except Exception as e:
            print(f"✗ Error loading tokenizer: {e}")
            raise

        try:
            if os.path.isdir(self.model_path):
                # Load config and weights from a local checkpoint directory.
                config = AutoConfig.from_pretrained(self.model_path)
                self.model = MultiHeadQAClassifier(config)
                model_state_path = os.path.join(self.model_path, "pytorch_model.bin")

                if not os.path.exists(model_state_path):
                    raise FileNotFoundError(f"Model file not found: {model_state_path}")

                state_dict = torch.load(model_state_path, map_location=self.device)
                self.model.load_state_dict(state_dict)
            else:
                # Otherwise defer to from_pretrained (e.g. a Hub identifier).
                self.model = MultiHeadQAClassifier.from_pretrained(self.model_path)

            self.model.to(self.device)
            self.model.eval()
            print(f"✓ Model loaded successfully on {self.device}")
        except Exception as e:
            print(f"✗ Error loading model: {e}")
            raise

    def predict(self, text: str, threshold: float = 0.5, return_raw: bool = False) -> Dict:
        """
        Predict QA metrics for a given transcript.

        Args:
            text: Input transcript text
            threshold: Threshold for binary classification (default: 0.5)
            return_raw: If True, include raw probabilities along with predictions

        Returns:
            Dictionary containing predictions for each QA head
        """
        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )

        input_ids = encoding["input_ids"].to(self.device)
        attention_mask = encoding["attention_mask"].to(self.device)

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs["logits"]  # sigmoid probabilities, keyed by head name

        results = {}
        for head, probs in logits.items():
            probs_np = probs.cpu().numpy()[0]
            preds = (probs_np > threshold).astype(int)
            submetrics = HEAD_SUBMETRIC_LABELS.get(head, [f"Submetric {i+1}" for i in range(len(probs_np))])

            head_results = []
            for label, prob, pred in zip(submetrics, probs_np, preds):
                result_item = {
                    "submetric": label,
                    "prediction": bool(pred),
                    "score": "✓" if pred else "✗"
                }
                if return_raw:
                    result_item["probability"] = float(prob)

                head_results.append(result_item)

            results[head] = head_results

        return results
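
    # Example of the structure returned by predict() (values are illustrative):
    #
    #     {
    #         "opening": [
    #             {"submetric": "Use of call opening phrase",
    #              "prediction": True,
    #              "score": "✓",
    #              "probability": 0.91}
    #         ],
    #         ...
    #     }
    #
    # The "probability" field is included only when return_raw=True.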

    def predict_and_display(self, text: str, threshold: float = 0.5):
        """
        Predict and display results in a formatted way.

        Args:
            text: Input transcript text
            threshold: Threshold for binary classification
        """
        print("\n📞 Transcript Analysis")
        print("=" * 60)
        print(f"Text: {text[:200]}{'...' if len(text) > 200 else ''}")
        print("=" * 60)

        results = self.predict(text, threshold, return_raw=True)

        for head, head_results in results.items():
            print(f"\n🔹 {head.upper()}:")
            for item in head_results:
                prob = item["probability"]
                print(f"  ➤ {item['submetric']}: P={prob:.3f} → {item['score']}")

    def batch_predict(self, texts: List[str], threshold: float = 0.5) -> List[Dict]:
        """
        Predict QA metrics for multiple transcripts.

        Args:
            texts: List of transcript texts
            threshold: Threshold for binary classification

        Returns:
            List of prediction dictionaries
        """
        results = []
        for text in texts:
            result = self.predict(text, threshold)
            results.append(result)
        return results

    def export_predictions(self, texts: List[str], output_path: str, threshold: float = 0.5):
        """
        Export predictions to a JSON file.

        Args:
            texts: List of transcript texts
            output_path: Path to save the results
            threshold: Threshold for binary classification
        """
        results = []
        for i, text in enumerate(texts):
            prediction = self.predict(text, threshold, return_raw=True)
            results.append({
                "text_id": i,
                "text": text,
                "predictions": prediction
            })

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)

        print(f"✓ Predictions exported to: {output_path}")


def main():
    """Main function for command-line interface."""
    parser = argparse.ArgumentParser(description="QA Metrics Inference Script")
    parser.add_argument("--model_path", required=True, help="Path to the trained model directory")
    parser.add_argument("--text", help="Text to analyze")
    parser.add_argument("--input_file", help="Path to text file containing transcripts (one per line)")
    parser.add_argument("--output_file", help="Path to save predictions (JSON format)")
    parser.add_argument("--threshold", type=float, default=0.5, help="Classification threshold (default: 0.5)")
    parser.add_argument("--interactive", action="store_true", help="Run in interactive mode")
    parser.add_argument("--device", help="Device to use (cpu/cuda)")

    args = parser.parse_args()

    try:
        inference_engine = QAMetricsInference(args.model_path, args.device)
    except Exception as e:
        print(f"Failed to initialize inference engine: {e}")
        return

    if args.interactive:
        print("\n🤖 QA Metrics Interactive Analysis")
        print("Type 'quit' to exit, 'help' for commands")
        print("-" * 50)

        while True:
            try:
                user_input = input("\nEnter transcript text: ").strip()

                if user_input.lower() == 'quit':
                    break
                elif user_input.lower() == 'help':
                    print("\nCommands:")
                    print("  - Enter transcript text to analyze")
                    print("  - 'quit' to exit")
                    print("  - 'help' to show this message")
                    continue
                elif not user_input:
                    print("Please enter some text to analyze.")
                    continue

                inference_engine.predict_and_display(user_input, args.threshold)

            except KeyboardInterrupt:
                print("\n\nGoodbye! 👋")
                break
            except Exception as e:
                print(f"Error during analysis: {e}")

    elif args.text:
        inference_engine.predict_and_display(args.text, args.threshold)

    elif args.input_file:
        try:
            with open(args.input_file, 'r', encoding='utf-8') as f:
                texts = [line.strip() for line in f if line.strip()]

            print(f"Processing {len(texts)} transcripts...")

            if args.output_file:
                inference_engine.export_predictions(texts, args.output_file, args.threshold)
            else:
                results = inference_engine.batch_predict(texts, args.threshold)
                for i, result in enumerate(results):
                    print(f"\n--- Transcript {i+1} ---")
                    print(json.dumps(result, indent=2))

        except FileNotFoundError:
            print(f"Input file not found: {args.input_file}")
        except Exception as e:
            print(f"Error processing file: {e}")

    else:
        print("Please provide either --text, --input_file, or use --interactive mode")
        print("Use --help for more information")


if __name__ == "__main__":
    main()
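

# Example invocations (illustrative paths and filenames):
#
#     python inference.py --model_path outputs/qa_model --text "Thank you for calling..."
#     python inference.py --model_path outputs/qa_model --input_file calls.txt --output_file preds.json
#     python inference.py --model_path outputs/qa_model --interactive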