jamesaasher committed on
Commit
369d21b
·
verified ·
1 Parent(s): 35bd2c2

Upload scheduler.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scheduler.py +123 -0
scheduler.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
+ """
7
+ This scheduler has 3 main responsibilities:
8
+
9
+ 1. Setup (init) - Pre-compute noise schedule
10
+ 2. Training (q_sample) - Add noise to images
11
+ 3. Generation (p_sample_text + sample_text) - Remove noise
12
+ step-by-step
13
+
14
+ """
15
+
16
+
17
class SimpleDDPMScheduler:
    """DDPM noise scheduler using a linear beta schedule.

    The constructor pre-computes the per-timestep coefficient tables once.
    ``q_sample`` uses them to noise clean inputs for training, while
    ``p_sample_text`` / ``sample_text`` use them to denoise step by step
    with text conditioning and optional classifier-free guidance (CFG).
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Pre-compute the noise schedule tables.

        Args:
            num_timesteps: Number of diffusion steps.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
        """
        self.num_timesteps = num_timesteps

        # Linear beta schedule - can replace with cosine
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        # Cumulative product - lets us jump to any timestep immediately.
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Same table shifted right by one step, with 1.0 as the value
        # "before" the first timestep.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        # (pre-computed for efficiency).
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        # Used by every reverse step; computed once here instead of being
        # rebuilt from `alphas` on each call to p_sample_text.
        self.sqrt_recip_alphas = 1.0 / torch.sqrt(self.alphas)

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        # This tells us how much randomness is appropriate at this step.
        # Removing this would lead to mode-seeking behavior (and poor sample quality).
        # Entry 0 is exactly zero when beta_start > 0, because
        # alphas_cumprod_prev[0] == 1.0.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def q_sample(self, x_start, t, noise=None):
        """Add noise to the clean images according to the noise schedule.

        So we can have examples at any timestep in the forward process.

        Args:
            x_start: Clean input tensor.
            t: 1-D tensor of timestep indices, one per batch element.
            noise: Optional noise tensor; drawn with ``torch.randn_like(x_start)``
                when omitted.

        Returns:
            ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_sample_text(self, model, x, t, text_embeddings, guidance_scale=1.0):
        """Sample x_{t-1} from x_t using the model with text conditioning and CFG.

        Args:
            model: The diffusion model, called as ``model(x, t, text_embeddings)``
                and expected to return the predicted noise.
            x: Current noisy image
            t: Current timestep (1-D tensor, one entry per batch element)
            text_embeddings: Text embeddings for conditioning
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance, higher = stronger)

        Returns:
            The sample for the previous timestep, same shape as ``x``.
        """
        # Get model prediction with text conditioning
        predicted_noise = model(x, t, text_embeddings)

        # Apply classifier-free guidance if scale > 1.0
        if guidance_scale > 1.0:
            # Also get unconditional prediction (with zero text embeddings).
            # NOTE(review): presumably the model was trained with zeroed
            # embeddings as the unconditional input - confirm against training.
            uncond_embeddings = torch.zeros_like(text_embeddings)
            uncond_noise = model(x, t, uncond_embeddings)

            # Amplify the difference between conditional and unconditional
            predicted_noise = uncond_noise + guidance_scale * (predicted_noise - uncond_noise)

        # Get coefficients
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Compute the mean of x_{t-1}
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        # Skip the noise only when every element is at the final step.
        # Checking just t[0] would drop the noise for the whole batch whenever
        # the first element happens to be 0; elements at t == 0 in a mixed
        # batch are still scaled by posterior_variance[0].
        if torch.all(t == 0):
            return model_mean

        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()  # sampling is inference-only; avoids keeping a graph across all steps
    def sample_text(self, model, shape, text_embeddings, device="cuda", guidance_scale=1.0):
        """Generate samples using DDPM sampling with text conditioning and CFG.

        Starts from Gaussian noise and applies ``p_sample_text`` for every
        timestep from ``num_timesteps - 1`` down to 0. Runs with gradient
        tracking disabled.

        Args:
            model: The diffusion model
            shape: Output shape (B, C, H, W)
            text_embeddings: Text embeddings for conditioning
            device: Device to use
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance, 3.0-7.0 typical)

        Returns:
            The final sample tensor of the requested ``shape``.
        """
        b = shape[0]
        img = torch.randn(shape, device=device)

        for i in reversed(range(0, self.num_timesteps)):
            # Every batch element is at the same timestep during sampling.
            t = torch.full((b,), i, device=device, dtype=torch.long)
            img = self.p_sample_text(model, img, t, text_embeddings, guidance_scale)

            # Clamp to prevent explosion
            img = torch.clamp(img, -2.0, 2.0)

        return img
117
+
118
+
119
def extract(a, t, x_shape):
    """Extract coefficients from a based on t and reshape to match x_shape.

    Args:
        a: 1-D tensor of per-timestep coefficients.
        t: 1-D tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the coefficients will be combined with;
            only its length (number of dimensions) is used.

    Returns:
        A tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # Index on the device where `a` lives (rather than assuming CPU) and cast
    # to int64, the index dtype accepted by every gather implementation.
    index = t.to(device=a.device, dtype=torch.long)
    out = a.gather(-1, index)
    # One trailing singleton dimension per non-batch dimension of x_shape.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)