Add Neural Pong application files
- DEPLOYMENT.md +68 -0
- Dockerfile +22 -0
- QUICKSTART.md +97 -0
- README.md +37 -5
- RUN_SETUP.md +58 -0
- SETUP_GUIDE.md +244 -0
- SETUP_STEPS.md +201 -0
- SOURCE_FILES.md +63 -0
- START_HERE.md +78 -0
- TROUBLESHOOTING.md +90 -0
- app.py +480 -0
- checkpoints/ckpt-step=053700-metric=0.00092727.pt +3 -0
- cleanup.sh +40 -0
- configs/inference.yaml +50 -0
- push-and-cleanup.sh +69 -0
- push.sh +64 -0
- requirements.txt +25 -0
- setup.sh +55 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/config.py +59 -0
- src/datasets/__init__.py +2 -0
- src/datasets/__pycache__/__init__.cpython-311.pyc +0 -0
- src/datasets/__pycache__/pong1m.cpython-311.pyc +0 -0
- src/datasets/pong1m.py +62 -0
- src/inference/__init__.py +1 -0
- src/inference/__pycache__/__init__.cpython-311.pyc +0 -0
- src/inference/__pycache__/sampling.cpython-311.pyc +0 -0
- src/inference/sampling.py +23 -0
- src/models/__init__.py +0 -0
- src/models/dit_dforce.py +206 -0
- src/nn/__init__.py +0 -0
- src/nn/attn.py +473 -0
- src/nn/geglu.py +20 -0
- src/nn/patch.py +80 -0
- src/nn/pe.py +77 -0
- src/utils/__init__.py +2 -0
- src/utils/checkpoint.py +283 -0
- static/index.html +162 -0
DEPLOYMENT.md
ADDED
@@ -0,0 +1,68 @@

# Hugging Face Space Setup

This folder contains everything needed to deploy the Neural Pong demo to Hugging Face Spaces.

## Files Structure

- `app.py` - Main Flask application (modified from play_pong.py, removed single-user limitation)
- `Dockerfile` - Docker configuration for HF Spaces
- `requirements.txt` - Python dependencies
- `README.md` - Space description and metadata
- `static/index.html` - Frontend web interface
- `configs/inference.yaml` - Model configuration
- `src/` - Source code for model loading and inference

## Important Notes

### Dependencies Fixed

✅ **No external git dependencies**: The app now imports `sample` directly from `src.inference.sampling` instead of going through training code, avoiding the `muon-optimizer` git dependency.

✅ **No data files needed**: The app uses `fixed2frame` directly instead of calling `get_loader`, so it doesn't need the training data files (`frames.npy`, `actions.npy`).

✅ **Minimal codebase**: Only inference-related code is included. All training scripts and utilities have been removed:
- Removed: `src/trainers/`, `src/main.py`, `src/main_dmd.py`
- Removed: Unused dataset files, alternative models, custom norm
- Removed: Matplotlib dependencies (not needed for inference)
- **Total: 15 Python files** (down from 25+)

See `SOURCE_FILES.md` for a complete list of included files.

### Checkpoint Path

The `configs/inference.yaml` file currently references a local checkpoint path:

```yaml
checkpoint: "experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
```

**Before deploying**, you need to either:

1. **Upload checkpoint to Hugging Face Hub** and update the path to load from Hub
2. **Include the checkpoint file** in this directory and update the path
3. **Use HF Spaces storage/secrets** to store the checkpoint

### Changes Made

- Removed single-user limitation (all users can connect simultaneously)
- Simplified frontend to remove busy state handling
- Updated port to use environment variable (defaults to 7860 for HF Spaces)
- Created Dockerfile for containerized deployment
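
The `__main__` block of `app.py` is not shown in this commit view, so the following is only a minimal sketch of what the environment-variable port handling mentioned above could look like (`app`, `socketio`, and `initialize_model` are the objects defined in `app.py`):

```python
# Hypothetical sketch (not taken from this commit): read PORT with a 7860
# default, as described in "Changes Made" above.
import os

if __name__ == "__main__":
    initialize_model()                            # defined in app.py: loads checkpoint, warms up
    port = int(os.environ.get("PORT", 7860))      # HF Spaces serves Docker apps on 7860 by default
    socketio.run(app, host="0.0.0.0", port=port)  # Flask-SocketIO server entry point
```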

## Deployment Steps

1. Upload this folder to a Hugging Face Space repository
2. Update the checkpoint path in `configs/inference.yaml` to point to your model
3. Ensure the Space has GPU access enabled
4. The Space will automatically build and deploy

## Testing Locally

To test locally with Docker:

```bash
docker build -t neural-pong .
docker run -p 7860:7860 neural-pong
```

Then visit http://localhost:7860
Dockerfile
ADDED
@@ -0,0 +1,22 @@

```dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port (HF Spaces will map this)
EXPOSE 7860

# Run the Flask app
CMD python app.py
```
QUICKSTART.md
ADDED
@@ -0,0 +1,97 @@

# 🚀 Quick Setup Guide for Hugging Face Space

Your Neural Pong demo is ready to deploy! Follow these steps:

## Step 1: Create Your Hugging Face Space

1. **Go to Hugging Face Spaces:** https://huggingface.co/spaces
2. **Click "Create new Space"**
3. **Fill in the details:**
   - **Space name:** `neural-pong` (or your preferred name)
   - **SDK:** Select **"Docker"** ⚠️ Important!
   - **Hardware:** Select **"GPU"** → **"T4 small"** (or larger)
   - **Visibility:** Public or Private
4. **Click "Create Space"**

## Step 2: Upload Files Using Git

```bash
cd /share/u/wendler/code/toy-wm-hf-space

# Initialize git (if not already done)
git init

# Add all files
git add .

# Commit
git commit -m "Initial commit: Neural Pong demo"

# Add your Space as remote (replace YOUR_USERNAME and SPACE_NAME)
git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME

# Push everything
git push -u origin main
```

**Note:** The checkpoint file is 225MB, so Git is recommended over web upload.

## Step 3: Wait for Build

1. After pushing, Hugging Face will automatically start building
2. Go to your Space page → **"Logs"** tab to watch progress
3. Build time: **5-15 minutes** (installing PyTorch, etc.)

## Step 4: Test Your Space

1. Once build completes, visit your Space URL
2. You should see the Pong interface
3. Wait for model to load (loading spinner)
4. Click **"Start Stream"**
5. Use **Arrow Keys** or **WASD** to play!
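
Optionally, you can check readiness programmatically before clicking anything: `app.py` exposes an `/api/health` endpoint. The sketch below assumes the usual `*.hf.space` direct-app URL; substitute your own username and Space name.

```python
# Poll the /api/health endpoint defined in app.py until the model reports ready.
# SPACE_URL is an assumed placeholder -- adjust it to your Space's direct URL.
import time
import requests

SPACE_URL = "https://YOUR_USERNAME-neural-pong.hf.space"

for _ in range(30):                   # try for up to ~5 minutes
    status = requests.get(f"{SPACE_URL}/api/health", timeout=10).json()
    print(status)                     # includes 'ready', 'model_loaded', 'device', ...
    if status.get("ready"):
        break
    time.sleep(10)
```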

## Quick Commands

```bash
# Run the setup script
./setup.sh

# Check files are ready
ls -la

# Test Docker build locally (optional)
docker build -t neural-pong .
docker run -p 7860:7860 neural-pong
```

## Troubleshooting

### Build Fails?
- Check **"Logs"** tab for errors
- Verify checkpoint path in `configs/inference.yaml`
- Ensure GPU is selected in Space settings

### Model Won't Load?
- Verify checkpoint exists: `checkpoints/ckpt-step=053700-metric=0.00092727.pt`
- Check the path in `configs/inference.yaml`
- Look for errors in the Logs tab

## What's Included

✅ **app.py** - Flask application (no single-user limitation)
✅ **checkpoints/** - Model checkpoint (225MB)
✅ **src/** - All necessary source code (15 Python files)
✅ **static/index.html** - Frontend interface
✅ **configs/inference.yaml** - Model configuration
✅ **Dockerfile** - Container configuration
✅ **requirements.txt** - Python dependencies

## Need More Help?

- See `SETUP_GUIDE.md` for detailed instructions
- See `DEPLOYMENT.md` for technical details
- Check Hugging Face Spaces docs: https://huggingface.co/docs/hub/spaces

---

**Ready?** Run `./setup.sh` to get started! 🚀
README.md
CHANGED
@@ -1,10 +1,42 @@

```diff
 ---
-title: Pong
-emoji:
-colorFrom:
-colorTo:
+title: Neural Pong
+emoji: 🎮
+colorFrom: blue
+colorTo: purple
 sdk: docker
 pinned: false
+license: mit
 ---
```

# Neural Pong

A real-time Pong game where frames are generated by a diffusion model trained with rectified flow matching. Control the blue paddle using arrow keys or WASD to play!

## Features

- **Real-time frame generation**: Uses a frame-autoregressive transformer with diffusion sampling
- **Interactive gameplay**: Control the paddle with keyboard inputs
- **Configurable parameters**: Adjust FPS and diffusion steps
- **Low-latency streaming**: Achieves ~20 FPS with 4 diffusion steps

## How to Play

1. Wait for the model to load (you'll see a loading spinner)
2. Click "Start Stream" to begin generating frames
3. Use **Arrow Keys** or **WASD** to control the blue paddle:
   - **Up/W**: Move paddle up
   - **Down/S**: Move paddle down
4. Adjust the FPS and diffusion steps using the controls
5. Click "Stop Stream" when done

## Technical Details

This demo uses a small transformer model trained with rectified flow matching to simulate Pong game frames conditioned on user inputs. The model generates 24×24 pixel frames in real-time using diffusion sampling with configurable steps.
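
For intuition only, the sketch below shows the generic Euler-style integration loop used in rectified flow sampling: starting from Gaussian noise, the model's predicted velocity is applied over a configurable number of steps. This is a schematic illustration, not the repository's `src/inference/sampling.sample` implementation, whose exact signature and action/cache conditioning are defined in the source.

```python
# Schematic rectified-flow sampling loop (illustration only; the real code lives
# in src/inference/sampling.py and additionally conditions on actions and the KV cache).
import torch

def euler_flow_sample(velocity_model, num_steps: int = 4, shape=(1, 3, 24, 24)):
    x = torch.randn(shape)                  # start from pure noise
    ts = torch.linspace(0.0, 1.0, num_steps + 1)
    for i in range(num_steps):
        t0, t1 = ts[i], ts[i + 1]
        v = velocity_model(x, t0)           # predicted velocity at time t0
        x = x + (t1 - t0) * v               # Euler step toward the data distribution
    return x
```

Fewer steps trade visual fidelity for latency, which is why the frontend exposes the diffusion-step count as a control.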

## Model Architecture

- Frame-autoregressive transformer
- Rectified flow matching training
- Caching for efficient inference
- GPU-accelerated generation
RUN_SETUP.md
ADDED
@@ -0,0 +1,58 @@

# Git Setup Complete - Next Steps

I've created a setup script for you. Here's what to do:

## Run the Setup Script

```bash
cd /share/u/wendler/code/toy-wm-hf-space
chmod +x setup-git.sh
./setup-git.sh
```

This script will:
1. ✅ Initialize git (if needed)
2. ✅ Remove old SSH remote
3. ✅ Add HTTPS remote: `https://huggingface.co/spaces/wendlerc/pong`
4. ✅ Stage all files
5. ✅ Create initial commit (if needed)
6. ✅ Ensure branch is named `main`
7. ✅ Show you the push command

## After Running the Script

The script will show you the exact command to push. It will be:

```bash
git push -u origin main
```

## Before Pushing - Important!

**Make sure your Space exists:**
1. Go to: https://huggingface.co/spaces/wendlerc/pong
2. If it doesn't exist, create it:
   - Go to: https://huggingface.co/spaces
   - Click "Create new Space"
   - Name: `pong`
   - SDK: **Docker**
   - Hardware: **GPU (T4 small)**
   - Click "Create Space"

## Then Push

```bash
git push -u origin main
```

You'll be prompted for your Hugging Face credentials (username and access token).

## If You Need an Access Token

1. Go to: https://huggingface.co/settings/tokens
2. Create a new token with "write" permissions
3. Use it as your password when pushing

---

**The setup script is ready!** Just run `./setup-git.sh` and follow the instructions.
SETUP_GUIDE.md
ADDED
@@ -0,0 +1,244 @@

# Step-by-Step Setup Guide for Hugging Face Space

This guide will walk you through deploying your Neural Pong demo to Hugging Face Spaces.

## Prerequisites

- A Hugging Face account (sign up at https://huggingface.co/join)
- The model checkpoint file (`ckpt-step=053700-metric=0.00092727.pt`)
- Git installed on your machine (for uploading files)

---

## Step 1: Prepare Your Checkpoint File

First, you need to decide how to handle the model checkpoint. You have two main options:

### Option A: Include Checkpoint in Repository (Simplest)

1. **Locate your checkpoint file:**
   ```bash
   # Check if the file exists
   ls /share/u/wendler/code/toy-wm/experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt
   ```

2. **Copy it to the hf-space directory:**
   ```bash
   mkdir -p /share/u/wendler/code/toy-wm/hf-space/checkpoints
   cp /share/u/wendler/code/toy-wm/experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt \
      /share/u/wendler/code/toy-wm/hf-space/checkpoints/
   ```

3. **Update the config file** to point to the new location:
   ```yaml
   checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
   ```

### Option B: Upload to Hugging Face Hub (Better for Large Files)

1. **Install Hugging Face Hub:**
   ```bash
   pip install huggingface-hub
   ```

2. **Login to Hugging Face:**
   ```bash
   huggingface-cli login
   ```

3. **Create a model repository and upload:**
   ```bash
   # Create a repository (replace YOUR_USERNAME with your HF username)
   huggingface-cli repo create YOUR_USERNAME/neural-pong-checkpoint --type model

   # Upload the checkpoint
   huggingface-cli upload YOUR_USERNAME/neural-pong-checkpoint \
     /share/u/wendler/code/toy-wm/experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt \
     ckpt-step=053700-metric=0.00092727.pt
   ```

4. **Modify the checkpoint loading code** to download from Hub (we'll do this in Step 2)

---

## Step 2: Update Configuration Files

### If using Option A (checkpoint in repo):

Update `configs/inference.yaml`:

```yaml
checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
```

### If using Option B (HF Hub):

We'll need to modify `app.py` to download the checkpoint. Let me know if you want to go this route.
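
If you do go the Hub route, a minimal sketch of the change, assuming the checkpoint was uploaded to a model repo named `YOUR_USERNAME/neural-pong-checkpoint` as in Option B above, could look like this (where exactly it slots into `app.py`'s `initialize_model` is up to you):

```python
# Hedged sketch: download the checkpoint from the Hub at startup instead of
# reading a local path. Repo id and filename follow Option B above.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="YOUR_USERNAME/neural-pong-checkpoint",
    filename="ckpt-step=053700-metric=0.00092727.pt",
)
# checkpoint_path now points at a locally cached copy; pass it to
# load_model_from_config(config_path, checkpoint_path=checkpoint_path, strict=False)
```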

---

## Step 3: Create a Hugging Face Space

1. **Go to Hugging Face Spaces:** https://huggingface.co/spaces

2. **Click "Create new Space"**

3. **Fill in the details:**
   - **Space name:** `neural-pong` (or your preferred name)
   - **SDK:** Select **Docker**
   - **Hardware:** Select **GPU** (T4 small or larger)
   - **Visibility:** Public or Private (your choice)

4. **Click "Create Space"**

---

## Step 4: Upload Files to the Space

You have two options:

### Option A: Using Git (Recommended)

1. **Initialize git in your hf-space directory:**
   ```bash
   cd /share/u/wendler/code/toy-wm/hf-space
   git init
   git add .
   git commit -m "Initial commit"
   ```

2. **Add the Hugging Face remote:**
   ```bash
   # Replace YOUR_USERNAME and SPACE_NAME with your values
   git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME
   ```

3. **Push to Hugging Face:**
   ```bash
   git push -u origin main
   ```

### Option B: Using Web Interface

1. **Go to your Space page** on Hugging Face
2. **Click "Files" tab**
3. **Click "Add file" → "Upload files"**
4. **Drag and drop all files** from the `hf-space` directory
5. **Click "Commit changes"**

**Note:** For large checkpoint files, Git is recommended as the web interface has size limits.

---

## Step 5: Configure the Space

1. **Go to your Space settings** (click the gear icon)

2. **Important settings:**
   - **Hardware:** Ensure GPU is selected (T4 small minimum)
   - **Environment variables:** None needed for basic setup
   - **Storage:** If using Option B, you might want persistent storage

3. **Save settings**

---

## Step 6: Wait for Build and Deployment

1. **After pushing files**, Hugging Face will automatically:
   - Build the Docker image
   - Install dependencies
   - Start your application

2. **Monitor the build:**
   - Go to your Space page
   - Click "Logs" tab to see build progress
   - Look for any errors

3. **Expected build time:** 5-15 minutes depending on dependencies

---

## Step 7: Test Your Space

1. **Once the build completes**, your Space will be live
2. **Visit your Space URL:** `https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME`
3. **Test the application:**
   - Wait for model to load (loading spinner)
   - Click "Start Stream"
   - Use arrow keys or WASD to control paddle
   - Verify frames are generating correctly

---

## Troubleshooting

### Build Fails

- **Check logs** in the Space's "Logs" tab
- **Common issues:**
  - Missing dependencies in `requirements.txt`
  - Dockerfile syntax errors
  - Checkpoint file not found (check path in `inference.yaml`)

### Model Won't Load

- **Check checkpoint path** in `configs/inference.yaml`
- **Verify checkpoint file exists** in the repository
- **Check GPU availability** in Space settings

### Port Issues

- The app uses port 7860 (HF Spaces default)
- If you see port errors, check the `PORT` environment variable

### Out of Memory

- **Reduce batch size** or model size
- **Upgrade to larger GPU** in Space settings
- **Check if checkpoint is too large** (consider Option B)

---

## Quick Reference Commands

```bash
# Navigate to hf-space directory
cd /share/u/wendler/code/toy-wm/hf-space

# Check files are ready
ls -la

# Test Docker build locally (optional)
docker build -t neural-pong .
docker run -p 7860:7860 neural-pong

# Git setup (if using Git)
git init
git add .
git commit -m "Initial commit"
git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME
git push -u origin main
```

---

## Next Steps

After successful deployment:

1. **Share your Space** with others
2. **Monitor usage** in the Space analytics
3. **Update as needed** by pushing new commits
4. **Consider adding:**
   - Better error handling
   - More configuration options
   - Performance optimizations

---

## Need Help?

- Check Hugging Face Spaces docs: https://huggingface.co/docs/hub/spaces
- Review your Space logs for errors
- Test locally with Docker first to catch issues early
SETUP_STEPS.md
ADDED
@@ -0,0 +1,201 @@

# Step-by-Step Setup for Hugging Face Space

Follow these steps to deploy your Neural Pong demo to Hugging Face Spaces.

## ✅ Pre-flight Check

Your directory structure looks good! Here's what you have:

```
toy-wm-hf-space/
├── app.py                ✅ Main Flask application
├── Dockerfile            ✅ Docker configuration
├── requirements.txt      ✅ Python dependencies
├── README.md             ✅ Space metadata
├── checkpoints/          ✅ Model checkpoint (225MB)
│   └── ckpt-step=053700-metric=0.00092727.pt
├── configs/
│   └── inference.yaml    ✅ Model config (checkpoint path correct)
├── static/
│   └── index.html        ✅ Frontend
└── src/                  ✅ All source code (15 files)
```

## Step 1: Create Your Hugging Face Space

1. **Go to:** https://huggingface.co/spaces
2. **Click:** "Create new Space" button
3. **Fill in:**
   - **Space name:** `neural-pong` (or your choice)
   - **SDK:** **Docker** ⚠️ Must be Docker!
   - **Hardware:** **GPU** → **T4 small** (minimum)
   - **Visibility:** Public or Private
4. **Click:** "Create Space"

You'll get a URL like: `https://huggingface.co/spaces/YOUR_USERNAME/neural-pong`

## Step 2: Initialize Git Repository

```bash
cd /share/u/wendler/code/toy-wm-hf-space

# Initialize git (if not already done)
git init

# Add all files
git add .

# Make initial commit
git commit -m "Initial commit: Neural Pong demo"
```

## Step 3: Connect to Your Hugging Face Space

Replace `YOUR_USERNAME` and `SPACE_NAME` with your actual values:

```bash
# Add your Space as remote
git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME

# Push everything
git push -u origin main
```

**Example:**
```bash
git remote add origin https://huggingface.co/spaces/johndoe/neural-pong
git push -u origin main
```

**Note:** The checkpoint is 225MB, so this may take a few minutes to upload.

## Step 4: Monitor the Build

1. **Go to your Space page** on Hugging Face
2. **Click the "Logs" tab** to watch the build progress
3. **Wait 5-15 minutes** for:
   - Docker image build
   - Dependency installation (PyTorch, Flask, etc.)
   - Model loading

## Step 5: Test Your Space

Once the build completes:

1. **Visit your Space URL** (e.g., `https://huggingface.co/spaces/YOUR_USERNAME/neural-pong`)
2. **You should see:**
   - Loading spinner while model loads
   - Controls for FPS and diffusion steps
   - "Start Stream" button
3. **Test the game:**
   - Click "Start Stream"
   - Use Arrow Keys or WASD to control paddle
   - Verify frames are generating

## Troubleshooting

### Build Fails?

**Check the Logs tab for errors:**

- **Missing dependencies?** → Check `requirements.txt`
- **Checkpoint not found?** → Verify path in `configs/inference.yaml`
- **GPU errors?** → Ensure GPU is enabled in Space settings
- **Port errors?** → Should use port 7860 automatically

### Model Won't Load?

1. **Verify checkpoint path** in `configs/inference.yaml`:
   ```yaml
   checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
   ```

2. **Check checkpoint exists:**
   ```bash
   ls -lh checkpoints/ckpt-step=053700-metric=0.00092727.pt
   ```

3. **Look for errors** in the Logs tab

### Out of Memory?

- **Upgrade GPU** in Space settings (T4 medium or larger)
- **Reduce batch size** if applicable
- **Check checkpoint size** (225MB is reasonable)

## Testing Locally (Optional)

Before deploying, you can test locally with Docker:

```bash
cd /share/u/wendler/code/toy-wm-hf-space

# Build Docker image
docker build -t neural-pong .

# Run container
docker run -p 7860:7860 --gpus all neural-pong

# Visit http://localhost:7860
```

**Note:** Requires Docker and NVIDIA Docker runtime for GPU support.

## Updating Your Space

After making changes:

```bash
git add .
git commit -m "Your update message"
git push origin main
```

Hugging Face will automatically rebuild and redeploy.

## File Checklist

Before pushing, verify:

- ✅ `app.py` exists and is executable
- ✅ `Dockerfile` exists
- ✅ `requirements.txt` has all dependencies
- ✅ `checkpoints/ckpt-step=053700-metric=0.00092727.pt` exists (225MB)
- ✅ `configs/inference.yaml` has correct checkpoint path
- ✅ `static/index.html` exists
- ✅ `src/` directory has all necessary files

## Quick Reference

**Your Space URL format:**
```
https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
```

**Git remote format:**
```bash
git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
```

**Key files:**
- `app.py` - Main application (port 7860)
- `Dockerfile` - Container config
- `requirements.txt` - Dependencies
- `configs/inference.yaml` - Model config

## Next Steps After Deployment

1. **Share your Space** with others
2. **Monitor usage** in Space analytics
3. **Update as needed** by pushing new commits
4. **Consider adding:**
   - Better error handling
   - Performance metrics
   - More configuration options

---

**Ready to deploy?** Follow Step 1 above! 🚀

For more details, see `QUICKSTART.md` or `DEPLOYMENT.md`.
SOURCE_FILES.md
ADDED
@@ -0,0 +1,63 @@

# Source Files Included in Hugging Face Space

This document lists all the source files included in the deployment. Only inference-related code is included; all training code has been removed.

## File Structure

```
src/
├── __init__.py              # Package init
├── config.py                # Configuration classes (Config, TransformerConfig, etc.)
├── datasets/
│   ├── __init__.py          # Datasets package init
│   └── pong1m.py            # Dataset utilities (only fixed2frame used)
├── inference/
│   ├── __init__.py          # Inference package init
│   └── sampling.py          # Diffusion sampling function
├── models/
│   ├── __init__.py          # Models package init
│   └── dit_dforce.py        # CausalDit model and get_model function
├── nn/
│   ├── __init__.py          # NN package init
│   ├── attn.py              # Attention mechanisms and KVCache
│   ├── geglu.py             # GEGLU activation
│   ├── patch.py             # Patch/UnPatch for image tokens
│   └── pe.py                # Positional encodings (RoPE, FrameRoPE, etc.)
└── utils/
    ├── __init__.py          # Utils package init
    └── checkpoint.py        # Model loading utilities
```

## Total: 15 Python files

## Files Removed (Training Code)

- ❌ `src/main.py` - Training script
- ❌ `src/main_dmd.py` - Training script
- ❌ `src/trainers/` - All training code (5 files)
- ❌ `src/datasets/pong1m_embedding.py` - Not used
- ❌ `src/datasets/pong1m_gpt.py` - Not used
- ❌ `src/models/dit.py` - Alternative model (not used)
- ❌ `src/nn/norm.py` - Custom norm (not used, PyTorch LayerNorm used instead)
- ❌ `src/utils/logging.py` - Logging utilities (not needed for inference)
- ❌ `src/config/` - Empty directory

## Dependencies Removed

- ✅ Removed `matplotlib` imports (not needed for inference)
- ✅ Removed `muon-optimizer` dependency (only used in training)
- ✅ Removed training data file dependencies

## Verification

All necessary classes and functions are included:
- ✅ `load_model_from_config` - Model loading
- ✅ `sample` - Diffusion sampling
- ✅ `fixed2frame` - Frame conversion
- ✅ `Config` - Configuration parsing
- ✅ `CausalDit` - Model class
- ✅ `KVCache` - KV caching for inference
- ✅ All NN components (Attention, GEGLU, Patch, Positional Encodings)

The codebase is now minimal and contains only what's needed for inference.
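
A quick way to confirm the trimmed tree still wires up is an import smoke test. The module paths below follow the structure and comments above (`CausalDit` in `dit_dforce`, `KVCache` in `attn`), so treat the exact symbol locations as assumptions to adjust if they differ:

```python
# Import smoke test for the pruned inference-only tree.
# Module paths follow the file-structure comments above; run from the repo root.
from src.config import Config
from src.utils.checkpoint import load_model_from_config
from src.inference.sampling import sample
from src.datasets.pong1m import fixed2frame
from src.models.dit_dforce import CausalDit   # assumed location per the tree above
from src.nn.attn import KVCache               # assumed location per the tree above

print("All inference imports resolved:",
      [o.__name__ for o in (Config, load_model_from_config, sample,
                            fixed2frame, CausalDit, KVCache)])
```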
START_HERE.md
ADDED
@@ -0,0 +1,78 @@

# 🚀 Your Hugging Face Space is Ready!

Everything is set up and ready to deploy. Here's what you need to do:

## Quick Start (3 Steps)

### 1. Create Your Space
- Go to: https://huggingface.co/spaces
- Click "Create new Space"
- Name: `neural-pong`
- SDK: **Docker** ⚠️
- Hardware: **GPU (T4 small)**

### 2. Push Your Code
```bash
cd /share/u/wendler/code/toy-wm-hf-space

git init
git add .
git commit -m "Initial commit"

# Replace with your actual Space URL
git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
git push -u origin main
```

### 3. Wait & Test
- Watch build in "Logs" tab (5-15 min)
- Visit your Space URL when done
- Click "Start Stream" and play!

## ✅ What's Ready

- ✅ **app.py** - Flask app (port 7860, no user limits)
- ✅ **Dockerfile** - Container config
- ✅ **requirements.txt** - All dependencies
- ✅ **checkpoints/** - Model file (225MB)
- ✅ **configs/inference.yaml** - Config (checkpoint path correct)
- ✅ **static/index.html** - Frontend
- ✅ **src/** - All source code (15 files, cleaned)

## 📚 Documentation

- **SETUP_STEPS.md** - Detailed step-by-step guide
- **QUICKSTART.md** - Quick reference
- **DEPLOYMENT.md** - Technical details
- **README.md** - Space description

## 🔍 Verify Before Pushing

```bash
# Check checkpoint exists
ls -lh checkpoints/ckpt-step=053700-metric=0.00092727.pt

# Check config path
grep checkpoint configs/inference.yaml

# Check main files
ls app.py Dockerfile requirements.txt
```

## 💡 Tips

- **Large file upload:** The checkpoint is 225MB, so Git is recommended
- **Build time:** 5-15 minutes (PyTorch installation)
- **GPU required:** Make sure GPU is enabled in Space settings
- **Port:** Automatically uses 7860 (HF Spaces default)

## 🆘 Need Help?

1. Check **SETUP_STEPS.md** for detailed instructions
2. Check Space **Logs** tab for build errors
3. Verify all files are present (see checklist above)

---

**Ready?** Start with Step 1 above! 🎮
TROUBLESHOOTING.md
ADDED
@@ -0,0 +1,90 @@

# Troubleshooting: Repository Not Found

The error "Repository not found" usually means one of these issues:

## Issue 1: Space Doesn't Exist Yet

**You need to create the Space on Hugging Face first!**

1. Go to: https://huggingface.co/spaces
2. Click "Create new Space"
3. Fill in:
   - **Space name:** `pong` (or your choice)
   - **SDK:** Docker
   - **Hardware:** GPU (T4 small)
4. Click "Create Space"

**Then** come back and push your code.

## Issue 2: Wrong Remote URL

Your remote is currently set to SSH: `git@hf.co:spaces/wendlerc/pong`

**For Hugging Face Spaces, use HTTPS instead:**

```bash
cd /share/u/wendler/code/toy-wm-hf-space

# Remove the old remote
git remote remove origin

# Add the correct HTTPS remote
git remote add origin https://huggingface.co/spaces/wendlerc/pong

# Now try pushing
git push -u origin main
```

## Issue 3: Wrong Space Name

Make sure the Space name matches exactly. If you created it with a different name, update the URL:

```bash
# Check what remote you have
git remote -v

# Update to correct URL (replace with your actual Space name)
git remote set-url origin https://huggingface.co/spaces/wendlerc/YOUR_SPACE_NAME
```

## Quick Fix Commands

```bash
cd /share/u/wendler/code/toy-wm-hf-space

# 1. Check current remote
git remote -v

# 2. Remove old remote
git remote remove origin

# 3. Add HTTPS remote (replace wendlerc/pong with your actual Space)
git remote add origin https://huggingface.co/spaces/wendlerc/pong

# 4. Verify remote
git remote -v

# 5. Push
git push -u origin main
```

## Verify Your Space Exists

1. Go to: https://huggingface.co/spaces/wendlerc
2. Check if `pong` Space exists
3. If not, create it first!

## Alternative: Use Hugging Face CLI

If you have `huggingface-cli` installed:

```bash
# Login
huggingface-cli login

# Create Space (if it doesn't exist)
huggingface-cli repo create wendlerc/pong --type space --sdk docker
```

Then push your code.
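
Equivalently, here is a sketch of the same check-and-create flow from Python via `huggingface_hub`, assuming you are already authenticated (e.g. through `huggingface-cli login`):

```python
# Check whether the Space exists and create it if not -- a Python counterpart to
# the CLI commands above. Assumes an existing authenticated token.
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError

api = HfApi()
repo_id = "wendlerc/pong"

try:
    api.space_info(repo_id)
    print(f"{repo_id} exists -- just push your code.")
except RepositoryNotFoundError:
    api.create_repo(repo_id, repo_type="space", space_sdk="docker")
    print(f"Created Space {repo_id}; now add it as a git remote and push.")
```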
app.py
ADDED
@@ -0,0 +1,480 @@

```python
#!/usr/bin/env python3
"""
Pong backend (GPU, eager) for Hugging Face Spaces.
Broadcasts readiness via Socket.IO so the frontend can auto-hide a loading overlay once the model is ready.
"""

# Eventlet must be imported first and monkey-patched before other imports
import eventlet
eventlet.monkey_patch()

import sys
import os
import time
import threading
import base64
import traceback
from contextlib import contextmanager
from io import BytesIO

import torch as t
import torch._dynamo as _dynamo
import numpy as np
from PIL import Image
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from flask_socketio import SocketIO, emit

# --------------------------
# Project imports
# --------------------------
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.utils.checkpoint import load_model_from_config
from src.inference.sampling import sample
from src.datasets.pong1m import fixed2frame
from src.config import Config

# --------------------------
# App setup
# --------------------------
app = Flask(__name__, static_folder='static')
CORS(app)
# Configure SocketIO - use eventlet for proper WebSocket support
socketio = SocketIO(
    app,
    cors_allowed_origins="*",
    async_mode='eventlet',
    logger=False,
    engineio_logger=False,
    ping_timeout=60,
    ping_interval=25,
    max_http_buffer_size=1e8  # Allow larger messages
)

# --------------------------
# Globals
# --------------------------
model = None
pred2frame = None
device = None

server_ready = False  # <--- readiness flag

stream_lock = threading.Lock()
stream_thread = None
stream_running = False
latest_action = 1  # 0=init, 1=nothing, 2=up, 3=down
target_fps = 30
frame_index = 0

noise_buf = None       # (1,1,3,24,24) on GPU
action_buf = None      # (1,1) long on GPU
cpu_png_buffer = None  # BytesIO; reused

step_once = None

# --------------------------
# Perf (new API)
# --------------------------
t.backends.cudnn.benchmark = True
t.backends.cudnn.conv.fp32_precision = "tf32"
t.backends.cuda.matmul.fp32_precision = "high"

# --------------------------
# Debug helpers
# --------------------------
def _shape(x):
    try:
        return f"{tuple(x.shape)} | {x.dtype} | {x.device}"
    except Exception:
        return "<?>"

def _shape_attr(obj, name):
    try:
        ten = getattr(obj, name, None)
        return None if ten is None else _shape(ten)
    except Exception:
        return None

def _fail(msg, extra=None):
    lines = [f"[GEN ERROR] {msg}"]
    if extra:
        for k, v in extra.items():
            lines.append(f"  - {k}: {v}")
    raise RuntimeError("\n".join(lines))

@contextmanager
def log_step_debug(action_tensor=None, noise_tensor=None):
    try:
        yield
    except Exception as e:
        tb = traceback.format_exc(limit=6)
        _fail("Step failed",
              extra={
                  "action": _shape(action_tensor),
                  "noise": _shape(noise_tensor),
                  "model.device": str(device),
                  "cache.keys": _shape_attr(getattr(model, "cache", None), "keys"),
                  "cache.values": _shape_attr(getattr(model, "cache", None), "values"),
                  "frame_index": str(frame_index),
                  "exception": f"{type(e).__name__}: {e}",
                  "trace": tb.strip()
              })

# --------------------------
# Utilities
# --------------------------
def _ensure_cuda():
    if not t.cuda.is_available():
        raise RuntimeError("CUDA GPU required; torch.cuda.is_available() is False.")
    return t.device("cuda:0")

def _png_base64_from_uint8(frame_uint8) -> str:
    global cpu_png_buffer
    if cpu_png_buffer is None:
        cpu_png_buffer = BytesIO()
    else:
        cpu_png_buffer.seek(0)
        cpu_png_buffer.truncate(0)
    Image.fromarray(frame_uint8).save(cpu_png_buffer, format="PNG")
    return base64.b64encode(cpu_png_buffer.getvalue()).decode()

def _reset_cache_fresh():
    model.cache.reset()

def _broadcast_ready():
    """Tell all clients whether the server is ready."""
    socketio.emit('server_status', {'ready': server_ready, 'busy': False})

# --------------------------
# Model init (pure eager) & warmup
# --------------------------
def initialize_model():
    global model, pred2frame, device
    global noise_buf, action_buf, step_once, server_ready

    t_start = time.time()
    print("Loading model and preparing GPU runtime...")
    device = _ensure_cuda()

    config_path = os.path.join(project_root, "configs/inference.yaml")

    cfg = Config.from_yaml(config_path)
    checkpoint_path = cfg.model.checkpoint

    model = load_model_from_config(config_path, checkpoint_path=checkpoint_path, strict=False)
    model.to(device)  # Move model to GPU before activating cache
    model.eval()

    model.activate_caching(1, 300)  # Cache will now be created on the same device as model

    # Use fixed2frame directly instead of get_loader to avoid loading data files
    globals()["pred2frame"] = fixed2frame

    H = W = 24
    noise_buf = t.empty((1, 1, 3, H, W), device=device)
    action_buf = t.empty((1, 1), dtype=t.long, device=device)

    @_dynamo.disable
    def _step(model_, action_scalar_long: int, n_steps: int, cfg: float, clamp: bool):
        # Match the notebook logic exactly: create fresh noise each time
        noise = t.randn(1, 1, 3, 24, 24, device=device)
        action_buf.fill_(int(action_scalar_long))

        assert action_buf.shape == (1, 1) and action_buf.dtype == t.long and action_buf.device == device, \
            f"action_buf wrong: { _shape(action_buf) }"
        assert noise.shape == (1, 1, 3, 24, 24) and noise.device == device, \
            f"noise wrong: { _shape(noise) }"

        # Debug: Check cache state before sampling
        if model_.cache is not None:
            cache_loc = model_.cache.local_location
            if cache_loc == 0:
                # Cache is empty, this should be fine for the first frame
                pass
            elif cache_loc > 0:
                # Check if cache has valid data
                k_test, v_test = model_.cache.get(0)
                if k_test.shape[1] == 0:
                    print(f"Warning: Cache returned empty tensors at frame {frame_index}, resetting...")
                    _reset_cache_fresh()

        # Sample with the fresh noise (matching notebook: sample(model, noise, actions[:, aidx:aidx+1], ...))
        z = sample(model_, noise, action_buf, num_steps=n_steps, cfg=cfg, negative_actions=None)

        # Update cache location after sample (matching notebook: model.cache.update_global_location(1))
        model_.cache.update_global_location(1)

        if clamp:
            z = t.clamp(z, -1, 1)
        return z

    globals()["step_once"] = _step
    print("Mode: eager (no torch.compile)")

    # Warmup
    _reset_cache_fresh()
    with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
        for _ in range(4):
            with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
                _ = step_once(model, action_scalar_long=1, n_steps=4, cfg=0.0, clamp=True)

    server_ready = True
    print(f"Model ready on {device}")
    _broadcast_ready()
    return model, pred2frame

# --------------------------
# Fixed-FPS streaming worker
# --------------------------
class FrameScheduler(threading.Thread):
    def __init__(self, fps=30, n_steps=8, cfg=0.0, clamp=True):
        super().__init__(daemon=True)
        self.frame_period = 1.0 / max(1, int(fps))
        self.n_steps = int(n_steps)
        self.cfg = float(cfg)
        self.clamp = bool(clamp)
        self._stop = threading.Event()
        # FPS tracking
        self.frame_times = []
        self.last_frame_time = None

    def stop(self):
        self._stop.set()

    def run(self):
        global frame_index, latest_action
        next_tick = time.perf_counter()
        while not self._stop.is_set():
            start = time.perf_counter()
            if start - next_tick > self.frame_period * 0.75:
                next_tick = start + self.frame_period
                continue
            try:
                with stream_lock:
                    action = int(latest_action)
                with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
                    with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
                        z = step_once(model, action_scalar_long=action,
                                      n_steps=self.n_steps, cfg=self.cfg, clamp=self.clamp)
                    frames_btchw = pred2frame(z)
                    # Debug: check what pred2frame returns
                    if frame_index < 3:
                        print(f"Frame {frame_index}: z range [{z.min().item():.3f}, {z.max().item():.3f}], "
                              f"frames_btchw dtype={frames_btchw.dtype}, range [{frames_btchw.min().item()}, {frames_btchw.max().item()}]")

                    frame_arr = frames_btchw[0, 0].permute(1, 2, 0).contiguous()
                    if isinstance(frame_arr, t.Tensor):
                        frame_np = frame_arr.to("cpu", non_blocking=True).numpy()
                    else:
                        frame_np = frame_arr.astype(np.uint8, copy=False)
                    img_b64 = _png_base64_from_uint8(frame_np)

                    # Calculate achieved FPS
                    current_time = time.perf_counter()
                    if self.last_frame_time is not None:
                        frame_delta = current_time - self.last_frame_time
                        self.frame_times.append(frame_delta)
                        # Keep only last 30 frames for moving average
                        if len(self.frame_times) > 30:
                            self.frame_times.pop(0)
                        avg_frame_time = sum(self.frame_times) / len(self.frame_times)
                        achieved_fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0
                    else:
                        achieved_fps = 0
                    self.last_frame_time = current_time

                    socketio.emit('frame', {'frame': img_b64,
                                            'frame_index': frame_index,
                                            'action': action,
                                            'fps': achieved_fps})
                frame_index += 1
            except Exception as e:
                print("Generation error:", repr(e))
                socketio.emit('error', {'message': str(e)})
            next_tick += self.frame_period
            now = time.perf_counter()
            sleep_for = next_tick - now
            if sleep_for > 0:
                time.sleep(sleep_for)

# --------------------------
# Routes
# --------------------------
@app.route('/')
def index():
    return send_from_directory('static', 'index.html')

@app.errorhandler(500)
def handle_500(e):
    """Handle WSGI errors gracefully"""
    import traceback
    print(f"Flask error handler caught: {e}")
    traceback.print_exc()
    return jsonify({'error': 'Internal server error'}), 500

@app.route('/api/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'ok',
        'ready': server_ready,
        'model_loaded': model is not None,
        'device': str(device) if device else None,
        'stream_running': stream_running,
        'target_fps': target_fps
    })

@app.route('/api/generate', methods=['POST'])
def generate_frames():
    try:
        if not server_ready:
            return jsonify({'success': False, 'error': 'Server not ready'}), 503

        data = request.json or {}
        actions_list = data.get('actions', [1])
        n_steps = int(data.get('n_steps', 8))
        cfg = float(data.get('cfg', 0))
        clamp = bool(data.get('clamp', True))

        if len(actions_list) == 0 or actions_list[0] != 0:
            actions_list = [0] + actions_list

        _reset_cache_fresh()

        frames_png = []
        with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
            for a in actions_list:
                with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
                    z = step_once(model, action_scalar_long=int(a), n_steps=n_steps, cfg=cfg, clamp=clamp)
                    f_btchw = pred2frame(z)
                    f_arr = f_btchw[0, 0].permute(1, 2, 0).contiguous()
                    if isinstance(f_arr, t.Tensor):
```
|
| 355 |
+
if f_arr.dtype != t.uint8:
|
| 356 |
+
f_arr = f_arr.to(t.uint8)
|
| 357 |
+
f_np = f_arr.to("cpu", non_blocking=True).numpy()
|
| 358 |
+
else:
|
| 359 |
+
f_np = f_arr.astype(np.uint8, copy=False)
|
| 360 |
+
frames_png.append(_png_base64_from_uint8(f_np))
|
| 361 |
+
|
| 362 |
+
return jsonify({'success': True, 'frames': frames_png, 'num_frames': len(frames_png)})
|
| 363 |
+
|
| 364 |
+
except Exception as e:
|
| 365 |
+
print("Batch generation error:", repr(e))
|
| 366 |
+
return jsonify({'success': False, 'error': str(e)}), 500
|
| 367 |
+
|
| 368 |
+
# --------------------------
|
| 369 |
+
# Socket events & helpers
|
| 370 |
+
# --------------------------
|
| 371 |
+
def start_stream(n_steps=8, cfg=0.0, fps=30, clamp=True):
|
| 372 |
+
global stream_thread, stream_running, frame_index, target_fps, latest_action
|
| 373 |
+
if not server_ready:
|
| 374 |
+
_broadcast_ready()
|
| 375 |
+
raise RuntimeError("Server not ready")
|
| 376 |
+
with stream_lock:
|
| 377 |
+
stop_stream()
|
| 378 |
+
target_fps = int(fps)
|
| 379 |
+
frame_index = 0
|
| 380 |
+
_reset_cache_fresh()
|
| 381 |
+
latest_action = 0 # first action = 0 (init)
|
| 382 |
+
stream_thread = FrameScheduler(fps=target_fps, n_steps=n_steps, cfg=cfg, clamp=clamp)
|
| 383 |
+
stream_running = True
|
| 384 |
+
stream_thread.start()
|
| 385 |
+
|
| 386 |
+
def stop_stream():
|
| 387 |
+
global stream_thread, stream_running
|
| 388 |
+
if stream_thread is not None:
|
| 389 |
+
stream_thread.stop()
|
| 390 |
+
stream_thread.join(timeout=1.0)
|
| 391 |
+
stream_thread = None
|
| 392 |
+
stream_running = False
|
| 393 |
+
|
| 394 |
+
@socketio.on_error_default
|
| 395 |
+
def default_error_handler(e):
|
| 396 |
+
print(f"SocketIO error: {e}")
|
| 397 |
+
import traceback
|
| 398 |
+
traceback.print_exc()
|
| 399 |
+
|
| 400 |
+
@socketio.on('connect')
|
| 401 |
+
def handle_connect():
|
| 402 |
+
try:
|
| 403 |
+
sid = request.sid
|
| 404 |
+
print(f'Client connected: {sid}')
|
| 405 |
+
|
| 406 |
+
# Immediately tell the new client current readiness
|
| 407 |
+
emit('server_status', {
|
| 408 |
+
'ready': server_ready,
|
| 409 |
+
'busy': False
|
| 410 |
+
})
|
| 411 |
+
emit('connected', {
|
| 412 |
+
'status': 'connected',
|
| 413 |
+
'model_loaded': model is not None,
|
| 414 |
+
'ready': server_ready
|
| 415 |
+
})
|
| 416 |
+
except Exception as e:
|
| 417 |
+
print(f"Error in handle_connect: {e}")
|
| 418 |
+
import traceback
|
| 419 |
+
traceback.print_exc()
|
| 420 |
+
|
| 421 |
+
@socketio.on('disconnect')
|
| 422 |
+
def handle_disconnect(*args):
|
| 423 |
+
sid = request.sid
|
| 424 |
+
print(f'Client disconnected: {sid}')
|
| 425 |
+
# Note: We don't stop the stream on disconnect since multiple users can be connected
|
| 426 |
+
|
| 427 |
+
@socketio.on('start_stream')
|
| 428 |
+
def handle_start_stream(data):
|
| 429 |
+
try:
|
| 430 |
+
if not server_ready:
|
| 431 |
+
# Tell client to keep showing spinner
|
| 432 |
+
emit('server_status', {'ready': server_ready, 'busy': False})
|
| 433 |
+
return
|
| 434 |
+
|
| 435 |
+
n_steps = int(data.get('n_steps', 8))
|
| 436 |
+
cfg = float(data.get('cfg', 0))
|
| 437 |
+
fps = int(data.get('fps', 30))
|
| 438 |
+
clamp = bool(data.get('clamp', True))
|
| 439 |
+
print(f"Starting stream @ {fps} FPS (n_steps={n_steps}, cfg={cfg}, clamp={clamp})")
|
| 440 |
+
try:
|
| 441 |
+
start_stream(n_steps=n_steps, cfg=cfg, fps=fps, clamp=clamp)
|
| 442 |
+
emit('stream_started', {'status': 'ok'})
|
| 443 |
+
except Exception as e:
|
| 444 |
+
print(f"Error starting stream: {e}")
|
| 445 |
+
import traceback
|
| 446 |
+
traceback.print_exc()
|
| 447 |
+
emit('error', {'message': str(e)})
|
| 448 |
+
except Exception as e:
|
| 449 |
+
print(f"Error in handle_start_stream: {e}")
|
| 450 |
+
import traceback
|
| 451 |
+
traceback.print_exc()
|
| 452 |
+
emit('error', {'message': f'Failed to start stream: {str(e)}'})
|
| 453 |
+
|
| 454 |
+
@socketio.on('action')
|
| 455 |
+
def handle_action(data):
|
| 456 |
+
global latest_action
|
| 457 |
+
action = int(data.get('action', 1))
|
| 458 |
+
with stream_lock:
|
| 459 |
+
latest_action = action
|
| 460 |
+
emit('action_ack', {'received': action, 'will_apply_to_frame_index': frame_index})
|
| 461 |
+
|
| 462 |
+
@socketio.on('stop_stream')
|
| 463 |
+
def handle_stop_stream():
|
| 464 |
+
print('Stopping stream')
|
| 465 |
+
stop_stream()
|
| 466 |
+
|
| 467 |
+
# --------------------------
|
| 468 |
+
# Entrypoint
|
| 469 |
+
# --------------------------
|
| 470 |
+
if __name__ == '__main__':
|
| 471 |
+
# Start model initialization in background thread so server starts immediately
|
| 472 |
+
init_thread = threading.Thread(target=initialize_model, daemon=True)
|
| 473 |
+
init_thread.start()
|
| 474 |
+
|
| 475 |
+
# Use PORT environment variable for Hugging Face Spaces, default to 7860
|
| 476 |
+
port = int(os.environ.get('PORT', 7860))
|
| 477 |
+
print(f"Starting Flask server on http://0.0.0.0:{port}")
|
| 478 |
+
print("Model will load in background...")
|
| 479 |
+
socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True, use_reloader=False)
|
| 480 |
+
|
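For reference, a minimal client sketch for the `/api/generate` route defined in `app.py` above. It is not part of the committed files, uses only the Python standard library, and assumes the server is reachable at `http://localhost:7860` (swap in your Space URL when deployed); the `frame_*.png` output names are just illustrative.

```python
# Hypothetical usage sketch (not part of the commit): request a short rollout
# from /api/generate and save the returned base64-encoded PNG frames.
import base64
import json
import urllib.request

payload = {"actions": [1, 1, 2, 2], "n_steps": 8, "cfg": 0.0, "clamp": True}
req = urllib.request.Request(
    "http://localhost:7860/api/generate",  # assumed local server; use your Space URL otherwise
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

if result["success"]:
    for i, frame_b64 in enumerate(result["frames"]):
        with open(f"frame_{i:03d}.png", "wb") as f:  # illustrative output names
            f.write(base64.b64decode(frame_b64))
```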
checkpoints/ckpt-step=053700-metric=0.00092727.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3813cf639d5370bb90be4bc3974de5b6858a9cb4216458f757c0d415537d0d6
size 235359093
cleanup.sh
ADDED
@@ -0,0 +1,40 @@
#!/bin/bash
# Cleanup script - removes temporary files and scripts from toy-wm-hf-space

set -e

echo "🧹 Cleaning up temporary files..."
echo ""

CLEANUP_DIR="/share/u/wendler/code/toy-wm-hf-space"

if [ ! -d "$CLEANUP_DIR" ]; then
    echo "⚠️  Directory $CLEANUP_DIR doesn't exist, skipping cleanup"
    exit 0
fi

cd "$CLEANUP_DIR"

echo "Removing temporary scripts..."
rm -f push.sh push-now.sh push-force.sh setup-git.sh setup.sh fix-remote.sh fix-and-push.sh push-to-hf.py 2>/dev/null || true

echo "Removing temporary documentation files..."
rm -f RUN_SETUP.md TROUBLESHOOTING.md SETUP_STEPS.md SETUP_GUIDE.md QUICKSTART.md START_HERE.md 2>/dev/null || true

echo "Keeping essential files:"
echo "  ✅ app.py"
echo "  ✅ Dockerfile"
echo "  ✅ requirements.txt"
echo "  ✅ README.md"
echo "  ✅ DEPLOYMENT.md"
echo "  ✅ SOURCE_FILES.md"
echo "  ✅ .gitignore"
echo "  ✅ All source code and checkpoints"

echo ""
echo "✅ Cleanup complete!"
echo ""
echo "The toy-wm-hf-space directory still contains your files"
echo "but temporary scripts have been removed."
configs/inference.yaml
ADDED
@@ -0,0 +1,50 @@
model:
  model_id: "dit_dforce"
  width: 24
  height: 24
  T: 1000
  in_channels: 3
  n_window: 30
  patch_size: 3
  n_heads: 12
  d_model: 384
  n_blocks: 8
  C: 5000
  bidirectional: false
  nocompile: false
  checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
  # "experiments/dulcet-disco-547/ckpt-step=000800-metric=0.00384521.pt"
  #"experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
  # "experiments/polished-paper-531/model.pt"
  #"experiments/polished-paper-531/ckpt-step=000200-metric=0.00251065.pt"
  #"experiments/polished-paper-531/ckpt-step=000800-metric=0.00450636.pt"
  #checkpoint: "experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
  # few frame 2-step distilled model experiments/smart-waterfall-528/model.pt
  # few frame 1-step distilled model experiments/blooming-flower-530/model.pt
  #"experiments/rich-meadow-488/model.pt"
  # "experiments/rich-meadow-488/ckpt-step=003700-metric=0.00309512.pt"
  #checkpoint: "experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
  #checkpoint: "experiments/glad-water-486/model.pt"
  #checkpoint: "experiments/dutiful-river-427/ckpt-step=011600-metric=0.00229805.pt"
  #checkpoint: "experiments/iconic-paper-421/ckpt-step=001600-metric=0.00471355.pt"
  #"experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
  #checkpoint: "experiments/frosty-sunset-395/ckpt-step=002100-metric=0.00160125.pt"
  #checkpoint: "experiments/vivid-sea-390/ckpt-step=000700-metric=0.01958773.pt"

train:
  lr1: 0.0002
  lr2: 1.5e-6
  betas: [0.9, 0.95]
  weight_decay: 1.0e-5
  max_steps: 20000
  batch_size: 16
  noclip: false
  duration: 1
  fps: 31
  debug: false
  p_pretrain: 0.95

wandb:
  name: null
  project: "toy-wm"
  run_name: "causal-layers8-heads12-d384"
push-and-cleanup.sh
ADDED
@@ -0,0 +1,69 @@
#!/bin/bash
# Complete script: Push from pong directory and clean up

set -e

echo "🚀 Neural Pong - Complete Push and Cleanup"
echo "==========================================="
echo ""

# Step 1: Push from pong directory
echo "Step 1: Pushing from /share/u/wendler/code/pong"
echo "-----------------------------------------------"
cd /share/u/wendler/code/pong

if [ ! -d ".git" ]; then
    echo "❌ Error: Not a git repository in pong directory"
    exit 1
fi

# Stage and commit
git add .
if ! git diff --cached --quiet; then
    git commit -m "Add Neural Pong application files" || git commit --amend --no-edit
fi

# Push
BRANCH=$(git branch --show-current 2>/dev/null || echo "main")
echo "Pushing to origin/$BRANCH..."
if ! git push -u origin $BRANCH 2>&1; then
    echo "Trying force push..."
    git push -u origin $BRANCH --force
fi

echo ""
echo "✅ Successfully pushed to Hugging Face Spaces!"
echo ""

# Step 2: Cleanup temporary files
echo "Step 2: Cleaning up temporary files"
echo "-----------------------------------"
CLEANUP_DIR="/share/u/wendler/code/toy-wm-hf-space"

if [ -d "$CLEANUP_DIR" ]; then
    cd "$CLEANUP_DIR"

    echo "Removing temporary scripts..."
    rm -f push.sh push-now.sh push-force.sh setup-git.sh setup.sh \
          fix-remote.sh fix-and-push.sh push-to-hf.py 2>/dev/null || true

    echo "Removing temporary documentation..."
    rm -f RUN_SETUP.md TROUBLESHOOTING.md SETUP_STEPS.md \
          SETUP_GUIDE.md QUICKSTART.md START_HERE.md 2>/dev/null || true

    echo "✅ Cleanup complete!"
else
    echo "⚠️  Cleanup directory not found, skipping"
fi

echo ""
echo "==========================================="
echo "✅ All done!"
echo "==========================================="
echo ""
echo "🌐 Your Space: https://huggingface.co/spaces/wendlerc/pong"
echo "📁 Working directory: /share/u/wendler/code/pong"
echo ""
echo "The build should start automatically. Check the Logs tab for progress."
push.sh
ADDED
@@ -0,0 +1,64 @@
#!/bin/bash
# Push script for /share/u/wendler/code/pong
# This will push all files to Hugging Face Spaces

set -e

cd /share/u/wendler/code/pong

echo "🚀 Pushing Neural Pong to Hugging Face Spaces..."
echo ""

# Check if we're in a git repo
if [ ! -d ".git" ]; then
    echo "❌ Error: Not a git repository"
    exit 1
fi

# Stage all files
echo "📁 Staging files..."
git add .

# Check status
echo ""
echo "📋 Files to be committed:"
git status --short | head -20

# Commit changes
echo ""
if git diff --cached --quiet; then
    echo "✅ No changes to commit"
else
    echo "💾 Committing changes..."
    git commit -m "Add Neural Pong application files" || git commit --amend --no-edit
    echo "✅ Changes committed"
fi

# Check remote
echo ""
echo "🔗 Checking remote..."
git remote -v

# Check branch
BRANCH=$(git branch --show-current 2>/dev/null || echo "main")
echo "🌿 Current branch: $BRANCH"

# Push
echo ""
echo "📤 Pushing to Hugging Face Spaces..."
if git push -u origin $BRANCH 2>&1; then
    echo ""
    echo "✅ Successfully pushed!"
else
    echo ""
    echo "⚠️  Push failed, trying force push..."
    git push -u origin $BRANCH --force
    echo ""
    echo "✅ Force pushed successfully!"
fi

echo ""
echo "🌐 Your Space is available at:"
echo "   https://huggingface.co/spaces/wendlerc/pong"
echo ""
echo "The build should start automatically. Check the Logs tab for progress."
requirements.txt
ADDED
@@ -0,0 +1,25 @@
# Core ML framework
torch>=2.0.0
torchvision>=0.15.0

# Web framework
flask>=3.1.0
flask-cors>=6.0.0
flask-socketio>=5.5.0
eventlet>=0.40.0

# Data processing
numpy>=1.24.0
pillow>=10.0.0
einops>=0.7.0

# Configuration
pyyaml>=6.0
omegaconf>=2.3.0

# Type hints
jaxtyping>=0.2.0

# Hugging Face Hub (for model loading if needed)
huggingface-hub>=0.20.0
setup.sh
ADDED
@@ -0,0 +1,55 @@
#!/bin/bash
# Quick setup script for Hugging Face Space deployment

set -e

echo "🚀 Neural Pong - Hugging Face Space Setup"
echo "=========================================="
echo ""

# Check if we're in the right directory
if [ ! -f "app.py" ] || [ ! -f "Dockerfile" ]; then
    echo "❌ Error: Please run this script from the toy-wm-hf-space directory"
    exit 1
fi

# Check if checkpoint exists
if [ ! -f "checkpoints/ckpt-step=053700-metric=0.00092727.pt" ]; then
    echo "❌ Error: Checkpoint file not found!"
    echo "   Expected: checkpoints/ckpt-step=053700-metric=0.00092727.pt"
    exit 1
fi

echo "✅ Checkpoint file found"
echo "✅ All required files present"
echo ""

# Check if git is initialized
if [ ! -d ".git" ]; then
    echo "📦 Initializing git repository..."
    git init
    echo "✅ Git initialized"
else
    echo "✅ Git repository already initialized"
fi

echo ""
echo "📋 Next steps:"
echo ""
echo "1. Create a Hugging Face Space:"
echo "   - Go to https://huggingface.co/spaces"
echo "   - Click 'Create new Space'"
echo "   - Name: neural-pong (or your choice)"
echo "   - SDK: Docker"
echo "   - Hardware: GPU (T4 small or larger)"
echo ""
echo "2. Add the remote and push:"
echo "   git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME"
echo "   git add ."
echo "   git commit -m 'Initial commit'"
echo "   git push -u origin main"
echo ""
echo "3. Wait for build (5-15 minutes)"
echo ""
echo "📖 For detailed instructions, see SETUP_GUIDE.md"
echo ""
src/__init__.py
ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (162 Bytes).
src/config.py
ADDED
@@ -0,0 +1,59 @@
from dataclasses import dataclass, field
from typing import List, Optional
import yaml
from omegaconf import OmegaConf

@dataclass
class TransformerConfig:
    model_id : str = None
    width : int = 24
    height : int = 24
    T : int = 1000
    in_channels : int = 3
    n_window : int = 7
    patch_size : int = 2
    n_heads : int = 4
    d_model : int = 64
    n_blocks : int = 12
    n_heads : int = 12
    d_model : int = 384
    patch_size : int = 1
    bidirectional : bool = True
    nocompile : bool = False
    checkpoint : str = None


@dataclass
class TrainingConfig:
    lr1 : float = 0.002
    lr2 : float = 3e-5
    betas : tuple = (0.9, 0.95)
    weight_decay : float = 1e-5
    max_steps : int = 26000
    batch_size : int = 32
    noclip : bool = False
    duration : int = 1
    fps : int = 7
    in_channels : int = 3
    debug : bool = False


@dataclass
class WANDBConfig:
    name : str = "toy-wm"
    project : str = None
    run_name : str = None

@dataclass
class Config:
    model: TransformerConfig
    train: TrainingConfig
    wandb: WANDBConfig

    @classmethod
    def from_yaml(cls, path):
        with open(path) as f:
            raw_cfg = yaml.safe_load(f)

        cfg = OmegaConf.create(raw_cfg)
        return OmegaConf.structured(cls(**cfg))
src/datasets/__init__.py
ADDED
@@ -0,0 +1,2 @@
# Datasets module
src/datasets/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (171 Bytes).
src/datasets/__pycache__/pong1m.cpython-311.pyc
ADDED
Binary file (4.54 kB).
src/datasets/pong1m.py
ADDED
@@ -0,0 +1,62 @@
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
import torch as t
import numpy as np
from einops import rearrange

mean = t.tensor([[[[[0.0352]],
                   [[0.1046]],
                   [[0.1046]]]]])
std = t.tensor([[[[[0.1066]],
                  [[0.0995]],
                  [[0.0995]]]]])

def fixed2frame(y, lam=1e-6):
    y = y.clamp(-1, 1) * 0.5 + 0.5
    frames = (y * 255.0).round().byte()
    return frames

def z2frame(y, lam=1e-6, mean=mean, std=std):
    y = y*std.to(y.dtype).to(y.device) + mean.to(y.dtype).to(y.device)
    frames = (y.clamp(0, 1) * 255.0).round().byte()
    return frames

def get_loader(batch_size=64, fps=30, duration=5, shuffle=True, debug=False, mode="-1,1", mean=mean, std=std, drop_duration=False):
    frames = t.from_numpy(np.load("./datasets/pong1M/frames.npy"))
    actions = t.from_numpy(np.load("./datasets/pong1M/actions.npy"))
    height, width, channels = frames.shape[-3:]
    n = frames.shape[0]//(fps*duration)
    frames = frames[:n*fps*duration]
    frames = frames.reshape(n, fps*duration, height, width, channels)
    frames = frames.permute(0, 1, 4, 2, 3)
    actions = actions[:n*fps*duration]
    actions = actions.reshape(-1, fps*duration)
    b, dur, c, h, w = frames.shape
    if mode == "-1,1":
        z = rearrange(frames, "b dur c h w -> (b dur h w) c")
        mask = (z == t.tensor([6, 24, 24], dtype=z.dtype)).all(dim=1)
        z = (z.float()/255.0 - 0.5)*2
        z[mask] = 0
        z = rearrange(z, "(b dur h w) c -> b dur c h w", b=b, dur=dur, c=c, h=h, w=w)
        frames = z
        pred2frame = fixed2frame
    elif mode == "z":
        frames = frames.float()/255.0
        frames = (frames - mean) / (std + 1e-6)
        pred2frame = z2frame
    else:
        raise ValueError(f"Invalid mode: {mode}")

    firstf = frames[0]
    firsta = actions[0]
    if debug:
        frames = 0*frames + firstf[None]
        actions = 0*actions + firsta[None]
        frames = 0*frames + frames[:,0].unsqueeze(1)
    if drop_duration:
        dataset = TensorDataset(frames[:, 0], actions[:,0]*0)
    else:
        dataset = TensorDataset(frames, actions)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    print(f"{frames.shape[0]//batch_size} batches")
    return loader, pred2frame
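A small sketch of how `fixed2frame` is used at inference time, mirroring what `app.py` does with `pred2frame`. The random tensor below just stands in for a model prediction, and the snippet assumes the repository root is on `PYTHONPATH`; it is not part of the committed files.

```python
# Minimal sketch: convert a prediction in [-1, 1] into uint8 frames.
import torch as t
from src.datasets.pong1m import fixed2frame

z = t.rand(1, 1, 3, 24, 24) * 2 - 1        # stand-in for a sampled prediction (batch, dur, c, h, w)
frames = fixed2frame(z)                    # uint8 tensor, same shape, values in 0..255
frame_hwc = frames[0, 0].permute(1, 2, 0)  # (24, 24, 3) layout, ready for PNG encoding
print(frames.dtype, frame_hwc.shape)
```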
src/inference/__init__.py
ADDED
@@ -0,0 +1 @@
from .sampling import sample, sample_with_grad
src/inference/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (258 Bytes).
src/inference/__pycache__/sampling.cpython-311.pyc
ADDED
Binary file (2.13 kB).
src/inference/sampling.py
ADDED
@@ -0,0 +1,23 @@
import torch as t

@t.no_grad()
def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
    return sample_with_grad(v, z, actions, num_steps, cfg, negative_actions)

def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
    device = v.device
    ts = 1 - t.linspace(0, 1, num_steps+1, device=device)
    ts = 3*ts/(2*ts + 1)
    z_prev = z.clone()
    z_prev = z_prev.to(device)
    for i in range(len(ts)-1):
        t_cond = ts[i].repeat(z_prev.shape[0], 1)
        v_pred = v(z_prev.to(device), actions.to(device), t_cond.to(device))
        if cfg > 0:
            if negative_actions is not None:
                v_neg = v(z_prev.to(device), negative_actions.to(device), t_cond.to(device))
            else:
                v_neg = v(z_prev.to(device), t.zeros_like(actions, dtype=t.long, device=device), t_cond.to(device))
            v_pred = v_neg + cfg * (v_pred - v_neg)
        z_prev = z_prev + (ts[i] - ts[i+1])*v_pred
    return z_prev
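A minimal sketch of the sampler's calling convention, matching how `app.py` invokes it (one frame of noise plus one action id per step). The `DummyVelocityModel` below is a made-up stand-in so the snippet runs without the checkpoint; the real call passes the `CausalDit` model. Assumes the repository root is on `PYTHONPATH`.

```python
# Minimal sketch of src.inference.sampling.sample, under the assumptions above.
import torch as t
from src.inference.sampling import sample

class DummyVelocityModel:
    """Made-up stand-in for CausalDit: predicts a zero velocity field."""
    device = t.device("cpu")

    def __call__(self, z, actions, t_cond):
        return t.zeros_like(z)

noise = t.randn(1, 1, 3, 24, 24)      # one frame of noise, as in app.py's step_once
actions = t.ones(1, 1, dtype=t.long)  # one action id for that frame
z = sample(DummyVelocityModel(), noise, actions, num_steps=8, cfg=0.0)
print(z.shape)                        # torch.Size([1, 1, 3, 24, 24])
```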
src/models/__init__.py
ADDED
File without changes
src/models/dit_dforce.py
ADDED
@@ -0,0 +1,206 @@
import torch as t
from torch import nn
import torch.nn.functional as F

from ..nn.attn import Attention, AttentionEinOps, KVCache
from ..nn.patch import Patch, UnPatch
from ..nn.geglu import GEGLU
from ..nn.pe import FrameRoPE, NumericEncoding, RoPE
from jaxtyping import Float, Bool, Int
from torch import Tensor
from typing import Optional

import math

def modulate(x, shift, scale):
    return x * (1 + scale) + shift

class CausalBlock(nn.Module):
    def __init__(self, layer_idx, d_model, expansion, n_heads, rope=None):
        super().__init__()
        self.layer_idx = layer_idx
        self.d_model = d_model
        self.expansion = expansion
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(d_model)
        if t.backends.mps.is_available():
            self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope)
        else:
            self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope)  # there is a problem with flexattn i think
        self.norm2 = nn.LayerNorm(d_model)
        self.geglu = GEGLU(d_model, expansion*d_model, d_model)

        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 6 * d_model, bias=True),
        )

    def forward(self, z, cond, mask_self, cache: Optional[KVCache] = None):
        # batch durseq1 d
        # batch durseq2 d
        mu1, sigma1, c1, mu2, sigma2, c2 = self.modulation(cond).chunk(6, dim=-1)
        residual = z
        z = modulate(self.norm1(z), mu1, sigma1)
        if cache is not None:
            k, v = cache.get(self.layer_idx)
            offset = cache.global_location  # this enables to include rope and ln into the cache
            offset = 0  # this is for reapplying rope again and again to stay more similar to training
            z, k_new, v_new = self.selfattn(z, z, mask=mask_self, k_cache=k, v_cache=v, offset=offset)
            cache.extend(self.layer_idx, k_new, v_new)
        else:
            z, _, _ = self.selfattn(z, z, mask=mask_self)

        z = residual + c1*z

        residual = z
        z = modulate(self.norm2(z), mu2, sigma2)
        z = self.geglu(z)
        z = residual + c2*z
        return z


class CausalDit(nn.Module):
    def __init__(self, height, width, n_window, d_model, T=1000, in_channels=3,
                 patch_size=2, n_heads=8, expansion=4, n_blocks=6,
                 n_registers=1, n_actions=4, bidirectional=False,
                 debug=False,
                 legacy=False,
                 frame_rope=False,
                 rope_C=10000,
                 rope_tmax=None):
        super().__init__()
        self.height = height
        self.width = width
        self.n_window = n_window
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.n_blocks = n_blocks
        self.expansion = expansion
        self.n_registers = n_registers
        self.T = T
        self.patch_size = patch_size
        self.debug = debug
        self.legacy = legacy
        self.bidirectional = bidirectional
        self.frame_rope = frame_rope
        self.toks_per_frame = (height//patch_size)*(width//patch_size) + n_registers
        self.rope_C = rope_C
        if frame_rope:
            print("Using frame rope")
            print(self.toks_per_frame)
            self.rope_seq = FrameRoPE(d_model//n_heads, self.n_window, self.toks_per_frame, C=rope_C)
            self.grid_pe = nn.Parameter(t.randn(self.toks_per_frame - n_registers, d_model) * 1/d_model**0.5)
        else:
            if rope_tmax is None:
                rope_tmax = self.n_window*self.toks_per_frame
            self.rope_seq = RoPE(d_model//n_heads, rope_tmax, C=rope_C)
            self.grid_pe = None
            self.rope_tmax = rope_tmax

        self.blocks = nn.ModuleList([CausalBlock(lidx, d_model, expansion, n_heads, rope=self.rope_seq) for lidx in range(n_blocks)])
        self.patch = Patch(in_channels=in_channels, out_channels=d_model, patch_size=patch_size)
        self.norm = nn.LayerNorm(d_model)
        self.unpatch = UnPatch(height, width, in_channels=d_model, out_channels=in_channels, patch_size=patch_size)
        self.action_emb = nn.Embedding(n_actions, d_model)
        self.registers = nn.Parameter(t.randn(n_registers, d_model) * 1/d_model**0.5)
        self.time_emb = NumericEncoding(dim=d_model, n_max=T)
        self.time_emb_mixer = nn.Linear(d_model, d_model)
        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 2 * d_model, bias=True),
        )
        self.cache = None

    def activate_caching(self, batch_size, max_frames=None, cache_rope=False):
        self.cache = KVCache(batch_size, self.n_blocks, self.n_heads, self.d_head, self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)
        if max_frames is not None:
            self.rope_seq = RoPE(self.d_head, max_frames*self.toks_per_frame, C=self.rope_C)
            print(self.rope_seq.sins.shape, self.rope_seq.coss.shape)
            self.rope_seq.to(self.device)
            self.rope_seq.to(self.dtype)
            for idx, block in enumerate(self.blocks):
                print("updating rope for block", idx)
                print(self.blocks[idx].selfattn.rope.sins.shape, self.blocks[idx].selfattn.rope.coss.shape)
                self.blocks[idx].selfattn.rope = self.rope_seq
                print(self.blocks[idx].selfattn.rope.sins.shape, self.blocks[idx].selfattn.rope.coss.shape)

    def deactivate_caching(self):
        self.cache = None

    def forward(self,
                z: Float[Tensor, "batch dur channels height width"],
                actions: Float[Tensor, "batch dur"],
                ts: Int[Tensor, "batch dur"]):

        if ts.shape[1] == 1:
            ts = ts.repeat(1, z.shape[1])

        a = self.action_emb(actions)  # batch dur d
        ts_scaled = (ts * self.T).clamp(0, self.T - 1).long()
        cond = self.time_emb_mixer(self.time_emb(ts_scaled)) + a
        #print(ts_scaled.shape, a.shape, cond.shape, actions.shape)
        cond = cond.repeat_interleave(self.toks_per_frame, dim=1)
        z = self.patch(z)  # batch dur seq d
        if self.grid_pe is not None:
            z = z + self.grid_pe[None, None]
        # self.registers is in 1x
        zr = t.cat((z, self.registers[None, None].repeat([z.shape[0], z.shape[1], 1, 1])), dim=2)  # z plus registers
        if self.bidirectional:
            mask_self = None
        else:
            mask_self = self.causal_mask
        batch, durzr, seqzr, d = zr.shape
        zr = zr.reshape(batch, -1, d)  # batch durseq d

        for block in self.blocks:
            zr = block(zr, cond, mask_self, cache=self.cache)
        mu, sigma = self.modulation(cond).chunk(2, dim=-1)
        zr = modulate(self.norm(zr), mu, sigma)
        zr = zr.reshape(batch, durzr, seqzr, d)
        out = self.unpatch(zr[:, :, :-self.n_registers])
        return out  # batch dur channels height width

    @property
    def causal_mask(self):
        size = self.n_window
        m_self = t.tril(t.ones((size, size), dtype=t.int8, device=self.device))  #- t.tril(t.ones((size, size), dtype=t.int8, device=self.device), diagonal=-self.n_window)
        m_self = t.kron(m_self, t.ones((self.toks_per_frame, self.toks_per_frame), dtype=t.int8, device=self.device))
        m_self = m_self.to(bool)
        return ~ m_self  # we want to mask out the ones

    @property
    def device(self):
        return self.parameters().__next__().device

    @property
    def dtype(self):
        return self.parameters().__next__().dtype


def get_model(height, width, n_window=5, d_model=64, T=100, n_blocks=2, patch_size=2, n_heads=8, bidirectional=False, in_channels=3, frame_rope=False, C=10000):
    return CausalDit(height, width, n_window, d_model, T, in_channels=in_channels, n_blocks=n_blocks, patch_size=patch_size, n_heads=n_heads, bidirectional=bidirectional, frame_rope=frame_rope, rope_C=C)

if __name__ == "__main__":
    print("running w/o cache")
    dit = CausalDit(20, 20, 100, 64, 5, n_blocks=2)
    z = t.rand((2, 6, 3, 20, 20))
    actions = t.randint(4, (2, 6))
    ts = t.rand((2, 6))
    out = dit(z, actions, ts)
    print(z.shape)
    print(out.shape)

    print("running w cache")
    dit = CausalDit(20, 20, 10, 64, 5, n_blocks=2)
    dit.activate_caching(2)
    print(dit.cache.toks_per_frame)
    print(dit.cache.size)
    for i in range(30):
        print(dit.cache.local_loc)
        print(dit.cache.global_loc)
        z = t.rand((2, 1, 3, 20, 20))
        actions = t.randint(4, (2, 1))
        ts = t.rand((2, 1))
        out = dit(z, actions, ts)
        print(i, z.shape)
        print(i, out.shape)
src/nn/__init__.py
ADDED
File without changes
src/nn/attn.py
ADDED
@@ -0,0 +1,473 @@
from torch import nn
from torch.nn import functional as F
import torch as t
import einops
from jaxtyping import Float, Bool
from torch import Tensor
from typing import Optional
from torch.nn.attention.flex_attention import flex_attention


class KVCache(nn.Module):
    """
    Rolling KV cache implemented as a ring buffer.
    - Shapes:
        keys/values per extend(): (batch_size, T, n_heads, d_head)
    - Internal storage:
        (n_layers, batch_size, size, n_heads, d_head) where size = toks_per_frame * n_window
    - Semantics:
        Call `extend(layer_idx, k, v)` once per layer for the *same* frame.
        Call `update_global_location(n_frames)` once after all layers to commit the frame(s).
    """
    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, *, dtype=None, device=None, enforce_layer_order=True):
        super().__init__()
        self.batch_size = batch_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_head = d_head
        self.toks_per_frame = toks_per_frame
        self.n_window = n_window
        self.size = (toks_per_frame * n_window)  #toks_per_frame # (toks_per_frame * n_window)

        # Pointers / counters
        self.curr_layer = 0   # which layer are we writing for this frame
        self.global_loc = 0   # total tokens ever committed
        self.local_loc = 0    # valid tokens in buffer (<= size)
        self._write_ptr = 0   # ring-buffer write pointer (index of next commit position)

        # Storage
        dtype = dtype if dtype is not None else t.float32
        self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
        self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))

        # Misc
        self.enforce_layer_order = enforce_layer_order

    # -------------- Public API --------------
    def get(self, layer_idx):
        """Return (K, V) for given layer in chronological order: shape (B, L, H, D) where L = local_loc."""
        self._check_layer(layer_idx)
        if self.local_loc == 0:
            # return empty views
            empty = self.keys[layer_idx, :, :0]
            return empty, empty

        start = (self._write_ptr - self.local_loc) % self.size
        if start + self.local_loc <= self.size:
            # contiguous slice
            k = self.keys[layer_idx, :, start:start + self.local_loc]
            v = self.values[layer_idx, :, start:start + self.local_loc]
        else:
            # wrap: concatenate two slices to maintain chronological order
            first = self.size - start
            k = t.cat([
                self.keys[layer_idx, :, start:self.size],
                self.keys[layer_idx, :, 0:(self.local_loc - first)]
            ], dim=1)
            v = t.cat([
                self.values[layer_idx, :, start:self.size],
                self.values[layer_idx, :, 0:(self.local_loc - first)]
            ], dim=1)
        return k, v

    @t.no_grad()
    def extend(self, layer_idx, keys, values):
        """
        Stage (but do not commit) tokens for the current frame for the given layer.
        Call update_global_location(n_frames) to commit after all layers wrote.
        """
        assert keys.shape == values.shape, f"keys and values shapes must match, got {keys.shape} vs {values.shape}"
        self._check_layer(layer_idx)

        # Expected shape: (B, T, H, D)
        B, T, H, D = keys.shape
        assert B == self.batch_size, f"batch mismatch: expected {self.batch_size}, got {B}"
        assert H == self.n_heads and D == self.d_head, f"heads/d_head mismatch: expected {(self.n_heads, self.d_head)}, got {(H, D)}"
        assert T > 0 and T <= self.size, f"T must be in 1..{self.size}, got {T}"
        # Optional: if you only ever append whole frames:
        # assert T == self.toks_per_frame, f"T must equal toks_per_frame ({self.toks_per_frame}), got {T}"

        # Cast to buffer dtype/device if needed
        if keys.dtype != self.keys.dtype or keys.device != self.keys.device:
            keys = keys.to(dtype=self.keys.dtype, device=self.keys.device)
        if values.dtype != self.values.dtype or values.device != self.values.device:
            values = values.to(dtype=self.values.dtype, device=self.values.device)

        # Write into the ring at the *current* write_ptr (uncommitted until update_global_location)
        i0 = self._write_ptr
        i1 = (self._write_ptr + T) % self.size
        if i0 < i1:
            self.keys[layer_idx, :, i0:i1] = keys
            self.values[layer_idx, :, i0:i1] = values
        else:
            # wraps: split write
            split = self.size - i0
            self.keys[layer_idx, :, i0:self.size] = keys[:, :split]
            self.values[layer_idx, :, i0:self.size] = values[:, :split]
            self.keys[layer_idx, :, 0:i1] = keys[:, split:]
            self.values[layer_idx, :, 0:i1] = values[:, split:]

        # Advance expected layer (but do *not* advance write_ptr/local_len here)
        self.curr_layer = (self.curr_layer + 1) % self.n_layers

    @t.no_grad()
    def update_global_location(self, n_frames):
        """
        Commit staged writes for n_frames (advances the ring write pointer once per frame).
        Keep calling extend(layer_idx, ...) for each layer before you call this.
        """
        assert n_frames >= 0, f"n_frames must be >= 0, got {n_frames}"
        tokens = n_frames * self.toks_per_frame
        if tokens == 0:
            return
        assert tokens <= self.size, f"Cannot commit {tokens} tokens (> buffer size {self.size})."

        self.global_loc += tokens
        # Update valid length (never exceeds capacity)
        self.local_loc = min(self.size, self.local_loc + tokens)
        # Advance write pointer
        self._write_ptr = (self._write_ptr + tokens) % self.size

    @t.no_grad()
    def reset(self, zero_memory: bool = True):
        self.global_loc = 0
        self.local_loc = 0
        self.curr_layer = 0
        self._write_ptr = 0
        if zero_memory:
            self.keys.zero_()
            self.values.zero_()

    # -------------- Convenience / Introspection --------------
    @property
    def local_location(self):
        return self.local_loc

    @property
    def global_location(self):
        return self.global_loc

    @property
    def device(self):
        return self.keys.device

    @property
    def dtype(self):
        return self.keys.dtype

    def get_recent(self, layer_idx, last_T):
        """Return the most recent last_T tokens for a layer (chronological)."""
        self._check_layer(layer_idx, allow_any=True)
        last_T = min(last_T, self.local_loc)
        if last_T == 0:
            empty = self.keys[layer_idx, :, :0]
            return empty, empty
        start = (self._write_ptr - last_T) % self.size
        if start + last_T <= self.size:
            k = self.keys[layer_idx, :, start:start + last_T]
            v = self.values[layer_idx, :, start:start + last_T]
        else:
            first = self.size - start
            k = t.cat([self.keys[layer_idx, :, start:self.size], self.keys[layer_idx, :, 0:(last_T - first)]], dim=1)
            v = t.cat([self.values[layer_idx, :, start:self.size], self.values[layer_idx, :, 0:(last_T - first)]], dim=1)
        return k, v

    # -------------- Internal checks --------------
    def _check_layer(self, layer_idx, allow_any=False):
        assert 0 <= layer_idx < self.n_layers, f"layer_idx out of range: 0..{self.n_layers-1}, got {layer_idx}"
        if self.enforce_layer_order and not allow_any:
            assert layer_idx == (self.curr_layer % self.n_layers), \
                f"Layer order mismatch: expected {self.curr_layer % self.n_layers}, got {layer_idx}"


class KVCacheMine(nn.Module):  # this does not work because it destroys the cache of later timesteps when the earlier ones overflow and move to the left. --> fix as an exercise.
    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window):
        """
        This is a rolling KVCache
        """
        super().__init__()
        self.batch_size = batch_size
        self.n_heads = n_heads
        self.d_head = d_head
        self.toks_per_frame = toks_per_frame
        self.n_window = n_window
        self.size = toks_per_frame * n_window  #5*n_window#(n_window + 1)
        self.n_layers = n_layers
        self.curr_layer = 0
        self.global_loc = 0
        self.local_loc = 0
        self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head))
        self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head))

    def get(self, layer_idx):
        assert layer_idx == self.curr_layer, f"layer idx should be the same as our internal counter but we got {layer_idx} and internal is {self.curr_layer}."
        return self.keys[layer_idx, :, :self.local_loc], self.values[layer_idx, :, :self.local_loc]

    def extend(self, layer_idx, keys, values):
        assert keys.shape == values.shape, f"keys and values shapes must match {self.keys.shape} != {self.values.shape}"
        assert layer_idx == self.curr_layer, f"layer idx should be the same as our internal counter but we got {layer_idx} and internal is {self.curr_layer}."
        assert self.local_loc <= self.size, f"the cache size should be between 0 and {self.size}"
        local_loc = self.local_loc
        if local_loc == self.size:
            # move to the left
            local_loc -= keys.shape[1]
            assert local_loc >= 0, f"the cache update {keys.shape[1]} was larger than the cache {self.size}, that's not supported for now."
            assert local_loc % self.toks_per_frame == 0, f"the number of elements in the cache {local_loc} must be a multiple of the number of tokens per frame {self.toks_per_frame}"
            self.keys[layer_idx, :, :local_loc] = self.keys[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
            self.values[layer_idx, :, :local_loc] = self.values[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
            #self.keys[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame] = self.keys[layer_idx, :, -local_loc:].clone()
            #self.values[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame] = self.values[layer_idx, :, -local_loc:].clone()

        assert local_loc + keys.shape[1] <= self.size, f"{local_loc + keys.shape[1]} out of bounds {self.size}"
        self.keys[layer_idx, :, local_loc:local_loc + keys.shape[1]] = keys
        self.values[layer_idx, :, local_loc:local_loc + keys.shape[1]] = values
        self.curr_layer = (self.curr_layer + 1) % self.n_layers

    def update_global_location(self, n_frames):
        self.global_loc += n_frames * self.toks_per_frame
        if self.local_loc < self.size:
            self.local_loc += n_frames * self.toks_per_frame
        assert self.local_loc <= self.size, f"the local loc {self.local_loc} should never be bigger than {self.size}, something went wrong."

    def reset(self):
        self.global_loc = 0
        self.local_loc = 0
        self.curr_layer = 0
        self.keys.zero_()
        self.values.zero_()

    @property
    def local_location(self):
        return self.local_loc

    @property
    def global_location(self):
        return self.global_loc

    @property
    def device(self):
        return self.keys.device

    @property
    def dtype(self):
        return self.keys.dtype


class AttentionEinOps(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, d_model, n_heads, rope=None):
        super().__init__()
        assert d_model % n_heads == 0, f"{d_model} must be divisible by {n_heads}"
        self.d_head = d_model // n_heads
        d_head = self.d_head
        self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_V = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_O = nn.Parameter(t.empty((n_heads, d_head, d_model)))
        self.b_Q = nn.Parameter(t.zeros((n_heads, d_head)))
        self.b_K = nn.Parameter(t.zeros((n_heads, d_head)))
        self.b_V = nn.Parameter(t.zeros((n_heads, d_head)))
        self.b_O = nn.Parameter(t.zeros((d_model)))
        nn.init.normal_(self.W_Q, 1/d_model**0.5)
        nn.init.normal_(self.W_K, 1/d_model**0.5)
        nn.init.normal_(self.W_V, 1/d_model**0.5)
        nn.init.normal_(self.W_O, 1/d_head**0.5)
        self.register_buffer("IGNORE", t.tensor(float('-inf'), dtype=t.float32))
        self.rope = rope
        self.ln1 = nn.LayerNorm(d_head)
        self.ln2 = nn.LayerNorm(d_head)


    def forward(
        self,
        x_q: Float[Tensor, "batch posq d_model"],
        x_kv: Float[Tensor, "batch posk d_model"],
        mask: Bool[Tensor, "posq posk"] = None,  # the 1s are removed
        k_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
        v_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
        offset: int = 0
    ) -> Float[Tensor, "batch posq d_model"]:
        assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
        d_head = self.d_head
        if k_cache is not None and v_cache is not None:
            q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
            k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
            v_new = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h') + self.b_V

            k = t.cat([k_cache, k_new], dim=1)
            v = t.cat([v_cache, v_new], dim=1)

            if self.rope is not None:
                q = self.rope(q, offset=k_cache.shape[1])
                k = self.rope(k, offset=0)
            q = self.ln1(q)  # this should be before rope
            k = self.ln2(k)
            mask = None
        else:
            q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
            k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
            v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h') + self.b_V
            if self.rope is not None:
                q = self.rope(q)
                k = self.rope(k)
            q = self.ln1(q)
            k = self.ln2(k)  # this learns much faster using layernorm here
            k_new = k
            v_new = v

        attention = einops.einsum(q, k, 'b sq n h, b sk n h -> b n sq sk')
        if mask is not None and k_cache is not None:
            attention = t.where(mask[k_cache.shape[1]:k_cache.shape[1]+q.shape[1], :k.shape[1]], self.IGNORE, attention)
        elif mask is not None:
            if attention.shape[-1] != mask.shape[-1] or attention.shape[-2] != mask.shape[-2]:
                #print(f"Warning: attention shape {attention.shape} does not match mask shape {mask.shape}")
                mask = mask[:attention.shape[-1], :attention.shape[-2]]
            attention = t.where(mask, self.IGNORE, attention)
        probas = attention.softmax(dim=3)
        #plt.imshow(probas[0, 0].cpu().numpy())
        #plt.show()
        z = einops.einsum(probas, v, 'b n sq sk, b sk n h -> b sq n h')
        out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
        out = out.sum(dim=2) + self.b_O
        return out, k_new, v_new


class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, d_model, n_heads, rope=None, use_flex_attention=False):
        raise NotImplementedError("Attention is not implemented yet")
        super().__init__()
|
| 342 |
+
assert d_model % n_heads == 0, f"{d_model} must be divisble by {n_heads}"
|
| 343 |
+
self.d_head = d_model // n_heads
|
| 344 |
+
d_head = self.d_head
|
| 345 |
+
self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
|
| 346 |
+
self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
|
| 347 |
+
self.W_V = nn.Parameter(t.empty((n_heads, d_model, d_head)))
|
| 348 |
+
self.W_O = nn.Parameter(t.empty((n_heads, d_head, d_model)))
|
| 349 |
+
#self.b_Q = nn.Parameter(t.zeros((n_heads, d_head)))
|
| 350 |
+
#self.b_K = nn.Parameter(t.zeros((n_heads, d_head)))
|
| 351 |
+
#self.b_V = nn.Parameter(t.zeros((n_heads, d_head)))
|
| 352 |
+
#self.b_O = nn.Parameter(t.zeros((d_model)))
|
| 353 |
+
nn.init.normal_(self.W_Q, 1/d_model**0.5)
|
| 354 |
+
nn.init.normal_(self.W_K, 1/d_model**0.5)
|
| 355 |
+
nn.init.normal_(self.W_V, 1/d_model**0.5)
|
| 356 |
+
nn.init.normal_(self.W_O, 1/d_head**0.5)
|
| 357 |
+
self.register_buffer("IGNORE", t.tensor(float('-inf'), dtype=t.float32))
|
| 358 |
+
self.rope = rope
|
| 359 |
+
self.use_flex_attention = use_flex_attention
|
| 360 |
+
self.ln1 = nn.LayerNorm(d_head)
|
| 361 |
+
self.ln2 = nn.LayerNorm(d_head)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def forward(
|
| 365 |
+
self,
|
| 366 |
+
x_q: Float[Tensor, "batch posq d_model"],
|
| 367 |
+
x_kv: Float[Tensor, "batch posk d_model"],
|
| 368 |
+
mask: Bool[Tensor, "posq posk"] = None, # the 1s are removed
|
| 369 |
+
k_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
|
| 370 |
+
v_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
|
| 371 |
+
) -> Float[Tensor, "batch posq d_model"]:
|
| 372 |
+
assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
|
| 373 |
+
d_head = self.d_head
|
| 374 |
+
if k_cache is not None and v_cache is not None:
|
| 375 |
+
raise NotImplementedError("kv cache not implemented yet")
|
| 376 |
+
q = einops.einsum(x, self.W_Q, 'b s d, n d h -> b s n h')
|
| 377 |
+
k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
|
| 378 |
+
v_new = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')
|
| 379 |
+
k = t.cat([k_cache, k_new], dim=1)
|
| 380 |
+
v = t.cat([v_cache, v_new], dim=1)
|
| 381 |
+
else:
|
| 382 |
+
q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h')
|
| 383 |
+
k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
|
| 384 |
+
v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')
|
| 385 |
+
|
| 386 |
+
q = self.ln1(q)
|
| 387 |
+
k = self.ln2(k)
|
| 388 |
+
if self.rope is not None:
|
| 389 |
+
q = self.rope(q)
|
| 390 |
+
k = self.rope(k)
|
| 391 |
+
|
| 392 |
+
# Convert to (batch, num_heads, seq_len, head_dim) format for flex_attention
|
| 393 |
+
q_perm = q.permute(0, 2, 1, 3) # (batch, n_heads, posq, d_head)
|
| 394 |
+
k_perm = k.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)
|
| 395 |
+
v_perm = v.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)
|
| 396 |
+
|
| 397 |
+
# Ensure tensors are contiguous to avoid flex_attention indexing bugs
|
| 398 |
+
q_perm = q_perm.contiguous()
|
| 399 |
+
k_perm = k_perm.contiguous()
|
| 400 |
+
v_perm = v_perm.contiguous()
|
| 401 |
+
|
| 402 |
+
if self.use_flex_attention:
|
| 403 |
+
# Handle mask using score_mod if needed
|
| 404 |
+
if mask is not None:
|
| 405 |
+
# Store mask and IGNORE for use in score_mod closure
|
| 406 |
+
mask_tensor = mask # (posq, posk)
|
| 407 |
+
ignore_val = self.IGNORE
|
| 408 |
+
def score_mod(score, b, h, q_idx, kv_idx):
|
| 409 |
+
# score_mod operates on individual scalar scores
|
| 410 |
+
# Apply mask: where mask is True, set to -inf
|
| 411 |
+
# Use torch ops that work in compiled context
|
| 412 |
+
mask_val = mask_tensor[q_idx, kv_idx]
|
| 413 |
+
return t.where(mask_val, ignore_val, score)
|
| 414 |
+
z = flex_attention(q_perm, k_perm, v_perm, score_mod=score_mod)
|
| 415 |
+
else:
|
| 416 |
+
z = flex_attention(q_perm, k_perm, v_perm)
|
| 417 |
+
else:
|
| 418 |
+
condi = mask is None and not self.dtype == t.float32
|
| 419 |
+
with t.backends.cuda.sdp_kernel(
|
| 420 |
+
enable_flash=condi,
|
| 421 |
+
enable_math=not condi,
|
| 422 |
+
enable_mem_efficient=not condi
|
| 423 |
+
):
|
| 424 |
+
z = F.scaled_dot_product_attention(
|
| 425 |
+
q_perm, k_perm, v_perm,
|
| 426 |
+
attn_mask = mask.logical_not() if mask is not None else None,
|
| 427 |
+
dropout_p = 0.0,
|
| 428 |
+
is_causal = False,
|
| 429 |
+
scale = 1.0
|
| 430 |
+
)
|
| 431 |
+
z = z.permute(0, 2, 1, 3) # Back to (batch, posq, n_heads, d_head)
|
| 432 |
+
out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
|
| 433 |
+
out = out.sum(dim=2)
|
| 434 |
+
#print(f"out {out.shape}, attention {probas.shape}, q {q.shape}, k {k.shape}, v {v.shape}")
|
| 435 |
+
return out, z, None
|
| 436 |
+
|
| 437 |
+
@property
|
| 438 |
+
def dtype(self):
|
| 439 |
+
return self.parameters().__next__().dtype
|
| 440 |
+
|
| 441 |
+
@property
|
| 442 |
+
def device(self):
|
| 443 |
+
return self.parameters().__next__().device
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
if __name__ == "__main__":
|
| 447 |
+
from .pe import RoPE
|
| 448 |
+
import inspect
|
| 449 |
+
rope = RoPE(256//8, 10000)
|
| 450 |
+
dtype = t.float32
|
| 451 |
+
rope = rope.to(dtype)
|
| 452 |
+
attn_slow = AttentionSlow(d_model=256, n_heads=8, rope=rope)
|
| 453 |
+
attn = Attention(d_model=256, n_heads=8, rope=rope)
|
| 454 |
+
attn.load_state_dict(attn_slow.state_dict(), strict=False)
|
| 455 |
+
attn.to(dtype)
|
| 456 |
+
attn_slow.to(dtype)
|
| 457 |
+
x = t.randn(1, 1000, 256, dtype=dtype)*10
|
| 458 |
+
xkv = t.randn(1, 1000, 256, dtype=dtype)*10
|
| 459 |
+
mask = t.randint(0, 2, (1000, 1000), dtype=t.bool)
|
| 460 |
+
y, z, _ = attn(x, xkv, mask=mask)
|
| 461 |
+
y_slow, z_slow, _ = attn_slow(x, xkv, mask=mask)
|
| 462 |
+
#assert t.allclose(z, z_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(z - z_slow).abs().max()}"
|
| 463 |
+
#assert t.allclose(y, y_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(y - y_slow).abs().max()}"
|
| 464 |
+
print("Attention and AttentionSlow are the same")
|
| 465 |
+
|
| 466 |
+
loss = t.nn.functional.mse_loss(y, y_slow)
|
| 467 |
+
loss.backward()
|
| 468 |
+
print("-"*100)
|
| 469 |
+
for n, p in attn.named_parameters():
|
| 470 |
+
print(n, p.grad.shape, p.grad.max(), p.grad.min())
|
| 471 |
+
print("-"*100)
|
| 472 |
+
for n, p in attn_slow.named_parameters():
|
| 473 |
+
print(n, p.grad.shape, p.grad.max(), p.grad.min())
|
src/nn/geglu.py
ADDED
@@ -0,0 +1,20 @@
from torch import nn

class GEGLU(nn.Module):
    def __init__(self, d_in, d_mid, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_mid = d_mid
        self.d_out = d_out
        self.up_proj = nn.Linear(d_in, d_mid, bias=True)
        self.up_proj.bias.data.zero_()
        self.up_gate = nn.Linear(d_in, d_mid, bias=True)
        self.up_gate.bias.data.zero_()
        self.down = nn.Linear(d_mid, d_out, bias=True)
        self.down.bias.data.zero_()
        self.nonlin = nn.SiLU()

    def forward(self, x):
        x = self.up_proj(x) * self.nonlin(self.up_gate(x))
        x = self.down(x)
        return x
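A quick shape check for the gated MLP above, with hypothetical dimensions (note the gate uses SiLU, so this is the SwiGLU-style variant despite the GEGLU name); assumes the repo root is on `PYTHONPATH`:

```python
import torch as t
from src.nn.geglu import GEGLU

mlp = GEGLU(d_in=256, d_mid=1024, d_out=256)   # hypothetical sizes
x = t.randn(2, 16, 256)                        # (batch, tokens, d_in)
y = mlp(x)
print(y.shape)                                 # torch.Size([2, 16, 256])
```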
src/nn/patch.py
ADDED
@@ -0,0 +1,80 @@
from torch import nn
from einops import rearrange
import torch as t

class Patch(nn.Module):  # adapted from https://github.com/cloneofsimo/minRF
    def __init__(self, in_channels=3, out_channels=64, patch_size=2):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        dim = out_channels
        if dim % 32 == 0 and dim > 32:
            self.init_conv_seq = nn.Sequential(
                nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.GroupNorm(32, dim // 2),
                nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.GroupNorm(32, dim // 2),
            )
        else:
            self.init_conv_seq = nn.Sequential(
                nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
            )

        self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
        nn.init.constant_(self.x_embedder.bias, 0)

    def forward(self, x):
        batch, dur, c, h, w = x.shape
        x = x.reshape(-1, c, h, w)
        x = self.init_conv_seq(x)
        x = self.patchify(x)
        x = self.x_embedder(x)
        x = x.reshape(batch, dur, -1, self.out_channels)
        return x

    def patchify(self, x):
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

class UnPatch(nn.Module):
    def __init__(self, height, width, in_channels=64, out_channels=3, patch_size=2):
        super().__init__()
        self.width = width
        self.height = height
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.unpatch = nn.Linear(in_channels, out_channels*patch_size**2)

    def forward(self, x):
        x = self.unpatch(x)
        batch, dur, seq, d = x.shape
        x = x.reshape(-1, seq, d)
        x = self.unpatchify(x)
        x = x.reshape(batch, dur, self.out_channels, self.height, self.width)
        return x

    def unpatchify(self, x):
        c = self.out_channels
        p = self.patch_size
        h = self.height // p
        w = self.width // p
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = t.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
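As a sanity check on the shapes, here is a Patch/UnPatch round trip for a short clip. The frame size is hypothetical; the token count per frame comes out as (H/patch_size)·(W/patch_size), and again the repo root is assumed to be importable:

```python
import torch as t
from src.nn.patch import Patch, UnPatch

B, T, C, H, W = 2, 4, 3, 64, 64               # hypothetical clip shape
patch = Patch(in_channels=C, out_channels=64, patch_size=2)
unpatch = UnPatch(height=H, width=W, in_channels=64, out_channels=C, patch_size=2)

x = t.randn(B, T, C, H, W)
tokens = patch(x)                              # (B, T, (H/2)*(W/2), 64)
print(tokens.shape)                            # torch.Size([2, 4, 1024, 64])
recon = unpatch(tokens)                        # back to frame space
print(recon.shape)                             # torch.Size([2, 4, 3, 64, 64])
```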
src/nn/pe.py
ADDED
@@ -0,0 +1,77 @@
import torch as t
import torch.nn as nn
import math

from jaxtyping import Float, Bool, Int
from torch import Tensor
from typing import Optional

class NumericEncoding(nn.Module):
    def __init__(self, C = 1e4, dim = 64, n_max = 10000):
        super().__init__()
        args = t.exp(- math.log(C) * t.arange(0, dim, 2)/dim)
        args = t.arange(n_max)[:, None] * args[None, :]
        sins = t.sin(args)
        coss = t.cos(args)
        pe = t.empty((n_max, dim))
        pe[:,::2] = sins
        pe[:,1::2] = coss
        self.register_buffer("pe", pe)

    def forward(self, num):
        """
        expects integers between 0 and n_max
        """
        assert num.dtype == t.int32 or num.dtype == t.int64, f"wrong dtype {num.dtype}"
        return self.pe[num]


class RoPE(nn.Module):
    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
        thetas = thetas.repeat([2,1]).T.flatten()
        positions = t.arange(n_ctx)
        all_thetas = positions.unsqueeze(1)*thetas.unsqueeze(0)
        sins = t.sin(all_thetas)
        coss = t.cos(all_thetas)
        self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
        self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        x = key_or_query
        # start with doing it for just a single position m
        x_perm = t.empty(x.shape, device=x.device, dtype=x.dtype)  # batch sequence n_head d_head, we perm the last axis
        even = t.arange(0, x.shape[-1], 2)
        odd = t.arange(1, x.shape[-1], 2)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        return self.coss[:,offset:offset+x.shape[1]]*x + self.sins[:,offset:offset+x.shape[1]]*x_perm


class FrameRoPE(nn.Module):
    def __init__(self, d_head, n_ctx, toks_per_frame, C=10000):
        super().__init__()
        thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
        thetas = thetas.repeat([2,1]).T.flatten()
        positions = t.arange(n_ctx)
        all_thetas = positions.unsqueeze(1)*thetas.unsqueeze(0)
        sins = t.sin(all_thetas)
        coss = t.cos(all_thetas)
        self.register_buffer('sins', sins.unsqueeze(0).unsqueeze(2))
        self.register_buffer('coss', coss.unsqueeze(0).unsqueeze(2))
        self.toks_per_frame = toks_per_frame

    def forward(self, key_or_query: Float[Tensor, "batch dur*seq n_head d_head"]):
        x = key_or_query
        # start with doing it for just a single position m
        x_perm = t.empty(x.shape, dtype=x.dtype, device=x.device)  # batch sequence n_head d_head, we perm the last axis
        even = t.arange(0, x.shape[-1], 2)
        odd = t.arange(1, x.shape[-1], 2)
        x_perm[:, :, :, even] = -x[:, :, :, odd]
        x_perm[:, :, :, odd] = x[:, :, :, even]
        idcs = t.arange(0, x.shape[1]//self.toks_per_frame, device=x.device)
        idcs = idcs[:, None].repeat(1, self.toks_per_frame).flatten()
        return self.coss[:,idcs]*x + self.sins[:,idcs]*x_perm
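The `offset` argument on `RoPE.forward` is what lets the KV-cached attention path rotate only the newly generated tokens as if they sat at their absolute position in the sequence. A small consistency check with hypothetical sizes:

```python
import torch as t
from src.nn.pe import RoPE

rope = RoPE(d_head=32, n_ctx=128)
x = t.randn(1, 16, 4, 32)                         # (batch, seq, n_head, d_head)

full = rope(x)                                    # rotate the whole sequence
tail = rope(x[:, 8:], offset=8)                   # rotate only the last 8 tokens at offset 8
print(t.allclose(full[:, 8:], tail, atol=1e-6))   # True
```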
src/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .logging import log_video
from .checkpoint import load_model_from_config
src/utils/checkpoint.py
ADDED
@@ -0,0 +1,283 @@
import os
import re
import json
import time
import shutil
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Optional, Dict, Any, List

import torch as t
from torch import nn

from ..models.dit_dforce import get_model
from ..config import Config

import yaml


def load_model_from_config(config_path: str, checkpoint_path: str = None, strict: bool = True) -> nn.Module:
    print(f"loading {config_path}")
    cmodel = Config.from_yaml(config_path).model
    model = get_model(cmodel.height, cmodel.width,
                      n_window=cmodel.n_window,
                      patch_size=cmodel.patch_size,
                      n_heads=cmodel.n_heads, d_model=cmodel.d_model,
                      n_blocks=cmodel.n_blocks,
                      T=cmodel.T,
                      in_channels=cmodel.in_channels,
                      bidirectional=cmodel.bidirectional)
    if checkpoint_path is None and cmodel.checkpoint is not None:
        checkpoint_path = cmodel.checkpoint
    if checkpoint_path is not None:
        state_dict = t.load(checkpoint_path, weights_only=False)
        if "model" in state_dict:
            state_dict = state_dict["model"]
        if "_orig_mod." in list(state_dict.keys())[0]:
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items() if k.startswith("_orig_mod.")}
        model.load_state_dict(state_dict, strict=strict)
        print('loaded state dict')
    return model



class CheckpointManager:
    """
    Manage top-K checkpoints by a metric. On each save:
      - Write a new checkpoint atomically
      - Keep only the top-K files by metric (max or min)
      - Delete files not in top-K
      - Maintain a small JSON index for quick reloads
    Also scans the directory on init to reconstruct state.

    Filenames are of the form: ckpt-step=<step>-metric=<metric>.pt
    """

    CKPT_PATTERN = re.compile(
        r"^ckpt-step=(?P<step>\d+)-metric=(?P<metric>[+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)\.pt$"
    )

    def __init__(
        self,
        dirpath: str | Path,
        k: int = 5,
        mode: str = "max",  # or "min"
        metric_name: str = "score",
        is_main_process: bool = True,
        index_filename: str = "ckpt_index.json",
    ):
        self.dir = Path(dirpath)
        self.dir.mkdir(parents=True, exist_ok=True)
        assert mode in {"max", "min"}
        self.k = int(k)
        self.mode = mode
        self.metric_name = metric_name
        self.is_main = bool(is_main_process)
        self.index_path = self.dir / index_filename

        # entries: list of {path(str), step(int), metric(float), ts(float)}
        self.entries: List[Dict[str, Any]] = []

        self._load_index()
        self._scan_and_merge()
        self._prune_and_persist()

    # ---------- Public API ----------

    @property
    def best(self) -> Optional[Dict[str, Any]]:
        return self.entries[0] if self.entries else None

    @property
    def paths(self) -> List[str]:
        return [e["path"] for e in self.entries]

    @property
    def should_save(self) -> bool:
        """Use inside DDP loops to gate saving to rank-0 only."""
        return self.is_main

    def save(
        self,
        *,
        metric: float,
        step: int,
        model: Optional[nn.Module] = None,
        optimizer: Optional[t.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        extra: Optional[Dict[str, Any]] = None,
        state_dict: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Save a checkpoint and keep only top-K by metric.

        Provide either `state_dict` or a `model` (optionally optimizer/scheduler).
        The saved file always contains:
          - 'model', 'optimizer', 'scheduler' (if provided)
          - 'step', metric_name, 'timestamp', 'manager'
        Returns info about the saved file and whether it made the top-K.
        """
        if not self.should_save:
            return {"saved": False, "kept": False, "reason": "not main process"}

        if state_dict is None:
            state_dict = {}
        if model is not None:
            state_dict["model"] = model.state_dict()
        if optimizer is not None:
            state_dict["optimizer"] = optimizer.state_dict()
        if scheduler is not None:
            # Some schedulers (e.g., OneCycleLR) have state_dict
            try:
                state_dict["scheduler"] = scheduler.state_dict()
            except Exception:
                pass

        ts = time.time()
        filename = f"ckpt-step={int(step):06d}-metric={float(metric):.8f}.pt"
        fpath = self.dir / filename

        # Attach metadata for convenience
        payload = {
            **state_dict,
            "step": int(step),
            self.metric_name: float(metric),
            "timestamp": ts,
            "manager": {
                "mode": self.mode,
                "k": self.k,
                "metric_name": self.metric_name,
                "filename": filename,
            },
        }

        # Atomic write
        with NamedTemporaryFile(dir=self.dir, delete=False) as tmp:
            tmp_path = Path(tmp.name)
        try:
            t.save(payload, tmp_path)
            os.replace(tmp_path, fpath)  # atomic on POSIX
        finally:
            if tmp_path.exists():
                try:
                    tmp_path.unlink()
                except Exception:
                    pass

        # Update entries and prune
        new_entry = {
            "path": str(fpath),
            "step": int(step),
            "metric": float(metric),
            "ts": ts,
        }
        self.entries.append(new_entry)
        kept = self._prune_and_persist()  # returns True if new file in top-K

        return {"saved": True, "kept": kept, "path": str(fpath), "best": self.best}

    # ---------- Internal helpers ----------

    def _sort_key(self, e: Dict[str, Any]):
        # For MAX: better first => sort by (-metric, step)
        # For MIN: better first => sort by (metric, step)
        return ((-e["metric"], e["step"]) if self.mode == "max" else (e["metric"], e["step"]))

    def _load_index(self):
        if not self.index_path.exists():
            self.entries = []
            return
        try:
            data = json.loads(self.index_path.read_text())
            entries = data.get("entries", [])
            # Drop missing files
            self.entries = [e for e in entries if Path(e["path"]).exists()]
            # Normalize types
            for e in self.entries:
                e["metric"] = float(e["metric"])
                e["step"] = int(e["step"])
                e["ts"] = float(e.get("ts", time.time()))
        except Exception:
            # If index is corrupted, fall back to empty and rescan
            self.entries = []

    def _scan_and_merge(self):
        """Scan directory for checkpoint files and merge with current entries."""
        seen = {Path(e["path"]).name for e in self.entries}
        for p in self.dir.glob("ckpt-step=*-metric=*.pt"):
            name = p.name
            if name in seen:
                continue
            m = self.CKPT_PATTERN.match(name)
            if not m:
                continue
            step = int(m.group("step"))
            try:
                metric = float(m.group("metric"))
            except ValueError:
                continue
            self.entries.append(
                {"path": str(p), "step": step, "metric": metric, "ts": p.stat().st_mtime}
            )

    def _prune_and_persist(self) -> bool:
        """Sort by metric, keep top-K, delete the rest. Return True if newest file is kept."""
        if not self.entries:
            self._persist_index()
            return False

        # Sort best-first
        self.entries.sort(key=self._sort_key)

        # Determine which to keep and which to delete
        keep = self.entries[: self.k]
        drop = self.entries[self.k :]

        keep_paths = {e["path"] for e in keep}
        newest_path = max(self.entries, key=lambda e: e["ts"])["path"]
        newest_kept = newest_path in keep_paths

        # Delete files not in top-K
        for e in drop:
            try:
                Path(e["path"]).unlink(missing_ok=True)
            except Exception:
                pass

        # Commit the top-K
        self.entries = keep
        self._persist_index()
        return newest_kept

    def _persist_index(self):
        data = {
            "k": self.k,
            "mode": self.mode,
            "metric_name": self.metric_name,
            "entries": self.entries,
            "updated_at": time.time(),
        }
        tmp = self.index_path.with_suffix(".json.tmp")
        tmp.write_text(json.dumps(data, indent=2))
        os.replace(tmp, self.index_path)


# ---------------------- Example usage ----------------------
if __name__ == "__main__":
    # Example (single process). In DDP, construct with is_main_process=(rank==0).
    mgr = CheckpointManager("checkpoints", k=5, mode="max", metric_name="val_acc")

    model = nn.Linear(10, 2)
    opt = t.optim.AdamW(model.parameters(), lr=1e-3)

    # Fake loop
    for epoch in range(10):
        metric = 0.5 + 0.1 * t.rand(1).item()  # pretend validation accuracy
        info = mgr.save(metric=metric, step=epoch, model=model, optimizer=opt)
        print(
            f"epoch {epoch:02d} metric={metric:.4f} saved={info['saved']} kept={info['kept']} "
            f"best_metric={mgr.best['metric'] if mgr.best else None:.4f}"
        )

    print("Top-K paths:", mgr.paths)
    print("Best:", mgr.best)
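For the Space, the relevant entry point is `load_model_from_config`, which builds the model from the config's `model` section and then pulls in the checkpoint named there (or an explicit override). A hedged usage sketch; the paths are placeholders, not the Space's actual config values:

```python
import torch as t
from src.utils.checkpoint import load_model_from_config

# Hypothetical paths; the real ones come from the Space's inference config.
model = load_model_from_config(
    "configs/inference.yaml",
    checkpoint_path=None,   # None falls back to the checkpoint listed in the config
    strict=True,
)
model = model.eval().to("cuda" if t.cuda.is_available() else "cpu")
```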
static/index.html
ADDED
@@ -0,0 +1,162 @@
<!doctype html>
<html>
<head>
  <meta charset="utf-8" />
  <title>Pong</title>
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <!-- Socket.IO client library (CDN) -->
  <script src="https://cdn.socket.io/4.5.4/socket.io.min.js"></script>
  <style>
    html, body { margin:0; height:100%; background:#111; color:#eee; font-family: system-ui, sans-serif; }
    #overlay {
      position: fixed; inset: 0; display: flex; align-items: center; justify-content: center;
      background: rgba(0,0,0,0.8); z-index: 9999; transition: opacity 200ms ease;
    }
    #overlay.hidden { opacity: 0; pointer-events: none; }
    .spinner {
      width: 64px; height: 64px; border: 6px solid #444; border-top-color: #09f; border-radius: 50%;
      animation: spin 0.9s linear infinite;
    }
    @keyframes spin { to { transform: rotate(360deg); } }
    #statusText { margin-top: 12px; color: #aaa; text-align: center; font-size: 14px; white-space: pre-line; }
    #app { padding: 16px; }
    button { padding: 8px 12px; background:#09f; color:#fff; border:none; border-radius:8px; cursor:pointer; }
    button:disabled { opacity: .5; cursor: not-allowed; }
    img#frame { image-rendering: pixelated; width: 240px; height: 240px; background:#222; display:block; margin-top:12px; }
  </style>
</head>
<body>
  <div id="overlay">
    <div>
      <div class="spinner"></div>
      <div id="statusText">Loading model…</div>
    </div>
  </div>

  <div id="app">
    <h1>Pong</h1>
    <div style="margin-bottom: 12px;">
      <label style="display: block; margin-bottom: 8px;">
        FPS: <input type="number" id="fpsInput" value="20" min="1" max="30" step="1" style="width: 60px; padding: 4px; margin-left: 8px;" />
        <span style="color: #aaa; font-size: 12px; margin-left: 8px;">frames per second</span>
      </label>
      <label style="display: block; margin-bottom: 8px;">
        Steps: <input type="number" id="stepsInput" value="4" min="1" max="10" step="1" style="width: 60px; padding: 4px; margin-left: 8px;" />
        <span style="color: #aaa; font-size: 12px; margin-left: 8px;">diffusion steps</span>
      </label>
    </div>
    <div>
      <button id="startBtn" disabled>Start Stream</button>
      <button id="stopBtn" disabled>Stop Stream</button>
    </div>
    <img id="frame" alt="Latest frame" />
    <div id="actionDisplay" style="margin-top: 12px; font-size: 16px; font-family: monospace;">
      Action: <span id="actionValue">-</span>
    </div>
    <div id="fpsDisplay" style="margin-top: 8px; font-size: 16px; font-family: monospace;">
      Achieved FPS: <span id="fpsValue">-</span>
    </div>
    <div>
      This is the output of a small frame-autoregressive transformer trained with rectified flow matching to simulate Pong frames conditioned on user input for the blue paddle. It should reach 20 FPS with 4 generation steps, unless something else is running on my machine.
    </div>
  </div>

  <script>
    // If you serve the socket.io client at /socket.io/socket.io.js you can use the global io():
    const socket = io({ transports: ['websocket', 'polling'] });

    const overlay = document.getElementById('overlay');
    const statusText = document.getElementById('statusText');
    const startBtn = document.getElementById('startBtn');
    const stopBtn = document.getElementById('stopBtn');
    const frameImg = document.getElementById('frame');

    function setStatus(isReady) {
      if (!isReady) {
        // Model is still loading
        overlay.classList.remove('hidden');
        startBtn.disabled = true;
        stopBtn.disabled = true;
        statusText.textContent = 'Loading model…';
      } else {
        // Server is ready and available
        overlay.classList.add('hidden');
        startBtn.disabled = false;
        stopBtn.disabled = false;
        statusText.textContent = 'Ready';
      }
    }

    // Initial state: assume not ready (show spinner)
    setStatus(false);

    socket.on('connect', () => {
      // server will immediately emit 'server_status' with current readiness
      console.log('connected');
    });

    // Backend broadcasts readiness changes
    socket.on('server_status', (payload) => {
      const ready = !!(payload && payload.ready);
      console.log('Server status:', { ready });
      setStatus(ready);
    });

    // Start/stop controls
    startBtn.addEventListener('click', () => {
      const fps = parseInt(document.getElementById('fpsInput').value) || 12;
      const n_steps = parseInt(document.getElementById('stepsInput').value) || 1;
      socket.emit('start_stream', { n_steps: n_steps, cfg: 0.0, fps: fps, clamp: true });
    });
    stopBtn.addEventListener('click', () => {
      socket.emit('stop_stream');
    });

    const actionValue = document.getElementById('actionValue');
    const fpsValue = document.getElementById('fpsValue');

    // Incoming frames
    socket.on('frame', ({ frame, frame_index, action, fps }) => {
      frameImg.src = `data:image/png;base64,${frame}`;
      // Display action: 0=START, 1=NOOP, 2=UP, 3=DOWN
      const actionLabels = ['START', 'NOOP', 'UP', 'DOWN'];
      actionValue.textContent = `${action} (${actionLabels[action] || 'UNKNOWN'})`;
      // Display achieved FPS
      if (fps !== undefined) {
        fpsValue.textContent = fps.toFixed(1);
      }
    });

    socket.on('error', (e) => {
      console.warn('server error', e);
      // The server_status event will handle showing the appropriate overlay
      // Just log the error for now
      if (e && e.message) {
        console.error('Server error message:', e.message);
      }
    });

    // Keyboard controls for paddle
    // Actions: 1=NOOP, 2=UP, 3=DOWN
    document.addEventListener('keydown', (e) => {
      let action = null;
      if (e.key === 'ArrowUp' || e.key === 'w' || e.key === 'W') {
        action = 2; // UP
      } else if (e.key === 'ArrowDown' || e.key === 's' || e.key === 'S') {
        action = 3; // DOWN
      }
      if (action !== null) {
        socket.emit('action', { action });
        e.preventDefault();
      }
    });

    document.addEventListener('keyup', (e) => {
      if (['ArrowUp', 'ArrowDown', 'w', 'W', 's', 'S'].includes(e.key)) {
        socket.emit('action', { action: 1 }); // NOOP when key released
        e.preventDefault();
      }
    });
  </script>
</body>
</html>
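The page above only assumes a Socket.IO backend that emits `server_status` and `frame` and listens for `start_stream`, `stop_stream`, and `action`. A minimal event skeleton matching those names, as a sketch only: it assumes Flask plus Flask-SocketIO (which the CDN client suggests but which `app.py` may implement differently), and the handler bodies, payload defaults, and port are illustrative, not the Space's actual code:

```python
from flask import Flask, send_from_directory
from flask_socketio import SocketIO, emit

app = Flask(__name__, static_folder="static")
socketio = SocketIO(app, cors_allowed_origins="*")

@app.route("/")
def index():
    # Serve the frontend shown above
    return send_from_directory("static", "index.html")

@socketio.on("connect")
def on_connect():
    # The page enables Start/Stop once it receives ready=True
    emit("server_status", {"ready": True})

@socketio.on("start_stream")
def on_start(data):
    n_steps = int(data.get("n_steps", 4))
    fps = int(data.get("fps", 20))
    # ...kick off the generation loop and emit('frame', {...}) per generated frame

@socketio.on("stop_stream")
def on_stop():
    # ...signal the generation loop to stop
    pass

@socketio.on("action")
def on_action(data):
    action = int(data.get("action", 1))  # 1=NOOP, 2=UP, 3=DOWN
    # ...store the latest action for the next generated frame

if __name__ == "__main__":
    socketio.run(app, host="0.0.0.0", port=7860)  # port is a placeholder
```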