# qa-helpline-distilbert-v1 / modeling_multihead_qa.py
"""
Multi-Head QA Classifier Model for Hugging Face Hub
==================================================
"""
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional, Dict

class MultiHeadQAClassifier(DistilBertPreTrainedModel):
    """
    Multi-head QA classifier model for call center transcript evaluation.
    Each head corresponds to a different QA metric.
    """

    def __init__(self, config):
        super().__init__(config)
        # Per-head output sizes; read from the model config when present,
        # otherwise fall back to these defaults.
        self.heads_config = getattr(config, "heads_config", {
            "opening": 1,
            "listening": 5,
            "proactiveness": 3,
            "resolution": 5,
            "hold": 2,
            "closing": 1,
        })
        # Note: the backbone attribute is named ``bert`` (not ``distilbert``),
        # so checkpoints for this class store their weights under ``bert.*``.
        self.bert = DistilBertModel(config)
        classifier_dropout = getattr(config, "classifier_dropout", 0.2)
        self.dropout = nn.Dropout(classifier_dropout)
        # One linear head per QA metric, each emitting ``output_dim`` scores.
        self.heads = nn.ModuleDict({
            head: nn.Linear(config.hidden_size, output_dim)
            for head, output_dim in self.heads_config.items()
        })
        # Initialize weights
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs,
    ):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        # DistilBERT has no pooler, so pool by taking the [CLS] token's
        # final hidden state.
        pooled_output = self.dropout(outputs.last_hidden_state[:, 0])

        logits = {}
        losses = {}
        loss_total = 0.0
        loss_fn = nn.BCEWithLogitsLoss()
        for head_name, head_layer in self.heads.items():
            out = head_layer(pooled_output)
            # Despite the "logits" key, these are sigmoid probabilities.
            logits[head_name] = torch.sigmoid(out)
            if labels is not None and head_name in labels:
                # BCEWithLogitsLoss applies the sigmoid itself, so it gets the
                # raw head outputs; targets must be float tensors.
                loss = loss_fn(out, labels[head_name].float())
                losses[head_name] = loss.item()
                loss_total = loss_total + loss

        return {
            "logits": logits,
            # Return the summed loss only if at least one head had labels;
            # otherwise it would be a plain float, not a tensor.
            "loss": loss_total if losses else None,
            "losses": losses if labels is not None else None,
        }
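

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the uploaded checkpoint's API):
# builds a DistilBertConfig locally so the example runs without downloading
# the hosted weights. The transcript text and the 0.5 decision threshold are
# assumptions made for this demo, not values fixed by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import DistilBertConfig, DistilBertTokenizerFast

    config = DistilBertConfig()
    config.heads_config = {
        "opening": 1, "listening": 5, "proactiveness": 3,
        "resolution": 5, "hold": 2, "closing": 1,
    }
    model = MultiHeadQAClassifier(config)
    model.eval()

    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    batch = tokenizer(
        "Agent: Thank you for calling, how may I help you today?",
        return_tensors="pt",
        truncation=True,
    )

    with torch.no_grad():
        preds = model(**batch)

    # Each head returns per-sub-metric probabilities (see forward()); a 0.5
    # threshold turns them into binary QA scores.
    for head, probs in preds["logits"].items():
        print(head, (probs > 0.5).int().tolist())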
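
    # Continuing the demo above: during training, ``labels`` is a dict with
    # one float tensor per head, shaped [batch_size, output_dim] (assumed
    # multi-label binary targets). Heads absent from the dict are skipped.
    labels = {
        head: torch.zeros(1, dim) for head, dim in config.heads_config.items()
    }
    out = model(**batch, labels=labels)
    print("total loss:", out["loss"].item())
    print("per-head losses:", out["losses"])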