Upload FastSLMForCausalLM
Browse files- delta_net.py +2 -4
delta_net.py
CHANGED
|
@@ -10,7 +10,7 @@ import torch.nn as nn
|
|
| 10 |
from einops import rearrange
|
| 11 |
from torch.nn import functional as F
|
| 12 |
|
| 13 |
-
from fla.modules import
|
| 14 |
from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
|
| 15 |
|
| 16 |
from typing import Any, Dict, List, Optional, Tuple
|
|
@@ -159,7 +159,7 @@ class DeltaNet(nn.Module):
|
|
| 159 |
)
|
| 160 |
if use_gate:
|
| 161 |
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
| 162 |
-
self.o_norm =
|
| 163 |
else:
|
| 164 |
self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
|
| 165 |
|
|
@@ -280,7 +280,6 @@ class DeltaNet(nn.Module):
|
|
| 280 |
initial_state=recurrent_state,
|
| 281 |
output_final_state=use_cache,
|
| 282 |
cu_seqlens=cu_seqlens,
|
| 283 |
-
head_first=False,
|
| 284 |
use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
|
| 285 |
)
|
| 286 |
elif mode == 'chunk':
|
|
@@ -292,7 +291,6 @@ class DeltaNet(nn.Module):
|
|
| 292 |
initial_state=recurrent_state,
|
| 293 |
output_final_state=use_cache,
|
| 294 |
cu_seqlens=cu_seqlens,
|
| 295 |
-
head_first=False,
|
| 296 |
use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
|
| 297 |
)
|
| 298 |
else:
|
|
|
|
| 10 |
from einops import rearrange
|
| 11 |
from torch.nn import functional as F
|
| 12 |
|
| 13 |
+
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
|
| 14 |
from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule
|
| 15 |
|
| 16 |
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
| 159 |
)
|
| 160 |
if use_gate:
|
| 161 |
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
|
| 162 |
+
self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
|
| 163 |
else:
|
| 164 |
self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
|
| 165 |
|
|
|
|
| 280 |
initial_state=recurrent_state,
|
| 281 |
output_final_state=use_cache,
|
| 282 |
cu_seqlens=cu_seqlens,
|
|
|
|
| 283 |
use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
|
| 284 |
)
|
| 285 |
elif mode == 'chunk':
|
|
|
|
| 291 |
initial_state=recurrent_state,
|
| 292 |
output_final_state=use_cache,
|
| 293 |
cu_seqlens=cu_seqlens,
|
|
|
|
| 294 |
use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False
|
| 295 |
)
|
| 296 |
else:
|