Fix layer_outputs handling: transformers 5.x returns tensor, not tuple
modeling_nvembed.py  CHANGED  (+20 -9)
@@ -181,13 +181,20 @@ class BidirectionalMistralModel(MistralModel):
                 **layer_kwargs,
             )
 
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
+            # Compatibility fix for transformers 5.x:
+            # In transformers 5.x, MistralDecoderLayer.forward returns a single tensor
+            # In transformers 4.x, it returns a tuple (hidden_states, present_key_value, ...)
+            if isinstance(layer_outputs, torch.Tensor):
+                # transformers 5.x: direct tensor output
+                hidden_states = layer_outputs
+                # Note: use_cache and output_attentions not supported in this code path
+            else:
+                # transformers 4.x: tuple output
+                hidden_states = layer_outputs[0]
+                if use_cache:
+                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+                if output_attentions:
+                    all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
 
@@ -196,8 +203,12 @@ class BidirectionalMistralModel(MistralModel):
             all_hidden_states += (hidden_states,)
 
         next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        if use_cache and next_decoder_cache is not None:
+            # Compatibility: to_legacy_cache may not exist in all versions
+            if use_legacy_cache and hasattr(next_decoder_cache, 'to_legacy_cache'):
+                next_cache = next_decoder_cache.to_legacy_cache()
+            else:
+                next_cache = next_decoder_cache
 
         if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
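For readers applying the same pattern elsewhere, the isinstance branch from the first hunk can be written as a small standalone helper. This is only a sketch of the idea, not code from this repository: the function name unpack_layer_outputs and its signature are invented for illustration, and the 4.x tuple layout (hidden_states, optional attention weights, optional present key/value) is assumed to match what the patch handles.

import torch

def unpack_layer_outputs(layer_outputs, use_cache=False, output_attentions=False):
    # Normalize a decoder layer's return value across transformers versions.
    # 5.x-style layers return the hidden-states tensor directly; 4.x-style
    # layers return a tuple (hidden_states, [attn_weights], [present_key_value]).
    if isinstance(layer_outputs, torch.Tensor):
        # Tensor output: no attention weights or cache entry to unpack here.
        return layer_outputs, None, None
    hidden_states = layer_outputs[0]
    attn_weights = layer_outputs[1] if output_attentions else None
    present_key_value = layer_outputs[2 if output_attentions else 1] if use_cache else None
    return hidden_states, attn_weights, present_key_value

Inside the decoder loop, this would collapse the patched branch to a single call, e.g. hidden_states, attn, present = unpack_layer_outputs(layer_outputs, use_cache, output_attentions).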