|
|
import math

import torch as t
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
|
|
def compute_trig(d_head, n_ctx, C): |
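    """Precompute RoPE sin/cos tables, each of shape (n_ctx, d_head).

    Frequencies follow theta_i = C**(-2i / d_head); each frequency is repeated
    twice so that consecutive (even, odd) channels share one rotation angle.
    """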
|
|
    thetas = t.exp(-math.log(C) * t.arange(0, d_head, 2) / d_head)
    thetas = thetas.repeat([2, 1]).T.flatten()
    positions = t.arange(n_ctx)
    all_thetas = positions.unsqueeze(1) * thetas.unsqueeze(0)
    sins = t.sin(all_thetas)
    coss = t.cos(all_thetas)
    return sins, coss
|
|
class NumericEncoding(nn.Module): |
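    """Classic sinusoidal encoding, exposed as a lookup table over integer values.

    Even channels hold sin terms and odd channels hold cos terms, following the
    original transformer positional-encoding layout; the table has n_max rows.
    """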
|
|
    def __init__(self, C=1e4, dim=64, n_max=10000):
        super().__init__()
        args = t.exp(-math.log(C) * t.arange(0, dim, 2) / dim)
        args = t.arange(n_max)[:, None] * args[None, :]
        sins = t.sin(args)
        coss = t.cos(args)
        pe = t.empty((n_max, dim))
        pe[:, ::2] = sins
        pe[:, 1::2] = coss
        self.register_buffer("pe", pe)

    def forward(self, num):
        """Expects integer indices in the range [0, n_max)."""
        assert num.dtype in (t.int32, t.int64), f"wrong dtype {num.dtype}"
        return self.pe[num]
|
|
class RoPE(nn.Module): |
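    """Rotary positional embedding (RoPE).

    Sin/cos tables are precomputed with `compute_trig` and stored with shape
    (1, n_ctx, 1, d_head) so they broadcast over batch and heads. forward()
    rotates each (even, odd) channel pair of a key or query by an angle
    proportional to its sequence position; `offset` shifts the starting
    position, e.g. when rotating a suffix of a longer sequence.
    """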
|
|
    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        self.register_buffer("sins", sins.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss", coss.unsqueeze(0).unsqueeze(2))

    def forward(
        self,
        key_or_query: Float[Tensor, "batch sequence n_head d_head"],
        offset: int = 0,
    ):
        x = key_or_query
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        even = t.arange(0, x.shape[-1], 2, device=x.device)
        odd = t.arange(1, x.shape[-1], 2, device=x.device)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        return (
            self.coss[:, offset:offset + x.shape[1]] * x
            + self.sins[:, offset:offset + x.shape[1]] * x_perm
        )
|
|
class LearnRoPE(nn.Module): |
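    """RoPE variant whose sin/cos tables are nn.Parameters.

    The rotation tables are initialised to the standard RoPE schedule produced by
    `compute_trig` but are free to change during training.
    """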
|
|
    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        self.sins = nn.Parameter(sins.unsqueeze(0).unsqueeze(2))
        self.coss = nn.Parameter(coss.unsqueeze(0).unsqueeze(2))

    def forward(
        self,
        key_or_query: Float[Tensor, "batch sequence n_head d_head"],
        offset: int = 0,
    ):
        x = key_or_query
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        even = t.arange(0, x.shape[-1], 2, device=x.device)
        odd = t.arange(1, x.shape[-1], 2, device=x.device)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        return (
            self.coss[:, offset:offset + x.shape[1]] * x
            + self.sins[:, offset:offset + x.shape[1]] * x_perm
        )
|
|
class VidRoPE(nn.Module): |
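    """Factorised rotary embedding for video tokens.

    The head dimension is split into x, y and t blocks of size d_x, d_y and d_t;
    each block is rotated by the token's column, row and frame index respectively,
    and any remaining channels are left unrotated. Each frame is assumed to contain
    ctx_x * ctx_y patch tokens followed by one extra token, which is assigned the
    out-of-range spatial positions (ctx_x, ctx_y); toks_per_frame must therefore
    equal ctx_x * ctx_y + 1 for the precomputed position indices to line up with
    the sequence. `offset` is counted in tokens and converted to a frame offset for
    the temporal block. The rotation is applied in place on the input tensor, which
    is also returned.
    """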
|
|
    def __init__(
        self,
        d_head,
        d_x,
        d_y,
        d_t,
        ctx_x,
        ctx_y,
        ctx_t,
        C_x,
        C_y,
        C_t,
        toks_per_frame,
        n_registers,
    ):
        super().__init__()
        assert d_x + d_y + d_t <= d_head, (
            f"d_x + d_y + d_t = {d_x + d_y + d_t} exceeds d_head = {d_head}"
        )
        self.d_head = d_head
        self.d_x = d_x
        self.d_y = d_y
        self.d_t = d_t
        self.ctx_x = ctx_x
        self.ctx_y = ctx_y
        self.ctx_t = ctx_t
        self.C_x = C_x
        self.C_y = C_y
        self.C_t = C_t
        self.toks_per_frame = toks_per_frame
        self.n_registers = n_registers
        sins_x, coss_x = compute_trig(d_x, ctx_x + 1, C_x)
        self.register_buffer("sins_x", sins_x.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_x", coss_x.unsqueeze(0).unsqueeze(2))
        sins_y, coss_y = compute_trig(d_y, ctx_y + 1, C_y)
        self.register_buffer("sins_y", sins_y.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_y", coss_y.unsqueeze(0).unsqueeze(2))
        sins_t, coss_t = compute_trig(d_t, ctx_t, C_t)
        self.register_buffer("sins_t", sins_t.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_t", coss_t.unsqueeze(0).unsqueeze(2))
|
|
        n_frames = ctx_t
        pos_x = t.arange(self.ctx_x).repeat(self.ctx_y)
        pos_x = t.cat([pos_x, t.tensor([self.ctx_x])])
        pos_x = pos_x.repeat(n_frames)
        pos_y = t.arange(self.ctx_y).repeat_interleave(self.ctx_x)
        pos_y = t.cat([pos_y, t.tensor([self.ctx_y])])
        pos_y = pos_y.repeat(n_frames)
        pos_t = t.arange(n_frames).repeat_interleave(self.toks_per_frame)
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)
        self.register_buffer("pos_t", pos_t)
|
    def rotate(self, x, pos_idcs, coss, sins):
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        even = t.arange(0, x.shape[-1], 2, device=x.device)
        odd = t.arange(1, x.shape[-1], 2, device=x.device)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        assert pos_idcs.shape[0] == x.shape[1], (
            f"pos_idcs length {pos_idcs.shape[0]} must match x.shape[1] {x.shape[1]}"
        )
        return coss[:, pos_idcs] * x + sins[:, pos_idcs] * x_perm
|
    def forward(
        self,
        key_or_query: Float[Tensor, "batch sequence n_head d_head"],
        offset: int = 0,
    ):
        x = key_or_query
        x[:, :, :, :self.d_x] = self.rotate(
            x[:, :, :, :self.d_x], self.pos_x, self.coss_x, self.sins_x
        )
        x[:, :, :, self.d_x:self.d_x + self.d_y] = self.rotate(
            x[:, :, :, self.d_x:self.d_x + self.d_y], self.pos_y, self.coss_y, self.sins_y
        )
        x[:, :, :, self.d_x + self.d_y:self.d_x + self.d_y + self.d_t] = self.rotate(
            x[:, :, :, self.d_x + self.d_y:self.d_x + self.d_y + self.d_t],
            self.pos_t + (offset // self.toks_per_frame),
            self.coss_t,
            self.sins_t,
        )
        return x
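

# Minimal smoke test (illustrative only; the shapes and hyperparameters below are
# arbitrary and not taken from any training configuration). It checks that each
# module accepts a (batch, sequence, n_head, d_head) tensor and preserves its shape,
# assuming toks_per_frame == ctx_x * ctx_y + 1 as VidRoPE's position indices require.
if __name__ == "__main__":
    batch, n_head, d_head = 2, 4, 64

    rope = RoPE(d_head=d_head, n_ctx=128)
    q = t.randn(batch, 16, n_head, d_head)
    assert rope(q).shape == q.shape

    learn_rope = LearnRoPE(d_head=d_head, n_ctx=128)
    assert learn_rope(q, offset=3).shape == q.shape

    enc = NumericEncoding(dim=32, n_max=1000)
    assert enc(t.arange(5)).shape == (5, 32)

    ctx_x, ctx_y, ctx_t = 4, 4, 3
    toks_per_frame = ctx_x * ctx_y + 1  # patch tokens plus one extra token per frame
    vid_rope = VidRoPE(
        d_head=d_head, d_x=16, d_y=16, d_t=16,
        ctx_x=ctx_x, ctx_y=ctx_y, ctx_t=ctx_t,
        C_x=10000, C_y=10000, C_t=10000,
        toks_per_frame=toks_per_frame, n_registers=0,
    )
    kv = t.randn(batch, ctx_t * toks_per_frame, n_head, d_head)
    assert vid_rope(kv).shape == kv.shape
    print("all positional-encoding smoke tests passed")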
|
|
|
|
|
|