eacortes committed
Commit f18795d · verified · 1 Parent(s): fde8c1e

feat: add mean_seq_mha pooling option


Add the MeanSeq+MHA pooling option from the updated MaxPoolBERT paper

README.md CHANGED

@@ -1,5 +1,6 @@
 ---
 license: apache-2.0
+base_model: Derify/ModChemBERT-IR-BASE
 library_name: transformers
 tags:
 - modernbert
@@ -64,13 +65,14 @@ This base model includes configurable pooling strategies for downstream fine-tun
 - `max_cls`: Max over last k layers of [CLS]
 - `cls_mha`: MHA with [CLS] as query
 - `max_seq_mha`: MHA with max pooled sequence as KV and max pooled [CLS] as query
+- `mean_seq_mha`: MHA with mean pooled sequence as KV and mean pooled [CLS] as query
 - `sum_mean`: Sum over all layers then mean tokens
 - `sum_sum`: Sum over all layers then sum tokens
 - `mean_mean`: Mean over all layers then mean tokens
 - `mean_sum`: Mean over all layers then sum tokens
 - `max_seq_mean`: Max over last k layers then mean tokens

-Note: ModChemBERT's `max_seq_mha` differs from MaxPoolBERT [3]. MaxPoolBERT uses PyTorch `nn.MultiheadAttention`, whereas ModChemBERT's `ModChemBertPoolingAttention` adapts ModernBERT's `ModernBertAttention`.
+Note: ModChemBERT's `cls_mha`, `max_seq_mha`, and `mean_seq_mha` differ from MaxPoolBERT [3]. MaxPoolBERT uses PyTorch `nn.MultiheadAttention`, whereas ModChemBERT's `ModChemBertPoolingAttention` adapts ModernBERT's `ModernBertAttention`.
 On ChemBERTa-3 benchmarks this variant produced stronger validation metrics and avoided the training instabilities (sporadic zero / NaN losses and gradient norms) seen with `nn.MultiheadAttention`. Training instability with ModernBERT has been reported in the past ([discussion 1](https://huggingface.co/answerdotai/ModernBERT-base/discussions/59) and [discussion 2](https://huggingface.co/answerdotai/ModernBERT-base/discussions/63)).

 ## Intended Use
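To make the new README option concrete, here is a minimal sketch of selecting `mean_seq_mha` at fine-tuning time. The checkpoint id below is just the `base_model` referenced in this commit and the label count is a placeholder; the pooling kwargs simply mirror the defaults documented in `ModChemBertConfig`.

```python
# Illustrative sketch only; checkpoint id and num_labels are placeholders.
from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained(
    "Derify/ModChemBERT-IR-BASE",              # placeholder checkpoint id
    trust_remote_code=True,                    # pulls in the ModChemBert config/modeling code
    classifier_pooling="mean_seq_mha",         # new option added in this commit
    classifier_pooling_num_attention_heads=4,  # defaults mirrored from the config docstring
    classifier_pooling_attention_dropout=0.0,
    classifier_pooling_last_k=8,
    num_labels=2,                              # placeholder task size
)
model = AutoModelForSequenceClassification.from_pretrained(
    "Derify/ModChemBERT-IR-BASE",
    config=config,
    trust_remote_code=True,
)
```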
configuration_modchembert.py CHANGED

@@ -37,14 +37,15 @@ class ModChemBertConfig(ModernBertConfig):
             - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
             - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
             - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
+            - "mean_seq_mha": Mean pooling over last k states + multi-head attention with CLS as query
             - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
             Defaults to "sum_mean".
         classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
-            pooling strategies (cls_mha, max_seq_mha). Defaults to 4.
+            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 4.
         classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
-            pooling strategies (cls_mha, max_seq_mha). Defaults to 0.0.
-        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max pooling
-            strategies (max_cls, max_seq_mha, max_seq_mean). Defaults to 8.
+            pooling strategies (cls_mha, max_seq_mha, mean_seq_mha). Defaults to 0.0.
+        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max/mean pooling
+            strategies (max_cls, max_seq_mha, mean_seq_mha, max_seq_mean). Defaults to 8.
         *args: Variable length argument list passed to ModernBertConfig.
         **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.

@@ -68,6 +69,7 @@ class ModChemBertConfig(ModernBertConfig):
             "max_cls",
             "cls_mha",
             "max_seq_mha",
+            "mean_seq_mha",
             "max_seq_mean",
         ] = "max_seq_mha",
         classifier_pooling_num_attention_heads: int = 4,
@@ -75,6 +77,25 @@ class ModChemBertConfig(ModernBertConfig):
         classifier_pooling_last_k: int = 8,
         **kwargs,
     ):
+        valid_classifier_pooling_options = [
+            "cls",
+            "mean",
+            "sum_mean",
+            "sum_sum",
+            "mean_mean",
+            "mean_sum",
+            "max_cls",
+            "cls_mha",
+            "max_seq_mha",
+            "mean_seq_mha",
+            "max_seq_mean",
+        ]
+        if classifier_pooling not in valid_classifier_pooling_options:
+            raise ValueError(
+                f"Invalid value for `classifier_pooling`, should be one of {valid_classifier_pooling_options}, "
+                f"but is {classifier_pooling}."
+            )
+
         # Pass classifier_pooling="cls" to circumvent ValueError in ModernBertConfig init
         super().__init__(*args, classifier_pooling="cls", **kwargs)
         # Override with custom value
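As a quick sanity check of the validation path added above, a small sketch (assuming `configuration_modchembert.py` is importable from the working directory) could exercise both the accepted and rejected branches:

```python
# Sketch assuming configuration_modchembert.py is importable locally.
from configuration_modchembert import ModChemBertConfig

# Accepted: the new option is in the explicit whitelist added by this commit,
# and the custom value overrides the "cls" passed to the parent ModernBertConfig.
config = ModChemBertConfig(classifier_pooling="mean_seq_mha", classifier_pooling_last_k=8)
print(config.classifier_pooling)  # mean_seq_mha

# Rejected: anything outside valid_classifier_pooling_options raises ValueError.
try:
    ModChemBertConfig(classifier_pooling="max_mean_mha")  # deliberately invalid
except ValueError as err:
    print(err)
```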
modeling_modchembert.py CHANGED

@@ -19,9 +19,9 @@
 # Modifications include:
 #   - Additional classifier_pooling options for ModChemBertForSequenceClassification
 #     - sum_mean, sum_sum, mean_sum, mean_mean: from ChemLM (utilizes all hidden states)
-#     - max_cls, cls_mha, max_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
+#     - max_cls, cls_mha, max_seq_mha, mean_seq_mha: from MaxPoolBERT (utilizes last k hidden states)
 #     - max_seq_mean: a merge between sum_mean and max_cls (utilizes last k hidden states)
-#   - Addition of ModChemBertPoolingAttention for cls_mha and max_seq_mha pooling options
+#   - Addition of ModChemBertPoolingAttention for cls_mha, max_seq_mha, and mean_seq_mha pooling options

 import copy
 import math
@@ -499,7 +499,7 @@ class ModChemBertForSequenceClassification(InitWeightsMixin, ModernBertPreTraine
         self.config = config

         self.model = ModernBertModel(config)
-        if self.config.classifier_pooling in {"cls_mha", "max_seq_mha"}:
+        if self.config.classifier_pooling in {"cls_mha", "max_seq_mha", "mean_seq_mha"}:
             self.pooling_attn = ModChemBertPoolingAttention(config=self.config)
         else:
             self.pooling_attn = None
@@ -649,6 +649,7 @@ def _pool_modchembert_output(
         - max_cls: Element-wise max pooling over the last k hidden states, then take CLS token
         - cls_mha: Multi-head attention with CLS token as query and full sequence as keys/values
         - max_seq_mha: Max pooling over last k states + multi-head attention with CLS as query
+        - mean_seq_mha: Mean pooling over last k states + multi-head attention with CLS as query
         - max_seq_mean: Max pooling over last k hidden states, then mean pooling over sequence
         - sum_mean: Sum all hidden states across layers, then mean pool over sequence
         - sum_sum: Sum all hidden states across layers, then sum pool over sequence
@@ -665,7 +666,7 @@ def _pool_modchembert_output(
         torch.Tensor: Pooled representation of shape (batch_size, hidden_size)

     Note:
-        Some pooling strategies (cls_mha, max_seq_mha) require the module to have a pooling_attn
+        Some pooling strategies (cls_mha, max_seq_mha, mean_seq_mha) require the module to have a pooling_attn
         attribute containing a ModChemBertPoolingAttention instance.
     """
     config = typing.cast(ModChemBertConfig, module.config)
@@ -689,10 +690,13 @@ def _pool_modchembert_output(
             q=q, kv=last_hidden_state, attention_mask=attention_mask
         )  # (batch, seq_len, hidden)
         last_hidden_state = torch.mean(attn_out, dim=1)
-    elif config.classifier_pooling == "max_seq_mha":
+    elif config.classifier_pooling in {"max_seq_mha", "mean_seq_mha"}:
         k_hidden_states = hidden_states[-config.classifier_pooling_last_k :]
         theta = torch.stack(k_hidden_states, dim=1)  # (batch, k, seq_len, hidden)
-        pooled_seq = torch.max(theta, dim=1).values  # Element-wise max over k -> (batch, seq_len, hidden)
+        if config.classifier_pooling == "max_seq_mha":
+            pooled_seq = torch.max(theta, dim=1).values  # Element-wise max over k -> (batch, seq_len, hidden)
+        else:
+            pooled_seq = torch.mean(theta, dim=1)  # Element-wise mean over k -> (batch, seq_len, hidden)
         # Query is pooled CLS token (position 0); Keys/Values are pooled sequence
         q = pooled_seq[:, 0, :].unsqueeze(1)  # (batch, 1, hidden)
         q = q.expand(-1, pooled_seq.shape[1], -1)  # (batch, seq_len, hidden)
@@ -729,6 +733,7 @@ def _pool_modchembert_output(


 __all__ = [
+    "ModChemBertModel",
     "ModChemBertForMaskedLM",
     "ModChemBertForSequenceClassification",
 ]
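For intuition, the `mean_seq_mha` branch can be traced on dummy tensors. The standalone sketch below mirrors the pooling arithmetic in the hunk above with plain torch ops, using a single-head scaled dot-product stand-in rather than the actual `ModChemBertPoolingAttention`; shapes and layer counts are made up for illustration.

```python
# Standalone sketch of the mean_seq_mha pooling arithmetic on dummy tensors.
# The attention here is a single-head scaled dot-product stand-in, not ModChemBertPoolingAttention.
import torch
import torch.nn.functional as F

batch, seq_len, hidden, num_layers, last_k = 2, 16, 32, 12, 8
hidden_states = [torch.randn(batch, seq_len, hidden) for _ in range(num_layers + 1)]

k_hidden_states = hidden_states[-last_k:]
theta = torch.stack(k_hidden_states, dim=1)               # (batch, k, seq_len, hidden)
pooled_seq = torch.mean(theta, dim=1)                     # mean over k -> (batch, seq_len, hidden)

# Query is the pooled [CLS] token (position 0); keys/values are the pooled sequence.
q = pooled_seq[:, 0, :].unsqueeze(1)                      # (batch, 1, hidden)
q = q.expand(-1, pooled_seq.shape[1], -1)                 # (batch, seq_len, hidden), as in the diff
scores = q @ pooled_seq.transpose(1, 2) / hidden ** 0.5   # (batch, seq_len, seq_len)
attn_out = F.softmax(scores, dim=-1) @ pooled_seq         # (batch, seq_len, hidden)
pooled = torch.mean(attn_out, dim=1)                      # (batch, hidden), fed to the classifier head
print(pooled.shape)  # torch.Size([2, 32])
```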