---
license: mit
tags:
- diffusion
- text-to-image
- quickdraw
- pytorch
- clip
- ddpm
language:
- en
datasets:
- Xenova/quickdraw-small
---
|
|
|
|
|
# Text-Conditional QuickDraw Diffusion Model |
|
|
|
|
|
A text-conditional diffusion model for generating Google QuickDraw-style sketches from text prompts. This model uses DDPM (Denoising Diffusion Probabilistic Models) with CLIP text encoding and classifier-free guidance to generate 64x64 grayscale sketches. |
|
|
|
|
|
## Model Description |
|
|
|
|
|
This is a U-Net-based diffusion model that generates sketches conditioned on text prompts. It uses:
|
|
- **CLIP text encoder** (`openai/clip-vit-base-patch32`) for text conditioning |
|
|
- **DDPM** for the diffusion process (1000 timesteps) |
|
|
- **Classifier-free guidance** for improved text-image alignment |
|
|
- Trained on **Google QuickDraw** dataset |
|
|
|
|
|
## Model Details |
|
|
|
|
|
- **Model Type**: Text-conditional DDPM diffusion model |
|
|
- **Architecture**: U-Net with cross-attention for text conditioning |
|
|
- **Image Size**: 64x64 grayscale |
|
|
- **Base Channels**: 256 |
|
|
- **Text Encoder**: CLIP ViT-B/32 (frozen) |
|
|
- **Training Epochs**: 100
|
|
- **Diffusion Timesteps**: 1000 |
|
|
- **Guidance Scale**: 5.0 (default) |
|
|
|
|
|
### Training Configuration |
|
|
|
|
|
- **Dataset**: Xenova/quickdraw-small (5 classes) |
|
|
- **Batch Size**: 128 (32 per GPU × 4 GPUs)
|
|
- **Learning Rate**: 1e-4 |
|
|
- **CFG Drop Probability**: 0.15 (text conditioning randomly dropped during training; see the sketch after this list)
|
|
- **Optimizer**: Adam |
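
The CFG drop probability controls how often the text conditioning is replaced by a null embedding during training, which is what later enables classifier-free guidance at sampling time. A minimal sketch of that step is below; `null_embedding` and the function name are illustrative assumptions, not the repository's actual training code.

```python
import torch

p_drop = 0.15  # CFG drop probability from the configuration above

def maybe_drop_text(text_embeddings: torch.Tensor, null_embedding: torch.Tensor) -> torch.Tensor:
    """Randomly replace each text embedding with a null embedding (illustrative sketch)."""
    batch_size = text_embeddings.shape[0]
    drop = torch.rand(batch_size, device=text_embeddings.device) < p_drop
    # Broadcast the mask over the embedding dimension and swap in the null embedding
    return torch.where(drop.unsqueeze(-1), null_embedding.expand_as(text_embeddings), text_embeddings)
```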
|
|
|
|
|
## Usage |
|
|
|
|
|
### Installation |
|
|
|
|
|
```bash
pip install torch torchvision transformers diffusers datasets matplotlib pillow tqdm
```
|
|
|
|
|
### Generate Images |
|
|
|
|
|
```python
import torch

from model import TextConditionedUNet
from scheduler import SimpleDDPMScheduler
from text_encoder import CLIPTextEncoder

# Load checkpoint
checkpoint_path = "text_diffusion_final_epoch_100.pt"
checkpoint = torch.load(checkpoint_path)

# Initialize model
model = TextConditionedUNet(text_dim=512).cuda()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize text encoder (frozen CLIP ViT-B/32)
text_encoder = CLIPTextEncoder(model_name="openai/clip-vit-base-patch32", freeze=True).cuda()
text_encoder.eval()

# Generate samples
scheduler = SimpleDDPMScheduler(1000)
prompt = "a drawing of a cat"
num_samples = 4
guidance_scale = 5.0

with torch.no_grad():
    # Encode the prompt once and repeat it for each sample in the batch
    text_embedding = text_encoder(prompt)
    text_embeddings = text_embedding.repeat(num_samples, 1)

    shape = (num_samples, 1, 64, 64)
    samples = scheduler.sample_text(model, shape, text_embeddings, 'cuda', guidance_scale)
```
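
The returned `samples` tensor has shape `(num_samples, 1, 64, 64)`. To inspect the results, it can be written out as an image grid; the rescaling below assumes the scheduler returns values in roughly [-1, 1], which may need adjusting:

```python
from torchvision.utils import save_image

# Map from [-1, 1] to [0, 1] (assumed output range) and save a 2x2 grid of sketches
save_image((samples.clamp(-1, 1) + 1) / 2, "samples.png", nrow=2)
```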
|
|
|
|
|
### Command Line Usage |
|
|
|
|
|
```bash
# Generate samples
python generate.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a fire truck" \
    --num-samples 4 \
    --guidance-scale 5.0

# Visualize denoising process
python visualize_generation.py --checkpoint text_diffusion_final_epoch_100.pt \
    --prompt "a drawing of a cat" \
    --num-steps 10
```
|
|
|
|
|
## Example Prompts |
|
|
|
|
|
Try these prompts for best results: |
|
|
- "a drawing of a cat" |
|
|
- "a drawing of a fire truck" |
|
|
- "a drawing of an airplane" |
|
|
- "a drawing of a house" |
|
|
- "a drawing of a tree" |
|
|
|
|
|
**Note**: The model is trained on a limited set of QuickDraw classes, so it works best with simple object descriptions in the format "a drawing of a [object]". |
|
|
|
|
|
## Classifier-Free Guidance |
|
|
|
|
|
The model supports classifier-free guidance to improve text-image alignment (a sketch of the guided prediction follows the list):
|
|
- `guidance_scale = 1.0`: No guidance (pure conditional generation) |
|
|
- `guidance_scale` between 3.0 and 7.0: Recommended range (default: 5.0)
|
|
- Higher values: Stronger adherence to text prompt (may reduce diversity) |
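
At each denoising step the scheduler combines a conditional and an unconditional noise prediction. A minimal sketch of that combination is below; the function and variable names are illustrative, not the repository's exact API.

```python
import torch

def guided_noise(model, x_t, t, text_emb, null_emb, guidance_scale: float) -> torch.Tensor:
    """Classifier-free guidance: move the prediction toward the text-conditioned direction."""
    eps_cond = model(x_t, t, text_emb)    # prediction with the text embedding
    eps_uncond = model(x_t, t, null_emb)  # prediction with the null (unconditional) embedding
    # guidance_scale = 1.0 reduces to the purely conditional prediction
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```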
|
|
|
|
|
## Model Architecture |
|
|
|
|
|
### U-Net Structure |
|
|
```
Input: (batch, 1, 64, 64)
├── Down Block 1: 1 → 256 channels
├── Down Block 2: 256 → 512 channels
├── Down Block 3: 512 → 512 channels
├── Middle Block: 512 channels
├── Up Block 3: 1024 → 512 channels (with skip connections)
├── Up Block 2: 768 → 256 channels (with skip connections)
└── Up Block 1: 512 → 1 channel (with skip connections)
Output: (batch, 1, 64, 64) - predicted noise
```
|
|
|
|
|
### Text Conditioning |
|
|
- Text prompts encoded via CLIP ViT-B/32 (see the encoding sketch below)
|
|
- 512-dimensional text embeddings |
|
|
- Injected into U-Net via cross-attention |
|
|
- Text conditioning dropped with 15% probability during training to enable classifier-free guidance
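
For reference, a prompt can be encoded with the same CLIP checkpoint directly through `transformers`; whether the repository's `CLIPTextEncoder` uses the pooled projection shown here or another hidden state is an assumption:

```python
import torch
from transformers import CLIPModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").eval()

with torch.no_grad():
    tokens = tokenizer(["a drawing of a cat"], padding=True, return_tensors="pt")
    text_embedding = clip.get_text_features(**tokens)  # shape: (1, 512)
```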
|
|
|
|
|
## Training Details |
|
|
|
|
|
- **Framework**: PyTorch 2.0+ |
|
|
- **Hardware**: 4x NVIDIA GPUs |
|
|
- **Training Duration**: 100 epochs
|
|
- **Dataset**: Google QuickDraw sketches (5 classes) |
|
|
- **Noise Schedule**: Linear (β from 0.0001 to 0.02; see the sketch below)
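
A minimal sketch of that linear schedule and the forward (noising) process it defines; the names are illustrative and the repository's `SimpleDDPMScheduler` may organize this differently:

```python
import torch

num_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_timesteps)  # linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise
```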
|
|
|
|
|
## Limitations |
|
|
|
|
|
- Limited to 64x64 resolution |
|
|
- Grayscale output only |
|
|
- Best performance on simple objects from QuickDraw classes |
|
|
- May not generalize well to complex or out-of-distribution prompts |
|
|
|
|
|
## Citation |
|
|
|
|
|
If you use this model, please cite: |
|
|
|
|
|
```bibtex |
|
|
@misc{quickdraw-text-diffusion, |
|
|
title={Text-Conditional QuickDraw Diffusion Model}, |
|
|
author={Your Name}, |
|
|
year={2024}, |
|
|
howpublished={\url{https://huggingface.co/YOUR_USERNAME/quickdraw-text-diffusion}} |
|
|
} |
|
|
``` |
|
|
|
|
|
## License |
|
|
|
|
|
MIT License |
|
|
|
|
|
## Acknowledgments |
|
|
|
|
|
- Google QuickDraw dataset |
|
|
- OpenAI CLIP |
|
|
- DDPM paper: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020) |
|
|
- Classifier-free guidance: "Classifier-Free Diffusion Guidance" (Ho & Salimans, 2022) |
|
|
|