"""
This scheduler has three main responsibilities:
1. Setup (__init__) - pre-compute the noise schedule
2. Training (q_sample) - add noise to clean images
3. Generation (p_sample_text + sample_text) - remove noise step by step
"""
import math

import torch
import torch.nn.functional as F
class SimpleDDPMScheduler:
def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
self.num_timesteps = num_timesteps
        # Linear beta schedule; see cosine_beta_schedule below for a drop-in alternative
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(
self.alphas, dim=0
) # cumulative product - lets us jump to any timestep immediately.
self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        # Calculations for the forward process q(x_t | x_0) and the reverse step
        # (pre-computed once so sampling doesn't redo them on every iteration)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
# Calculations for posterior q(x_{t-1} | x_t, x_0)
# This tells us how much randomness is appropriate at this step.
# Removing this would lead to mode-seeking behavior (and poor sample quality).
self.posterior_variance = (
self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
)
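    # What the buffers above implement, in standard DDPM notation (Ho et al., 2020):
    #   q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I)
    # q_sample below evaluates this in one shot via the reparameterization trick.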
def q_sample(self, x_start, t, noise=None):
"""Add noise to the clean images according to the noise schedule.
So we can have examples at any timestep in the forward process."""
# Generate original noise
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = extract(
self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
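    # Typical training step built on q_sample (a sketch; `scheduler`, `model`,
    # `images`, and `text_emb` are assumed to exist outside this file):
    #   t = torch.randint(0, scheduler.num_timesteps, (images.shape[0],), device=images.device)
    #   noise = torch.randn_like(images)
    #   noisy = scheduler.q_sample(images, t, noise)
    #   loss = F.mse_loss(model(noisy, t, text_emb), noise)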
def p_sample_text(self, model, x, t, text_embeddings, guidance_scale=1.0):
"""Sample x_{t-1} from x_t using the model with text conditioning and CFG.
Args:
model: The diffusion model
x: Current noisy image
t: Current timestep
text_embeddings: Text embeddings for conditioning
guidance_scale: Classifier-free guidance scale (1.0 = no guidance, higher = stronger)
"""
# Get model prediction with text conditioning
predicted_noise = model(x, t, text_embeddings)
        # Apply classifier-free guidance if scale > 1.0
        if guidance_scale > 1.0:
            # The unconditional prediction uses zero text embeddings; this
            # assumes the model saw zeroed-out conditioning during training
            uncond_embeddings = torch.zeros_like(text_embeddings)
            uncond_noise = model(x, t, uncond_embeddings)
            # Extrapolate past the unconditional prediction, amplifying the
            # direction in which the text conditioning shifts the noise estimate
            predicted_noise = uncond_noise + guidance_scale * (predicted_noise - uncond_noise)
        # Get coefficients for this timestep
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
        # Posterior mean of x_{t-1}:
        #   mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alphas_cumprod_t) * predicted_noise)
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )
        # During sampling every batch element shares the same timestep, so
        # checking t[0] suffices; at t=0 we return the mean with no added noise
        if t[0] == 0:
            return model_mean
else:
posterior_variance_t = extract(self.posterior_variance, t, x.shape)
noise = torch.randn_like(x)
return model_mean + torch.sqrt(posterior_variance_t) * noise
    @torch.no_grad()
    def sample_text(self, model, shape, text_embeddings, device="cuda", guidance_scale=1.0):
"""Generate samples using DDPM sampling with text conditioning and CFG.
Args:
model: The diffusion model
shape: Output shape (B, C, H, W)
text_embeddings: Text embeddings for conditioning
device: Device to use
guidance_scale: Classifier-free guidance scale (1.0 = no guidance, 3.0-7.0 typical)
"""
        b = shape[0]
        # Start from pure Gaussian noise and denoise step by step
        img = torch.randn(shape, device=device)
        for i in reversed(range(0, self.num_timesteps)):
            t = torch.full((b,), i, device=device, dtype=torch.long)
            img = self.p_sample_text(model, img, t, text_embeddings, guidance_scale)
            # Heuristic clamp to prevent values from exploding (assumes data
            # normalized to roughly [-1, 1])
            img = torch.clamp(img, -2.0, 2.0)
return img
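
# The linear schedule above is the original DDPM choice. A common drop-in
# alternative is the cosine schedule of Nichol & Dhariwal (2021); this sketch
# is not wired into the class, but its output can replace torch.linspace above.
def cosine_beta_schedule(num_timesteps, s=0.008):
    """Cosine noise schedule; clipped to avoid degenerate betas at the ends."""
    steps = num_timesteps + 1
    x = torch.linspace(0, num_timesteps, steps)
    alphas_cumprod = torch.cos(((x / num_timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)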
def extract(a, t, x_shape):
    """Extract per-timestep coefficients from a and reshape for broadcasting.

    The schedule tensors live on CPU, so t is moved there for the gather and
    the result is moved back to t's device afterwards."""
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    # Reshape to (batch, 1, 1, ...) so the coefficients broadcast over images
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
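
if __name__ == "__main__":
    # Minimal smoke test (a sketch: DummyModel stands in for a real diffusion
    # model with the assumed (x, t, text_embeddings) -> noise signature)
    class DummyModel(torch.nn.Module):
        def forward(self, x, t, text_embeddings):
            return torch.zeros_like(x)

    scheduler = SimpleDDPMScheduler(num_timesteps=10)
    x0 = torch.randn(2, 3, 8, 8)
    t = torch.randint(0, scheduler.num_timesteps, (2,))
    print(scheduler.q_sample(x0, t).shape)  # torch.Size([2, 3, 8, 8])
    text_emb = torch.zeros(2, 16)
    samples = scheduler.sample_text(
        DummyModel(), (2, 3, 8, 8), text_emb, device="cpu", guidance_scale=1.0
    )
    print(samples.shape)  # torch.Size([2, 3, 8, 8])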