Commit b4672f4 (verified) by YongganFu · Parent(s): e77795a

Upload NemotronFlashForCausalLM

model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6cea1b75a41cb4881552e82a27b03b380185cd2afc5604623ec7a6f9e7b71eb5
- size 4987938992
+ oid sha256:816ac0d20ac6856b93bbcd72e700d87175cb42685508cd43362ab36e3ee1db54
+ size 4987920224
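
Only the LFS pointer of this shard changed, so the binary itself was re-uploaded. A minimal sketch for checking that a locally downloaded copy matches the new pointer (file name and expected values come from the diff above; the local path and chunk size are assumptions):

import hashlib
import os

EXPECTED_SHA256 = "816ac0d20ac6856b93bbcd72e700d87175cb42685508cd43362ab36e3ee1db54"
EXPECTED_SIZE = 4987920224  # bytes, from the updated pointer

def verify_shard(path: str = "model-00001-of-00002.safetensors") -> bool:
    # Cheap size check first, then stream the file through sha256 in 1 MiB chunks.
    if os.path.getsize(path) != EXPECTED_SIZE:
        return False
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == EXPECTED_SHA256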
model.safetensors.index.json CHANGED
@@ -1,7 +1,7 @@
 {
   "metadata": {
-    "total_parameters": 2750007840,
-    "total_size": 5500015680
+    "total_parameters": 2749998624,
+    "total_size": 5499997248
   },
   "weight_map": {
     "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
@@ -34,7 +34,6 @@
     "model.layers.11.ffn.up_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.11.pre_ffn_layernorm.weight": "model-00001-of-00002.safetensors",
     "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
-    "model.layers.12.pre_ffn_layernorm.weight": "model-00001-of-00002.safetensors",
     "model.layers.12.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.12.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
@@ -93,7 +92,6 @@
     "model.layers.2.mamba.norm.weight": "model-00001-of-00002.safetensors",
     "model.layers.2.mamba.out_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
-    "model.layers.20.pre_ffn_layernorm.weight": "model-00001-of-00002.safetensors",
     "model.layers.20.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.20.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
@@ -201,7 +199,6 @@
     "model.layers.35.ffn.up_proj.weight": "model-00002-of-00002.safetensors",
     "model.layers.35.pre_ffn_layernorm.weight": "model-00002-of-00002.safetensors",
     "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
-    "model.layers.4.pre_ffn_layernorm.weight": "model-00001-of-00002.safetensors",
     "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
     "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
modeling_nemotron_flash.py CHANGED
@@ -918,11 +918,12 @@ class NemotronFlashAttentionDecoderLayer(nn.Module):
 
         if self.config.intermediate_size > 0:
             self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx)
+            self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
             self.ffn = None
+            self.pre_ffn_layernorm = None
 
         self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -1037,13 +1038,12 @@ class NemotronFlashMambaDecoderLayer(nn.Module):
         self.intermediate_size = config.intermediate_size
         if self.intermediate_size > 0:
             self.ffn = NemotronFlashMLP(config, layer_idx=layer_idx)
-
-        self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
-        if self.intermediate_size > 0:
             self.pre_ffn_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         else:
+            self.ffn = None
             self.pre_ffn_layernorm = None
+
+        self.input_layernorm = NemotronFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
     def forward(
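
After this change, both decoder-layer constructors create self.ffn and self.pre_ffn_layernorm together: either both exist (intermediate_size > 0) or both are None, which is what removes the unused norm weights from the checkpoint index above. The forward path is not part of this diff, so the sketch below only illustrates, with stand-in modules, how such paired optional sub-modules are typically consumed behind a single guard:

import torch
from torch import nn

class DecoderLayerSketch(nn.Module):
    """Illustrative only: the FFN and its pre-norm are created together or
    not at all, so the forward pass needs a single `is not None` guard.
    LayerNorm and Sequential stand in for NemotronFlashRMSNorm / NemotronFlashMLP."""

    def __init__(self, hidden_size: int, intermediate_size: int, eps: float = 1e-6):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        if intermediate_size > 0:
            self.ffn = nn.Sequential(
                nn.Linear(hidden_size, intermediate_size),
                nn.SiLU(),
                nn.Linear(intermediate_size, hidden_size),
            )
            self.pre_ffn_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        else:
            self.ffn = None
            self.pre_ffn_layernorm = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The attention / Mamba mixer is omitted; only the optional FFN block is shown.
        if self.ffn is not None:
            residual = hidden_states
            hidden_states = self.pre_ffn_layernorm(hidden_states)
            hidden_states = residual + self.ffn(hidden_states)
        return hidden_states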