---
license: mit
tags:
- diffusion
- text-to-image
- quickdraw
- pytorch
- clip
- ddpm
language:
- en
datasets:
- Xenova/quickdraw-small
---
# Text-Conditional QuickDraw Diffusion Model
A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches.
## Model Description
This is a U-Net-based diffusion model that generates sketches conditioned on text prompts. It uses:
- **CLIP text encoder** (`openai/clip-vit-base-patch32`) for text conditioning
- **DDPM** for the diffusion process (1000 timesteps)
- **Classifier-free guidance** for improved text-image alignment
- Trained on **Google QuickDraw** dataset
## Model Details
- **Model Type**: Text-conditional DDPM diffusion model
- **Architecture**: U-Net with cross-attention for text conditioning
- **Image Size**: 64x64 grayscale
- **Base Channels**: 256
- **Text Encoder**: CLIP ViT-B/32 (frozen)
- **Training Epochs**: 100
- **Diffusion Timesteps**: 1000
- **Guidance Scale**: 5.0 (default)
### Training Configuration
- **Dataset**: Xenova/quickdraw-small (5 classes)
- **Batch Size**: 128 (32 per GPU × 4 GPUs)
- **Learning Rate**: 1e-4
- **CFG Drop Probability**: 0.15
- **Optimizer**: Adam
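For convenience, the architecture and training hyperparameters above can be gathered into a single configuration object. This is only an illustrative sketch; the field names are hypothetical and may not match the actual training script.
```python
from dataclasses import dataclass

@dataclass
class QuickDrawDiffusionConfig:
    # Architecture (see "Model Details")
    image_size: int = 64            # 64x64 grayscale sketches
    base_channels: int = 256
    text_dim: int = 512             # CLIP ViT-B/32 embedding size
    # Diffusion
    num_timesteps: int = 1000
    guidance_scale: float = 5.0     # default at sampling time
    # Training
    dataset: str = "Xenova/quickdraw-small"
    epochs: int = 100
    batch_size: int = 128           # 32 per GPU x 4 GPUs
    learning_rate: float = 1e-4
    cfg_drop_prob: float = 0.15     # probability of dropping the text conditioning
```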
## Usage
### Installation
```bash
pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm
```
### Generate Images
```python
import torch
from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder
from generate import generate_samples
# Load checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path, map_location="cuda")
# Initialize model
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Initialize text encoder
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()
# Generate samples
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0
with torch.no_grad():
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)
    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
```
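The sampler returns a batch of generated sketches as a tensor; to inspect them, convert each sample to a grayscale PIL image. This sketch assumes `sample_text` returns images in the `[-1, 1]` range, which is typical for DDPM code but worth checking against the scheduler.
```python
from PIL import Image

# Rescale from [-1, 1] to [0, 255] (assumption about the sampler's output range)
images = ((samples.clamp(-1, 1) + 1) / 2 * 255).to(torch.uint8).cpu()

for i, img in enumerate(images):
    # Each sample has shape (1, 64, 64); drop the channel dim for a grayscale image
    Image.fromarray(img.squeeze(0).numpy(), mode="L").save(f"sample_{i}.png")
```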
### Command Line Usage
```bash
# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a fire truck" \
    --num-samples 4 \
    --guidance-scale 5.0

# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a cat" \
    --num-steps 10
```
## Example Prompts
Try these prompts for best results:
- "a drawing of a cat"
- "a drawing of a fire truck"
- "a drawing of an airplane"
- "a drawing of a house"
- "a drawing of a tree"
**Note**: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]".
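With the model, text encoder, and scheduler from the snippet above already loaded, the example prompts can be run in a simple loop; this is a convenience sketch, not one of the repository's scripts.
```python
prompts = [
    "a drawing of a cat",
    "a drawing of a fire truck",
    "a drawing of an airplane",
    "a drawing of a house",
    "a drawing of a tree",
]

with torch.no_grad():
    for prompt in prompts:
        emb = text_encoder(prompt).repeat(4, 1)  # 4 samples per prompt
        out = scheduler.sample_text(model, (4, 1, 64, 64), emb, 'cuda', 5.0)
        # `out` now holds four 64x64 sketches for this prompt
```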
## Classifier-Free Guidance
The model supports classifier-free guidance to improve text-image alignment:
- `guidance_scale = 1.0`: No guidance (pure conditional generation)
- `guidance_scale = 3.0-7.0`: Recommended range (default: 5.0)
- Higher values: Stronger adherence to the text prompt (may reduce diversity); see the sketch below
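At each denoising step, classifier-free guidance combines a conditional and an unconditional noise prediction. The sketch below shows the standard update from Ho & Salimans (2022); the model call signature and null embedding here are assumptions, since the real logic lives inside `SimpleDDPMScheduler.sample_text`.
```python
def guided_noise(model, x_t, t, text_emb, null_emb, guidance_scale):
    # eps = eps_uncond + s * (eps_cond - eps_uncond)
    eps_cond = model(x_t, t, text_emb)    # prediction conditioned on the prompt
    eps_uncond = model(x_t, t, null_emb)  # prediction with the "empty" conditioning
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```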
## Model Architecture
### U-Net Structure
```
Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
```
### Text Conditioning
- Text prompts encoded via CLIP ViT-B/32
- 512-dimensional text embeddings
- Injected into U-Net via cross-attention
- Classifier-free guidance enabled by a 15% text-conditioning dropout during training (see the sketch below)
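The 15% dropout is what makes classifier-free guidance possible at sampling time: during training, some text embeddings are replaced with a null embedding so the model also learns an unconditional prediction. A minimal sketch, assuming zeros are used as the null embedding (the actual code may use a learned vector or the CLIP embedding of an empty string):
```python
import torch

def maybe_drop_text(text_emb: torch.Tensor, drop_prob: float = 0.15) -> torch.Tensor:
    # With probability drop_prob per sample, zero out the text embedding so the
    # U-Net sees an unconditional training example.
    keep = (torch.rand(text_emb.shape[0], device=text_emb.device) >= drop_prob).float()
    return text_emb * keep[:, None]
```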
## Training Details
- **Framework**: PyTorch 2.0+
- **Hardware**: 4x NVIDIA GPUs
- **Training Epochs**: 100
- **Dataset**: Google QuickDraw sketches (5 classes)
- **Noise Schedule**: Linear (β from 0.0001 to 0.02)
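The linear schedule and the quantities derived from it can be written in a few lines; this is a generic DDPM sketch, not the exact contents of `SimpleDDPMScheduler`.
```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # linear beta schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # cumulative product used by q(x_t | x_0)

def q_sample(x0, t, noise):
    # Forward (noising) process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise
```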
## Limitations
- Limited to 64x64 resolution
- Grayscale output only
- Best performance on simple objects from QuickDraw classes
- May not generalize well to complex or out-of-distribution prompts
## Citation
If you use this model, please cite:
```bibtex
@misc{quickdraw-text-diffusion,
  title={Text-Conditional QuickDraw Diffusion Model},
  author={Your Name},
  year={2024},
  howpublished={\url{https://huggingface.co/YOUR_USERNAME/quickdraw-text-diffusion}}
}
```
## License
MIT License
## Acknowledgments
- Google QuickDraw dataset
- OpenAI CLIP
- DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)
- Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022)