YongganFu commited on
Commit
a20abb6
·
verified ·
1 Parent(s): 9e67537

Upload FastSLMForCausalLM

Browse files
Files changed (1) hide show
  1. delta_net.py +2 -4
delta_net.py CHANGED
@@ -10,7 +10,7 @@ import torch.nn as nn
10
  from einops import rearrange
11
  from torch.nn import functional as F
12
 
13
- from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
14
  from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
 
16
  from typing import Any, Dict, List, Optional, Tuple
@@ -159,7 +159,7 @@ class DeltaNet(nn.Module):
159
  )
160
  if use_gate:
161
  self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
162
- self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
163
  else:
164
  self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
165
 
@@ -280,7 +280,6 @@ class DeltaNet(nn.Module):
280
  initial_state=recurrent_state,
281
  output_final_state=use_cache,
282
  cu_seqlens=cu_seqlens,
283
- head_first=False,
284
  use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
285
  )
286
  elif mode == 'chunk':
@@ -292,7 +291,6 @@ class DeltaNet(nn.Module):
292
  initial_state=recurrent_state,
293
  output_final_state=use_cache,
294
  cu_seqlens=cu_seqlens,
295
- head_first=False,
296
  use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
297
  )
298
  else:
 
10
  from einops import rearrange
11
  from torch.nn import functional as F
12
 
13
+ from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
14
  from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
15
 
16
  from typing import Any, Dict, List, Optional, Tuple
 
159
  )
160
  if use_gate:
161
  self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
162
+ self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
163
  else:
164
  self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
165
 
 
280
  initial_state=recurrent_state,
281
  output_final_state=use_cache,
282
  cu_seqlens=cu_seqlens,
 
283
  use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
284
  )
285
  elif mode == 'chunk':
 
291
  initial_state=recurrent_state,
292
  output_final_state=use_cache,
293
  cu_seqlens=cu_seqlens,
 
294
  use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
295
  )
296
  else: