|
|
""" |
|
|
Multi-Head QA Classifier Model for Hugging Face Hub |
|
|
================================================== |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import DistilBertModel, DistilBertPreTrainedModel |
|
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
from typing import Optional, Dict |
|
|
|
|
|
|
|
|
class MultiHeadQAClassifier(DistilBertPreTrainedModel):
    """Multi-head QA classifier for call center transcript evaluation.

    A DistilBERT encoder feeds a shared dropout layer followed by one
    independent linear head per QA metric. Each head emits ``output_dim``
    values that are trained with binary cross-entropy on the raw head outputs.

    Config attributes read (all optional):
        heads_config: Mapping of head name -> number of outputs. Falls back
            to ``DEFAULT_HEADS_CONFIG`` when missing or ``None``.
        classifier_dropout: Dropout probability applied to the pooled
            representation. Falls back to ``DEFAULT_CLASSIFIER_DROPOUT`` when
            missing or ``None``.
    """

    # Head layout used when the config does not define ``heads_config``.
    DEFAULT_HEADS_CONFIG = {
        "opening": 1,
        "listening": 5,
        "proactiveness": 3,
        "resolution": 5,
        "hold": 2,
        "closing": 1,
    }

    # Dropout probability used when the config does not define one.
    DEFAULT_CLASSIFIER_DROPOUT = 0.2

    def __init__(self, config):
        """Build the encoder, the dropout layer and one linear layer per head.

        Args:
            config: Model configuration. Must expose ``hidden_size``; may
                expose ``heads_config`` and ``classifier_dropout``.
        """
        super().__init__(config)

        heads_config = getattr(config, "heads_config", None)
        if heads_config is None:
            # Copy so instances never share (or mutate) the class-level default.
            heads_config = dict(self.DEFAULT_HEADS_CONFIG)
        self.heads_config = heads_config

        # NOTE(review): the encoder lives under the attribute name ``bert``;
        # confirm this matches the key prefix of the checkpoints being loaded.
        self.bert = DistilBertModel(config)

        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            # nn.Dropout rejects None, so an explicit null in the config must
            # fall back to the default as well.
            classifier_dropout = self.DEFAULT_CLASSIFIER_DROPOUT
        self.dropout = nn.Dropout(classifier_dropout)

        self.heads = nn.ModuleDict({
            head: nn.Linear(config.hidden_size, output_dim)
            for head, output_dim in self.heads_config.items()
        })

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs
    ) -> dict:
        """Run the encoder and every head; optionally compute per-head losses.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            labels: Optional mapping of head name -> target tensor. Heads
                without an entry are skipped for the loss. Targets must be
                broadcast-compatible with the head output (assumed to be
                ``(batch, output_dim)`` - confirm against the data collator).
                Non-floating targets are cast to the head output dtype.
            **kwargs: Extra keyword arguments forwarded to the encoder.

        Returns:
            A dict with:
                ``"logits"``: head name -> ``torch.sigmoid`` of the head
                    output. Despite the key name these are probabilities,
                    not raw logits; the key is kept for compatibility.
                ``"loss"``: sum of the per-head BCE-with-logits losses, or
                    ``None`` when ``labels`` is ``None``. When ``labels`` is
                    given but matches no head, a zero scalar tensor (which
                    carries no gradient).
                ``"losses"``: head name -> Python float loss value, or
                    ``None`` when ``labels`` is ``None``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )

        # Position 0 of the encoder output is the last hidden state for both
        # the ModelOutput and the plain-tuple (return_dict=False) forms.
        hidden_states = outputs[0]
        # Use the hidden state at the first sequence position as the pooled
        # representation (presumably the [CLS] token - depends on tokenizer).
        pooled_output = self.dropout(hidden_states[:, 0])

        # The loss module is stateless; build it once instead of once per head.
        loss_fn = nn.BCEWithLogitsLoss() if labels is not None else None

        logits = {}
        losses = {}
        loss_total = None

        for head_name, head_layer in self.heads.items():
            out = head_layer(pooled_output)
            logits[head_name] = torch.sigmoid(out)

            if labels is None or head_name not in labels:
                continue

            target = labels[head_name].to(out.device)
            if not target.is_floating_point():
                # BCEWithLogitsLoss requires floating-point targets.
                target = target.to(out.dtype)

            # The loss is computed on the raw head output, not the sigmoid.
            loss = loss_fn(out, target)
            # .item() copies the value to the host as a plain float.
            losses[head_name] = loss.item()
            loss_total = loss if loss_total is None else loss_total + loss

        if labels is None:
            return {"logits": logits, "loss": None, "losses": None}

        if loss_total is None:
            # Labels were supplied but none matched a head: keep the loss a
            # tensor so callers always see a consistent type.
            loss_total = pooled_output.new_zeros(())

        return {
            "logits": logits,
            "loss": loss_total,
            "losses": losses,
        }
|
|
|