Update src/nn/attn.py

src/nn/attn.py CHANGED: +71 −250
@@ -6,7 +6,9 @@ from jaxtyping import Float, Bool
 from torch import Tensor
 from typing import Optional
 from torch.nn.attention.flex_attention import flex_attention
+from matplotlib import pyplot as plt

+from pdb import set_trace

 class KVCache(nn.Module):
     """
@@ -19,7 +21,7 @@ class KVCache(nn.Module):
     Call `extend(layer_idx, k, v)` once per layer for the *same* frame.
     Call `update_global_location(n_frames)` once after all layers to commit the frame(s).
     """
-    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, *, dtype=None, device=None
+    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, *, dtype=None, device=None):
         super().__init__()
         self.batch_size = batch_size
         self.n_layers = n_layers
@@ -27,10 +29,9 @@
         self.d_head = d_head
         self.toks_per_frame = toks_per_frame
         self.n_window = n_window
-        self.size =
+        self.size = toks_per_frame * (n_window-1) #toks_per_frame # (toks_per_frame * n_window)

         # Pointers / counters
-        self.curr_layer = 0 # which layer are we writing for this frame
         self.global_loc = 0 # total tokens ever committed
         self.local_loc = 0 # valid tokens in buffer (<= size)
         self._write_ptr = 0 # ring-buffer write pointer (index of next commit position)
@@ -40,93 +41,67 @@
         self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
         self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))

-        # Misc
-        self.enforce_layer_order = enforce_layer_order

-
-    def get(self, layer_idx):
+    def get(self):
         """Return (K, V) for given layer in chronological order: shape (B, L, H, D) where L = local_loc."""
-        self._check_layer(layer_idx)
         if self.local_loc == 0:
             # return empty views
-            empty = self.keys[
+            empty = self.keys[:, :, :0]
             return empty, empty

         start = (self._write_ptr - self.local_loc) % self.size
         if start + self.local_loc <= self.size:
             # contiguous slice
-            k = self.keys[
-            v = self.values[
+            k = self.keys[:, :, start:start + self.local_loc]
+            v = self.values[:, :, start:start + self.local_loc]
         else:
             # wrap: concatenate two slices to maintain chronological order
             first = self.size - start
             k = t.cat([
-                self.keys[
-                self.keys[
-            ], dim=
+                self.keys[:, :, start:self.size],
+                self.keys[:, :, 0:(self.local_loc - first)]
+            ], dim=2)
             v = t.cat([
-                self.values[
-                self.values[
-            ], dim=
+                self.values[:, :, start:self.size],
+                self.values[:, :, 0:(self.local_loc - first)]
+            ], dim=2)
         return k, v

     @t.no_grad()
-    def extend(self,
+    def extend(self, keys, values):
         """
         Stage (but do not commit) tokens for the current frame for the given layer.
         Call update_global_location(n_frames) to commit after all layers wrote.
         """
         assert keys.shape == values.shape, f"keys and values shapes must match, got {keys.shape} vs {values.shape}"
-        self._check_layer(layer_idx)

-
-
+        L, B, T, H, D = keys.shape
+        assert L == self.n_layers, f"nlayers mismatch: expected {self.n_layers}, got {L}"
         assert B == self.batch_size, f"batch mismatch: expected {self.batch_size}, got {B}"
         assert H == self.n_heads and D == self.d_head, f"heads/d_head mismatch: expected {(self.n_heads, self.d_head)}, got {(H, D)}"
         assert T > 0 and T <= self.size, f"T must be in 1..{self.size}, got {T}"
-        # Optional: if you only ever append whole frames:
-        # assert T == self.toks_per_frame, f"T must equal toks_per_frame ({self.toks_per_frame}), got {T}"

-        # Cast to buffer dtype/device if needed
         if keys.dtype != self.keys.dtype or keys.device != self.keys.device:
             keys = keys.to(dtype=self.keys.dtype, device=self.keys.device)
         if values.dtype != self.values.dtype or values.device != self.values.device:
             values = values.to(dtype=self.values.dtype, device=self.values.device)

-        # Write into the ring at the *current* write_ptr (uncommitted until update_global_location)
         i0 = self._write_ptr
         i1 = (self._write_ptr + T) % self.size
         if i0 < i1:
-            self.keys[
-            self.values[
+            self.keys[:, :, i0:i1] = keys
+            self.values[:, :, i0:i1] = values
         else:
-            #
+            # wrap
             split = self.size - i0
-            self.keys[
-            self.values[
-            self.keys[
-            self.values[
+            self.keys[:, :, i0:self.size] = keys[:, :, :split]
+            self.values[:, :, i0:self.size] = values[:, :, :split]
+            self.keys[:, :, 0:i1] = keys[:, :, split:]
+            self.values[:, :, 0:i1] = values[:, :, split:]

-
-        self.
-
-    @t.no_grad()
-    def update_global_location(self, n_frames):
-        """
-        Commit staged writes for n_frames (advances the ring write pointer once per frame).
-        Keep calling extend(layer_idx, ...) for each layer before you call this.
-        """
-        assert n_frames >= 0, f"n_frames must be >= 0, got {n_frames}"
-        tokens = n_frames * self.toks_per_frame
-        if tokens == 0:
-            return
-        assert tokens <= self.size, f"Cannot commit {tokens} tokens (> buffer size {self.size})."
-
-        self.global_loc += tokens
-        # Update valid length (never exceeds capacity)
-        self.local_loc = min(self.size, self.local_loc + tokens)
-        # Advance write pointer
-        self._write_ptr = (self._write_ptr + tokens) % self.size
+        self.global_loc += keys.shape[2]
+        self.local_loc = min(self.size, self.local_loc + keys.shape[2])
+        self._write_ptr = (self._write_ptr + keys.shape[2]) % self.size

     @t.no_grad()
     def reset(self, zero_memory: bool = True):
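Note (not part of the diff): the new get() reads the ring buffer back in chronological order, concatenating two slices whenever the valid window wraps past the end of the buffer. A minimal toy version of that read, with illustrative values:

import torch as t

size, write_ptr, local_loc = 8, 3, 8        # toy numbers: buffer is full, next write lands at index 3
buf = t.arange(size)                        # stand-in for the token dimension of the key/value buffer
start = (write_ptr - local_loc) % size      # index of the oldest committed token
if start + local_loc <= size:
    chron = buf[start:start + local_loc]    # contiguous case
else:
    first = size - start                    # wrap case: tail of the buffer, then the head
    chron = t.cat([buf[start:size], buf[0:local_loc - first]])
print(chron)                                # tensor([3, 4, 5, 6, 7, 0, 1, 2]) -- oldest to newest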
@@ -138,7 +113,6 @@ class KVCache(nn.Module):
             self.keys.zero_()
             self.values.zero_()

-    # -------------- Convenience / Introspection --------------
     @property
     def local_location(self):
         return self.local_loc
@@ -155,33 +129,10 @@ class KVCache(nn.Module):
     def dtype(self):
         return self.keys.dtype

-    def get_recent(self, layer_idx, last_T):
-        """Return the most recent last_T tokens for a layer (chronological)."""
-        self._check_layer(layer_idx, allow_any=True)
-        last_T = min(last_T, self.local_loc)
-        if last_T == 0:
-            empty = self.keys[layer_idx, :, :0]
-            return empty, empty
-        start = (self._write_ptr - last_T) % self.size
-        if start + last_T <= self.size:
-            k = self.keys[layer_idx, :, start:start + last_T]
-            v = self.values[layer_idx, :, start:start + last_T]
-        else:
-            first = self.size - start
-            k = t.cat([self.keys[layer_idx, :, start:self.size], self.keys[layer_idx, :, 0:(last_T - first)]], dim=1)
-            v = t.cat([self.values[layer_idx, :, start:self.size], self.values[layer_idx, :, 0:(last_T - first)]], dim=1)
-        return k, v
-
-    # -------------- Internal checks --------------
-    def _check_layer(self, layer_idx, allow_any=False):
-        assert 0 <= layer_idx < self.n_layers, f"layer_idx out of range: 0..{self.n_layers-1}, got {layer_idx}"
-        if self.enforce_layer_order and not allow_any:
-            assert layer_idx == (self.curr_layer % self.n_layers), \
-                f"Layer order mismatch: expected {self.curr_layer % self.n_layers}, got {layer_idx}"


-class KVCacheMine(nn.Module): # this does not work because it destroys the cache
-    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window):
+class KVCacheNaive(nn.Module):
+    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, dtype=t.float32, device='cuda'):
         """
         This is a rolling KVCache
         """
@@ -191,42 +142,40 @@ class KVCacheMine(nn.Module): # this does not work because it destroys the cache
         self.d_head = d_head
         self.toks_per_frame = toks_per_frame
         self.n_window = n_window
-        self.size = toks_per_frame *
+        self.size = toks_per_frame * (n_window - 1)
         self.n_layers = n_layers
-        self.curr_layer = 0
         self.global_loc = 0
         self.local_loc = 0
-
-        self.register_buffer('
+
+        self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
+        self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))

-    def get(self
-
-        return self.keys[layer_idx, :, :self.local_loc], self.values[layer_idx, :, :self.local_loc]
+    def get(self):
+        return self.keys[:, :, :self.local_loc], self.values[:, :, :self.local_loc]

-    def extend(self,
+    def extend(self, keys, values):
+        """
+        this should only be called on the last denoising step respectively.
+        """
         assert keys.shape == values.shape, f"keys and values shapes must match {self.keys.shape} != {self.values.shape}"
-        assert layer_idx == self.curr_layer, f"layer idx should be the same as our internal counter but we got {layer_idx} and internal is {self.curr_layer}."
         assert self.local_loc <= self.size, f"the cache size should be between 0 and {self.size}"
         local_loc = self.local_loc
         if local_loc == self.size:
             # move to the left
-            local_loc -= keys.shape[
-            assert local_loc >= 0, f"the cache update {keys.shape[
+            local_loc -= keys.shape[2]
+            assert local_loc >= 0, f"the cache update {keys.shape[2]} was larger than the cache {self.size}, that's not supported for now."
             assert local_loc % self.toks_per_frame == 0, f"the number of elements in the cache {local_loc} must be a multiple of the number of tokens per frame {self.toks_per_frame}"
-            self.keys[
-            self.values[
-
-
-
-
-        self.keys[layer_idx, :, local_loc:local_loc + keys.shape[1]] = keys
-        self.values[layer_idx, :, local_loc:local_loc + keys.shape[1]] = values
+            self.keys[:, :, :local_loc] = self.keys[:, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
+            self.values[:, :, :local_loc] = self.values[:, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
+
+        assert local_loc + keys.shape[2] <= self.size, f"{local_loc + keys.shape[2]} out of bounds {self.size}"
+        self.keys[:, :, local_loc:local_loc + keys.shape[2]] = keys
+        self.values[:, :, local_loc:local_loc + keys.shape[2]] = values
         self.curr_layer = (self.curr_layer + 1) % self.n_layers

-
-        self.global_loc += n_frames * self.toks_per_frame
+        self.global_loc += keys.shape[2]
         if self.local_loc < self.size:
-            self.local_loc +=
+            self.local_loc += keys.shape[2]
         assert self.local_loc <= self.size, f"the local loc {self.local_loc} should never be bigger than {self.size}, something went wrong."

     def reset(self):
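Note (not part of the diff): once KVCacheNaive is full, extend() drops the oldest frame by copying everything one frame to the left and then writes the new frame at the end. A minimal toy version of that update, with illustrative values:

import torch as t

toks_per_frame, n_window = 2, 4
size = toks_per_frame * (n_window - 1)      # 6 tokens: the cache keeps n_window - 1 past frames
cache = t.tensor([1, 1, 2, 2, 3, 3])        # three committed frames
new_frame = t.tensor([4, 4])

local_loc = size - toks_per_frame           # make room by discarding the oldest frame
cache[:local_loc] = cache[toks_per_frame:local_loc + toks_per_frame].clone()   # shift left by one frame
cache[local_loc:local_loc + toks_per_frame] = new_frame                        # append the new frame
print(cache)                                # tensor([2, 2, 3, 3, 4, 4])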
@@ -256,10 +205,13 @@ class KVCacheMine(nn.Module): # this does not work because it destroys the cache
 class AttentionEinOps(nn.Module):
     IGNORE: Float[Tensor, ""]

-    def __init__(self, d_model, n_heads, rope=None):
+    def __init__(self, d_model, n_heads, rope=None, ln_first=False):
         super().__init__()
         assert d_model % n_heads == 0, f"{d_model} must be divisble by {n_heads}"
         self.d_head = d_model // n_heads
+        self.d_model = d_model
+        self.n_heads = n_heads
+        self.ln_first = ln_first
         d_head = self.d_head
         self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
         self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
@@ -289,7 +241,6 @@ class AttentionEinOps(nn.Module):
         offset: int = 0
     ) -> Float[Tensor, "batch posq d_model"]:
         assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
-        d_head = self.d_head
         if k_cache is not None and v_cache is not None:
             q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
             k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
@@ -297,177 +248,47 @@ class AttentionEinOps(nn.Module):

             k = t.cat([k_cache, k_new], dim=1)
             v = t.cat([v_cache, v_new], dim=1)
+            if self.ln_first:
+                q = self.ln1(q)
+                k = self.ln2(k)

             if self.rope is not None:
-                q = self.rope(q, offset=k_cache.shape[1])
+                q = self.rope(q, offset=k_cache.shape[1])
                 k = self.rope(k, offset=0)
-
-
+
+            if not self.ln_first:
+                q = self.ln1(q) # ppl usually do this before rope but our best checkpoint has it after rope, so this is for bwd compatibility; but in quick test on singleframe this did not make a big difference
+                k = self.ln2(k)
             mask = None
         else:
             q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
             k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
             v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h') + self.b_V
+            if self.ln_first:
+                q = self.ln1(q)
+                k = self.ln2(k)
+
             if self.rope is not None:
                 q = self.rope(q)
                 k = self.rope(k)
-
-
+
+            if not self.ln_first:
+                q = self.ln1(q)
+                k = self.ln2(k)
             k_new = k
             v_new = v

         attention = einops.einsum(q, k, 'b sq n h, b sk n h -> b n sq sk')
         if mask is not None and k_cache is not None:
-            attention = t.where(mask[k_cache.shape[1]:k_cache.shape[1]+q.shape[1], :k.shape[1]], self.IGNORE
+            attention = t.where(mask[k_cache.shape[1]:k_cache.shape[1]+q.shape[1], :k.shape[1]], attention, self.IGNORE)
         elif mask is not None:
             if attention.shape[-1] != mask.shape[-1] or attention.shape[-2] != mask.shape[-2]:
-                #print(f"Warning: attention shape {attention.shape} does not match mask shape {mask.shape}")
                 mask = mask[:attention.shape[-1], :attention.shape[-2]]
-            attention = t.where(mask, self.IGNORE
+            attention = t.where(mask, attention, self.IGNORE)
         probas = attention.softmax(dim=3)
-        #plt.imshow(probas[0, 0].cpu().numpy())
-        #plt.show()
         z = einops.einsum(probas, v, 'b n sq sk, b sk n h -> b sq n h')
         out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
         out = out.sum(dim=2) + self.b_O
         return out, k_new, v_new


-class Attention(nn.Module):
-    IGNORE: Float[Tensor, ""]
-
-    def __init__(self, d_model, n_heads, rope=None, use_flex_attention=False):
-        raise NotImplementedError("Attention is not implemented yet")
-        super().__init__()
-        assert d_model % n_heads == 0, f"{d_model} must be divisble by {n_heads}"
-        self.d_head = d_model // n_heads
-        d_head = self.d_head
-        self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
-        self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
-        self.W_V = nn.Parameter(t.empty((n_heads, d_model, d_head)))
-        self.W_O = nn.Parameter(t.empty((n_heads, d_head, d_model)))
-        #self.b_Q = nn.Parameter(t.zeros((n_heads, d_head)))
-        #self.b_K = nn.Parameter(t.zeros((n_heads, d_head)))
-        #self.b_V = nn.Parameter(t.zeros((n_heads, d_head)))
-        #self.b_O = nn.Parameter(t.zeros((d_model)))
-        nn.init.normal_(self.W_Q, 1/d_model**0.5)
-        nn.init.normal_(self.W_K, 1/d_model**0.5)
-        nn.init.normal_(self.W_V, 1/d_model**0.5)
-        nn.init.normal_(self.W_O, 1/d_head**0.5)
-        self.register_buffer("IGNORE", t.tensor(float('-inf'), dtype=t.float32))
-        self.rope = rope
-        self.use_flex_attention = use_flex_attention
-        self.ln1 = nn.LayerNorm(d_head)
-        self.ln2 = nn.LayerNorm(d_head)
-
-
-    def forward(
-        self,
-        x_q: Float[Tensor, "batch posq d_model"],
-        x_kv: Float[Tensor, "batch posk d_model"],
-        mask: Bool[Tensor, "posq posk"] = None, # the 1s are removed
-        k_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
-        v_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
-    ) -> Float[Tensor, "batch posq d_model"]:
-        assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
-        d_head = self.d_head
-        if k_cache is not None and v_cache is not None:
-            raise NotImplementedError("kv cache not implemented yet")
-            q = einops.einsum(x, self.W_Q, 'b s d, n d h -> b s n h')
-            k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
-            v_new = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')
-            k = t.cat([k_cache, k_new], dim=1)
-            v = t.cat([v_cache, v_new], dim=1)
-        else:
-            q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h')
-            k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
-            v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')
-
-        q = self.ln1(q)
-        k = self.ln2(k)
-        if self.rope is not None:
-            q = self.rope(q)
-            k = self.rope(k)
-
-        # Convert to (batch, num_heads, seq_len, head_dim) format for flex_attention
-        q_perm = q.permute(0, 2, 1, 3) # (batch, n_heads, posq, d_head)
-        k_perm = k.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)
-        v_perm = v.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)
-
-        # Ensure tensors are contiguous to avoid flex_attention indexing bugs
-        q_perm = q_perm.contiguous()
-        k_perm = k_perm.contiguous()
-        v_perm = v_perm.contiguous()
-
-        if self.use_flex_attention:
-            # Handle mask using score_mod if needed
-            if mask is not None:
-                # Store mask and IGNORE for use in score_mod closure
-                mask_tensor = mask # (posq, posk)
-                ignore_val = self.IGNORE
-                def score_mod(score, b, h, q_idx, kv_idx):
-                    # score_mod operates on individual scalar scores
-                    # Apply mask: where mask is True, set to -inf
-                    # Use torch ops that work in compiled context
-                    mask_val = mask_tensor[q_idx, kv_idx]
-                    return t.where(mask_val, ignore_val, score)
-                z = flex_attention(q_perm, k_perm, v_perm, score_mod=score_mod)
-            else:
-                z = flex_attention(q_perm, k_perm, v_perm)
-        else:
-            condi = mask is None and not self.dtype == t.float32
-            with t.backends.cuda.sdp_kernel(
-                enable_flash=condi,
-                enable_math=not condi,
-                enable_mem_efficient=not condi
-            ):
-                z = F.scaled_dot_product_attention(
-                    q_perm, k_perm, v_perm,
-                    attn_mask = mask.logical_not() if mask is not None else None,
-                    dropout_p = 0.0,
-                    is_causal = False,
-                    scale = 1.0
-                )
-        z = z.permute(0, 2, 1, 3) # Back to (batch, posq, n_heads, d_head)
-        out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
-        out = out.sum(dim=2)
-        #print(f"out {out.shape}, attention {probas.shape}, q {q.shape}, k {k.shape}, v {v.shape}")
-        return out, z, None
-
-    @property
-    def dtype(self):
-        return self.parameters().__next__().dtype
-
-    @property
-    def device(self):
-        return self.parameters().__next__().device
-
-
-if __name__ == "__main__":
-    from .pe import RoPE
-    import inspect
-    rope = RoPE(256//8, 10000)
-    dtype = t.float32
-    rope = rope.to(dtype)
-    attn_slow = AttentionSlow(d_model=256, n_heads=8, rope=rope)
-    attn = Attention(d_model=256, n_heads=8, rope=rope)
-    attn.load_state_dict(attn_slow.state_dict(), strict=False)
-    attn.to(dtype)
-    attn_slow.to(dtype)
-    x = t.randn(1, 1000, 256, dtype=dtype)*10
-    xkv = t.randn(1, 1000, 256, dtype=dtype)*10
-    mask = t.randint(0, 2, (1000, 1000), dtype=t.bool)
-    y, z, _ = attn(x, xkv, mask=mask)
-    y_slow, z_slow, _ = attn_slow(x, xkv, mask=mask)
-    #assert t.allclose(z, z_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(z - z_slow).abs().max()}"
-    #assert t.allclose(y, y_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(y - y_slow).abs().max()}"
-    print("Attention and AttentionSlow are the same")
-
-    loss = t.nn.functional.mse_loss(y, y_slow)
-    loss.backward()
-    print("-"*100)
-    for n, p in attn.named_parameters():
-        print(n, p.grad.shape, p.grad.max(), p.grad.min())
-    print("-"*100)
-    for n, p in attn_slow.named_parameters():
-        print(n, p.grad.shape, p.grad.max(), p.grad.min())
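Note (not part of the diff): in AttentionEinOps.forward the argument order of t.where changes, so a True entry in mask now keeps the attention score and a False entry maps it to IGNORE (-inf) before the softmax; the earlier code appears to have used the opposite convention (cf. the removed comment "the 1s are removed" and mask.logical_not() in the deleted Attention class). A toy illustration of the new convention:

import torch as t

IGNORE = t.tensor(float('-inf'))
scores = t.tensor([[1.0, 2.0],
                   [3.0, 4.0]])
mask = t.tensor([[True, False],
                 [True, True]])             # True = attend, False = masked out
masked = t.where(mask, scores, IGNORE)
print(masked.softmax(dim=-1))               # row 0 puts all its weight on position 0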
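Note (not part of the diff): a hypothetical usage sketch of the updated KVCache, assuming the module imports as src.nn.attn (the file path shown in this diff) and using made-up sizes. After this commit, extend() takes stacked per-layer keys/values of shape (n_layers, batch, T, n_heads, d_head) and commits them immediately, and get() returns the whole window in chronological order:

import torch as t
from src.nn.attn import KVCache             # import path assumed from the file location

cache = KVCache(batch_size=1, n_layers=2, n_heads=4, d_head=16, toks_per_frame=8, n_window=5)
for _ in range(10):                         # push ten frames, more than the window can hold
    k = t.randn(2, 1, 8, 4, 16)             # (n_layers, batch, toks_per_frame, n_heads, d_head)
    v = t.randn(2, 1, 8, 4, 16)
    cache.extend(k, v)
keys, values = cache.get()                  # oldest-to-newest window: 4 past frames = 32 tokens
print(keys.shape, cache.local_loc, cache.global_loc)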