Update src/nn/pe.py
Browse files- src/nn/pe.py +103 -28
src/nn/pe.py
CHANGED
|
@@ -5,6 +5,16 @@ import math
|
|
| 5 |
from jaxtyping import Float, Bool, Int
|
| 6 |
from torch import Tensor
|
| 7 |
from typing import Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
class NumericEncoding(nn.Module):
|
| 10 |
def __init__(self, C = 1e4, dim = 64, n_max = 10000):
|
|
@@ -29,19 +39,16 @@ class NumericEncoding(nn.Module):
|
|
| 29 |
class RoPE(nn.Module):
|
| 30 |
def __init__(self, d_head, n_ctx, C=10000):
|
| 31 |
super().__init__()
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
sins = t.sin(all_thetas)
|
| 37 |
-
coss = t.cos(all_thetas)
|
| 38 |
self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
|
| 39 |
self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))
|
| 40 |
|
| 41 |
def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
|
| 42 |
offset: int = 0):
|
| 43 |
-
x = key_or_query
|
| 44 |
-
# start with doing it for just a single position m
|
| 45 |
x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
|
| 46 |
even = t.arange(0, x.shape[-1], 2)
|
| 47 |
odd = t.arange(1, x.shape[-1],2)
|
|
@@ -50,28 +57,96 @@ class RoPE(nn.Module):
|
|
| 50 |
assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
|
| 51 |
return self.coss[:,offset:offset+x.shape[1]]*x + self.sins[:,offset:offset+x.shape[1]]*x_perm
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def __init__(self, d_head, n_ctx, toks_per_frame, C=10000):
|
| 56 |
super().__init__()
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
sins =
|
| 62 |
-
coss =
|
| 63 |
-
self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
|
| 64 |
-
self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))
|
| 65 |
-
self.toks_per_frame = toks_per_frame
|
| 66 |
|
| 67 |
-
def forward(self, key_or_query: Float[Tensor, "batch
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
x_perm = t.empty(x.shape,
|
| 71 |
even = t.arange(0, x.shape[-1], 2)
|
| 72 |
-
odd = t.arange(1, x.shape[-1],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
x_perm[:, :, :, even] = -x[:, :, :, odd]
|
| 74 |
x_perm[:, :, :, odd] = x[:, :, :, even]
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from jaxtyping import Float, Bool, Int
|
| 6 |
from torch import Tensor
|
| 7 |
from typing import Optional
|
| 8 |
+
from pdb import set_trace
|
| 9 |
+
|
| 10 |
+
def compute_trig(d_head, n_ctx, C):
|
| 11 |
+
thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
|
| 12 |
+
thetas = thetas.repeat([2,1]).T.flatten()
|
| 13 |
+
positions = t.arange(n_ctx)
|
| 14 |
+
all_thetas = positions.unsqueeze(1)*thetas.unsqueeze(0)
|
| 15 |
+
sins = t.sin(all_thetas)
|
| 16 |
+
coss = t.cos(all_thetas)
|
| 17 |
+
return sins, coss
|
| 18 |
|
| 19 |
class NumericEncoding(nn.Module):
|
| 20 |
def __init__(self, C = 1e4, dim = 64, n_max = 10000):
|
|
|
|
| 39 |
class RoPE(nn.Module):
|
| 40 |
def __init__(self, d_head, n_ctx, C=10000):
|
| 41 |
super().__init__()
|
| 42 |
+
self.d_head = d_head
|
| 43 |
+
self.n_ctx = n_ctx
|
| 44 |
+
self.C = C
|
| 45 |
+
sins, coss = compute_trig(d_head, n_ctx, C)
|
|
|
|
|
|
|
| 46 |
self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
|
| 47 |
self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))
|
| 48 |
|
| 49 |
def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
|
| 50 |
offset: int = 0):
|
| 51 |
+
x = key_or_query
|
|
|
|
| 52 |
x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
|
| 53 |
even = t.arange(0, x.shape[-1], 2)
|
| 54 |
odd = t.arange(1, x.shape[-1],2)
|
|
|
|
| 57 |
assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
|
| 58 |
return self.coss[:,offset:offset+x.shape[1]]*x + self.sins[:,offset:offset+x.shape[1]]*x_perm
|
| 59 |
|
| 60 |
+
class LearnRoPE(nn.Module):
|
| 61 |
+
def __init__(self, d_head, n_ctx, C=10000):
|
|
|
|
| 62 |
super().__init__()
|
| 63 |
+
self.d_head = d_head
|
| 64 |
+
self.n_ctx = n_ctx
|
| 65 |
+
self.C = C
|
| 66 |
+
sins, coss = compute_trig(d_head, n_ctx, C)
|
| 67 |
+
self.sins = nn.Parameter(sins.unsqueeze(0).unsqueeze(2))
|
| 68 |
+
self.coss = nn.Parameter(coss.unsqueeze(0).unsqueeze(2))
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
+
def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
|
| 71 |
+
offset: int = 0):
|
| 72 |
+
x = key_or_query
|
| 73 |
+
x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
|
| 74 |
even = t.arange(0, x.shape[-1], 2)
|
| 75 |
+
odd = t.arange(1, x.shape[-1],2)
|
| 76 |
+
x_perm[:, :, :, even] = -x[:, :, :, odd]
|
| 77 |
+
x_perm[:, :, :, odd] = x[:, :, :, even]
|
| 78 |
+
assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
|
| 79 |
+
return self.coss[:,offset:offset+x.shape[1]]*x + self.sins[:,offset:offset+x.shape[1]]*x_perm
|
| 80 |
+
|
| 81 |
+
class VidRoPE(nn.Module):
|
| 82 |
+
def __init__(self, d_head,
|
| 83 |
+
d_x,
|
| 84 |
+
d_y,
|
| 85 |
+
d_t,
|
| 86 |
+
ctx_x,
|
| 87 |
+
ctx_y,
|
| 88 |
+
ctx_t,
|
| 89 |
+
C_x,
|
| 90 |
+
C_y,
|
| 91 |
+
C_t,
|
| 92 |
+
toks_per_frame,
|
| 93 |
+
n_registers):
|
| 94 |
+
super().__init__()
|
| 95 |
+
assert d_x + d_y + d_t <= d_head, f"dx + dy + dt > d_head"
|
| 96 |
+
self.d_head = d_head
|
| 97 |
+
self.d_x = d_x
|
| 98 |
+
self.d_y = d_y
|
| 99 |
+
self.d_t = d_t
|
| 100 |
+
self.ctx_x = ctx_x
|
| 101 |
+
self.ctx_y = ctx_y
|
| 102 |
+
self.ctx_t = ctx_t
|
| 103 |
+
self.C_x = C_x
|
| 104 |
+
self.C_y = C_y
|
| 105 |
+
self.C_t = C_t
|
| 106 |
+
self.toks_per_frame = toks_per_frame
|
| 107 |
+
self.n_registers = n_registers
|
| 108 |
+
sins_x, coss_x = compute_trig(d_x, ctx_x+1, C_x) # +1 for the register
|
| 109 |
+
self.register_buffer("sins_x", sins_x.unsqueeze(0).unsqueeze(2))
|
| 110 |
+
self.register_buffer("coss_x", coss_x.unsqueeze(0).unsqueeze(2))
|
| 111 |
+
sins_y, coss_y = compute_trig(d_y, ctx_y+1, C_y) # +1 for the register
|
| 112 |
+
self.register_buffer("sins_y", sins_y.unsqueeze(0).unsqueeze(2))
|
| 113 |
+
self.register_buffer("coss_y", coss_y.unsqueeze(0).unsqueeze(2))
|
| 114 |
+
sins_t, coss_t = compute_trig(d_t, ctx_t, C_t)
|
| 115 |
+
self.register_buffer("sins_t", sins_t.unsqueeze(0).unsqueeze(2))
|
| 116 |
+
self.register_buffer("coss_t", coss_t.unsqueeze(0).unsqueeze(2))
|
| 117 |
+
n_frames = ctx_t
|
| 118 |
+
# ctx_x should be equal to width
|
| 119 |
+
# ctx_y should be equal to height
|
| 120 |
+
pos_x = t.arange(self.ctx_x).repeat(self.ctx_y) # w cols with h entries each
|
| 121 |
+
pos_x = t.cat([pos_x, t.tensor([self.ctx_x], dtype=t.int32)]) # deal with register
|
| 122 |
+
pos_x = pos_x.repeat(n_frames)
|
| 123 |
+
pos_y = t.arange(self.ctx_y).repeat_interleave(self.ctx_x) # h rows with w entries each
|
| 124 |
+
pos_y = t.cat([pos_y, t.tensor([self.ctx_y], dtype=t.int32)]) # deal with register
|
| 125 |
+
pos_y = pos_y.repeat(n_frames)
|
| 126 |
+
pos_t = t.arange(n_frames).repeat_interleave(self.toks_per_frame)
|
| 127 |
+
self.register_buffer("pos_x", pos_x)
|
| 128 |
+
self.register_buffer("pos_y", pos_y)
|
| 129 |
+
self.register_buffer("pos_t", pos_t)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def rotate(self, x, pos_idcs, coss, sins):
|
| 133 |
+
x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
|
| 134 |
+
even = t.arange(0, x.shape[-1], 2, device=x.device)
|
| 135 |
+
odd = t.arange(1, x.shape[-1], 2, device=x.device)
|
| 136 |
x_perm[:, :, :, even] = -x[:, :, :, odd]
|
| 137 |
x_perm[:, :, :, odd] = x[:, :, :, even]
|
| 138 |
+
assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
|
| 139 |
+
assert pos_idcs.shape[0] == x.shape[1], f"pos_idcs length {pos_idcs.shape[0]} must match x.shape[1] {x.shape[1]}"
|
| 140 |
+
|
| 141 |
+
return coss[:,pos_idcs]*x + sins[:,pos_idcs]*x_perm
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
|
| 145 |
+
offset: int = 0):
|
| 146 |
+
x = key_or_query
|
| 147 |
+
x[:, :, :, :self.d_x] = self.rotate(x[:, :, :, :self.d_x], self.pos_x, self.coss_x, self.sins_x)
|
| 148 |
+
x[:, :, :, self.d_x:self.d_x+self.d_y] = self.rotate(x[:, :, :, self.d_x:self.d_x+self.d_y], self.pos_y, self.coss_y, self.sins_y)
|
| 149 |
+
x[:, :, :, self.d_x+self.d_y:self.d_x+self.d_y+self.d_t] = self.rotate(x[:, : , :, self.d_x+self.d_y:self.d_x+self.d_y+self.d_t], self.pos_t+(offset//self.toks_per_frame), self.coss_t, self.sins_t)
|
| 150 |
+
return x
|
| 151 |
+
|
| 152 |
+
|