jamesaasher commited on
Commit
5d34f66
·
verified ·
1 Parent(s): c997f38

Upload visualize_generation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. visualize_generation.py +111 -0
visualize_generation.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Visualize the diffusion generation process - capture images at each timestep."""
2
+ import torch
3
+ import argparse
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+
7
+ import config
8
+ from model import TextConditionedUNet
9
+ from scheduler import SimpleDDPMScheduler
10
+ from text_encoder import CLIPTextEncoder
11
+ from generate import tensor_to_image
12
+
13
+
14
+ def sample_with_snapshots(scheduler, model, shape, text_embeddings, device='cuda',
15
+ guidance_scale=1.0, snapshot_steps=None):
16
+ """Modified sampling that captures snapshots at specific timesteps."""
17
+ b = shape[0]
18
+ img = torch.randn(shape, device=device)
19
+
20
+ # Default: capture 10 evenly spaced steps
21
+ if snapshot_steps is None:
22
+ interval = scheduler.num_timesteps // 10
23
+ snapshot_steps = list(range(scheduler.num_timesteps - 1, -1, -interval))
24
+ if 0 not in snapshot_steps:
25
+ snapshot_steps.append(0)
26
+
27
+ snapshots = {}
28
+
29
+ for i in reversed(range(0, scheduler.num_timesteps)):
30
+ t = torch.full((b,), i, device=device, dtype=torch.long)
31
+ img = scheduler.p_sample_text(model, img, t, text_embeddings, guidance_scale)
32
+ img = torch.clamp(img, -2.0, 2.0)
33
+
34
+ if i in snapshot_steps:
35
+ snapshots[i] = img.clone().detach()
36
+
37
+ return img, snapshots
38
+
39
+
40
+ def plot_denoising_process(snapshots, prompt, output_path, sample_idx=0):
41
+ """Plot snapshots side by side showing noise -> final image."""
42
+ timesteps = sorted(snapshots.keys(), reverse=True) # noise to clean
43
+ num_steps = len(timesteps)
44
+
45
+ fig, axes = plt.subplots(1, num_steps, figsize=(2.5 * num_steps, 3))
46
+ if num_steps == 1:
47
+ axes = [axes]
48
+
49
+ fig.suptitle(f'Denoising Process: "{prompt}"', fontsize=12, fontweight='bold')
50
+
51
+ for idx, t in enumerate(timesteps):
52
+ img_tensor = snapshots[t][sample_idx]
53
+ img = tensor_to_image(img_tensor)
54
+
55
+ axes[idx].imshow(img, cmap='gray')
56
+ axes[idx].axis('off')
57
+ axes[idx].set_title(f't={t}' if t > 0 else 'Final', fontsize=10)
58
+
59
+ plt.tight_layout()
60
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
61
+ plt.close()
62
+
63
+
64
+ def main():
65
+ parser = argparse.ArgumentParser(description='Visualize denoising process')
66
+ parser.add_argument('--checkpoint', type=str, required=True)
67
+ parser.add_argument('--prompt', type=str, default="a drawing of a cat")
68
+ parser.add_argument('--guidance-scale', type=float, default=config.CFG_GUIDANCE_SCALE)
69
+ parser.add_argument('--num-steps', type=int, default=10,
70
+ help='Number of snapshots to capture')
71
+ parser.add_argument('--device', type=str, default='cuda')
72
+ args = parser.parse_args()
73
+
74
+ if args.device == 'cuda' and not torch.cuda.is_available():
75
+ args.device = 'cpu'
76
+
77
+ # Load model
78
+ checkpoint = torch.load(args.checkpoint, map_location=args.device)
79
+ ckpt_config = checkpoint.get('config', {})
80
+
81
+ model = TextConditionedUNet(text_dim=ckpt_config.get('text_dim', config.TEXT_DIM)).to(args.device)
82
+ model.load_state_dict(checkpoint['model_state_dict'])
83
+ model.eval()
84
+
85
+ text_encoder = CLIPTextEncoder(
86
+ model_name=ckpt_config.get('clip_model', config.CLIP_MODEL), freeze=True
87
+ ).to(args.device)
88
+ text_encoder.eval()
89
+
90
+ scheduler = SimpleDDPMScheduler(config.TIMESTEPS)
91
+
92
+ # Generate with snapshots
93
+ with torch.no_grad():
94
+ text_embedding = text_encoder(args.prompt)
95
+ shape = (1, 1, config.IMAGE_SIZE, config.IMAGE_SIZE)
96
+
97
+ _, snapshots = sample_with_snapshots(
98
+ scheduler, model, shape, text_embedding, args.device, args.guidance_scale
99
+ )
100
+
101
+ # Save visualization
102
+ os.makedirs("outputs", exist_ok=True)
103
+ safe_prompt = "".join(c if c.isalnum() or c in " _" else "" for c in args.prompt)[:50]
104
+ output_path = f"outputs/denoising_{safe_prompt}.png"
105
+
106
+ plot_denoising_process(snapshots, args.prompt, output_path)
107
+ print(f"✅ Saved visualization: {output_path}")
108
+
109
+
110
+ if __name__ == "__main__":
111
+ main()