chrisxx committed · verified
Commit 82ee7b0 · Parent(s): 8ef5ca9

Update src/nn/attn.py

Files changed (1)
1. src/nn/attn.py  (+71 -250)
src/nn/attn.py CHANGED
@@ -6,7 +6,9 @@ from jaxtyping import Float, Bool
6
  from torch import Tensor
7
  from typing import Optional
8
  from torch.nn.attention.flex_attention import flex_attention
 
9
 
 
10
 
11
  class KVCache(nn.Module):
12
  """
@@ -19,7 +21,7 @@ class KVCache(nn.Module):
19
  Call `extend(layer_idx, k, v)` once per layer for the *same* frame.
20
  Call `update_global_location(n_frames)` once after all layers to commit the frame(s).
21
  """
22
- def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, *, dtype=None, device=None, enforce_layer_order=True):
23
  super().__init__()
24
  self.batch_size = batch_size
25
  self.n_layers = n_layers
@@ -27,10 +29,9 @@ class KVCache(nn.Module):
27
  self.d_head = d_head
28
  self.toks_per_frame = toks_per_frame
29
  self.n_window = n_window
30
- self.size = toks_per_frame * n_window
31
 
32
  # Pointers / counters
33
- self.curr_layer = 0 # which layer are we writing for this frame
34
  self.global_loc = 0 # total tokens ever committed
35
  self.local_loc = 0 # valid tokens in buffer (<= size)
36
  self._write_ptr = 0 # ring-buffer write pointer (index of next commit position)
@@ -40,93 +41,67 @@ class KVCache(nn.Module):
40
  self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
41
  self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
42
 
43
- # Misc
44
- self.enforce_layer_order = enforce_layer_order
45
 
46
- # -------------- Public API --------------
47
- def get(self, layer_idx):
48
  """Return (K, V) for given layer in chronological order: shape (B, L, H, D) where L = local_loc."""
49
- self._check_layer(layer_idx)
50
  if self.local_loc == 0:
51
  # return empty views
52
- empty = self.keys[layer_idx, :, :0]
53
  return empty, empty
54
 
55
  start = (self._write_ptr - self.local_loc) % self.size
56
  if start + self.local_loc <= self.size:
57
  # contiguous slice
58
- k = self.keys[layer_idx, :, start:start + self.local_loc]
59
- v = self.values[layer_idx, :, start:start + self.local_loc]
60
  else:
61
  # wrap: concatenate two slices to maintain chronological order
62
  first = self.size - start
63
  k = t.cat([
64
- self.keys[layer_idx, :, start:self.size],
65
- self.keys[layer_idx, :, 0:(self.local_loc - first)]
66
- ], dim=1)
67
  v = t.cat([
68
- self.values[layer_idx, :, start:self.size],
69
- self.values[layer_idx, :, 0:(self.local_loc - first)]
70
- ], dim=1)
71
  return k, v
72
 
73
  @t.no_grad()
74
- def extend(self, layer_idx, keys, values):
75
  """
76
  Stage (but do not commit) tokens for the current frame for the given layer.
77
  Call update_global_location(n_frames) to commit after all layers wrote.
78
  """
79
  assert keys.shape == values.shape, f"keys and values shapes must match, got {keys.shape} vs {values.shape}"
80
- self._check_layer(layer_idx)
81
 
82
- # Expected shape: (B, T, H, D)
83
- B, T, H, D = keys.shape
84
  assert B == self.batch_size, f"batch mismatch: expected {self.batch_size}, got {B}"
85
  assert H == self.n_heads and D == self.d_head, f"heads/d_head mismatch: expected {(self.n_heads, self.d_head)}, got {(H, D)}"
86
  assert T > 0 and T <= self.size, f"T must be in 1..{self.size}, got {T}"
87
- # Optional: if you only ever append whole frames:
88
- # assert T == self.toks_per_frame, f"T must equal toks_per_frame ({self.toks_per_frame}), got {T}"
89
 
90
- # Cast to buffer dtype/device if needed
91
  if keys.dtype != self.keys.dtype or keys.device != self.keys.device:
92
  keys = keys.to(dtype=self.keys.dtype, device=self.keys.device)
93
  if values.dtype != self.values.dtype or values.device != self.values.device:
94
  values = values.to(dtype=self.values.dtype, device=self.values.device)
95
 
96
- # Write into the ring at the *current* write_ptr (uncommitted until update_global_location)
97
  i0 = self._write_ptr
98
  i1 = (self._write_ptr + T) % self.size
99
  if i0 < i1:
100
- self.keys[layer_idx, :, i0:i1] = keys
101
- self.values[layer_idx, :, i0:i1] = values
102
  else:
103
- # wraps: split write
104
  split = self.size - i0
105
- self.keys[layer_idx, :, i0:self.size] = keys[:, :split]
106
- self.values[layer_idx, :, i0:self.size] = values[:, :split]
107
- self.keys[layer_idx, :, 0:i1] = keys[:, split:]
108
- self.values[layer_idx, :, 0:i1] = values[:, split:]
109
 
110
- # Advance expected layer (but do *not* advance write_ptr/local_len here)
111
- self.curr_layer = (self.curr_layer + 1) % self.n_layers
112
-
113
- @t.no_grad()
114
- def update_global_location(self, n_frames):
115
- """
116
- Commit staged writes for n_frames (advances the ring write pointer once per frame).
117
- Keep calling extend(layer_idx, ...) for each layer before you call this.
118
- """
119
- assert n_frames >= 0, f"n_frames must be >= 0, got {n_frames}"
120
- tokens = n_frames * self.toks_per_frame
121
- if tokens == 0:
122
- return
123
- assert tokens <= self.size, f"Cannot commit {tokens} tokens (> buffer size {self.size})."
124
-
125
- self.global_loc += tokens
126
- # Update valid length (never exceeds capacity)
127
- self.local_loc = min(self.size, self.local_loc + tokens)
128
- # Advance write pointer
129
- self._write_ptr = (self._write_ptr + tokens) % self.size
130
 
131
  @t.no_grad()
132
  def reset(self, zero_memory: bool = True):
@@ -138,7 +113,6 @@ class KVCache(nn.Module):
138
  self.keys.zero_()
139
  self.values.zero_()
140
 
141
- # -------------- Convenience / Introspection --------------
142
  @property
143
  def local_location(self):
144
  return self.local_loc
@@ -155,33 +129,10 @@ class KVCache(nn.Module):
155
  def dtype(self):
156
  return self.keys.dtype
157
 
158
- def get_recent(self, layer_idx, last_T):
159
- """Return the most recent last_T tokens for a layer (chronological)."""
160
- self._check_layer(layer_idx, allow_any=True)
161
- last_T = min(last_T, self.local_loc)
162
- if last_T == 0:
163
- empty = self.keys[layer_idx, :, :0]
164
- return empty, empty
165
- start = (self._write_ptr - last_T) % self.size
166
- if start + last_T <= self.size:
167
- k = self.keys[layer_idx, :, start:start + last_T]
168
- v = self.values[layer_idx, :, start:start + last_T]
169
- else:
170
- first = self.size - start
171
- k = t.cat([self.keys[layer_idx, :, start:self.size], self.keys[layer_idx, :, 0:(last_T - first)]], dim=1)
172
- v = t.cat([self.values[layer_idx, :, start:self.size], self.values[layer_idx, :, 0:(last_T - first)]], dim=1)
173
- return k, v
174
-
175
- # -------------- Internal checks --------------
176
- def _check_layer(self, layer_idx, allow_any=False):
177
- assert 0 <= layer_idx < self.n_layers, f"layer_idx out of range: 0..{self.n_layers-1}, got {layer_idx}"
178
- if self.enforce_layer_order and not allow_any:
179
- assert layer_idx == (self.curr_layer % self.n_layers), \
180
- f"Layer order mismatch: expected {self.curr_layer % self.n_layers}, got {layer_idx}"
181
 
182
 
183
- class KVCacheMine(nn.Module): # broken: when the buffer is full, shifting earlier entries to the left clobbers the cache of later timesteps; fix as an exercise.
184
- def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window):
185
  """
186
  This is a rolling KVCache
187
  """
@@ -191,42 +142,40 @@ class KVCacheMine(nn.Module): # this does not work because it destroys the cache
191
  self.d_head = d_head
192
  self.toks_per_frame = toks_per_frame
193
  self.n_window = n_window
194
- self.size = toks_per_frame * n_window
195
  self.n_layers = n_layers
196
- self.curr_layer = 0
197
  self.global_loc = 0
198
  self.local_loc = 0
199
- self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head))
200
- self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head))
 
201
 
202
- def get(self, layer_idx):
203
- assert layer_idx == self.curr_layer, f"layer idx should be the same as our internal counter but we got {layer_idx} and internal is {self.curr_layer}."
204
- return self.keys[layer_idx, :, :self.local_loc], self.values[layer_idx, :, :self.local_loc]
205
 
206
- def extend(self, layer_idx, keys, values):
 
 
 
207
assert keys.shape == values.shape, f"keys and values shapes must match, got {keys.shape} vs {values.shape}"
208
- assert layer_idx == self.curr_layer, f"layer idx should be the same as our internal counter but we got {layer_idx} and internal is {self.curr_layer}."
209
  assert self.local_loc <= self.size, f"the cache size should be between 0 and {self.size}"
210
  local_loc = self.local_loc
211
  if local_loc == self.size:
212
  # move to the left
213
- local_loc -= keys.shape[1]
214
- assert local_loc >= 0, f"the cache update {keys.shape[1]} was larger than the cache {self.size}, that's not supported for now."
215
  assert local_loc % self.toks_per_frame == 0, f"the number of elements in the cache {local_loc} must be a multiple of the number of tokens per frame {self.toks_per_frame}"
216
- self.keys[layer_idx, :, :local_loc] = self.keys[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
217
- self.values[layer_idx, :, :local_loc] = self.values[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
218
- #self.keys[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame] = self.keys[layer_idx, :, -local_loc:].clone()
219
- #self.values[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame] = self.values[layer_idx, :, -local_loc:].clone()
220
-
221
- assert local_loc + keys.shape[1] <= self.size, f"{local_loc + keys.shape[1]} out of bounds {self.size}"
222
- self.keys[layer_idx, :, local_loc:local_loc + keys.shape[1]] = keys
223
- self.values[layer_idx, :, local_loc:local_loc + keys.shape[1]] = values
224
  self.curr_layer = (self.curr_layer + 1) % self.n_layers
225
 
226
- def update_global_location(self, n_frames):
227
- self.global_loc += n_frames * self.toks_per_frame
228
  if self.local_loc < self.size:
229
- self.local_loc += n_frames * self.toks_per_frame
230
  assert self.local_loc <= self.size, f"the local loc {self.local_loc} should never be bigger than {self.size}, something went wrong."
231
 
232
  def reset(self):
@@ -256,10 +205,13 @@ class KVCacheMine(nn.Module): # this does not work because it destroys the cache
256
  class AttentionEinOps(nn.Module):
257
  IGNORE: Float[Tensor, ""]
258
 
259
- def __init__(self, d_model, n_heads, rope=None):
260
  super().__init__()
261
assert d_model % n_heads == 0, f"{d_model} must be divisible by {n_heads}"
262
  self.d_head = d_model // n_heads
 
 
 
263
  d_head = self.d_head
264
  self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
265
  self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
@@ -289,7 +241,6 @@ class AttentionEinOps(nn.Module):
289
  offset: int = 0
290
  ) -> Float[Tensor, "batch posq d_model"]:
291
  assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
292
- d_head = self.d_head
293
  if k_cache is not None and v_cache is not None:
294
  q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
295
  k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
@@ -297,177 +248,47 @@ class AttentionEinOps(nn.Module):
297
 
298
  k = t.cat([k_cache, k_new], dim=1)
299
  v = t.cat([v_cache, v_new], dim=1)
 
 
 
300
 
301
  if self.rope is not None:
302
- q = self.rope(q, offset=k_cache.shape[1])
303
  k = self.rope(k, offset=0)
304
- q = self.ln1(q) # this should be before rope
305
- k = self.ln2(k)
 
 
306
  mask = None
307
  else:
308
  q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
309
  k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
310
  v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h') + self.b_V
 
 
 
 
311
  if self.rope is not None:
312
  q = self.rope(q)
313
  k = self.rope(k)
314
- q = self.ln1(q)
315
- k = self.ln2(k) # this learns much faster with layernorm here
 
 
316
  k_new = k
317
  v_new = v
318
 
319
  attention = einops.einsum(q, k, 'b sq n h, b sk n h -> b n sq sk')
320
  if mask is not None and k_cache is not None:
321
- attention = t.where(mask[k_cache.shape[1]:k_cache.shape[1]+q.shape[1], :k.shape[1]], self.IGNORE, attention)
322
  elif mask is not None:
323
  if attention.shape[-1] != mask.shape[-1] or attention.shape[-2] != mask.shape[-2]:
324
- #print(f"Warning: attention shape {attention.shape} does not match mask shape {mask.shape}")
325
  mask = mask[:attention.shape[-1], :attention.shape[-2]]
326
- attention = t.where(mask, self.IGNORE, attention)
327
  probas = attention.softmax(dim=3)
328
- #plt.imshow(probas[0, 0].cpu().numpy())
329
- #plt.show()
330
  z = einops.einsum(probas, v, 'b n sq sk, b sk n h -> b sq n h')
331
  out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
332
  out = out.sum(dim=2) + self.b_O
333
  return out, k_new, v_new
334
 
335
 
336
- class Attention(nn.Module):
337
- IGNORE: Float[Tensor, ""]
338
-
339
- def __init__(self, d_model, n_heads, rope=None, use_flex_attention=False):
340
- raise NotImplementedError("Attention is not implemented yet")
341
- super().__init__()
342
- assert d_model % n_heads == 0, f"{d_model} must be divisble by {n_heads}"
343
- self.d_head = d_model // n_heads
344
- d_head = self.d_head
345
- self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
346
- self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
347
- self.W_V = nn.Parameter(t.empty((n_heads, d_model, d_head)))
348
- self.W_O = nn.Parameter(t.empty((n_heads, d_head, d_model)))
349
- #self.b_Q = nn.Parameter(t.zeros((n_heads, d_head)))
350
- #self.b_K = nn.Parameter(t.zeros((n_heads, d_head)))
351
- #self.b_V = nn.Parameter(t.zeros((n_heads, d_head)))
352
- #self.b_O = nn.Parameter(t.zeros((d_model)))
353
- nn.init.normal_(self.W_Q, 1/d_model**0.5)
354
- nn.init.normal_(self.W_K, 1/d_model**0.5)
355
- nn.init.normal_(self.W_V, 1/d_model**0.5)
356
- nn.init.normal_(self.W_O, 1/d_head**0.5)
357
- self.register_buffer("IGNORE", t.tensor(float('-inf'), dtype=t.float32))
358
- self.rope = rope
359
- self.use_flex_attention = use_flex_attention
360
- self.ln1 = nn.LayerNorm(d_head)
361
- self.ln2 = nn.LayerNorm(d_head)
362
-
363
-
364
- def forward(
365
- self,
366
- x_q: Float[Tensor, "batch posq d_model"],
367
- x_kv: Float[Tensor, "batch posk d_model"],
368
- mask: Bool[Tensor, "posq posk"] = None, # the 1s are removed
369
- k_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
370
- v_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
371
- ) -> Float[Tensor, "batch posq d_model"]:
372
- assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
373
- d_head = self.d_head
374
- if k_cache is not None and v_cache is not None:
375
- raise NotImplementedError("kv cache not implemented yet")
376
- q = einops.einsum(x, self.W_Q, 'b s d, n d h -> b s n h')
377
- k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
378
- v_new = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')
379
- k = t.cat([k_cache, k_new], dim=1)
380
- v = t.cat([v_cache, v_new], dim=1)
381
- else:
382
- q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h')
383
- k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
384
- v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')
385
-
386
- q = self.ln1(q)
387
- k = self.ln2(k)
388
- if self.rope is not None:
389
- q = self.rope(q)
390
- k = self.rope(k)
391
-
392
- # Convert to (batch, num_heads, seq_len, head_dim) format for flex_attention
393
- q_perm = q.permute(0, 2, 1, 3) # (batch, n_heads, posq, d_head)
394
- k_perm = k.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)
395
- v_perm = v.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)
396
-
397
- # Ensure tensors are contiguous to avoid flex_attention indexing bugs
398
- q_perm = q_perm.contiguous()
399
- k_perm = k_perm.contiguous()
400
- v_perm = v_perm.contiguous()
401
-
402
- if self.use_flex_attention:
403
- # Handle mask using score_mod if needed
404
- if mask is not None:
405
- # Store mask and IGNORE for use in score_mod closure
406
- mask_tensor = mask # (posq, posk)
407
- ignore_val = self.IGNORE
408
- def score_mod(score, b, h, q_idx, kv_idx):
409
- # score_mod operates on individual scalar scores
410
- # Apply mask: where mask is True, set to -inf
411
- # Use torch ops that work in compiled context
412
- mask_val = mask_tensor[q_idx, kv_idx]
413
- return t.where(mask_val, ignore_val, score)
414
- z = flex_attention(q_perm, k_perm, v_perm, score_mod=score_mod)
415
- else:
416
- z = flex_attention(q_perm, k_perm, v_perm)
417
- else:
418
- condi = mask is None and not self.dtype == t.float32
419
- with t.backends.cuda.sdp_kernel(
420
- enable_flash=condi,
421
- enable_math=not condi,
422
- enable_mem_efficient=not condi
423
- ):
424
- z = F.scaled_dot_product_attention(
425
- q_perm, k_perm, v_perm,
426
- attn_mask = mask.logical_not() if mask is not None else None,
427
- dropout_p = 0.0,
428
- is_causal = False,
429
- scale = 1.0
430
- )
431
- z = z.permute(0, 2, 1, 3) # Back to (batch, posq, n_heads, d_head)
432
- out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
433
- out = out.sum(dim=2)
434
- #print(f"out {out.shape}, attention {probas.shape}, q {q.shape}, k {k.shape}, v {v.shape}")
435
- return out, z, None
436
-
437
- @property
438
- def dtype(self):
439
- return self.parameters().__next__().dtype
440
-
441
- @property
442
- def device(self):
443
- return self.parameters().__next__().device
444
-
445
-
446
- if __name__ == "__main__":
447
- from .pe import RoPE
448
- import inspect
449
- rope = RoPE(256//8, 10000)
450
- dtype = t.float32
451
- rope = rope.to(dtype)
452
- attn_slow = AttentionSlow(d_model=256, n_heads=8, rope=rope)
453
- attn = Attention(d_model=256, n_heads=8, rope=rope)
454
- attn.load_state_dict(attn_slow.state_dict(), strict=False)
455
- attn.to(dtype)
456
- attn_slow.to(dtype)
457
- x = t.randn(1, 1000, 256, dtype=dtype)*10
458
- xkv = t.randn(1, 1000, 256, dtype=dtype)*10
459
- mask = t.randint(0, 2, (1000, 1000), dtype=t.bool)
460
- y, z, _ = attn(x, xkv, mask=mask)
461
- y_slow, z_slow, _ = attn_slow(x, xkv, mask=mask)
462
- #assert t.allclose(z, z_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(z - z_slow).abs().max()}"
463
- #assert t.allclose(y, y_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(y - y_slow).abs().max()}"
464
- print("Attention and AttentionSlow are the same")
465
-
466
- loss = t.nn.functional.mse_loss(y, y_slow)
467
- loss.backward()
468
- print("-"*100)
469
- for n, p in attn.named_parameters():
470
- print(n, p.grad.shape, p.grad.max(), p.grad.min())
471
- print("-"*100)
472
- for n, p in attn_slow.named_parameters():
473
- print(n, p.grad.shape, p.grad.max(), p.grad.min())
 
6
  from torch import Tensor
7
  from typing import Optional
8
  from torch.nn.attention.flex_attention import flex_attention
9
+ from matplotlib import pyplot as plt
10
 
11
+ from pdb import set_trace
12
 
13
  class KVCache(nn.Module):
14
  """
 
21
Call `extend(k, v)` once per frame with tensors of shape (n_layers, batch, T, n_heads, d_head).
22
The write is committed immediately; `get()` returns (keys, values) for all layers in chronological order.
23
  """
24
+ def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, *, dtype=None, device=None):
25
  super().__init__()
26
  self.batch_size = batch_size
27
  self.n_layers = n_layers
 
29
  self.d_head = d_head
30
  self.toks_per_frame = toks_per_frame
31
  self.n_window = n_window
32
+ self.size = toks_per_frame * (n_window - 1)  # ring capacity: holds the previous n_window - 1 frames
33
 
34
  # Pointers / counters
 
35
  self.global_loc = 0 # total tokens ever committed
36
  self.local_loc = 0 # valid tokens in buffer (<= size)
37
  self._write_ptr = 0 # ring-buffer write pointer (index of next commit position)
 
41
  self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
42
  self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
43
 
 
 
44
 
45
+ def get(self):
 
46
  """Return (K, V) for given layer in chronological order: shape (B, L, H, D) where L = local_loc."""
 
47
  if self.local_loc == 0:
48
  # return empty views
49
+ empty = self.keys[:, :, :0]
50
  return empty, empty
51
 
52
  start = (self._write_ptr - self.local_loc) % self.size
53
  if start + self.local_loc <= self.size:
54
  # contiguous slice
55
+ k = self.keys[:, :, start:start + self.local_loc]
56
+ v = self.values[:, :, start:start + self.local_loc]
57
  else:
58
  # wrap: concatenate two slices to maintain chronological order
59
  first = self.size - start
60
  k = t.cat([
61
+ self.keys[:, :, start:self.size],
62
+ self.keys[:, :, 0:(self.local_loc - first)]
63
+ ], dim=2)
64
  v = t.cat([
65
+ self.values[:, :, start:self.size],
66
+ self.values[:, :, 0:(self.local_loc - first)]
67
+ ], dim=2)
68
  return k, v
69
 
70
  @t.no_grad()
71
+ def extend(self, keys, values):
72
  """
73
Write tokens for the current frame for all layers at once; expects (n_layers, B, T, H, D).
74
The write is committed immediately: global_loc, local_loc and the ring write pointer advance here.
75
  """
76
  assert keys.shape == values.shape, f"keys and values shapes must match, got {keys.shape} vs {values.shape}"
 
77
 
78
+ L, B, T, H, D = keys.shape
79
+ assert L == self.n_layers, f"n_layers mismatch: expected {self.n_layers}, got {L}"
80
  assert B == self.batch_size, f"batch mismatch: expected {self.batch_size}, got {B}"
81
  assert H == self.n_heads and D == self.d_head, f"heads/d_head mismatch: expected {(self.n_heads, self.d_head)}, got {(H, D)}"
82
  assert T > 0 and T <= self.size, f"T must be in 1..{self.size}, got {T}"
 
 
83
 
 
84
  if keys.dtype != self.keys.dtype or keys.device != self.keys.device:
85
  keys = keys.to(dtype=self.keys.dtype, device=self.keys.device)
86
  if values.dtype != self.values.dtype or values.device != self.values.device:
87
  values = values.to(dtype=self.values.dtype, device=self.values.device)
88
 
 
89
  i0 = self._write_ptr
90
  i1 = (self._write_ptr + T) % self.size
91
  if i0 < i1:
92
+ self.keys[:, :, i0:i1] = keys
93
+ self.values[:, :, i0:i1] = values
94
  else:
95
+ # wrap
96
  split = self.size - i0
97
+ self.keys[:, :, i0:self.size] = keys[:, :, :split]
98
+ self.values[:, :, i0:self.size] = values[:, :, :split]
99
+ self.keys[:, :, 0:i1] = keys[:, :, split:]
100
+ self.values[:, :, 0:i1] = values[:, :, split:]
101
 
102
+ self.global_loc += keys.shape[2]
103
+ self.local_loc = min(self.size, self.local_loc + keys.shape[2])
104
+ self._write_ptr = (self._write_ptr + keys.shape[2]) % self.size
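# Worked example of the ring-buffer bookkeeping above (illustrative numbers, not from the source):
# with toks_per_frame=2, n_window=4 -> size = 2 * (4 - 1) = 6; frames are committed 2 tokens at a time.
#   after frames 1-3: _write_ptr = 0, local_loc = 6, global_loc = 6 (buffer holds tokens 1..6)
#   frame 4 writes tokens 7..8 at positions 0..1: _write_ptr = 2, local_loc = 6, global_loc = 8
#   get(): start = (2 - 6) % 6 = 2 and start + local_loc > size, so it concatenates
#          positions 2..5 (tokens 3..6) with positions 0..1 (tokens 7..8) -> tokens 3..8 in order.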
 
105
 
106
  @t.no_grad()
107
  def reset(self, zero_memory: bool = True):
 
113
  self.keys.zero_()
114
  self.values.zero_()
115
 
 
116
  @property
117
  def local_location(self):
118
  return self.local_loc
 
129
  def dtype(self):
130
  return self.keys.dtype
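A minimal usage sketch of the updated all-layers API introduced by this commit. It assumes the class is importable as src.nn.attn.KVCache; the sizes are illustrative.

import torch as t
from src.nn.attn import KVCache

cache = KVCache(batch_size=1, n_layers=2, n_heads=4, d_head=16,
                toks_per_frame=8, n_window=3)   # size = 8 * (3 - 1) = 16 tokens
frame_k = t.zeros(2, 1, 8, 4, 16)               # (n_layers, B, T, H, D) for one frame
frame_v = t.zeros(2, 1, 8, 4, 16)
cache.extend(frame_k, frame_v)                  # write + commit in a single call
k, v = cache.get()                              # (n_layers, B, local_loc, H, D), chronological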
131
 
132
 
133
 
134
+ class KVCacheNaive(nn.Module):
135
+ def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, dtype=t.float32, device='cuda'):
136
  """
137
  This is a rolling KVCache
138
  """
 
142
  self.d_head = d_head
143
  self.toks_per_frame = toks_per_frame
144
  self.n_window = n_window
145
+ self.size = toks_per_frame * (n_window - 1)
146
  self.n_layers = n_layers
 
147
  self.global_loc = 0
148
  self.local_loc = 0
149
+
150
+ self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
151
+ self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
152
 
153
+ def get(self):
154
+ return self.keys[:, :, :self.local_loc], self.values[:, :, :self.local_loc]
 
155
 
156
+ def extend(self, keys, values):
157
+ """
158
+ This should only be called on the last denoising step of each frame.
159
+ """
160
assert keys.shape == values.shape, f"keys and values shapes must match, got {keys.shape} vs {values.shape}"
 
161
  assert self.local_loc <= self.size, f"the cache size should be between 0 and {self.size}"
162
  local_loc = self.local_loc
163
  if local_loc == self.size:
164
  # move to the left
165
+ local_loc -= keys.shape[2]
166
+ assert local_loc >= 0, f"the cache update {keys.shape[2]} was larger than the cache {self.size}, that's not supported for now."
167
  assert local_loc % self.toks_per_frame == 0, f"the number of elements in the cache {local_loc} must be a multiple of the number of tokens per frame {self.toks_per_frame}"
168
+ self.keys[:, :, :local_loc] = self.keys[:, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
169
+ self.values[:, :, :local_loc] = self.values[:, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
170
+
171
+ assert local_loc + keys.shape[2] <= self.size, f"{local_loc + keys.shape[2]} out of bounds {self.size}"
172
+ self.keys[:, :, local_loc:local_loc + keys.shape[2]] = keys
173
+ self.values[:, :, local_loc:local_loc + keys.shape[2]] = values
 
 
174
  self.curr_layer = (self.curr_layer + 1) % self.n_layers
175
 
176
+ self.global_loc += keys.shape[2]
 
177
  if self.local_loc < self.size:
178
+ self.local_loc += keys.shape[2]
179
  assert self.local_loc <= self.size, f"the local loc {self.local_loc} should never be bigger than {self.size}, something went wrong."
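# Rolling behaviour sketch (illustrative numbers, not from the source): toks_per_frame=2, n_window=4 -> size = 6.
#   while local_loc < size, each new frame is appended at positions local_loc : local_loc + 2.
#   once local_loc == size, the buffer is shifted left by toks_per_frame (dropping the oldest
#   frame) and the new frame is written into the freed slots at the end, so get() keeps
#   returning the most recent n_window - 1 frames in chronological order without any wrap logic.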
180
 
181
  def reset(self):
 
205
  class AttentionEinOps(nn.Module):
206
  IGNORE: Float[Tensor, ""]
207
 
208
+ def __init__(self, d_model, n_heads, rope=None, ln_first=False):
209
  super().__init__()
210
assert d_model % n_heads == 0, f"{d_model} must be divisible by {n_heads}"
211
  self.d_head = d_model // n_heads
212
+ self.d_model = d_model
213
+ self.n_heads = n_heads
214
+ self.ln_first = ln_first
215
  d_head = self.d_head
216
  self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
217
  self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
 
241
  offset: int = 0
242
  ) -> Float[Tensor, "batch posq d_model"]:
243
  assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
 
244
  if k_cache is not None and v_cache is not None:
245
  q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
246
  k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
 
248
 
249
  k = t.cat([k_cache, k_new], dim=1)
250
  v = t.cat([v_cache, v_new], dim=1)
251
+ if self.ln_first:
252
+ q = self.ln1(q)
253
+ k = self.ln2(k)
254
 
255
  if self.rope is not None:
256
+ q = self.rope(q, offset=k_cache.shape[1])
257
  k = self.rope(k, offset=0)
258
+
259
+ if not self.ln_first:
260
+ q = self.ln1(q) # usually applied before RoPE, but our best checkpoint applies it after, so keep this order for backward compatibility; in a quick single-frame test the ordering made little difference
261
+ k = self.ln2(k)
262
  mask = None
263
  else:
264
  q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
265
  k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
266
  v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h') + self.b_V
267
+ if self.ln_first:
268
+ q = self.ln1(q)
269
+ k = self.ln2(k)
270
+
271
  if self.rope is not None:
272
  q = self.rope(q)
273
  k = self.rope(k)
274
+
275
+ if not self.ln_first:
276
+ q = self.ln1(q)
277
+ k = self.ln2(k)
278
  k_new = k
279
  v_new = v
280
 
281
  attention = einops.einsum(q, k, 'b sq n h, b sk n h -> b n sq sk')
282
  if mask is not None and k_cache is not None:
283
+ attention = t.where(mask[k_cache.shape[1]:k_cache.shape[1]+q.shape[1], :k.shape[1]], attention, self.IGNORE)  # mask convention: True = attend, False = masked out with -inf
284
  elif mask is not None:
285
  if attention.shape[-1] != mask.shape[-1] or attention.shape[-2] != mask.shape[-2]:
 
286
  mask = mask[:attention.shape[-1], :attention.shape[-2]]
287
+ attention = t.where(mask, attention, self.IGNORE)  # True = attend, False = -inf before softmax
288
  probas = attention.softmax(dim=3)
 
 
289
  z = einops.einsum(probas, v, 'b n sq sk, b sk n h -> b sq n h')
290
  out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
291
  out = out.sum(dim=2) + self.b_O
292
  return out, k_new, v_new
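A minimal call sketch for the updated mask convention (True = attend). It assumes AttentionEinOps is importable from src.nn.attn and that the constructor (outside this hunk) initializes the projection weights, biases and layer norms used in forward; shapes are illustrative.

import torch as t
from src.nn.attn import AttentionEinOps

attn = AttentionEinOps(d_model=256, n_heads=8)          # no RoPE for this sketch
x = t.randn(2, 10, 256)                                 # (batch, posq, d_model)
causal_keep = t.tril(t.ones(10, 10, dtype=t.bool))      # True = attend, False = masked
out, k_new, v_new = attn(x, x, mask=causal_keep)        # out: (batch, posq, d_model)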
293
 
294