---
license: mit
tags:
- diffusion
- text-to-image
- quickdraw
- pytorch
- clip
- ddpm
language:
- en
datasets:
- Xenova/quickdraw-small
---

# Text-Conditional QuickDraw Diffusion Model

A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches.

## Model Description

This is a U-Net based diffusion model that generates sketches conditioned on text prompts. It uses:
- **CLIP text encoder** (`openai/clip-vit-base-patch32`) for text conditioning
- **DDPM** for the diffusion process (1000 timesteps)
- **Classifier-free guidance** for improved text-image alignment
- Trained on **Google QuickDraw** dataset

## Model Details

- **Model Type**: Text-conditional DDPM diffusion model
- **Architecture**: U-Net with cross-attention for text conditioning
- **Image Size**: 64x64 grayscale
- **Base Channels**: 256
- **Text Encoder**: CLIP ViT-B/32 (frozen)
- **Training Epochs**: 100
- **Diffusion Timesteps**: 1000
- **Guidance Scale**: 5.0 (default)

### Training Configuration

- **Dataset**: Xenova/quickdraw-small (5 classes)
- **Batch Size**: 128 (32 per GPU × 4 GPUs)
- **Learning Rate**: 1e-4
- **CFG Drop Probability**: 0.15 (see the training sketch below)
- **Optimizer**: Adam
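
The CFG drop probability controls how often the text conditioning is replaced by a null embedding during training, which is what later enables classifier-free guidance at sampling time. A minimal sketch of that step, assuming an all-zero null embedding (the actual training script may use a learned or empty-string CLIP embedding instead):

```python
import torch

def apply_cfg_dropout(text_embeddings, drop_prob=0.15):
    """Randomly replace per-sample text embeddings with a null embedding.

    text_embeddings: (batch, 512) CLIP embeddings.
    drop_prob: probability of dropping the conditioning for each sample.
    """
    drop_mask = torch.rand(text_embeddings.shape[0], device=text_embeddings.device) < drop_prob
    null_embedding = torch.zeros_like(text_embeddings)  # assumed null conditioning
    return torch.where(drop_mask.unsqueeze(1), null_embedding, text_embeddings)
```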

## Usage

### Installation

```bash
pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm
```

### Generate Images

```python
import torch
from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder
from generate import generate_samples

# Load checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path)

# Initialize model
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize text encoder
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()

# Generate samples
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0

with torch.no_grad():
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)

    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
```
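
The sampler returns a batch of `(num_samples, 1, 64, 64)` tensors. Assuming the model works in the usual `[-1, 1]` pixel range (an assumption; the repository's normalization is not shown here), the samples can be saved as PNGs like this, continuing from the snippet above:

```python
import torch
from PIL import Image

# Rescale from [-1, 1] to [0, 255] and convert each sample to a grayscale PNG.
images = ((samples.clamp(-1, 1) + 1) / 2 * 255).to(torch.uint8).cpu()
for i, img in enumerate(images):
    Image.fromarray(img.squeeze(0).numpy(), mode="L").save(f"sample_{i}.png")
```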

### Command Line Usage

```bash
# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
                   --prompt "a drawing of a fire truck" \
                   --num-samples 4 \
                   --guidance-scale 5.0

# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
                               --prompt "a drawing of a cat" \
                               --num-steps 10
```

## Example Prompts

Try these prompts for best results:
- "a drawing of a cat"
- "a drawing of a fire truck"
- "a drawing of an airplane"
- "a drawing of a house"
- "a drawing of a tree"

**Note**: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]".

## Classifier-Free Guidance

The model supports classifier-free guidance to improve text-image alignment (a sketch of the guided prediction follows the list):
- `guidance_scale = 1.0`: No guidance (pure conditional generation)
- `guidance_scale = 3.0-7.0`: Recommended range (default: 5.0)
- Higher values: Stronger adherence to text prompt (may reduce diversity)
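
Internally, classifier-free guidance runs the U-Net twice per denoising step and blends the two noise predictions. A minimal sketch of that blend (the function and variable names are illustrative, not the internals of `SimpleDDPMScheduler.sample_text`):

```python
def guided_noise_prediction(model, x_t, t, text_emb, null_emb, guidance_scale):
    """Push the prediction away from the unconditional estimate and toward the
    text-conditional one; guidance_scale = 1.0 reduces to pure conditional generation."""
    eps_cond = model(x_t, t, text_emb)    # prediction with the text prompt
    eps_uncond = model(x_t, t, null_emb)  # prediction with null conditioning
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```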

## Model Architecture

### U-Net Structure
```
Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
```

### Text Conditioning
- Text prompts encoded via CLIP ViT-B/32
- 512-dimensional text embeddings
- Injected into the U-Net via cross-attention (sketched below)
- Classifier-free guidance with 15% dropout during training
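
One way the 512-dimensional text embedding can be injected into a U-Net block via cross-attention is sketched below; the layer names and dimensions are assumptions for illustration, not the repository's `TextConditionedUNet` internals:

```python
import torch
import torch.nn as nn

class TextCrossAttention(nn.Module):
    """Image feature tokens attend to the (single) pooled CLIP text embedding."""

    def __init__(self, channels, text_dim=512, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.to_context = nn.Linear(text_dim, channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x, text_emb):
        # x: (batch, channels, H, W) feature map; text_emb: (batch, text_dim)
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)             # (batch, H*W, channels)
        context = self.to_context(text_emb).unsqueeze(1)  # (batch, 1, channels)
        attended, _ = self.attn(self.norm(tokens), context, context)
        tokens = tokens + attended                        # residual connection
        return tokens.transpose(1, 2).reshape(b, c, h, w)
```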

## Training Details

- **Framework**: PyTorch 2.0+
- **Hardware**: 4x NVIDIA GPUs
- **Training Length**: 100 epochs
- **Dataset**: Google QuickDraw sketches (5 classes)
- **Noise Schedule**: Linear (β from 0.0001 to 0.02); see the sketch below
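
The schedule fixes how much noise is mixed into a clean sketch at each timestep during training. A minimal sketch of the standard DDPM formulation with these values (parameter names are illustrative):

```python
import torch

# Linear beta schedule over 1000 timesteps, from beta_1 = 1e-4 to beta_T = 0.02.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def add_noise(x0, t, noise):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alpha_bars[t].view(-1, 1, 1, 1)  # t: (batch,) of integer timesteps
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise
```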

## Limitations

- Limited to 64x64 resolution
- Grayscale output only
- Best performance on simple objects from QuickDraw classes
- May not generalize well to complex or out-of-distribution prompts

## Citation

If you use this model, please cite:

```bibtex
@misc{quickdraw-text-diffusion,
  title={Text-Conditional QuickDraw Diffusion Model},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/YOUR_USERNAME/quickdraw-text-diffusion}}
}
```

## License

MIT License

## Acknowledgments

- Google QuickDraw dataset
- OpenAI CLIP
- DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)