File size: 6,588 Bytes
8746765 a9440e0 8746765 a9440e0 8746765 a9440e0 8746765 a9440e0 8746765 a9440e0 8746765 a9440e0 8746765 a9440e0 8746765 a9440e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch as t
import torch.nn as nn
import math
from jaxtyping import Float, Bool, Int
from torch import Tensor
from typing import Optional
from pdb import set_trace
def compute_trig(d_head, n_ctx, C):
    """Build the sine and cosine lookup tables used for rotary embeddings.

    One frequency ``C ** (-k / d_head)`` is produced for every even ``k`` in
    ``range(0, d_head, 2)``; each frequency is then duplicated so that two
    neighbouring feature columns share the same rotation angle.

    Args:
        d_head: Feature width; sets the number of frequencies.
        n_ctx: Number of positions; angles are tabulated for ``0 .. n_ctx - 1``.
        C: Base of the geometric frequency progression.

    Returns:
        ``(sins, coss)``: two tensors with one row per position and one column
        per duplicated frequency (``d_head`` columns when ``d_head`` is even).
    """
    pair_index = t.arange(0, d_head, 2)
    freqs = t.exp(-math.log(C) * pair_index / d_head)
    # [f0, f1, ...] -> [f0, f0, f1, f1, ...]: both members of a pair rotate together.
    freqs = freqs.repeat_interleave(2)
    pos = t.arange(n_ctx)
    angles = pos[:, None] * freqs[None, :]
    return t.sin(angles), t.cos(angles)
class NumericEncoding(nn.Module):
    """Fixed sinusoidal lookup table mapping integers to ``dim``-wide vectors.

    Row ``n`` of the table holds ``sin(n * f_i)`` in the even columns and
    ``cos(n * f_i)`` in the odd columns, where ``f_i = C ** (-2i / dim)``.
    The table is stored as the non-trainable buffer ``pe``.
    """

    def __init__(self, C = 1e4, dim = 64, n_max = 10000):
        super().__init__()
        freqs = t.exp(-math.log(C) * t.arange(0, dim, 2) / dim)
        numbers = t.arange(n_max)
        angles = numbers.unsqueeze(1) * freqs.unsqueeze(0)
        table = t.empty((n_max, dim))
        # Interleave: sines in columns 0, 2, 4, ...; cosines in columns 1, 3, 5, ...
        table[:, 0::2] = t.sin(angles)
        table[:, 1::2] = t.cos(angles)
        self.register_buffer("pe", table)

    def forward(self, num):
        """Return the table rows selected by ``num``.

        ``num`` must be an int32 or int64 tensor; it is used directly as an
        index into the ``n_max``-row table.
        """
        is_integer = num.dtype == t.int32 or num.dtype == t.int64
        assert is_integer, f"wrong dtype {num.dtype}"
        return self.pe[num]
class RoPE(nn.Module):
    """Rotary position embedding with fixed (non-trainable) sin/cos tables.

    The tables come from ``compute_trig`` for positions ``0 .. n_ctx - 1`` and
    are stored as buffers with singleton batch and head axes
    (``(1, n_ctx, 1, d_head)`` for even ``d_head``) so they broadcast against
    a ``(batch, sequence, n_head, d_head)`` input.
    """

    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
        self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        """Rotate each (even, odd) feature pair by its position-dependent angle.

        For pair ``(x[2i], x[2i+1])`` at a given position the result is
        ``(cos * x[2i] - sin * x[2i+1], cos * x[2i+1] + sin * x[2i])``.

        Args:
            key_or_query: Tensor of shape ``(batch, sequence, n_head, d_head)``.
            offset: Position of the first token of ``key_or_query`` within the
                precomputed table.

        Returns:
            A new tensor with the same shape as ``key_or_query``.

        Raises:
            AssertionError: If the sequence axis is empty.
            ValueError: If ``offset`` is negative or ``offset + sequence``
                runs past the end of the precomputed table.
        """
        x = key_or_query
        seq_len = x.shape[1]
        assert seq_len >= 1, f"x.shape[1] must be >= 1, got {x.shape}"

        stop = offset + seq_len
        table_len = self.coss.shape[1]
        # Slicing past the table end silently truncates; a truncated slice of
        # length 1 would then broadcast over the whole sequence and produce
        # wrong rotations without any error, so reject it explicitly.
        if offset < 0 or stop > table_len:
            raise ValueError(
                f"positions [{offset}, {stop}) fall outside the precomputed "
                f"table of length {table_len}"
            )

        # Strided slices instead of index tensors: no per-call index
        # allocation and nothing that can sit on a different device than x.
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        x_perm[:, :, :, 0::2] = -x[:, :, :, 1::2]
        x_perm[:, :, :, 1::2] = x[:, :, :, 0::2]

        cos = self.coss[:, offset:stop]
        sin = self.sins[:, offset:stop]
        return cos * x + sin * x_perm
class LearnRoPE(nn.Module):
    """Rotary position embedding whose sin/cos tables are trainable.

    The tables are initialised from ``compute_trig`` for positions
    ``0 .. n_ctx - 1`` and registered as ``nn.Parameter`` objects with
    singleton batch and head axes (``(1, n_ctx, 1, d_head)`` for even
    ``d_head``), so an optimizer can update them.
    """

    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        self.sins = nn.Parameter(sins.unsqueeze(0).unsqueeze(2))
        self.coss = nn.Parameter(coss.unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        """Apply the (learned) per-position rotation to each feature pair.

        For pair ``(x[2i], x[2i+1])`` the result is
        ``(c * x[2i] - s * x[2i+1], c * x[2i+1] + s * x[2i])`` where ``c`` and
        ``s`` are the table entries for that position and column.

        Args:
            key_or_query: Tensor of shape ``(batch, sequence, n_head, d_head)``.
            offset: Position of the first token of ``key_or_query`` within the
                table.

        Returns:
            A new tensor with the same shape as ``key_or_query``.

        Raises:
            AssertionError: If the sequence axis is empty.
            ValueError: If ``offset`` is negative or ``offset + sequence``
                runs past the end of the table.
        """
        x = key_or_query
        seq_len = x.shape[1]
        assert seq_len >= 1, f"x.shape[1] must be >= 1, got {x.shape}"

        stop = offset + seq_len
        table_len = self.coss.shape[1]
        # Slicing past the table end silently truncates; a truncated slice of
        # length 1 would then broadcast over the whole sequence and produce
        # wrong rotations without any error, so reject it explicitly.
        if offset < 0 or stop > table_len:
            raise ValueError(
                f"positions [{offset}, {stop}) fall outside the table of "
                f"length {table_len}"
            )

        # Strided slices instead of index tensors: no per-call index
        # allocation and nothing that can sit on a different device than x.
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        x_perm[:, :, :, 0::2] = -x[:, :, :, 1::2]
        x_perm[:, :, :, 1::2] = x[:, :, :, 0::2]

        cos = self.coss[:, offset:stop]
        sin = self.sins[:, offset:stop]
        return cos * x + sin * x_perm
class VidRoPE(nn.Module):
    """Rotary position embedding split across x, y and time coordinates.

    The first ``d_x`` features of each head are rotated by the token's x
    position, the next ``d_y`` by its y position and the next ``d_t`` by its
    frame index. Any remaining features (``d_head - d_x - d_y - d_t``) are
    left untouched.

    Token layout assumed by the position buffers: within a frame, tokens are
    in row-major order (x varies fastest over ``ctx_x`` columns, y over
    ``ctx_y`` rows), followed by ONE register token that is assigned the extra
    position ``(ctx_x, ctx_y)``. This per-frame pattern is repeated ``ctx_t``
    times.

    NOTE(review): ``n_registers`` is stored but does not influence the
    position buffers; exactly one register position is appended per frame.
    ``pos_x``/``pos_y`` therefore have ``ctx_t * (ctx_x * ctx_y + 1)`` entries
    while ``pos_t`` has ``ctx_t * toks_per_frame``; ``forward`` requires all
    three to equal the input's sequence length, which only holds when
    ``toks_per_frame == ctx_x * ctx_y + 1``. Confirm against callers.
    """

    def __init__(self, d_head,
                 d_x,
                 d_y,
                 d_t,
                 ctx_x,
                 ctx_y,
                 ctx_t,
                 C_x,
                 C_y,
                 C_t,
                 toks_per_frame,
                 n_registers):
        super().__init__()
        assert d_x + d_y + d_t <= d_head, "dx + dy + dt > d_head"
        self.d_head = d_head
        self.d_x = d_x
        self.d_y = d_y
        self.d_t = d_t
        self.ctx_x = ctx_x
        self.ctx_y = ctx_y
        self.ctx_t = ctx_t
        self.C_x = C_x
        self.C_y = C_y
        self.C_t = C_t
        self.toks_per_frame = toks_per_frame
        self.n_registers = n_registers

        # Spatial tables get one extra row (index ctx_x / ctx_y) reserved for
        # the register token.
        sins_x, coss_x = compute_trig(d_x, ctx_x + 1, C_x)
        self.register_buffer("sins_x", sins_x.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_x", coss_x.unsqueeze(0).unsqueeze(2))
        sins_y, coss_y = compute_trig(d_y, ctx_y + 1, C_y)
        self.register_buffer("sins_y", sins_y.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_y", coss_y.unsqueeze(0).unsqueeze(2))
        sins_t, coss_t = compute_trig(d_t, ctx_t, C_t)
        self.register_buffer("sins_t", sins_t.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_t", coss_t.unsqueeze(0).unsqueeze(2))

        n_frames = ctx_t
        # ctx_x should be equal to width
        # ctx_y should be equal to height
        # x index of every token in one frame: 0..ctx_x-1 repeated for each row.
        pos_x = t.arange(ctx_x).repeat(ctx_y)
        # Register token; dtype matched to arange so the concatenation does
        # not depend on implicit type promotion.
        pos_x = t.cat([pos_x, t.tensor([ctx_x], dtype=pos_x.dtype)])
        pos_x = pos_x.repeat(n_frames)
        # y index of every token in one frame: each row index repeated ctx_x times.
        pos_y = t.arange(ctx_y).repeat_interleave(ctx_x)
        pos_y = t.cat([pos_y, t.tensor([ctx_y], dtype=pos_y.dtype)])
        pos_y = pos_y.repeat(n_frames)
        # Frame index of every token.
        pos_t = t.arange(n_frames).repeat_interleave(toks_per_frame)
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)
        self.register_buffer("pos_t", pos_t)

    def rotate(self, x, pos_idcs, coss, sins):
        """Rotate feature pairs of ``x`` using table rows chosen by ``pos_idcs``.

        Args:
            x: Tensor indexed as ``(batch, sequence, n_head, features)``.
            pos_idcs: 1-D integer tensor with one table row index per
                sequence position.
            coss, sins: Tables whose axis 1 is indexed by ``pos_idcs``.

        Returns:
            A new tensor ``coss[:, pos_idcs] * x + sins[:, pos_idcs] * x_perm``
            where ``x_perm`` holds ``(-x[2i+1], x[2i])`` for each feature pair.

        Raises:
            AssertionError: If the sequence axis is empty or its length
                differs from ``len(pos_idcs)``.
        """
        # Validate before doing any work.
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        assert pos_idcs.shape[0] == x.shape[1], f"pos_idcs length {pos_idcs.shape[0]} must match x.shape[1] {x.shape[1]}"
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        # Strided slices select the even / odd features without index tensors.
        x_perm[:, :, :, 0::2] = -x[:, :, :, 1::2]
        x_perm[:, :, :, 1::2] = x[:, :, :, 0::2]
        return coss[:, pos_idcs] * x + sins[:, pos_idcs] * x_perm

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        """Apply the x, y and time rotations to ``key_or_query`` IN PLACE.

        The input tensor itself is overwritten slice by slice and then
        returned; callers that need the un-rotated values must clone first.

        Args:
            key_or_query: Tensor of shape ``(batch, sequence, n_head, d_head)``
                whose sequence length equals the length of the position
                buffers.
            offset: Token offset; only ``offset // toks_per_frame`` is used,
                as a shift of the frame index for the time rotation.

        Returns:
            ``key_or_query`` (the same object), rotated.

        Raises:
            AssertionError: If the sequence length is zero or does not match
                the position buffers.
            IndexError: If the shifted frame indices fall outside the time
                table.
        """
        x = key_or_query
        seq_len = x.shape[1]

        # Validate everything up front: the writes below mutate the caller's
        # tensor, so failing part-way would leave it half-rotated.
        assert seq_len >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        for pos in (self.pos_x, self.pos_y, self.pos_t):
            assert pos.shape[0] == seq_len, f"pos_idcs length {pos.shape[0]} must match x.shape[1] {seq_len}"

        frame_shift = offset // self.toks_per_frame
        t_table_len = self.coss_t.shape[1]
        # pos_t spans 0 .. ctx_t - 1. A positive shift would index past the
        # table (a device-side assert on CUDA) and a negative one would wrap
        # around silently, so reject both with a clear error.
        if frame_shift < 0 or (self.ctx_t - 1 + frame_shift) >= t_table_len:
            raise IndexError(
                f"frame offset {frame_shift} (offset={offset}) moves time "
                f"positions outside the table of length {t_table_len}"
            )

        x_end = self.d_x
        y_end = x_end + self.d_y
        t_end = y_end + self.d_t
        x[:, :, :, :x_end] = self.rotate(x[:, :, :, :x_end], self.pos_x, self.coss_x, self.sins_x)
        x[:, :, :, x_end:y_end] = self.rotate(x[:, :, :, x_end:y_end], self.pos_y, self.coss_y, self.sins_y)
        x[:, :, :, y_end:t_end] = self.rotate(x[:, :, :, y_end:t_end], self.pos_t + frame_shift, self.coss_t, self.sins_t)
        return x
|