feat: add mean_seq_mha pooling option
Add the MeanSeq+MHA pooling option from the updated MaxPoolBERT paper
- README.md +3 -1
- configuration_modchembert.py +25 -4
- modeling_modchembert.py +11 -6
README.md CHANGED

@@ -1,5 +1,6 @@
 ---
 license: apache-2.0
+base_model: Derify/ModChemBERT-IR-BASE
 library_name: transformers
 tags:
 - modernbert
@@ -64,13 +65,14 @@ This base model includes configurable pooling strategies for downstream fine-tuning
 - `max_cls`: Max over last k layers of [CLS]
 - `cls_mha`: MHA with [CLS] as query
 - `max_seq_mha`: MHA with max pooled sequence as KV and max pooled [CLS] as query
+- `mean_seq_mha`: MHA with mean pooled sequence as KV and mean pooled [CLS] as query
 - `sum_mean`: Sum over all layers then mean tokens
 - `sum_sum`: Sum over all layers then sum tokens
 - `mean_mean`: Mean over all layers then mean tokens
 - `mean_sum`: Mean over all layers then sum tokens
 - `max_seq_mean`: Max over last k layers then mean tokens
 
-Note: ModChemBERT's `max_seq_mha` differs from MaxPoolBERT [3]. MaxPoolBERT uses PyTorch `nn.MultiheadAttention`, whereas ModChemBERT's `ModChemBertPoolingAttention` adapts ModernBERT's `ModernBertAttention`.
+Note: ModChemBERT's `cls_mha`, `max_seq_mha`, and `mean_seq_mha` differ from MaxPoolBERT [3]. MaxPoolBERT uses PyTorch `nn.MultiheadAttention`, whereas ModChemBERT's `ModChemBertPoolingAttention` adapts ModernBERT's `ModernBertAttention`.
 On ChemBERTa-3 benchmarks this variant produced stronger validation metrics and avoided the training instabilities (sporadic zero / NaN losses and gradient norms) seen with `nn.MultiheadAttention`. Training instability with ModernBERT has been reported in the past ([discussion 1](https://huggingface.co/answerdotai/ModernBERT-base/discussions/59) and [discussion 2](https://huggingface.co/answerdotai/ModernBERT-base/discussions/63)).
 
 ## Intended Use
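For context, the new option is selected the same way as the existing pooling strategies when fine-tuning. A minimal sketch, assuming the repo ships its custom classes for `trust_remote_code` loading and using the `Derify/ModChemBERT-IR-BASE` id that the README metadata above adds as `base_model`; the SMILES input and label count are only illustrative:

```python
# Hypothetical fine-tuning setup selecting the new mean_seq_mha pooling option.
# Assumes the checkpoint exposes ModChemBERT's custom config/model via trust_remote_code.
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

repo_id = "Derify/ModChemBERT-IR-BASE"  # taken from the base_model field added above

config = AutoConfig.from_pretrained(
    repo_id,
    trust_remote_code=True,
    classifier_pooling="mean_seq_mha",         # option added in this commit
    classifier_pooling_last_k=8,               # how many final layers to mean-pool
    classifier_pooling_num_attention_heads=4,  # heads for the pooling attention
    num_labels=2,
)
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    repo_id, config=config, trust_remote_code=True
)

inputs = tokenizer("CCO", return_tensors="pt")  # an example SMILES string (ethanol)
logits = model(**inputs).logits
```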
configuration_modchembert.py CHANGED

@@ -37,14 +37,15 @@ class ModChemBertConfig(ModernBertConfig):
             - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
             - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
             - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
+            - "mean_seq_mha": Mean pooling over last k states + multi-head attention with CLS as query
             - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
             Defaults to "sum_mean".
         classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
-            pooling strategies (cls_mha, max_seq_mha). Defaults to 4.
+            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 4.
         classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
-            pooling strategies (cls_mha, max_seq_mha). Defaults to 0.0.
-        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max pooling
-            strategies (max_cls, max_seq_mha, max_seq_mean). Defaults to 8.
+            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 0.0.
+        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max/mean pooling
+            strategies (max_cls, max_seq_mha, mean_seq_mha, max_seq_mean). Defaults to 8.
         *args: Variable length argument list passed to ModernBertConfig.
         **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.
 
@@ -68,6 +69,7 @@ class ModChemBertConfig(ModernBertConfig):
             "max_cls",
             "cls_mha",
             "max_seq_mha",
+            "mean_seq_mha",
             "max_seq_mean",
         ] = "max_seq_mha",
         classifier_pooling_num_attention_heads: int = 4,
@@ -75,6 +77,25 @@ class ModChemBertConfig(ModernBertConfig):
         classifier_pooling_last_k: int = 8,
         **kwargs,
     ):
+        valid_classifier_pooling_options = [
+            "cls",
+            "mean",
+            "sum_mean",
+            "sum_sum",
+            "mean_mean",
+            "mean_sum",
+            "max_cls",
+            "cls_mha",
+            "max_seq_mha",
+            "mean_seq_mha",
+            "max_seq_mean",
+        ]
+        if classifier_pooling not in valid_classifier_pooling_options:
+            raise ValueError(
+                f"Invalid value for `classifier_pooling`, should be one of {valid_classifier_pooling_options}, "
+                f"but is {classifier_pooling}."
+            )
+
         # Pass classifier_pooling="cls" to circumvent ValueError in ModernBertConfig init
         super().__init__(*args, classifier_pooling="cls", **kwargs)
         # Override with custom value
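A short sketch of how the new validation behaves, assuming `ModChemBertConfig` is importable from the repo's `configuration_modchembert.py` (e.g. with the file downloaded locally); the `"max_mha"` value is a deliberately invalid example:

```python
# Sketch of the classifier_pooling validation added in __init__ above.
from configuration_modchembert import ModChemBertConfig

cfg = ModChemBertConfig(classifier_pooling="mean_seq_mha")  # accepted: in the valid list
print(cfg.classifier_pooling)  # -> "mean_seq_mha"

try:
    ModChemBertConfig(classifier_pooling="max_mha")  # hypothetical typo, not in the list
except ValueError as err:
    # "Invalid value for `classifier_pooling`, should be one of [...], but is max_mha."
    print(err)
```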
modeling_modchembert.py CHANGED

@@ -19,9 +19,9 @@
 # Modifications include:
 # - Additional classifier_pooling options for ModChemBertForSequenceClassification
 # - sum_mean, sum_sum, mean_sum, mean_mean: from ChemLM (utilizes all hidden states)
-# - max_cls, cls_mha, max_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
+# - max_cls, cls_mha, max_seq_mha, mean_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
 # - max_seq_mean: a merge between sum_mean and max_cls (utilizes last k hidden states)
-# - Addition of ModChemBertPoolingAttention for cls_mha and max_seq_mha pooling options
+# - Addition of ModChemBertPoolingAttention for cls_mha, max_seq_mha, and mean_seq_mha pooling options
 
 import copy
 import math
@@ -499,7 +499,7 @@ class ModChemBertForSequenceClassification(InitWeightsMixin, ModernBertPreTrainedModel):
         self.config = config
 
         self.model = ModernBertModel(config)
-        if self.config.classifier_pooling in {"cls_mha", "max_seq_mha"}:
+        if self.config.classifier_pooling in {"cls_mha", "max_seq_mha", "mean_seq_mha"}:
             self.pooling_attn = ModChemBertPoolingAttention(config=self.config)
         else:
             self.pooling_attn = None
@@ -649,6 +649,7 @@ def _pool_modchembert_output(
         - max_cls: Element-wise max pooling over the last k hidden states, then take CLS token
         - cls_mha: Multi-head attention with CLS token as query and full sequence as keys/values
         - max_seq_mha: Max pooling over last k states + multi-head attention with CLS as query
+        - mean_seq_mha: Mean pooling over last k states + multi-head attention with CLS as query
         - max_seq_mean: Max pooling over last k hidden states, then mean pooling over sequence
         - sum_mean: Sum all hidden states across layers, then mean pool over sequence
         - sum_sum: Sum all hidden states across layers, then sum pool over sequence
@@ -665,7 +666,7 @@ def _pool_modchembert_output(
         torch.Tensor: Pooled representation of shape (batch_size, hidden_size)
 
     Note:
-        Some pooling strategies (cls_mha, max_seq_mha) require the module to have a pooling_attn
+        Some pooling strategies (cls_mha, max_seq_mha, mean_seq_mha) require the module to have a pooling_attn
         attribute containing a ModChemBertPoolingAttention instance.
     """
     config = typing.cast(ModChemBertConfig, module.config)
@@ -689,10 +690,13 @@ def _pool_modchembert_output(
             q=q, kv=last_hidden_state, attention_mask=attention_mask
         )  # (batch, seq_len, hidden)
         last_hidden_state = torch.mean(attn_out, dim=1)
-    elif config.classifier_pooling == "max_seq_mha":
+    elif config.classifier_pooling in {"max_seq_mha", "mean_seq_mha"}:
         k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
         theta = torch.stack(k_hidden_states, dim=1)  # (batch, k, seq_len, hidden)
-        pooled_seq = torch.max(theta, dim=1).values  # Element-wise max over k -> (batch, seq_len, hidden)
+        if config.classifier_pooling == "max_seq_mha":
+            pooled_seq = torch.max(theta, dim=1).values  # Element-wise max over k -> (batch, seq_len, hidden)
+        else:
+            pooled_seq = torch.mean(theta, dim=1)  # Element-wise mean over k -> (batch, seq_len, hidden)
         # Query is pooled CLS token (position 0); Keys/Values are pooled sequence
         q = pooled_seq[:, 0, :].unsqueeze(1)  # (batch, 1, hidden)
         q = q.expand(-1, pooled_seq.shape[1], -1)  # (batch, seq_len, hidden)
@@ -729,6 +733,7 @@ def _pool_modchembert_output(
 
 
 __all__ = [
+    "ModChemBertModel",
     "ModChemBertForMaskedLM",
     "ModChemBertForSequenceClassification",
 ]
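To illustrate the shapes in the new `mean_seq_mha` path, here is a standalone sketch with plain tensors. It substitutes `torch.nn.functional.scaled_dot_product_attention` for `ModChemBertPoolingAttention` (which in the actual code adapts `ModernBertAttention`), so it mirrors only the pooling arithmetic, not the real multi-head attention module; the tensor sizes are arbitrary examples.

```python
# Standalone sketch of the mean_seq_mha pooling arithmetic added above.
import torch
import torch.nn.functional as F

batch, seq_len, hidden, k = 2, 16, 768, 8
# hidden_states as returned with output_hidden_states=True: one tensor per layer
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(23)]

theta = torch.stack(hidden_states[-k:], dim=1)  # (batch, k, seq_len, hidden)
pooled_seq = torch.mean(theta, dim=1)           # mean over last k layers -> (batch, seq_len, hidden)

q = pooled_seq[:, 0, :].unsqueeze(1)            # pooled [CLS] as query -> (batch, 1, hidden)
q = q.expand(-1, seq_len, -1)                   # broadcast query over sequence positions

# Stand-in for the pooling attention: query attends over the mean-pooled sequence as keys/values
attn_out = F.scaled_dot_product_attention(q, pooled_seq, pooled_seq)  # (batch, seq_len, hidden)
pooled = attn_out.mean(dim=1)                   # final mean over tokens -> (batch, hidden)
print(pooled.shape)  # torch.Size([2, 768])
```

The `max_seq_mha` branch differs only in the first pooling step, using `torch.max(theta, dim=1).values` instead of the mean over the last k layers.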