File size: 6,588 Bytes
8746765
 
 
 
 
 
 
a9440e0
 
 
 
 
 
 
 
 
 
8746765
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9440e0
 
 
 
8746765
 
 
 
 
a9440e0
8746765
 
 
 
 
 
 
 
a9440e0
 
8746765
a9440e0
 
 
 
 
 
8746765
a9440e0
 
 
 
8746765
a9440e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8746765
 
a9440e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch as t
import torch.nn as nn
import math

from jaxtyping import Float, Bool, Int
from torch import Tensor
from typing import Optional
from pdb import set_trace 

def compute_trig(d_head, n_ctx, C):
    """Return ``(sins, coss)`` tables for rotary position embeddings.

    Frequency ``w_i = C ** (-2i / d_head)`` is assigned to channel pair
    ``(2i, 2i + 1)``; both channels of a pair share that frequency. Entry
    ``[p, j]`` of each table is the sine / cosine of ``p * w_{j // 2}``.

    Args:
        d_head: number of channels; an odd value yields ``d_head + 1`` columns.
        n_ctx: number of positions (table rows), starting at position 0.
        C: frequency base.

    Returns:
        Two tensors of shape ``(n_ctx, 2 * ceil(d_head / 2))``.
    """
    pair_freqs = t.exp(-math.log(C) * t.arange(0, d_head, 2) / d_head)
    # [w0, w0, w1, w1, ...]: duplicate each frequency for both channels of its pair.
    channel_freqs = pair_freqs.repeat_interleave(2)
    angles = t.arange(n_ctx)[:, None] * channel_freqs[None, :]
    return t.sin(angles), t.cos(angles)

class NumericEncoding(nn.Module):
    """Fixed sinusoidal lookup table mapping integers to ``dim``-wide vectors.

    Row ``n`` of the table holds ``sin(n * w_i)`` in column ``2i`` and
    ``cos(n * w_i)`` in column ``2i + 1``, with ``w_i = C ** (-2i / dim)``.
    The table is a buffer, so it follows the module across devices but is
    not trained.

    Args:
        C: frequency base.
        dim: embedding width; it must be even for the sin/cos columns to fill
            the table exactly (an odd value raises while building it).
        n_max: number of rows, so valid inputs are ``0 .. n_max - 1``.
    """

    def __init__(self, C = 1e4, dim = 64, n_max = 10000):
        super().__init__()
        freqs = t.exp(- math.log(C) * t.arange(0, dim, 2)/dim)
        positions = t.arange(n_max).unsqueeze(1)
        angles = positions * freqs.unsqueeze(0)
        table = t.empty((n_max, dim))
        # Interleave columns: [sin_0, cos_0, sin_1, cos_1, ...].
        table[:, 0::2] = t.sin(angles)
        table[:, 1::2] = t.cos(angles)
        self.register_buffer("pe", table)

    def forward(self, num):
        """Look up the encoding rows selected by the integer tensor ``num``.

        ``num`` must have dtype int32 or int64 (checked with an assert). Values
        should lie in ``0 .. n_max - 1``; as with any tensor indexing, negative
        values select rows from the end and larger values raise. The result has
        shape ``num.shape + (dim,)``.
        """
        accepted = (t.int32, t.int64)
        assert num.dtype in accepted, f"wrong dtype {num.dtype}"
        return self.pe[num]


class RoPE(nn.Module):
    """Rotary positional embedding (RoPE) with fixed sin/cos tables.

    Channels are rotated in adjacent pairs ``(2i, 2i + 1)``. The per-position
    angles come from ``compute_trig(d_head, n_ctx, C)`` and are stored as
    non-trainable buffers of shape ``(1, n_ctx, 1, d_head)``.

    Args:
        d_head: size of the last (head) dimension to rotate.
        n_ctx: number of positions covered by the tables.
        C: frequency base.
    """

    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        # Singleton axes 0 and 2 let the tables broadcast over batch and heads.
        self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
        self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                      offset: int = 0):
        """Rotate ``key_or_query`` by the angles of positions ``offset .. offset + seq - 1``.

        Args:
            key_or_query: tensor of shape (batch, sequence, n_head, d_head).
            offset: absolute position of the first token along ``sequence``.

        Returns:
            A new tensor with the same shape as ``key_or_query``.

        Raises:
            AssertionError: if the sequence axis is empty.
            ValueError: if ``offset`` is negative or the window extends past
                ``n_ctx``.
        """
        x = key_or_query
        seq_len = x.shape[1]
        assert seq_len >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        # Without this check an out-of-range window is silently truncated by the
        # slice below and then broadcast: a length-1 slice applies one position to
        # every token, and a length-0 slice produces an empty output.
        if offset < 0 or offset + seq_len > self.n_ctx:
            raise ValueError(
                f"positions [{offset}, {offset + seq_len}) are outside the "
                f"precomputed range [0, {self.n_ctx})"
            )
        # x_perm[..., 2i] = -x[..., 2i+1] and x_perm[..., 2i+1] = x[..., 2i], so
        # cos * x + sin * x_perm rotates each channel pair by its angle.
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        even = t.arange(0, x.shape[-1], 2, device=x.device)
        odd = t.arange(1, x.shape[-1], 2, device=x.device)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        window = slice(offset, offset + seq_len)
        return self.coss[:, window] * x + self.sins[:, window] * x_perm

class LearnRoPE(nn.Module):
    """Rotary-style positional embedding whose sin/cos tables are trainable.

    The tables are initialised from ``compute_trig(d_head, n_ctx, C)`` and
    registered as ``nn.Parameter`` tensors of shape ``(1, n_ctx, 1, d_head)``,
    so they may drift away from exact sines/cosines during training.

    Args:
        d_head: size of the last (head) dimension.
        n_ctx: number of positions covered by the tables.
        C: frequency base used for the initial tables.
    """

    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        self.d_head = d_head
        self.n_ctx = n_ctx
        self.C = C
        sins, coss = compute_trig(d_head, n_ctx, C)
        # Singleton axes 0 and 2 let the tables broadcast over batch and heads.
        self.sins = nn.Parameter(sins.unsqueeze(0).unsqueeze(2))
        self.coss = nn.Parameter(coss.unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                      offset: int = 0):
        """Apply the tables for positions ``offset .. offset + seq - 1`` to ``key_or_query``.

        Args:
            key_or_query: tensor of shape (batch, sequence, n_head, d_head).
            offset: absolute position of the first token along ``sequence``.

        Returns:
            A new tensor with the same shape as ``key_or_query``.

        Raises:
            AssertionError: if the sequence axis is empty.
            ValueError: if ``offset`` is negative or the window extends past
                ``n_ctx``.
        """
        x = key_or_query
        seq_len = x.shape[1]
        assert seq_len >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        # Without this check an out-of-range window is silently truncated by the
        # slice below and then broadcast: a length-1 slice applies one position to
        # every token, and a length-0 slice produces an empty output.
        if offset < 0 or offset + seq_len > self.n_ctx:
            raise ValueError(
                f"positions [{offset}, {offset + seq_len}) are outside the "
                f"precomputed range [0, {self.n_ctx})"
            )
        # x_perm[..., 2i] = -x[..., 2i+1] and x_perm[..., 2i+1] = x[..., 2i]; with
        # the initial tables, cos * x + sin * x_perm rotates each channel pair.
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        even = t.arange(0, x.shape[-1], 2, device=x.device)
        odd = t.arange(1, x.shape[-1], 2, device=x.device)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        window = slice(offset, offset + seq_len)
        return self.coss[:, window] * x + self.sins[:, window] * x_perm

class VidRoPE(nn.Module):
    """Factorised rotary embedding for flattened sequences of video tokens.

    The head dimension is split into three leading slices that are rotated
    independently: ``[0, d_x)`` by the token's column, ``[d_x, d_x + d_y)`` by
    its row, and ``[d_x + d_y, d_x + d_y + d_t)`` by its frame index. Channels
    from ``d_x + d_y + d_t`` to ``d_head`` pass through unchanged.

    Per-frame token layout assumed by the position tables: ``ctx_y`` rows of
    ``ctx_x`` tokens in row-major order (x varies fastest), followed by one
    register token placed at the out-of-grid position ``(ctx_x, ctx_y)``.
    ``ctx_t`` such frames follow each other along the sequence.
    NOTE(review): ``n_registers`` is stored but unused; the layout hard-codes a
    single register, so ``toks_per_frame`` must equal ``ctx_x * ctx_y + 1``.
    """

    def __init__(self, d_head, 
                    d_x, 
                    d_y,
                    d_t,
                    ctx_x,
                    ctx_y,
                    ctx_t,
                    C_x,
                    C_y,
                    C_t,
                    toks_per_frame,
                    n_registers):
        super().__init__()
        assert d_x + d_y + d_t <= d_head, (
            f"d_x + d_y + d_t = {d_x + d_y + d_t} > d_head = {d_head}"
        )
        self.d_head = d_head
        self.d_x = d_x
        self.d_y = d_y
        self.d_t = d_t
        self.ctx_x = ctx_x 
        self.ctx_y = ctx_y
        self.ctx_t = ctx_t 
        self.C_x = C_x
        self.C_y = C_y
        self.C_t = C_t
        self.toks_per_frame = toks_per_frame 
        self.n_registers = n_registers  # stored only; see class NOTE(review)
        # x and y tables get one extra position, reserved for the register token.
        sins_x, coss_x = compute_trig(d_x, ctx_x+1, C_x)
        self.register_buffer("sins_x", sins_x.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_x", coss_x.unsqueeze(0).unsqueeze(2))
        sins_y, coss_y = compute_trig(d_y, ctx_y+1, C_y)
        self.register_buffer("sins_y", sins_y.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_y", coss_y.unsqueeze(0).unsqueeze(2))
        sins_t, coss_t = compute_trig(d_t, ctx_t, C_t)
        self.register_buffer("sins_t", sins_t.unsqueeze(0).unsqueeze(2))
        self.register_buffer("coss_t", coss_t.unsqueeze(0).unsqueeze(2)) 
        n_frames = ctx_t
        # Row-major grid: x cycles 0..ctx_x-1 within each row, y is constant per
        # row; the register token is appended at (ctx_x, ctx_y). The register
        # literal uses arange's default integer dtype so t.cat needs no promotion.
        pos_x = t.arange(self.ctx_x).repeat(self.ctx_y)
        pos_x = t.cat([pos_x, t.tensor([self.ctx_x])])
        pos_x = pos_x.repeat(n_frames)
        pos_y = t.arange(self.ctx_y).repeat_interleave(self.ctx_x)
        pos_y = t.cat([pos_y, t.tensor([self.ctx_y])])
        pos_y = pos_y.repeat(n_frames)
        # Every token of frame f (register included) gets time position f.
        pos_t = t.arange(n_frames).repeat_interleave(self.toks_per_frame)
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)
        self.register_buffer("pos_t", pos_t)

    def rotate(self, x, pos_idcs, coss, sins):
        """Rotate channel pairs of ``x`` by the table angles at ``pos_idcs``.

        Args:
            x: (batch, sequence, n_head, d) slice of a key/query tensor.
            pos_idcs: 1-D integer positions, one per token along ``sequence``.
            coss, sins: tables of shape (1, n_positions, 1, d).

        Returns:
            ``coss[:, pos] * x + sins[:, pos] * x_perm`` (a new tensor).
        """
        # x_perm[..., 2i] = -x[..., 2i+1] and x_perm[..., 2i+1] = x[..., 2i].
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)
        even = t.arange(0, x.shape[-1], 2, device=x.device)
        odd = t.arange(1, x.shape[-1], 2, device=x.device)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        assert pos_idcs.shape[0] == x.shape[1], f"pos_idcs length {pos_idcs.shape[0]} must match x.shape[1] {x.shape[1]}"

        return coss[:,pos_idcs]*x + sins[:,pos_idcs]*x_perm

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                      offset: int = 0):
        """Apply the x/y/t rotations to the tokens at positions ``offset .. offset + seq - 1``.

        Args:
            key_or_query: tensor of shape (batch, sequence, n_head, d_head).
            offset: index of the first token of ``key_or_query`` within the
                flattened sequence of ``ctx_t`` frames (the same token-offset
                convention as ``RoPE.forward``).

        Returns:
            A new tensor with the shape and dtype of ``key_or_query``; the input
            is not modified.

        Raises:
            ValueError: if the x/y and t position tables differ in length
                (``toks_per_frame`` inconsistent with the frame layout), or if
                the requested window lies outside the tables.
        """
        x = key_or_query
        seq_len = x.shape[1]
        n_pos = self.pos_t.shape[0]
        if self.pos_x.shape[0] != n_pos or self.pos_y.shape[0] != n_pos:
            raise ValueError(
                f"toks_per_frame={self.toks_per_frame} is inconsistent with a "
                f"{self.ctx_x}x{self.ctx_y} grid plus one register token"
            )
        if offset < 0 or offset + seq_len > n_pos:
            raise ValueError(
                f"positions [{offset}, {offset + seq_len}) are outside the "
                f"precomputed range [0, {n_pos})"
            )
        # Positions of exactly the tokens present in x; for offset 0 and a
        # full-length input this is the whole table.
        window = slice(offset, offset + seq_len)
        x_end = self.d_x
        y_end = x_end + self.d_y
        t_end = y_end + self.d_t
        # Write into a copy so the caller's tensor is left untouched; assigning
        # into the copy also casts rotated slices back to the input dtype.
        out = x.clone()
        out[:, :, :, :x_end] = self.rotate(
            x[:, :, :, :x_end], self.pos_x[window], self.coss_x, self.sins_x)
        out[:, :, :, x_end:y_end] = self.rotate(
            x[:, :, :, x_end:y_end], self.pos_y[window], self.coss_y, self.sins_y)
        out[:, :, :, y_end:t_end] = self.rotate(
            x[:, :, :, y_end:t_end], self.pos_t[window], self.coss_t, self.sins_t)
        return out