chrisxx committed
Commit 8ef5ca9 · verified · 1 Parent(s): f266945

Update src/models/dit_dforce.py

Files changed (1)
  1. src/models/dit_dforce.py +108 -76
src/models/dit_dforce.py CHANGED
@@ -2,31 +2,45 @@ import torch as t
 from torch import nn
 import torch.nn.functional as F
 
-from ..nn.attn import Attention, AttentionEinOps, KVCache
+from ..nn.attn import AttentionEinOps, KVCache, KVCacheNaive
 from ..nn.patch import Patch, UnPatch
 from ..nn.geglu import GEGLU
-from ..nn.pe import FrameRoPE, NumericEncoding, RoPE
+from ..nn.pe import NumericEncoding, RoPE, LearnRoPE, VidRoPE
 from jaxtyping import Float, Bool, Int
 from torch import Tensor
-from typing import Optional
+from typing import Optional, Literal
 
+import matplotlib.pyplot as plt
 import math
 
 def modulate(x, shift, scale):
+    b, s, d = x.shape
+    toks_per_frame = s // shift.shape[1]
+    x = x.reshape(b, -1, toks_per_frame, d)
+    x = x * (1 + scale[:, :, None, :]) + shift[:, :, None, :]
+    x = x.reshape(b, s, d)
+    return x
+
+def gate(x, gate):
+    b, s, d = x.shape
+    toks_per_frame = s // gate.shape[1]
+    x = x.reshape(b, -1, toks_per_frame, d)
+    x = x * gate[:, :, None, :]
+    x = x.reshape(b, s, d)
+    return x
+
+def modulate_deprecated(x, shift, scale):
     return x * (1 + scale) + shift
 
 class CausalBlock(nn.Module):
-    def __init__(self, layer_idx, d_model, expansion, n_heads, rope=None):
+    def __init__(self, layer_idx, d_model, expansion, n_heads, rope=None, ln_first = False):
         super().__init__()
         self.layer_idx = layer_idx
         self.d_model = d_model
         self.expansion = expansion
         self.n_heads = n_heads
         self.norm1 = nn.LayerNorm(d_model)
-        if t.backends.mps.is_available():
-            self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope)
-        else:
-            self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope) # there is a problem with flexattn i think
+        self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope, ln_first=ln_first)
         self.norm2 = nn.LayerNorm(d_model)
         self.geglu = GEGLU(d_model, expansion*d_model, d_model)
 
@@ -35,28 +49,20 @@ class CausalBlock(nn.Module):
             nn.Linear(d_model, 6 * d_model, bias=True),
         )
 
-    def forward(self, z, cond, mask_self, cache: Optional[KVCache] = None):
+    def forward(self, z, cond, mask_self, cached_k=None, cached_v=None):
         # batch durseq1 d
         # batch durseq2 d
         mu1, sigma1, c1, mu2, sigma2, c2 = self.modulation(cond).chunk(6, dim=-1)
         residual = z
         z = modulate(self.norm1(z), mu1, sigma1)
-        if cache is not None:
-            k, v = cache.get(self.layer_idx)
-            offset = cache.global_location # this enables to include rope and ln into the cache
-            offset = 0 # this is for reapplying rope again and again to stay more similar to training
-            z, k_new, v_new = self.selfattn(z, z, mask=mask_self, k_cache=k, v_cache=v, offset=offset)
-            cache.extend(self.layer_idx, k_new, v_new)
-        else:
-            z, _, _ = self.selfattn(z, z, mask=mask_self)
-
-        z = residual + c1*z
+        z, k_new, v_new = self.selfattn(z, z, mask=mask_self, k_cache=cached_k, v_cache=cached_v)
+        z = residual + gate(z, c1)
 
         residual = z
         z = modulate(self.norm2(z), mu2, sigma2)
         z = self.geglu(z)
-        z = residual + c2*z
-        return z
+        z = residual + gate(z, c2)
+        return z, k_new, v_new
 
 
 class CausalDit(nn.Module):
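The rewritten modulate and the new gate above replace the old repeat_interleave of cond over tokens: the flattened durseq axis is reshaped so that each frame's shift/scale/gate vector broadcasts over that frame's own tokens, and the block now applies gate(z, c1)/gate(z, c2) instead of c1*z/c2*z. A standalone shape sketch of that broadcast (sizes are made up for illustration; this is not part of the diff):

import torch

b, dur, toks_per_frame, d = 2, 3, 5, 8
x = torch.randn(b, dur * toks_per_frame, d)      # batch, durseq, d
shift = torch.randn(b, dur, d)                   # one shift vector per frame
scale = torch.randn(b, dur, d)                   # one scale vector per frame

x = x.reshape(b, dur, toks_per_frame, d)         # split durseq into frames x tokens
x = x * (1 + scale[:, :, None, :]) + shift[:, :, None, :]   # broadcast over the token axis
x = x.reshape(b, dur * toks_per_frame, d)        # back to the flattened layout
print(x.shape)                                   # torch.Size([2, 15, 8])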
@@ -64,10 +70,10 @@ class CausalDit(nn.Module):
                  patch_size=2, n_heads=8, expansion=4, n_blocks=6,
                  n_registers=1, n_actions=4, bidirectional=False,
                  debug=False,
-                 legacy=False,
-                 frame_rope=False,
                  rope_C=10000,
-                 rope_tmax=None):
+                 rope_tmax=None,
+                 rope_type: Literal["rope", "learn", "vid"] = "rope",
+                 ln_first: bool = False):
         super().__init__()
         self.height = height
         self.width = width
@@ -81,24 +87,39 @@ class CausalDit(nn.Module):
         self.T = T
         self.patch_size = patch_size
         self.debug = debug
-        self.legacy = legacy
         self.bidirectional = bidirectional
-        self.frame_rope = frame_rope
         self.toks_per_frame = (height//patch_size)*(width//patch_size) + n_registers
         self.rope_C = rope_C
-        if frame_rope:
-            print("Using frame rope")
-            print(self.toks_per_frame)
-            self.rope_seq = FrameRoPE(d_model//n_heads, self.n_window, self.toks_per_frame, C=rope_C)
-            self.grid_pe = nn.Parameter(t.randn(self.toks_per_frame - n_registers, d_model) * 1/d_model**0.5)
-        else:
-            if rope_tmax is None:
-                rope_tmax = self.n_window*self.toks_per_frame
+        if rope_tmax is None:
+            rope_tmax = self.n_window*self.toks_per_frame
+        if rope_type == "rope":
             self.rope_seq = RoPE(d_model//n_heads, rope_tmax, C=rope_C)
-            self.grid_pe = None
+        elif rope_type == "learn":
+            self.rope_seq = LearnRoPE(d_model//n_heads, rope_tmax, C=rope_C)
+        elif rope_type == "vid":
+            d_head = d_model//n_heads
+            d_x = d_y = d_t = d_head // 3
+            C_x = C_y = C_t = rope_C // 3
+            ctx_x = width // patch_size
+            ctx_y = height // patch_size
+            ctx_t = self.n_window
+            self.rope_seq = VidRoPE(d_head,
+                                    d_x,
+                                    d_y,
+                                    d_t,
+                                    ctx_x,
+                                    ctx_y,
+                                    ctx_t,
+                                    C_x,
+                                    C_y,
+                                    C_t,
+                                    self.toks_per_frame,
+                                    n_registers)
+
+        self.grid_pe = None
         self.rope_tmax = rope_tmax
 
-        self.blocks = nn.ModuleList([CausalBlock(lidx, d_model, expansion, n_heads, rope=self.rope_seq) for lidx in range(n_blocks)])
+        self.blocks = nn.ModuleList([CausalBlock(lidx, d_model, expansion, n_heads, rope=self.rope_seq, ln_first=ln_first) for lidx in range(n_blocks)])
         self.patch = Patch(in_channels=in_channels, out_channels=d_model, patch_size=patch_size)
         self.norm = nn.LayerNorm(d_model)
         self.unpatch = UnPatch(height, width, in_channels=d_model, out_channels=in_channels, patch_size=patch_size)
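In the new "vid" branch, the head dimension and the RoPE base are split evenly over the x, y and t axes. As a quick arithmetic sketch (not part of the diff), with the defaults used by get_model below and 20x20 inputs the sizes handed to VidRoPE would be:

d_model, n_heads, patch_size, height, width, n_window, rope_C = 64, 8, 2, 20, 20, 5, 10000
d_head = d_model // n_heads      # 8
d_x = d_y = d_t = d_head // 3    # 2 each, so 6 of the 8 head dims get rotated
C_x = C_y = C_t = rope_C // 3    # 3333
ctx_x = width // patch_size      # 10
ctx_y = height // patch_size     # 10
ctx_t = n_window                 # 5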
@@ -112,25 +133,18 @@ class CausalDit(nn.Module):
         )
         self.cache = None
 
-    def activate_caching(self, batch_size, max_frames=None, cache_rope=False):
-        self.cache = KVCache(batch_size, self.n_blocks, self.n_heads, self.d_head, self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)
-        if max_frames is not None:
-            self.rope_seq = RoPE(self.d_head, max_frames*self.toks_per_frame, C=self.rope_C)
-            print(self.rope_seq.sins.shape, self.rope_seq.coss.shape)
-            self.rope_seq.to(self.device)
-            self.rope_seq.to(self.dtype)
-            for idx, block in enumerate(self.blocks):
-                print("updating rope for block", idx)
-                print(self.blocks[idx].selfattn.rope.sins.shape, self.blocks[idx].selfattn.rope.coss.shape)
-                self.blocks[idx].selfattn.rope = self.rope_seq
-                print(self.blocks[idx].selfattn.rope.sins.shape, self.blocks[idx].selfattn.rope.coss.shape)
-    def deactivate_caching(self):
-        self.cache = None
+    def create_cache(self, batch_size):
+        return KVCache(batch_size, self.n_blocks, self.n_heads, self.d_head, self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)
+
+    def create_cache2(self, batch_size):
+        return KVCacheNaive(batch_size, self.n_blocks, self.n_heads, self.d_head, self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)
 
     def forward(self,
                 z: Float[Tensor, "batch dur channels height width"],
                 actions: Float[Tensor, "batch dur"],
-                ts: Int[Tensor, "batch dur"]):
+                ts: Int[Tensor, "batch dur"],
+                cached_k: Optional[Float[Tensor, "layer batch dur seq d"]] = None,
+                cached_v: Optional[Float[Tensor, "layer batch dur seq d"]] = None):
 
         if ts.shape[1] == 1:
             ts = ts.repeat(1, z.shape[1])
@@ -138,11 +152,11 @@ class CausalDit(nn.Module):
         a = self.action_emb(actions) # batch dur d
         ts_scaled = (ts * self.T).clamp(0, self.T - 1).long()
         cond = self.time_emb_mixer(self.time_emb(ts_scaled)) + a
-        #print(ts_scaled.shape, a.shape, cond.shape, actions.shape)
-        cond = cond.repeat_interleave(self.toks_per_frame, dim=1)
+
         z = self.patch(z) # batch dur seq d
         if self.grid_pe is not None:
            z = z + self.grid_pe[None, None]
+
         # self.registers is in 1x
         zr = t.cat((z, self.registers[None, None].repeat([z.shape[0], z.shape[1], 1, 1])), dim=2)# z plus registers
         if self.bidirectional:
@@ -152,18 +166,29 @@ class CausalDit(nn.Module):
         batch, durzr, seqzr, d = zr.shape
         zr = zr.reshape(batch, -1, d) # batch durseq d
 
-        for block in self.blocks:
-            zr = block(zr, cond, mask_self, cache=self.cache)
+        k_update = []
+        v_update = []
+        for bidx, block in enumerate(self.blocks):
+            ks = cached_k[bidx] if cached_k is not None else None
+            vs = cached_v[bidx] if cached_v is not None else None
+            zr, k_new, v_new = block(zr, cond, mask_self, cached_k=ks, cached_v=vs)
+            if k_new is not None:
+                k_update.append(k_new.unsqueeze(0))
+                v_update.append(v_new.unsqueeze(0))
+        if len(k_update) > 0:
+            k_update = t.cat(k_update, dim=0)
+            v_update = t.cat(v_update, dim=0)
+
         mu, sigma = self.modulation(cond).chunk(2, dim=-1)
         zr = modulate(self.norm(zr), mu, sigma)
         zr = zr.reshape(batch, durzr, seqzr, d)
         out = self.unpatch(zr[:, :, :-self.n_registers])
-        return out # batch dur channels height width
+        return out, k_update, v_update
 
     @property
     def causal_mask(self):
         size = self.n_window
-        m_self = t.tril(t.ones((size, size), dtype=t.int8, device=self.device)) #- t.tril(t.ones((size, size), dtype=t.int8, device=self.device), diagonal=-self.n_window)
+        m_self = t.tril(t.ones((size, size), dtype=t.int8, device=self.device)) # - t.tril(t.ones((size, size), dtype=t.int8, device=self.device), diagonal=-self.n_window) # this would be useful if we go bigger than windowxwindow
         m_self = t.kron(m_self, t.ones((self.toks_per_frame, self.toks_per_frame), dtype=t.int8, device=self.device))
         m_self = m_self.to(bool)
         return ~ m_self # we want to mask out the ones
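With activate_caching/deactivate_caching removed, caching is now explicit: forward accepts per-layer cached_k/cached_v and returns the keys/values it just computed, stacked over layers. A hedged sketch of how a caller might thread them through a frame-by-frame rollout, mirroring the file's __main__ style (how the mask and the n_window cropping interact with a passed-in cache is not visible in this file, so treat this as an assumption, not the intended API):

import torch as t   # CausalDit as defined in this module

dit = CausalDit(20, 20, 10, 64, 5, n_blocks=2)

z0, a0, t0 = t.rand((2, 1, 3, 20, 20)), t.randint(4, (2, 1)), t.rand((2, 1))
out0, k0, v0 = dit(z0, a0, t0)                            # first frame: no cache

z1, a1, t1 = t.rand((2, 1, 3, 20, 20)), t.randint(4, (2, 1)), t.rand((2, 1))
out1, k1, v1 = dit(z1, a1, t1, cached_k=k0, cached_v=v0)  # reuse the stacked per-layer K/V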
@@ -177,8 +202,30 @@ class CausalDit(nn.Module):
         return self.parameters().__next__().dtype
 
 
-def get_model(height, width, n_window=5, d_model=64, T=100, n_blocks=2, patch_size=2, n_heads=8, bidirectional=False, in_channels=3, frame_rope=False, C=10000):
-    return CausalDit(height, width, n_window, d_model, T, in_channels=in_channels, n_blocks=n_blocks, patch_size=patch_size, n_heads=n_heads, bidirectional=bidirectional, frame_rope=frame_rope, rope_C=C)
+def get_model(height, width,
+              n_window=5,
+              d_model=64,
+              T=100,
+              n_blocks=2,
+              patch_size=2,
+              n_heads=8,
+              bidirectional=False,
+              in_channels=3,
+              C=10000,
+              rope_type: Literal["rope", "learn", "vid"] = "rope",
+              ln_first=False):
+    return CausalDit(height, width,
+                     n_window,
+                     d_model,
+                     T,
+                     in_channels=in_channels,
+                     n_blocks=n_blocks,
+                     patch_size=patch_size,
+                     n_heads=n_heads,
+                     bidirectional=bidirectional,
+                     rope_C=C,
+                     rope_type=rope_type,
+                     ln_first=ln_first)
 
 if __name__ == "__main__":
     print("running w/o cache")
@@ -186,21 +233,6 @@ if __name__ == "__main__":
     z = t.rand((2, 6, 3, 20, 20))
     actions = t.randint(4, (2, 6))
     ts = t.rand((2, 6))
-    out = dit(z, actions, ts)
+    out, _, _ = dit(z, actions, ts)
     print(z.shape)
-    print(out.shape)
-
-    print("running w cache")
-    dit = CausalDit(20, 20, 10, 64, 5, n_blocks=2)
-    dit.activate_caching(2)
-    print(dit.cache.toks_per_frame)
-    print(dit.cache.size)
-    for i in range(30):
-        print(dit.cache.local_loc)
-        print(dit.cache.global_loc)
-        z = t.rand((2, 1, 3, 20, 20))
-        actions = t.randint(4, (2, 1))
-        ts = t.rand((2, 1))
-        out = dit(z, actions, ts)
-        print(i, z.shape)
-        print(i, out.shape)
+    print(out.shape)