chrisxx commited on
Commit
a9440e0
·
verified ·
1 Parent(s): dd18454

Update src/nn/pe.py

Browse files
Files changed (1) hide show
  1. src/nn/pe.py +103 -28
src/nn/pe.py CHANGED
@@ -5,6 +5,16 @@ import math
5
  from jaxtyping import Float, Bool, Int
6
  from torch import Tensor
7
  from typing import Optional
 
 
 
 
 
 
 
 
 
 
8
 
9
  class NumericEncoding(nn.Module):
10
  def __init__(self, C = 1e4, dim = 64, n_max = 10000):
@@ -29,19 +39,16 @@ class NumericEncoding(nn.Module):
29
  class RoPE(nn.Module):
30
  def __init__(self, d_head, n_ctx, C=10000):
31
  super().__init__()
32
- thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
33
- thetas = thetas.repeat([2,1]).T.flatten()
34
- positions = t.arange(n_ctx)
35
- all_thetas = positions.unsqueeze(1)*thetas.unsqueeze(0)
36
- sins = t.sin(all_thetas)
37
- coss = t.cos(all_thetas)
38
  self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
39
  self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))
40
 
41
  def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
42
  offset: int = 0):
43
- x = key_or_query
44
- # start with doing it for just a single position m
45
  x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
46
  even = t.arange(0, x.shape[-1], 2)
47
  odd = t.arange(1, x.shape[-1],2)
@@ -50,28 +57,96 @@ class RoPE(nn.Module):
50
  assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
51
  return self.coss[:,offset:offset+x.shape[1]]*x + self.sins[:,offset:offset+x.shape[1]]*x_perm
52
 
53
-
54
- class FrameRoPE(nn.Module):
55
- def __init__(self, d_head, n_ctx, toks_per_frame, C=10000):
56
  super().__init__()
57
- thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
58
- thetas = thetas.repeat([2,1]).T.flatten()
59
- positions = t.arange(n_ctx)
60
- all_thetas = positions.unsqueeze(1)*thetas.unsqueeze(0)
61
- sins = t.sin(all_thetas)
62
- coss = t.cos(all_thetas)
63
- self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
64
- self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))
65
- self.toks_per_frame = toks_per_frame
66
 
67
- def forward(self, key_or_query: Float[Tensor, "batch dur*seq n_head d_head"]):
68
- x = key_or_query
69
- # start with doing it for just a single position m
70
- x_perm = t.empty(x.shape, dtype=x.dtype, device=x.device) # batch sequence n_head d_head, we perm the last axis
71
  even = t.arange(0, x.shape[-1], 2)
72
- odd = t.arange(1, x.shape[-1], 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  x_perm[:, :, :, even] = -x[:, :, :, odd]
74
  x_perm[:, :, :, odd] = x[:, :, :, even]
75
- idcs = t.arange(0, x.shape[1]//self.toks_per_frame, device=x.device)
76
- idcs = idcs[:, None].repeat(1, self.toks_per_frame).flatten()
77
- return self.coss[:,idcs]*x + self.sins[:,idcs]*x_perm
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from jaxtyping import Float, Bool, Int
6
  from torch import Tensor
7
  from typing import Optional
8
+ from pdb import set_trace
9
+
10
+ def compute_trig(d_head, n_ctx, C):
11
+ thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
12
+ thetas = thetas.repeat([2,1]).T.flatten()
13
+ positions = t.arange(n_ctx)
14
+ all_thetas = positions.unsqueeze(1)*thetas.unsqueeze(0)
15
+ sins = t.sin(all_thetas)
16
+ coss = t.cos(all_thetas)
17
+ return sins, coss
18
 
19
  class NumericEncoding(nn.Module):
20
  def __init__(self, C = 1e4, dim = 64, n_max = 10000):
 
39
  class RoPE(nn.Module):
40
  def __init__(self, d_head, n_ctx, C=10000):
41
  super().__init__()
42
+ self.d_head = d_head
43
+ self.n_ctx = n_ctx
44
+ self.C = C
45
+ sins, coss = compute_trig(d_head, n_ctx, C)
 
 
46
  self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
47
  self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))
48
 
49
  def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
50
  offset: int = 0):
51
+ x = key_or_query
 
52
  x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
53
  even = t.arange(0, x.shape[-1], 2)
54
  odd = t.arange(1, x.shape[-1],2)
 
57
  assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
58
  return self.coss[:,offset:offset+x.shape[1]]*x + self.sins[:,offset:offset+x.shape[1]]*x_perm
59
 
60
+ class LearnRoPE(nn.Module):
61
+ def __init__(self, d_head, n_ctx, C=10000):
 
62
  super().__init__()
63
+ self.d_head = d_head
64
+ self.n_ctx = n_ctx
65
+ self.C = C
66
+ sins, coss = compute_trig(d_head, n_ctx, C)
67
+ self.sins = nn.Parameter(sins.unsqueeze(0).unsqueeze(2))
68
+ self.coss = nn.Parameter(coss.unsqueeze(0).unsqueeze(2))
 
 
 
69
 
70
+ def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
71
+ offset: int = 0):
72
+ x = key_or_query
73
+ x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
74
  even = t.arange(0, x.shape[-1], 2)
75
+ odd = t.arange(1, x.shape[-1],2)
76
+ x_perm[:, :, :, even] = -x[:, :, :, odd]
77
+ x_perm[:, :, :, odd] = x[:, :, :, even]
78
+ assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
79
+ return self.coss[:,offset:offset+x.shape[1]]*x + self.sins[:,offset:offset+x.shape[1]]*x_perm
80
+
81
+ class VidRoPE(nn.Module):
82
+ def __init__(self, d_head,
83
+ d_x,
84
+ d_y,
85
+ d_t,
86
+ ctx_x,
87
+ ctx_y,
88
+ ctx_t,
89
+ C_x,
90
+ C_y,
91
+ C_t,
92
+ toks_per_frame,
93
+ n_registers):
94
+ super().__init__()
95
+ assert d_x + d_y + d_t <= d_head, f"dx + dy + dt > d_head"
96
+ self.d_head = d_head
97
+ self.d_x = d_x
98
+ self.d_y = d_y
99
+ self.d_t = d_t
100
+ self.ctx_x = ctx_x
101
+ self.ctx_y = ctx_y
102
+ self.ctx_t = ctx_t
103
+ self.C_x = C_x
104
+ self.C_y = C_y
105
+ self.C_t = C_t
106
+ self.toks_per_frame = toks_per_frame
107
+ self.n_registers = n_registers
108
+ sins_x, coss_x = compute_trig(d_x, ctx_x+1, C_x) # +1 for the register
109
+ self.register_buffer("sins_x", sins_x.unsqueeze(0).unsqueeze(2))
110
+ self.register_buffer("coss_x", coss_x.unsqueeze(0).unsqueeze(2))
111
+ sins_y, coss_y = compute_trig(d_y, ctx_y+1, C_y) # +1 for the register
112
+ self.register_buffer("sins_y", sins_y.unsqueeze(0).unsqueeze(2))
113
+ self.register_buffer("coss_y", coss_y.unsqueeze(0).unsqueeze(2))
114
+ sins_t, coss_t = compute_trig(d_t, ctx_t, C_t)
115
+ self.register_buffer("sins_t", sins_t.unsqueeze(0).unsqueeze(2))
116
+ self.register_buffer("coss_t", coss_t.unsqueeze(0).unsqueeze(2))
117
+ n_frames = ctx_t
118
+ # ctx_x should be equal to width
119
+ # ctx_y should be equal to height
120
+ pos_x = t.arange(self.ctx_x).repeat(self.ctx_y) # w cols with h entries each
121
+ pos_x = t.cat([pos_x, t.tensor([self.ctx_x], dtype=t.int32)]) # deal with register
122
+ pos_x = pos_x.repeat(n_frames)
123
+ pos_y = t.arange(self.ctx_y).repeat_interleave(self.ctx_x) # h rows with w entries each
124
+ pos_y = t.cat([pos_y, t.tensor([self.ctx_y], dtype=t.int32)]) # deal with register
125
+ pos_y = pos_y.repeat(n_frames)
126
+ pos_t = t.arange(n_frames).repeat_interleave(self.toks_per_frame)
127
+ self.register_buffer("pos_x", pos_x)
128
+ self.register_buffer("pos_y", pos_y)
129
+ self.register_buffer("pos_t", pos_t)
130
+
131
+
132
+ def rotate(self, x, pos_idcs, coss, sins):
133
+ x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype) # batch sequence n_head d_head, we perm the last axis
134
+ even = t.arange(0, x.shape[-1], 2, device=x.device)
135
+ odd = t.arange(1, x.shape[-1], 2, device=x.device)
136
  x_perm[:, :, :, even] = -x[:, :, :, odd]
137
  x_perm[:, :, :, odd] = x[:, :, :, even]
138
+ assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
139
+ assert pos_idcs.shape[0] == x.shape[1], f"pos_idcs length {pos_idcs.shape[0]} must match x.shape[1] {x.shape[1]}"
140
+
141
+ return coss[:,pos_idcs]*x + sins[:,pos_idcs]*x_perm
142
+
143
+
144
+ def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
145
+ offset: int = 0):
146
+ x = key_or_query
147
+ x[:, :, :, :self.d_x] = self.rotate(x[:, :, :, :self.d_x], self.pos_x, self.coss_x, self.sins_x)
148
+ x[:, :, :, self.d_x:self.d_x+self.d_y] = self.rotate(x[:, :, :, self.d_x:self.d_x+self.d_y], self.pos_y, self.coss_y, self.sins_y)
149
+ x[:, :, :, self.d_x+self.d_y:self.d_x+self.d_y+self.d_t] = self.rotate(x[:, : , :, self.d_x+self.d_y:self.d_x+self.d_y+self.d_t], self.pos_t+(offset//self.toks_per_frame), self.coss_t, self.sins_t)
150
+ return x
151
+
152
+