multimodalart (HF Staff) committed
Commit e3acd98 · verified · Parent: dc748ce

Create app.py

Files changed (1): app.py (+357 −0)
app.py ADDED
@@ -0,0 +1,357 @@
# app.py

import os
import math
import pickle
import shutil
import subprocess
import sys
import textwrap
import time
from dataclasses import dataclass
from typing import Optional
import spaces

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

# --- One-Time Setup Function ---

def setup_data():
    """
    Checks for dataset metadata and prepares it if missing.
    This involves cloning a repo, running a script, and cleaning up.
    """
    data_dir = 'shakespeare_char'
    meta_path = os.path.join(data_dir, 'meta.pkl')

    if os.path.exists(meta_path):
        print("Dataset metadata found. Skipping setup.")
        return

    print("Dataset metadata not found. Starting one-time setup...")
    print("This may take a minute...")

    repo_url = "https://github.com/karpathy/nanoGPT"
    repo_dir = "nanoGPT"

    try:
        # 1. Clone the repository
        print(f"Cloning {repo_url}...")
        subprocess.run(["git", "clone", repo_url], check=True, capture_output=True)

        # 2. Copy the data directory
        source_data_dir = os.path.join(repo_dir, 'data', 'shakespeare_char')
        print(f"Copying data from {source_data_dir} to {data_dir}...")
        shutil.copytree(source_data_dir, data_dir)

        # 3. Run the preparation script
        prepare_script_path = os.path.join(data_dir, 'prepare.py')
        print(f"Running {prepare_script_path} to generate metadata...")
        # Use the same python executable that is running this script
        subprocess.run([sys.executable, prepare_script_path], check=True, capture_output=True)

        print("Setup successful. 'meta.pkl' has been created.")

    except subprocess.CalledProcessError as e:
        print(f"An error occurred during setup: {e}", file=sys.stderr)
        print(f"Stdout: {e.stdout.decode()}", file=sys.stderr)
        print(f"Stderr: {e.stderr.decode()}", file=sys.stderr)
        sys.exit("Setup failed. Please check your git installation and internet connection.")
    except Exception as e:
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        sys.exit("Setup failed.")
    finally:
        # 4. Clean up the cloned repository
        if os.path.exists(repo_dir):
            print(f"Cleaning up by removing '{repo_dir}' directory...")
            shutil.rmtree(repo_dir)

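# For reference, a manual-setup sketch equivalent to what setup_data() does
# (hypothetical shell commands, assuming git is available; nanoGPT's
# prepare.py also writes train.bin/val.bin, which this demo does not use):
#
#   git clone https://github.com/karpathy/nanoGPT
#   cp -r nanoGPT/data/shakespeare_char ./shakespeare_char
#   python shakespeare_char/prepare.py
#   rm -rf nanoGPT
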
# --- Run Setup and Load Data ---
setup_data()

# Load metadata for character mappings
data_dir = './shakespeare_char/'
meta_path = os.path.join(data_dir, 'meta.pkl')
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)

itos = meta['itos']
stoi = meta['stoi']
vocab_size = meta['vocab_size']
CONTEXT_LENGTH = 256

def decode(indices_tensor: torch.Tensor):
    '''Decodes a 1D tensor of indices (or the first row of a 2D batch) to text'''
    if indices_tensor.dim() == 2:
        indices_tensor = indices_tensor[0]
    indices = indices_tensor.cpu().numpy()
    return ''.join([itos[i] for i in indices])

def wrap_text(long_text, width=80):
    """Wraps text to a maximum line width, preserving paragraph breaks."""
    paragraphs = long_text.splitlines()
    wrapped = [textwrap.fill(p, width=width) if p else '' for p in paragraphs]
    return "\n".join(wrapped)

# --- Model Architecture (Copied from the notebook) ---

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        if self.flash:
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=False)
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x * (1 + scale) + shift

def bias_add_scale(x: torch.Tensor, bias: Optional[torch.Tensor], scale: torch.Tensor, residual: Optional[torch.Tensor]) -> torch.Tensor:
    out = scale * (x + bias) if bias is not None else scale * x
    return residual + out if residual is not None else out

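# A note on the two helpers above (shapes given as used in DDiTBlock below):
# modulate() is adaptive layer-norm conditioning, x -> x * (1 + scale) + shift,
# so zero shift/scale leaves x unchanged. bias_add_scale() is a fused
# gate-plus-residual, out = residual + scale * (x + bias). With x of shape
# (B, T, n_embd) and shift/scale/gate of shape (B, 1, n_embd), the modulation
# broadcasts across the sequence dimension.
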
class DDiTBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = SelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
        self.adaLN_modulation = nn.Linear(config.cond_dim, 6 * config.n_embd, bias=True)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)
        x_skip = x
        x = modulate(self.ln_1(x), shift_msa, scale_msa)
        x = self.attn(x)
        x = bias_add_scale(x, None, gate_msa, x_skip)
        x = bias_add_scale(self.mlp(modulate(self.ln_2(x), shift_mlp, scale_mlp)), None, gate_mlp, x)
        return x

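# The zero initialization of adaLN_modulation above is the "adaLN-Zero"
# scheme from the DiT paper: with zero shift/scale/gate, modulate() reduces
# to plain LayerNorm and bias_add_scale() passes the residual through, so
# each block starts out close to an identity map and training gradually
# opens the gates. For this demo the point is moot, since pretrained weights
# are loaded below.
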
class DDitFinalLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.linear = nn.Linear(config.n_embd, config.vocab_size)
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()
        self.adaLN_modulation = nn.Linear(config.cond_dim, 2 * config.n_embd)
        self.adaLN_modulation.weight.data.zero_()
        self.adaLN_modulation.bias.data.zero_()

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)

class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)

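# timestep_embedding() above is the standard sinusoidal embedding: with
# frequencies f_j = exp(-ln(max_period) * j / half) for j = 0..half-1, it
# concatenates cos(t * f_j) and sin(t * f_j) into a dim-sized vector, which
# the small MLP then projects to the conditioning width. Note that t here is
# the continuous noise level sigma, not an integer diffusion step index.
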
class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.sigma_map = TimestepEmbedder(config.cond_dim)
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([DDiTBlock(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = DDitFinalLayer(config)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, sigma):
        sigma = sigma.reshape(-1)
        b, t = idx.size()
        c = F.silu(self.sigma_map(sigma))
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x, c)
        x = self.lm_head(x, c)
        return torch.scatter(x, -1, idx[..., None], torch.zeros_like(x[..., :1]))

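# About the torch.scatter call closing GPT.forward: per position, the network
# outputs log score ratios log(p_t(y) / p_t(x_t)) over the vocabulary. For
# y == x_t the true ratio is exactly 1, so its log is pinned to 0 rather than
# learned; scatter writes that zero at the index of the current token.
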
@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    cond_dim: int = 64
    dropout: float = 0.0
    bias: bool = False

# --- Noise Schedule & Sampling Logic ---

class GeometricNoise:
    def __init__(self, sigma_min=1e-4, sigma_max=20):
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t

    def __call__(self, t):
        return self.total_noise(t), self.rate_noise(t)

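# The schedule above interpolates the total noise geometrically:
#     total_noise(t) = sigma_min**(1 - t) * sigma_max**t,
# so t=0 gives sigma_min and t=1 gives sigma_max, and rate_noise(t) is its
# time derivative, total_noise(t) * log(sigma_max / sigma_min). The sampler
# below only consumes total_noise (NOISE(t)[0]).
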
def transition(x_t: torch.Tensor, delta_sigma: torch.Tensor) -> torch.Tensor:
    base_prob = (1 - torch.exp(-delta_sigma[..., None])) / vocab_size
    trans = torch.ones(*x_t.shape, vocab_size, device=x_t.device) * base_prob
    trans = trans.scatter(-1, x_t[..., None], torch.zeros_like(trans))
    diag_fill = 1 - trans.sum(dim=-1, keepdim=True)
    return trans.scatter(-1, x_t[..., None], diag_fill)

def staggered_score(score, delta_sigma):
    exp_factor = torch.exp(-delta_sigma)[..., None]
    correction = ((exp_factor - 1) / (vocab_size * exp_factor)) * score.sum(dim=-1, keepdim=True)
    return correction + score / exp_factor

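# Reverse-step sketch for this uniform-noise kernel with vocabulary size
# K = vocab_size: over a noise increment d, transition(x_t, d) builds each
# position's row of the forward kernel exp(d * Q), with off-diagonal mass
# (1 - exp(-d)) / K and the remaining probability kept on the current token,
# while staggered_score() rescales the model's score so that the elementwise
# product stag_score * transition(...) is an unnormalized distribution over
# the previous (less noisy) token. No explicit normalization is needed
# because the Gumbel sampler below is shift-invariant in log space.
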
def sample_categorical(probs: torch.Tensor) -> torch.Tensor:
    eps = 1e-10
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(probs) + eps) + eps)
    return torch.argmax(torch.log(probs + eps) + gumbel_noise, dim=-1)

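# sample_categorical() uses the Gumbel-max trick: argmax(log p + g) with
# g ~ Gumbel(0, 1) is an exact categorical sample, and adding a constant to
# log p does not change the argmax, so unnormalized rows are fine. A rough
# equivalent for nonnegative 2D-flattened inputs (a sketch, not used here):
#     torch.multinomial(probs.flatten(0, -2), 1).view(probs.shape[:-1])
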
# --- Global Model Loading ---

print("Setting up model and device...")
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model_args = dict(n_layer=6, n_head=6, n_embd=384, cond_dim=64,
                  bias=False, vocab_size=vocab_size, block_size=CONTEXT_LENGTH, dropout=0.2)
config = GPTConfig(**model_args)
model = GPT(config)

print("Loading pre-trained model weights...")
model.load_state_dict(
    torch.hub.load_state_dict_from_url(
        'https://raw.githubusercontent.com/ash80/diffusion-gpt/master/pretrained_model/model_epoch_25.pth',
        map_location=DEVICE
    )
)
model.to(DEVICE)
model.eval()

NOISE = GeometricNoise(sigma_min=1e-4, sigma_max=20)
print("Model setup complete. Launching Gradio demo...")

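# Note: these hyperparameters (6 layers, 6 heads, 384-dim embeddings, 256
# block size) must match the architecture the checkpoint was trained with;
# load_state_dict() raises on any shape mismatch. The dropout=0.2 setting is
# inert at inference time because of model.eval().
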
# --- Gradio Generation Function ---

@spaces.GPU
def generate_text(steps, progress=gr.Progress()):
    """Generator function that yields denoised text at each step.

    The progress tracker is injected by Gradio; without this parameter the
    progress(...) call below would raise a NameError.
    """
    steps = int(steps)
    eps = 1e-5
    timesteps = torch.linspace(1, eps, steps + 1, device=DEVICE)
    step_size = (1 - eps) / steps

    # Start with a fresh random sample
    x = torch.randint(0, vocab_size, (1, CONTEXT_LENGTH), device=DEVICE)

    # Initial random text
    initial_text = decode(x)
    yield f"Step 0/{steps} (Initial Noise):\n\n{wrap_text(initial_text)}"
    time.sleep(0.5)

    with torch.no_grad():
        for i in range(steps):
            progress(i / steps, desc=f"Denoising Step {i+1}/{steps}")

            t = timesteps[i] * torch.ones(x.shape[0], 1, device=DEVICE)
            curr_sigma_bar = NOISE(t)[0]

            next_sigma_bar = NOISE(t - step_size)[0]
            delta_sigma = curr_sigma_bar - next_sigma_bar

            log_score = model(x, curr_sigma_bar)
            score = torch.exp(log_score)

            stag_score = staggered_score(score, delta_sigma)
            probs = stag_score * transition(x, delta_sigma)
            x = sample_categorical(probs)

            # Yield the decoded text and step info
            decoded_text = decode(x)
            yield f"Step {i+1}/{steps}:\n\n{wrap_text(decoded_text)}"

    # Final result
    final_text = decode(x)
    yield f"Final Result (Step {steps}/{steps}):\n\n{wrap_text(final_text)}"

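# generate_text() is a Python generator: Gradio streams every yielded string
# straight into the output component, which is what animates the denoising.
# More steps mean a smaller delta_sigma per update and typically more
# coherent text, at the cost of one model forward pass per step.
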
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # The Annotated Discrete Diffusion Model: Live Demo
        This demo visualizes the denoising process of a character-level discrete diffusion model.
        Starting from pure random noise, watch coherent text in the style of Shakespeare emerge over several steps.
        """
    )
    with gr.Row():
        steps_slider = gr.Slider(
            minimum=10,
            maximum=200,
            value=128,
            step=1,
            label="Number of Denoising Steps",