Hyukkyu committed on
Commit
7d27cd8
·
verified ·
1 Parent(s): b73d528

Fix layer_outputs handling: transformers 5.x returns tensor, not tuple

Browse files
Files changed (1) hide show
  1. modeling_nvembed.py +20 -9
modeling_nvembed.py CHANGED
@@ -181,13 +181,20 @@ class BidirectionalMistralModel(MistralModel):
181
  **layer_kwargs,
182
  )
183
 
184
- hidden_states = layer_outputs[0]
185
-
186
- if use_cache:
187
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
188
-
189
- if output_attentions:
190
- all_self_attns += (layer_outputs[1],)
 
 
 
 
 
 
 
191
 
192
  hidden_states = self.norm(hidden_states)
193
 
@@ -196,8 +203,12 @@ class BidirectionalMistralModel(MistralModel):
196
  all_hidden_states += (hidden_states,)
197
 
198
  next_cache = None
199
- if use_cache:
200
- next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
 
 
 
 
201
 
202
  if not return_dict:
203
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
181
  **layer_kwargs,
182
  )
183
 
184
+ # Compatibility fix for transformers 5.x:
185
+ # In transformers 5.x, MistralDecoderLayer.forward returns a single tensor
186
+ # In transformers 4.x, it returns a tuple (hidden_states, present_key_value, ...)
187
+ if isinstance(layer_outputs, torch.Tensor):
188
+ # transformers 5.x: direct tensor output
189
+ hidden_states = layer_outputs
190
+ # Note: use_cache and output_attentions not supported in this code path
191
+ else:
192
+ # transformers 4.x: tuple output
193
+ hidden_states = layer_outputs[0]
194
+ if use_cache:
195
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
196
+ if output_attentions:
197
+ all_self_attns += (layer_outputs[1],)
198
 
199
  hidden_states = self.norm(hidden_states)
200
 
 
203
  all_hidden_states += (hidden_states,)
204
 
205
  next_cache = None
206
+ if use_cache and next_decoder_cache is not None:
207
+ # Compatibility: to_legacy_cache may not exist in all versions
208
+ if use_legacy_cache and hasattr(next_decoder_cache, 'to_legacy_cache'):
209
+ next_cache = next_decoder_cache.to_legacy_cache()
210
+ else:
211
+ next_cache = next_decoder_cache
212
 
213
  if not return_dict:
214
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)