Upload modeling_nemotron_h.py
modeling_nemotron_h.py  +15 -10  CHANGED
@@ -24,21 +24,21 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from
-from
-from
-from
+from ...activations import ACT2FN
+from ...cache_utils import DynamicCache  # we need __iter__ and __len__ of pkv
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
-from
-from
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     logging,
 )
-from
+from ...utils.import_utils import (
     is_causal_conv1d_available,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
@@ -70,7 +70,7 @@ else:
     causal_conv1d_update, causal_conv1d_fn = None, None
 
 if is_flash_attn_2_available():
-    from
+    from ...modeling_flash_attention_utils import _flash_attention_forward
 
 is_fast_path_available = all(
     (
@@ -844,8 +844,8 @@ class NemotronHAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        if config.
-            self.head_dim = config.
+        if config.attention_head_dim is not None:
+            self.head_dim = config.attention_head_dim
         else:
             self.head_dim = config.hidden_size // config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
@@ -1542,6 +1542,11 @@ class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and empty_past_kv:
+            # TODO(pjin): workaround fix for properly extending inputs_embeds;
+            # longer term, may be better handled elsewhere in .generate().
+            if input_ids is not None and inputs_embeds.shape[1] < input_ids.shape[1]:
+                new_token_embeds = self.get_input_embeddings()(input_ids[:, inputs_embeds.shape[1]:])
+                inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
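Note on the second hunk: the surrounding context (lines 70-76) guards optional fused kernels, degrading missing packages to None and folding the checks into is_fast_path_available. A minimal, self-contained sketch of that guard pattern, covering only the causal-conv1d part; fast_path_ready and the plain try/except are illustrative, not the module's actual flag or helper:

# Sketch of the optional-kernel pattern: import fused kernels only if their
# package is installed, otherwise fall back to None, then collapse the results
# into a single availability flag that the forward pass can consult.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

# The real file builds `is_fast_path_available` from several such checks;
# `fast_path_ready` here mirrors only the causal-conv1d piece.
fast_path_ready = all(fn is not None for fn in (causal_conv1d_fn, causal_conv1d_update))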
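Note on the attention hunk: the head dimension now prefers an explicit attention_head_dim from the config and only falls back to hidden_size // num_attention_heads. A minimal sketch of that selection logic; ToyConfig, resolve_head_dim, and the numbers are illustrative stand-ins, not Nemotron-H defaults:

# Prefer an explicit `attention_head_dim`, otherwise derive it from the hidden size,
# mirroring the logic added at lines 847-850.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyConfig:
    hidden_size: int = 4096
    num_attention_heads: int = 32
    attention_head_dim: Optional[int] = None

def resolve_head_dim(config: ToyConfig) -> int:
    if config.attention_head_dim is not None:
        return config.attention_head_dim
    return config.hidden_size // config.num_attention_heads

assert resolve_head_dim(ToyConfig()) == 128                        # derived: 4096 // 32
assert resolve_head_dim(ToyConfig(attention_head_dim=256)) == 256  # explicit override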
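Note on the last hunk: before the first generation step, inputs_embeds is padded with embeddings of any tokens that exist only in input_ids, so the model sees embeddings for every position. A standalone sketch of that extension step, using a toy embedding table and sizes in place of the model's get_input_embeddings():

# When `input_ids` is longer than the provided `inputs_embeds`, embed only the
# missing tail and concatenate along the sequence dimension, as in lines 1545-1549.
# The embedding layer and sizes below are toy stand-ins.
import torch
from torch import nn

embed_tokens = nn.Embedding(num_embeddings=100, embedding_dim=16)  # stand-in for get_input_embeddings()

input_ids = torch.randint(0, 100, (2, 7))        # batch of 2, 7 tokens known as ids
inputs_embeds = embed_tokens(input_ids[:, :5])   # caller only embedded the first 5

if inputs_embeds.shape[1] < input_ids.shape[1]:
    new_token_embeds = embed_tokens(input_ids[:, inputs_embeds.shape[1]:])
    inputs_embeds = torch.cat([inputs_embeds, new_token_embeds], dim=1)

assert inputs_embeds.shape == (2, 7, 16)  # now covers all 7 positions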