chrisxx committed on
Commit
8746765
·
1 Parent(s): a8aa75d

Add Neural Pong application files

Browse files
DEPLOYMENT.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Space Setup
2
+
3
+ This folder contains everything needed to deploy the Neural Pong demo to Hugging Face Spaces.
4
+
5
+ ## Files Structure
6
+
7
+ - `app.py` - Main Flask application (modified from play_pong.py, removed single-user limitation)
8
+ - `Dockerfile` - Docker configuration for HF Spaces
9
+ - `requirements.txt` - Python dependencies
10
+ - `README.md` - Space description and metadata
11
+ - `static/index.html` - Frontend web interface
12
+ - `configs/inference.yaml` - Model configuration
13
+ - `src/` - Source code for model loading and inference
14
+
15
+ ## Important Notes
16
+
17
+ ### Dependencies Fixed
18
+
19
+ ✅ **No external git dependencies**: The app now imports `sample` directly from `src.inference.sampling` instead of going through training code, avoiding the `muon-optimizer` git dependency.
20
+
21
+ ✅ **No data files needed**: The app uses `fixed2frame` directly instead of calling `get_loader`, so it doesn't need the training data files (`frames.npy`, `actions.npy`).
22
+
23
+ ✅ **Minimal codebase**: Only inference-related code is included. All training scripts and utilities have been removed:
24
+ - Removed: `src/trainers/`, `src/main.py`, `src/main_dmd.py`
25
+ - Removed: Unused dataset files, alternative models, custom norm
26
+ - Removed: Matplotlib dependencies (not needed for inference)
27
+ - **Total: 15 Python files** (down from 25+)
28
+
29
+ See `SOURCE_FILES.md` for a complete list of included files.
30
+
31
+ ### Checkpoint Path
32
+
33
+ The `configs/inference.yaml` file currently references a local checkpoint path:
34
+ ```yaml
35
+ checkpoint: "experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
36
+ ```
37
+
38
+ **Before deploying**, you need to either:
39
+
40
+ 1. **Upload checkpoint to Hugging Face Hub** and update the path to load from Hub
41
+ 2. **Include the checkpoint file** in this directory and update the path
42
+ 3. **Use HF Spaces storage/secrets** to store the checkpoint
43
+
44
+ ### Changes Made
45
+
46
+ - Removed single-user limitation (all users can connect simultaneously)
47
+ - Simplified frontend to remove busy state handling
48
+ - Updated port to use environment variable (defaults to 7860 for HF Spaces)
49
+ - Created Dockerfile for containerized deployment
50
+
51
+ ## Deployment Steps
52
+
53
+ 1. Upload this folder to a Hugging Face Space repository
54
+ 2. Update the checkpoint path in `configs/inference.yaml` to point to your model
55
+ 3. Ensure the Space has GPU access enabled
56
+ 4. The Space will automatically build and deploy
57
+
58
+ ## Testing Locally
59
+
60
+ To test locally with Docker:
61
+
62
+ ```bash
63
+ docker build -t neural-pong .
64
+ docker run -p 7860:7860 neural-pong
65
+ ```
66
+
67
+ Then visit http://localhost:7860
68
+
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && apt-get install -y \
7
+ build-essential \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copy requirements and install Python dependencies
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy application code
15
+ COPY . .
16
+
17
+ # Expose port (HF Spaces will map this)
18
+ EXPOSE 7860
19
+
20
+ # Run the Flask app
21
+ CMD ["python", "app.py"]
22
+
QUICKSTART.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 Quick Setup Guide for Hugging Face Space
2
+
3
+ Your Neural Pong demo is ready to deploy! Follow these steps:
4
+
5
+ ## Step 1: Create Your Hugging Face Space
6
+
7
+ 1. **Go to Hugging Face Spaces:** https://huggingface.co/spaces
8
+ 2. **Click "Create new Space"**
9
+ 3. **Fill in the details:**
10
+ - **Space name:** `neural-pong` (or your preferred name)
11
+ - **SDK:** Select **"Docker"** ⚠️ Important!
12
+ - **Hardware:** Select **"GPU"** → **"T4 small"** (or larger)
13
+ - **Visibility:** Public or Private
14
+ 4. **Click "Create Space"**
15
+
16
+ ## Step 2: Upload Files Using Git
17
+
18
+ ```bash
19
+ cd /share/u/wendler/code/toy-wm-hf-space
20
+
21
+ # Initialize git (if not already done)
22
+ git init
23
+
24
+ # Add all files
25
+ git add .
26
+
27
+ # Commit
28
+ git commit -m "Initial commit: Neural Pong demo"
29
+
30
+ # Add your Space as remote (replace YOUR_USERNAME and SPACE_NAME)
31
+ git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME
32
+
33
+ # Push everything
34
+ git push -u origin main
35
+ ```
36
+
37
+ **Note:** The checkpoint file is 225MB, so Git is recommended over web upload.
38
+
39
+ ## Step 3: Wait for Build
40
+
41
+ 1. After pushing, Hugging Face will automatically start building
42
+ 2. Go to your Space page → **"Logs"** tab to watch progress
43
+ 3. Build time: **5-15 minutes** (installing PyTorch, etc.)
44
+
45
+ ## Step 4: Test Your Space
46
+
47
+ 1. Once build completes, visit your Space URL
48
+ 2. You should see the Pong interface
49
+ 3. Wait for model to load (loading spinner)
50
+ 4. Click **"Start Stream"**
51
+ 5. Use **Arrow Keys** or **WASD** to play!
52
+
53
+ ## Quick Commands
54
+
55
+ ```bash
56
+ # Run the setup script
57
+ ./setup.sh
58
+
59
+ # Check files are ready
60
+ ls -la
61
+
62
+ # Test Docker build locally (optional)
63
+ docker build -t neural-pong .
64
+ docker run -p 7860:7860 neural-pong
65
+ ```
66
+
67
+ ## Troubleshooting
68
+
69
+ ### Build Fails?
70
+ - Check **"Logs"** tab for errors
71
+ - Verify checkpoint path in `configs/inference.yaml`
72
+ - Ensure GPU is selected in Space settings
73
+
74
+ ### Model Won't Load?
75
+ - Verify checkpoint exists: `checkpoints/ckpt-step=053700-metric=0.00092727.pt`
76
+ - Check the path in `configs/inference.yaml`
77
+ - Look for errors in the Logs tab
78
+
79
+ ## What's Included
80
+
81
+ ✅ **app.py** - Flask application (no single-user limitation)
82
+ ✅ **checkpoints/** - Model checkpoint (225MB)
83
+ ✅ **src/** - All necessary source code (15 Python files)
84
+ ✅ **static/index.html** - Frontend interface
85
+ ✅ **configs/inference.yaml** - Model configuration
86
+ ✅ **Dockerfile** - Container configuration
87
+ ✅ **requirements.txt** - Python dependencies
88
+
89
+ ## Need More Help?
90
+
91
+ - See `SETUP_GUIDE.md` for detailed instructions
92
+ - See `DEPLOYMENT.md` for technical details
93
+ - Check Hugging Face Spaces docs: https://huggingface.co/docs/hub/spaces
94
+
95
+ ---
96
+
97
+ **Ready?** Run `./setup.sh` to get started! 🚀
README.md CHANGED
@@ -1,10 +1,42 @@
1
  ---
2
- title: Pong
3
- emoji: 📊
4
- colorFrom: purple
5
- colorTo: gray
6
  sdk: docker
7
  pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Neural Pong
3
+ emoji: 🎮
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
  ---
10
 
11
+ # Neural Pong
12
+
13
+ A real-time Pong game where frames are generated by a diffusion model trained with rectified flow matching. Control the blue paddle using arrow keys or WASD to play!
14
+
15
+ ## Features
16
+
17
+ - **Real-time frame generation**: Uses a frame-autoregressive transformer with diffusion sampling
18
+ - **Interactive gameplay**: Control the paddle with keyboard inputs
19
+ - **Configurable parameters**: Adjust FPS and diffusion steps
20
+ - **Low-latency streaming**: Achieves ~20 FPS with 4 diffusion steps
21
+
22
+ ## How to Play
23
+
24
+ 1. Wait for the model to load (you'll see a loading spinner)
25
+ 2. Click "Start Stream" to begin generating frames
26
+ 3. Use **Arrow Keys** or **WASD** to control the blue paddle:
27
+ - **Up/W**: Move paddle up
28
+ - **Down/S**: Move paddle down
29
+ 4. Adjust the FPS and diffusion steps using the controls
30
+ 5. Click "Stop Stream" when done
31
+
32
+ ## Technical Details
33
+
34
+ This demo uses a small transformer model trained with rectified flow matching to simulate Pong game frames conditioned on user inputs. The model generates 24×24 pixel frames in real-time using diffusion sampling with configurable steps.
35
+
36
+ ## Model Architecture
37
+
38
+ - Frame-autoregressive transformer
39
+ - Rectified flow matching training
40
+ - Caching for efficient inference
41
+ - GPU-accelerated generation
42
+
RUN_SETUP.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Git Setup Complete - Next Steps
2
+
3
+ I've created a setup script for you. Here's what to do:
4
+
5
+ ## Run the Setup Script
6
+
7
+ ```bash
8
+ cd /share/u/wendler/code/toy-wm-hf-space
9
+ chmod +x setup-git.sh
10
+ ./setup-git.sh
11
+ ```
12
+
13
+ This script will:
14
+ 1. ✅ Initialize git (if needed)
15
+ 2. ✅ Remove old SSH remote
16
+ 3. ✅ Add HTTPS remote: `https://huggingface.co/spaces/wendlerc/pong`
17
+ 4. ✅ Stage all files
18
+ 5. ✅ Create initial commit (if needed)
19
+ 6. ✅ Ensure branch is named `main`
20
+ 7. ✅ Show you the push command
21
+
22
+ ## After Running the Script
23
+
24
+ The script will show you the exact command to push. It will be:
25
+ ```bash
26
+ git push -u origin main
27
+ ```
28
+
29
+ ## Before Pushing - Important!
30
+
31
+ **Make sure your Space exists:**
32
+ 1. Go to: https://huggingface.co/spaces/wendlerc/pong
33
+ 2. If it doesn't exist, create it:
34
+ - Go to: https://huggingface.co/spaces
35
+ - Click "Create new Space"
36
+ - Name: `pong`
37
+ - SDK: **Docker**
38
+ - Hardware: **GPU (T4 small)**
39
+ - Click "Create Space"
40
+
41
+ ## Then Push
42
+
43
+ ```bash
44
+ git push -u origin main
45
+ ```
46
+
47
+ You'll be prompted for your Hugging Face credentials (username and access token).
48
+
49
+ ## If You Need an Access Token
50
+
51
+ 1. Go to: https://huggingface.co/settings/tokens
52
+ 2. Create a new token with "write" permissions
53
+ 3. Use it as your password when pushing
54
+
55
+ ---
56
+
57
+ **The setup script is ready!** Just run `./setup-git.sh` and follow the instructions.
58
+
SETUP_GUIDE.md ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Step-by-Step Setup Guide for Hugging Face Space
2
+
3
+ This guide will walk you through deploying your Neural Pong demo to Hugging Face Spaces.
4
+
5
+ ## Prerequisites
6
+
7
+ - A Hugging Face account (sign up at https://huggingface.co/join)
8
+ - The model checkpoint file (`ckpt-step=053700-metric=0.00092727.pt`)
9
+ - Git installed on your machine (for uploading files)
10
+
11
+ ---
12
+
13
+ ## Step 1: Prepare Your Checkpoint File
14
+
15
+ First, you need to decide how to handle the model checkpoint. You have two main options:
16
+
17
+ ### Option A: Include Checkpoint in Repository (Simplest)
18
+
19
+ 1. **Locate your checkpoint file:**
20
+ ```bash
21
+ # Check if the file exists
22
+ ls /share/u/wendler/code/toy-wm/experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt
23
+ ```
24
+
25
+ 2. **Copy it to the hf-space directory:**
26
+ ```bash
27
+ mkdir -p /share/u/wendler/code/toy-wm/hf-space/checkpoints
28
+ cp /share/u/wendler/code/toy-wm/experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt \
29
+ /share/u/wendler/code/toy-wm/hf-space/checkpoints/
30
+ ```
31
+
32
+ 3. **Update the config file** to point to the new location:
33
+ ```yaml
34
+ checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
35
+ ```
36
+
37
+ ### Option B: Upload to Hugging Face Hub (Better for Large Files)
38
+
39
+ 1. **Install Hugging Face Hub:**
40
+ ```bash
41
+ pip install huggingface-hub
42
+ ```
43
+
44
+ 2. **Login to Hugging Face:**
45
+ ```bash
46
+ huggingface-cli login
47
+ ```
48
+
49
+ 3. **Create a model repository and upload:**
50
+ ```bash
51
+ # Create a repository (replace YOUR_USERNAME with your HF username)
52
+ huggingface-cli repo create YOUR_USERNAME/neural-pong-checkpoint --type model
53
+
54
+ # Upload the checkpoint
55
+ huggingface-cli upload YOUR_USERNAME/neural-pong-checkpoint \
56
+ /share/u/wendler/code/toy-wm/experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt \
57
+ ckpt-step=053700-metric=0.00092727.pt
58
+ ```
59
+
60
+ 4. **Modify the checkpoint loading code** to download from Hub (we'll do this in Step 2)
61
+
62
+ ---
63
+
64
+ ## Step 2: Update Configuration Files
65
+
66
+ ### If using Option A (checkpoint in repo):
67
+
68
+ Update `configs/inference.yaml`:
69
+ ```yaml
70
+ checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
71
+ ```
72
+
73
+ ### If using Option B (HF Hub):
74
+
75
+ We'll need to modify the app.py to download the checkpoint. Let me know if you want to go this route.
76
+
77
+ ---
78
+
79
+ ## Step 3: Create a Hugging Face Space
80
+
81
+ 1. **Go to Hugging Face Spaces:** https://huggingface.co/spaces
82
+
83
+ 2. **Click "Create new Space"**
84
+
85
+ 3. **Fill in the details:**
86
+ - **Space name:** `neural-pong` (or your preferred name)
87
+ - **SDK:** Select **Docker**
88
+ - **Hardware:** Select **GPU** (T4 small or larger)
89
+ - **Visibility:** Public or Private (your choice)
90
+
91
+ 4. **Click "Create Space"**
92
+
93
+ ---
94
+
95
+ ## Step 4: Upload Files to the Space
96
+
97
+ You have two options:
98
+
99
+ ### Option A: Using Git (Recommended)
100
+
101
+ 1. **Initialize git in your hf-space directory:**
102
+ ```bash
103
+ cd /share/u/wendler/code/toy-wm/hf-space
104
+ git init
105
+ git add .
106
+ git commit -m "Initial commit"
107
+ ```
108
+
109
+ 2. **Add the Hugging Face remote:**
110
+ ```bash
111
+ # Replace YOUR_USERNAME and SPACE_NAME with your values
112
+ git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME
113
+ ```
114
+
115
+ 3. **Push to Hugging Face:**
116
+ ```bash
117
+ git push -u origin main
118
+ ```
119
+
120
+ ### Option B: Using Web Interface
121
+
122
+ 1. **Go to your Space page** on Hugging Face
123
+ 2. **Click "Files" tab**
124
+ 3. **Click "Add file" → "Upload files"**
125
+ 4. **Drag and drop all files** from the `hf-space` directory
126
+ 5. **Click "Commit changes"**
127
+
128
+ **Note:** For large checkpoint files, Git is recommended as the web interface has size limits.
129
+
130
+ ---
131
+
132
+ ## Step 5: Configure the Space
133
+
134
+ 1. **Go to your Space settings** (click the gear icon)
135
+
136
+ 2. **Important settings:**
137
+ - **Hardware:** Ensure GPU is selected (T4 small minimum)
138
+ - **Environment variables:** None needed for basic setup
139
+ - **Storage:** If using Option B, you might want persistent storage
140
+
141
+ 3. **Save settings**
142
+
143
+ ---
144
+
145
+ ## Step 6: Wait for Build and Deployment
146
+
147
+ 1. **After pushing files**, Hugging Face will automatically:
148
+ - Build the Docker image
149
+ - Install dependencies
150
+ - Start your application
151
+
152
+ 2. **Monitor the build:**
153
+ - Go to your Space page
154
+ - Click "Logs" tab to see build progress
155
+ - Look for any errors
156
+
157
+ 3. **Expected build time:** 5-15 minutes depending on dependencies
158
+
159
+ ---
160
+
161
+ ## Step 7: Test Your Space
162
+
163
+ 1. **Once the build completes**, your Space will be live
164
+ 2. **Visit your Space URL:** `https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME`
165
+ 3. **Test the application:**
166
+ - Wait for model to load (loading spinner)
167
+ - Click "Start Stream"
168
+ - Use arrow keys or WASD to control paddle
169
+ - Verify frames are generating correctly
170
+
171
+ ---
172
+
173
+ ## Troubleshooting
174
+
175
+ ### Build Fails
176
+
177
+ - **Check logs** in the Space's "Logs" tab
178
+ - **Common issues:**
179
+ - Missing dependencies in `requirements.txt`
180
+ - Dockerfile syntax errors
181
+ - Checkpoint file not found (check path in `inference.yaml`)
182
+
183
+ ### Model Won't Load
184
+
185
+ - **Check checkpoint path** in `configs/inference.yaml`
186
+ - **Verify checkpoint file exists** in the repository
187
+ - **Check GPU availability** in Space settings
188
+
189
+ ### Port Issues
190
+
191
+ - The app uses port 7860 (HF Spaces default)
192
+ - If you see port errors, check the `PORT` environment variable
193
+
194
+ ### Out of Memory
195
+
196
+ - **Reduce batch size** or model size
197
+ - **Upgrade to larger GPU** in Space settings
198
+ - **Check if checkpoint is too large** (consider Option B)
199
+
200
+ ---
201
+
202
+ ## Quick Reference Commands
203
+
204
+ ```bash
205
+ # Navigate to hf-space directory
206
+ cd /share/u/wendler/code/toy-wm/hf-space
207
+
208
+ # Check files are ready
209
+ ls -la
210
+
211
+ # Test Docker build locally (optional)
212
+ docker build -t neural-pong .
213
+ docker run -p 7860:7860 neural-pong
214
+
215
+ # Git setup (if using Git)
216
+ git init
217
+ git add .
218
+ git commit -m "Initial commit"
219
+ git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME
220
+ git push -u origin main
221
+ ```
222
+
223
+ ---
224
+
225
+ ## Next Steps
226
+
227
+ After successful deployment:
228
+
229
+ 1. **Share your Space** with others
230
+ 2. **Monitor usage** in the Space analytics
231
+ 3. **Update as needed** by pushing new commits
232
+ 4. **Consider adding:**
233
+ - Better error handling
234
+ - More configuration options
235
+ - Performance optimizations
236
+
237
+ ---
238
+
239
+ ## Need Help?
240
+
241
+ - Check Hugging Face Spaces docs: https://huggingface.co/docs/hub/spaces
242
+ - Review your Space logs for errors
243
+ - Test locally with Docker first to catch issues early
244
+
SETUP_STEPS.md ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Step-by-Step Setup for Hugging Face Space
2
+
3
+ Follow these steps to deploy your Neural Pong demo to Hugging Face Spaces.
4
+
5
+ ## ✅ Pre-flight Check
6
+
7
+ Your directory structure looks good! Here's what you have:
8
+
9
+ ```
10
+ toy-wm-hf-space/
11
+ ├── app.py ✅ Main Flask application
12
+ ├── Dockerfile ✅ Docker configuration
13
+ ├── requirements.txt ✅ Python dependencies
14
+ ├── README.md ✅ Space metadata
15
+ ├── checkpoints/ ✅ Model checkpoint (225MB)
16
+ │ └── ckpt-step=053700-metric=0.00092727.pt
17
+ ├── configs/
18
+ │ └── inference.yaml ✅ Model config (checkpoint path correct)
19
+ ├── static/
20
+ │ └── index.html ✅ Frontend
21
+ └── src/ ✅ All source code (15 files)
22
+ ```
23
+
24
+ ## Step 1: Create Your Hugging Face Space
25
+
26
+ 1. **Go to:** https://huggingface.co/spaces
27
+ 2. **Click:** "Create new Space" button
28
+ 3. **Fill in:**
29
+ - **Space name:** `neural-pong` (or your choice)
30
+ - **SDK:** **Docker** ⚠️ Must be Docker!
31
+ - **Hardware:** **GPU** → **T4 small** (minimum)
32
+ - **Visibility:** Public or Private
33
+ 4. **Click:** "Create Space"
34
+
35
+ You'll get a URL like: `https://huggingface.co/spaces/YOUR_USERNAME/neural-pong`
36
+
37
+ ## Step 2: Initialize Git Repository
38
+
39
+ ```bash
40
+ cd /share/u/wendler/code/toy-wm-hf-space
41
+
42
+ # Initialize git (if not already done)
43
+ git init
44
+
45
+ # Add all files
46
+ git add .
47
+
48
+ # Make initial commit
49
+ git commit -m "Initial commit: Neural Pong demo"
50
+ ```
51
+
52
+ ## Step 3: Connect to Your Hugging Face Space
53
+
54
+ Replace `YOUR_USERNAME` and `SPACE_NAME` with your actual values:
55
+
56
+ ```bash
57
+ # Add your Space as remote
58
+ git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/SPACE_NAME
59
+
60
+ # Push everything
61
+ git push -u origin main
62
+ ```
63
+
64
+ **Example:**
65
+ ```bash
66
+ git remote add origin https://huggingface.co/spaces/johndoe/neural-pong
67
+ git push -u origin main
68
+ ```
69
+
70
+ **Note:** The checkpoint is 225MB, so this may take a few minutes to upload.
71
+
72
+ ## Step 4: Monitor the Build
73
+
74
+ 1. **Go to your Space page** on Hugging Face
75
+ 2. **Click the "Logs" tab** to watch the build progress
76
+ 3. **Wait 5-15 minutes** for:
77
+ - Docker image build
78
+ - Dependency installation (PyTorch, Flask, etc.)
79
+ - Model loading
80
+
81
+ ## Step 5: Test Your Space
82
+
83
+ Once the build completes:
84
+
85
+ 1. **Visit your Space URL** (e.g., `https://huggingface.co/spaces/YOUR_USERNAME/neural-pong`)
86
+ 2. **You should see:**
87
+ - Loading spinner while model loads
88
+ - Controls for FPS and diffusion steps
89
+ - "Start Stream" button
90
+ 3. **Test the game:**
91
+ - Click "Start Stream"
92
+ - Use Arrow Keys or WASD to control paddle
93
+ - Verify frames are generating
94
+
95
+ ## Troubleshooting
96
+
97
+ ### Build Fails?
98
+
99
+ **Check the Logs tab for errors:**
100
+
101
+ - **Missing dependencies?** → Check `requirements.txt`
102
+ - **Checkpoint not found?** → Verify path in `configs/inference.yaml`
103
+ - **GPU errors?** → Ensure GPU is enabled in Space settings
104
+ - **Port errors?** → Should use port 7860 automatically
105
+
106
+ ### Model Won't Load?
107
+
108
+ 1. **Verify checkpoint path** in `configs/inference.yaml`:
109
+ ```yaml
110
+ checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
111
+ ```
112
+
113
+ 2. **Check checkpoint exists:**
114
+ ```bash
115
+ ls -lh checkpoints/ckpt-step=053700-metric=0.00092727.pt
116
+ ```
117
+
118
+ 3. **Look for errors** in the Logs tab
119
+
120
+ ### Out of Memory?
121
+
122
+ - **Upgrade GPU** in Space settings (T4 medium or larger)
123
+ - **Reduce batch size** if applicable
124
+ - **Check checkpoint size** (225MB is reasonable)
125
+
126
+ ## Testing Locally (Optional)
127
+
128
+ Before deploying, you can test locally with Docker:
129
+
130
+ ```bash
131
+ cd /share/u/wendler/code/toy-wm-hf-space
132
+
133
+ # Build Docker image
134
+ docker build -t neural-pong .
135
+
136
+ # Run container
137
+ docker run -p 7860:7860 --gpus all neural-pong
138
+
139
+ # Visit http://localhost:7860
140
+ ```
141
+
142
+ **Note:** Requires Docker and NVIDIA Docker runtime for GPU support.
143
+
144
+ ## Updating Your Space
145
+
146
+ After making changes:
147
+
148
+ ```bash
149
+ git add .
150
+ git commit -m "Your update message"
151
+ git push origin main
152
+ ```
153
+
154
+ Hugging Face will automatically rebuild and redeploy.
155
+
156
+ ## File Checklist
157
+
158
+ Before pushing, verify:
159
+
160
+ - ✅ `app.py` exists and is executable
161
+ - ✅ `Dockerfile` exists
162
+ - ✅ `requirements.txt` has all dependencies
163
+ - ✅ `checkpoints/ckpt-step=053700-metric=0.00092727.pt` exists (225MB)
164
+ - ✅ `configs/inference.yaml` has correct checkpoint path
165
+ - ✅ `static/index.html` exists
166
+ - ✅ `src/` directory has all necessary files
167
+
168
+ ## Quick Reference
169
+
170
+ **Your Space URL format:**
171
+ ```
172
+ https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
173
+ ```
174
+
175
+ **Git remote format:**
176
+ ```bash
177
+ git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
178
+ ```
179
+
180
+ **Key files:**
181
+ - `app.py` - Main application (port 7860)
182
+ - `Dockerfile` - Container config
183
+ - `requirements.txt` - Dependencies
184
+ - `configs/inference.yaml` - Model config
185
+
186
+ ## Next Steps After Deployment
187
+
188
+ 1. **Share your Space** with others
189
+ 2. **Monitor usage** in Space analytics
190
+ 3. **Update as needed** by pushing new commits
191
+ 4. **Consider adding:**
192
+ - Better error handling
193
+ - Performance metrics
194
+ - More configuration options
195
+
196
+ ---
197
+
198
+ **Ready to deploy?** Follow Step 1 above! 🚀
199
+
200
+ For more details, see `QUICKSTART.md` or `DEPLOYMENT.md`.
201
+
SOURCE_FILES.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Source Files Included in Hugging Face Space
2
+
3
+ This document lists all the source files included in the deployment. Only inference-related code is included - all training code has been removed.
4
+
5
+ ## File Structure
6
+
7
+ ```
8
+ src/
9
+ ├── __init__.py # Package init
10
+ ├── config.py # Configuration classes (Config, TransformerConfig, etc.)
11
+ ├── datasets/
12
+ │ ├── __init__.py # Datasets package init
13
+ │ └── pong1m.py # Dataset utilities (only fixed2frame used)
14
+ ├── inference/
15
+ │ ├── __init__.py # Inference package init
16
+ │ └── sampling.py # Diffusion sampling function
17
+ ├── models/
18
+ │ ├── __init__.py # Models package init
19
+ │ └── dit_dforce.py # CausalDit model and get_model function
20
+ ├── nn/
21
+ │ ├── __init__.py # NN package init
22
+ │ ├── attn.py # Attention mechanisms and KVCache
23
+ │ ├── geglu.py # GEGLU activation
24
+ │ ├── patch.py # Patch/UnPatch for image tokens
25
+ │ └── pe.py # Positional encodings (RoPE, FrameRoPE, etc.)
26
+ └── utils/
27
+ ├── __init__.py # Utils package init
28
+ └── checkpoint.py # Model loading utilities
29
+ ```
30
+
31
+ ## Total: 15 Python files
32
+
33
+ ## Files Removed (Training Code)
34
+
35
+ - ❌ `src/main.py` - Training script
36
+ - ❌ `src/main_dmd.py` - Training script
37
+ - ❌ `src/trainers/` - All training code (5 files)
38
+ - ❌ `src/datasets/pong1m_embedding.py` - Not used
39
+ - ❌ `src/datasets/pong1m_gpt.py` - Not used
40
+ - ❌ `src/models/dit.py` - Alternative model (not used)
41
+ - ❌ `src/nn/norm.py` - Custom norm (not used, PyTorch LayerNorm used instead)
42
+ - ❌ `src/utils/logging.py` - Logging utilities (not needed for inference)
43
+ - ❌ `src/config/` - Empty directory
44
+
45
+ ## Dependencies Removed
46
+
47
+ - ✅ Removed `matplotlib` imports (not needed for inference)
48
+ - ✅ Removed `muon-optimizer` dependency (only used in training)
49
+ - ✅ Removed training data file dependencies
50
+
51
+ ## Verification
52
+
53
+ All necessary classes and functions are included:
54
+ - ✅ `load_model_from_config` - Model loading
55
+ - ✅ `sample` - Diffusion sampling
56
+ - ✅ `fixed2frame` - Frame conversion
57
+ - ✅ `Config` - Configuration parsing
58
+ - ✅ `CausalDit` - Model class
59
+ - ✅ `KVCache` - KV caching for inference
60
+ - ✅ All NN components (Attention, GEGLU, Patch, Positional Encodings)
61
+
62
+ The codebase is now minimal and contains only what's needed for inference.
63
+
START_HERE.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 Your Hugging Face Space is Ready!
2
+
3
+ Everything is set up and ready to deploy. Here's what you need to do:
4
+
5
+ ## Quick Start (3 Steps)
6
+
7
+ ### 1. Create Your Space
8
+ - Go to: https://huggingface.co/spaces
9
+ - Click "Create new Space"
10
+ - Name: `neural-pong`
11
+ - SDK: **Docker** ⚠️
12
+ - Hardware: **GPU (T4 small)**
13
+
14
+ ### 2. Push Your Code
15
+ ```bash
16
+ cd /share/u/wendler/code/toy-wm-hf-space
17
+
18
+ git init
19
+ git add .
20
+ git commit -m "Initial commit"
21
+
22
+ # Replace with your actual Space URL
23
+ git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME
24
+ git push -u origin main
25
+ ```
26
+
27
+ ### 3. Wait & Test
28
+ - Watch build in "Logs" tab (5-15 min)
29
+ - Visit your Space URL when done
30
+ - Click "Start Stream" and play!
31
+
32
+ ## ✅ What's Ready
33
+
34
+ - ✅ **app.py** - Flask app (port 7860, no user limits)
35
+ - ✅ **Dockerfile** - Container config
36
+ - ✅ **requirements.txt** - All dependencies
37
+ - ✅ **checkpoints/** - Model file (225MB)
38
+ - ✅ **configs/inference.yaml** - Config (checkpoint path correct)
39
+ - ✅ **static/index.html** - Frontend
40
+ - ✅ **src/** - All source code (15 files, cleaned)
41
+
42
+ ## 📚 Documentation
43
+
44
+ - **SETUP_STEPS.md** - Detailed step-by-step guide
45
+ - **QUICKSTART.md** - Quick reference
46
+ - **DEPLOYMENT.md** - Technical details
47
+ - **README.md** - Space description
48
+
49
+ ## 🔍 Verify Before Pushing
50
+
51
+ ```bash
52
+ # Check checkpoint exists
53
+ ls -lh checkpoints/ckpt-step=053700-metric=0.00092727.pt
54
+
55
+ # Check config path
56
+ grep checkpoint configs/inference.yaml
57
+
58
+ # Check main files
59
+ ls app.py Dockerfile requirements.txt
60
+ ```
61
+
62
+ ## 💡 Tips
63
+
64
+ - **Large file upload:** The checkpoint is 225MB, Git is recommended
65
+ - **Build time:** 5-15 minutes (PyTorch installation)
66
+ - **GPU required:** Make sure GPU is enabled in Space settings
67
+ - **Port:** Automatically uses 7860 (HF Spaces default)
68
+
69
+ ## 🆘 Need Help?
70
+
71
+ 1. Check **SETUP_STEPS.md** for detailed instructions
72
+ 2. Check Space **Logs** tab for build errors
73
+ 3. Verify all files are present (see checklist above)
74
+
75
+ ---
76
+
77
+ **Ready?** Start with Step 1 above! 🎮
78
+
TROUBLESHOOTING.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Troubleshooting: Repository Not Found
2
+
3
+ The error "Repository not found" usually means one of these issues:
4
+
5
+ ## Issue 1: Space Doesn't Exist Yet
6
+
7
+ **You need to create the Space on Hugging Face first!**
8
+
9
+ 1. Go to: https://huggingface.co/spaces
10
+ 2. Click "Create new Space"
11
+ 3. Fill in:
12
+ - **Space name:** `pong` (or your choice)
13
+ - **SDK:** Docker
14
+ - **Hardware:** GPU (T4 small)
15
+ 4. Click "Create Space"
16
+
17
+ **Then** come back and push your code.
18
+
19
+ ## Issue 2: Wrong Remote URL
20
+
21
+ Your remote is currently set to SSH: `git@hf.co:spaces/wendlerc/pong`
22
+
23
+ **For Hugging Face Spaces, use HTTPS instead:**
24
+
25
+ ```bash
26
+ cd /share/u/wendler/code/toy-wm-hf-space
27
+
28
+ # Remove the old remote
29
+ git remote remove origin
30
+
31
+ # Add the correct HTTPS remote
32
+ git remote add origin https://huggingface.co/spaces/wendlerc/pong
33
+
34
+ # Now try pushing
35
+ git push -u origin main
36
+ ```
37
+
38
+ ## Issue 3: Wrong Space Name
39
+
40
+ Make sure the Space name matches exactly. If you created it with a different name, update the URL:
41
+
42
+ ```bash
43
+ # Check what remote you have
44
+ git remote -v
45
+
46
+ # Update to correct URL (replace with your actual Space name)
47
+ git remote set-url origin https://huggingface.co/spaces/wendlerc/YOUR_SPACE_NAME
48
+ ```
49
+
50
+ ## Quick Fix Commands
51
+
52
+ ```bash
53
+ cd /share/u/wendler/code/toy-wm-hf-space
54
+
55
+ # 1. Check current remote
56
+ git remote -v
57
+
58
+ # 2. Remove old remote
59
+ git remote remove origin
60
+
61
+ # 3. Add HTTPS remote (replace wendlerc/pong with your actual Space)
62
+ git remote add origin https://huggingface.co/spaces/wendlerc/pong
63
+
64
+ # 4. Verify remote
65
+ git remote -v
66
+
67
+ # 5. Push
68
+ git push -u origin main
69
+ ```
70
+
71
+ ## Verify Your Space Exists
72
+
73
+ 1. Go to: https://huggingface.co/spaces/wendlerc
74
+ 2. Check if `pong` Space exists
75
+ 3. If not, create it first!
76
+
77
+ ## Alternative: Use Hugging Face CLI
78
+
79
+ If you have `huggingface-cli` installed:
80
+
81
+ ```bash
82
+ # Login
83
+ huggingface-cli login
84
+
85
+ # Create Space (if it doesn't exist)
86
+ huggingface-cli repo create wendlerc/pong --type space --sdk docker
87
+ ```
88
+
89
+ Then push your code.
90
+
app.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Pong backend (GPU, eager) for Hugging Face Spaces.
4
+ Broadcasts readiness via Socket.IO so the frontend can auto-hide a loading overlay once the model is ready.
5
+ """
6
+
7
+ # Eventlet must be imported first and monkey-patched before other imports
8
+ import eventlet
9
+ eventlet.monkey_patch()
10
+
11
+ import sys
12
+ import os
13
+ import time
14
+ import threading
15
+ import base64
16
+ import traceback
17
+ from contextlib import contextmanager
18
+ from io import BytesIO
19
+
20
+ import torch as t
21
+ import torch._dynamo as _dynamo
22
+ import numpy as np
23
+ from PIL import Image
24
+ from flask import Flask, request, jsonify, send_from_directory
25
+ from flask_cors import CORS
26
+ from flask_socketio import SocketIO, emit
27
+
28
+ # --------------------------
29
+ # Project imports
30
+ # --------------------------
31
+ project_root = os.path.dirname(os.path.abspath(__file__))
32
+ if project_root not in sys.path:
33
+ sys.path.insert(0, project_root)
34
+
35
+ from src.utils.checkpoint import load_model_from_config
36
+ from src.inference.sampling import sample
37
+ from src.datasets.pong1m import fixed2frame
38
+ from src.config import Config
39
+
40
+ # --------------------------
41
+ # App setup
42
+ # --------------------------
43
+ app = Flask(__name__, static_folder='static')
44
+ CORS(app)
45
+ # Configure SocketIO - use eventlet for proper WebSocket support
46
+ socketio = SocketIO(
47
+ app,
48
+ cors_allowed_origins="*",
49
+ async_mode='eventlet',
50
+ logger=False,
51
+ engineio_logger=False,
52
+ ping_timeout=60,
53
+ ping_interval=25,
54
+ max_http_buffer_size=1e8 # Allow larger messages
55
+ )
56
+
57
+ # --------------------------
58
+ # Globals
59
+ # --------------------------
60
+ model = None
61
+ pred2frame = None
62
+ device = None
63
+
64
+ server_ready = False # <--- readiness flag
65
+
66
+ stream_lock = threading.Lock()
67
+ stream_thread = None
68
+ stream_running = False
69
+ latest_action = 1 # 0=init, 1=nothing, 2=up, 3=down
70
+ target_fps = 30
71
+ frame_index = 0
72
+
73
+ noise_buf = None # (1,1,3,24,24) on GPU
74
+ action_buf = None # (1,1) long on GPU
75
+ cpu_png_buffer = None # BytesIO; reused
76
+
77
+ step_once = None
78
+
79
+ # --------------------------
80
+ # Perf (new API)
81
+ # --------------------------
82
+ t.backends.cudnn.benchmark = True
83
+ t.backends.cudnn.conv.fp32_precision = "tf32"
84
+ t.backends.cuda.matmul.fp32_precision = "high"
85
+
86
+ # --------------------------
87
+ # Debug helpers
88
+ # --------------------------
89
+ def _shape(x):
90
+ try:
91
+ return f"{tuple(x.shape)} | {x.dtype} | {x.device}"
92
+ except Exception:
93
+ return "<?>"
94
+
95
+ def _shape_attr(obj, name):
96
+ try:
97
+ ten = getattr(obj, name, None)
98
+ return None if ten is None else _shape(ten)
99
+ except Exception:
100
+ return None
101
+
102
+ def _fail(msg, extra=None):
103
+ lines = [f"[GEN ERROR] {msg}"]
104
+ if extra:
105
+ for k, v in extra.items():
106
+ lines.append(f" - {k}: {v}")
107
+ raise RuntimeError("\n".join(lines))
108
+
109
+ @contextmanager
110
+ def log_step_debug(action_tensor=None, noise_tensor=None):
111
+ try:
112
+ yield
113
+ except Exception as e:
114
+ tb = traceback.format_exc(limit=6)
115
+ _fail("Step failed",
116
+ extra={
117
+ "action": _shape(action_tensor),
118
+ "noise": _shape(noise_tensor),
119
+ "model.device": str(device),
120
+ "cache.keys": _shape_attr(getattr(model, "cache", None), "keys"),
121
+ "cache.values": _shape_attr(getattr(model, "cache", None), "values"),
122
+ "frame_index": str(frame_index),
123
+ "exception": f"{type(e).__name__}: {e}",
124
+ "trace": tb.strip()
125
+ })
126
+
127
+ # --------------------------
128
+ # Utilities
129
+ # --------------------------
130
+ def _ensure_cuda():
131
+ if not t.cuda.is_available():
132
+ raise RuntimeError("CUDA GPU required; torch.cuda.is_available() is False.")
133
+ return t.device("cuda:0")
134
+
135
+ def _png_base64_from_uint8(frame_uint8) -> str:
136
+ global cpu_png_buffer
137
+ if cpu_png_buffer is None:
138
+ cpu_png_buffer = BytesIO()
139
+ else:
140
+ cpu_png_buffer.seek(0)
141
+ cpu_png_buffer.truncate(0)
142
+ Image.fromarray(frame_uint8).save(cpu_png_buffer, format="PNG")
143
+ return base64.b64encode(cpu_png_buffer.getvalue()).decode()
144
+
145
+ def _reset_cache_fresh():
146
+ model.cache.reset()
147
+
148
+ def _broadcast_ready():
149
+ """Tell all clients whether the server is ready."""
150
+ socketio.emit('server_status', {'ready': server_ready, 'busy': False})
151
+
152
+ # --------------------------
153
+ # Model init (pure eager) & warmup
154
+ # --------------------------
155
+ def initialize_model():
156
+ global model, pred2frame, device
157
+ global noise_buf, action_buf, step_once, server_ready
158
+
159
+ t_start = time.time()
160
+ print("Loading model and preparing GPU runtime...")
161
+ device = _ensure_cuda()
162
+
163
+ config_path = os.path.join(project_root, "configs/inference.yaml")
164
+
165
+ cfg = Config.from_yaml(config_path)
166
+ checkpoint_path = cfg.model.checkpoint
167
+
168
+ model = load_model_from_config(config_path, checkpoint_path=checkpoint_path, strict=False)
169
+ model.to(device) # Move model to GPU before activating cache
170
+ model.eval()
171
+
172
+ model.activate_caching(1, 300) # Cache will now be created on the same device as model
173
+
174
+ # Use fixed2frame directly instead of get_loader to avoid loading data files
175
+ globals()["pred2frame"] = fixed2frame
176
+
177
+ H = W = 24
178
+ noise_buf = t.empty((1, 1, 3, H, W), device=device)
179
+ action_buf = t.empty((1, 1), dtype=t.long, device=device)
180
+
181
+ @_dynamo.disable
182
+ def _step(model_, action_scalar_long: int, n_steps: int, cfg: float, clamp: bool):
183
+ # Match the notebook logic exactly: create fresh noise each time
184
+ noise = t.randn(1, 1, 3, 24, 24, device=device)
185
+ action_buf.fill_(int(action_scalar_long))
186
+
187
+ assert action_buf.shape == (1, 1) and action_buf.dtype == t.long and action_buf.device == device, \
188
+ f"action_buf wrong: { _shape(action_buf) }"
189
+ assert noise.shape == (1, 1, 3, 24, 24) and noise.device == device, \
190
+ f"noise wrong: { _shape(noise) }"
191
+
192
+ # Debug: Check cache state before sampling
193
+ if model_.cache is not None:
194
+ cache_loc = model_.cache.local_location
195
+ if cache_loc == 0:
196
+ # Cache is empty, this should be fine for the first frame
197
+ pass
198
+ elif cache_loc > 0:
199
+ # Check if cache has valid data
200
+ k_test, v_test = model_.cache.get(0)
201
+ if k_test.shape[1] == 0:
202
+ print(f"Warning: Cache returned empty tensors at frame {frame_index}, resetting...")
203
+ _reset_cache_fresh()
204
+
205
+ # Sample with the fresh noise (matching notebook: sample(model, noise, actions[:, aidx:aidx+1], ...))
206
+ z = sample(model_, noise, action_buf, num_steps=n_steps, cfg=cfg, negative_actions=None)
207
+
208
+ # Update cache location after sample (matching notebook: model.cache.update_global_location(1))
209
+ model_.cache.update_global_location(1)
210
+
211
+ if clamp:
212
+ z = t.clamp(z, -1, 1)
213
+ return z
214
+
215
+ globals()["step_once"] = _step
216
+ print("Mode: eager (no torch.compile)")
217
+
218
+ # Warmup
219
+ _reset_cache_fresh()
220
+ with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
221
+ for _ in range(4):
222
+ with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
223
+ _ = step_once(model, action_scalar_long=1, n_steps=4, cfg=0.0, clamp=True)
224
+
225
+ server_ready = True
226
+ print(f"Model ready on {device}")
227
+ _broadcast_ready()
228
+ return model, pred2frame
229
+
230
+ # --------------------------
231
+ # Fixed-FPS streaming worker
232
+ # --------------------------
233
+ class FrameScheduler(threading.Thread):
234
+ def __init__(self, fps=30, n_steps=8, cfg=0.0, clamp=True):
235
+ super().__init__(daemon=True)
236
+ self.frame_period = 1.0 / max(1, int(fps))
237
+ self.n_steps = int(n_steps)
238
+ self.cfg = float(cfg)
239
+ self.clamp = bool(clamp)
240
+ self._stop = threading.Event()
241
+ # FPS tracking
242
+ self.frame_times = []
243
+ self.last_frame_time = None
244
+
245
+ def stop(self):
246
+ self._stop.set()
247
+
248
+ def run(self):
249
+ global frame_index, latest_action
250
+ next_tick = time.perf_counter()
251
+ while not self._stop.is_set():
252
+ start = time.perf_counter()
253
+ if start - next_tick > self.frame_period * 0.75:
254
+ next_tick = start + self.frame_period
255
+ continue
256
+ try:
257
+ with stream_lock:
258
+ action = int(latest_action)
259
+ with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
260
+ with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
261
+ z = step_once(model, action_scalar_long=action,
262
+ n_steps=self.n_steps, cfg=self.cfg, clamp=self.clamp)
263
+ frames_btchw = pred2frame(z)
264
+ # Debug: check what pred2frame returns
265
+ if frame_index < 3:
266
+ print(f"Frame {frame_index}: z range [{z.min().item():.3f}, {z.max().item():.3f}], "
267
+ f"frames_btchw dtype={frames_btchw.dtype}, range [{frames_btchw.min().item()}, {frames_btchw.max().item()}]")
268
+
269
+ frame_arr = frames_btchw[0, 0].permute(1, 2, 0).contiguous()
270
+ if isinstance(frame_arr, t.Tensor):
271
+ frame_np = frame_arr.to("cpu", non_blocking=True).numpy()
272
+ else:
273
+ frame_np = frame_arr.astype(np.uint8, copy=False)
274
+ img_b64 = _png_base64_from_uint8(frame_np)
275
+
276
+ # Calculate achieved FPS
277
+ current_time = time.perf_counter()
278
+ if self.last_frame_time is not None:
279
+ frame_delta = current_time - self.last_frame_time
280
+ self.frame_times.append(frame_delta)
281
+ # Keep only last 30 frames for moving average
282
+ if len(self.frame_times) > 30:
283
+ self.frame_times.pop(0)
284
+ avg_frame_time = sum(self.frame_times) / len(self.frame_times)
285
+ achieved_fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0
286
+ else:
287
+ achieved_fps = 0
288
+ self.last_frame_time = current_time
289
+
290
+ socketio.emit('frame', {'frame': img_b64,
291
+ 'frame_index': frame_index,
292
+ 'action': action,
293
+ 'fps': achieved_fps})
294
+ frame_index += 1
295
+ except Exception as e:
296
+ print("Generation error:", repr(e))
297
+ socketio.emit('error', {'message': str(e)})
298
+ next_tick += self.frame_period
299
+ now = time.perf_counter()
300
+ sleep_for = next_tick - now
301
+ if sleep_for > 0:
302
+ time.sleep(sleep_for)
303
+
304
+ # --------------------------
305
+ # Routes
306
+ # --------------------------
307
+ @app.route('/')
308
+ def index():
309
+ return send_from_directory('static', 'index.html')
310
+
311
+ @app.errorhandler(500)
312
+ def handle_500(e):
313
+ """Handle WSGI errors gracefully"""
314
+ import traceback
315
+ print(f"Flask error handler caught: {e}")
316
+ traceback.print_exc()
317
+ return jsonify({'error': 'Internal server error'}), 500
318
+
319
+ @app.route('/api/health', methods=['GET'])
320
+ def health():
321
+ return jsonify({
322
+ 'status': 'ok',
323
+ 'ready': server_ready,
324
+ 'model_loaded': model is not None,
325
+ 'device': str(device) if device else None,
326
+ 'stream_running': stream_running,
327
+ 'target_fps': target_fps
328
+ })
329
+
330
+ @app.route('/api/generate', methods=['POST'])
331
+ def generate_frames():
332
+ try:
333
+ if not server_ready:
334
+ return jsonify({'success': False, 'error': 'Server not ready'}), 503
335
+
336
+ data = request.json or {}
337
+ actions_list = data.get('actions', [1])
338
+ n_steps = int(data.get('n_steps', 8))
339
+ cfg = float(data.get('cfg', 0))
340
+ clamp = bool(data.get('clamp', True))
341
+
342
+ if len(actions_list) == 0 or actions_list[0] != 0:
343
+ actions_list = [0] + actions_list
344
+
345
+ _reset_cache_fresh()
346
+
347
+ frames_png = []
348
+ with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
349
+ for a in actions_list:
350
+ with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
351
+ z = step_once(model, action_scalar_long=int(a), n_steps=n_steps, cfg=cfg, clamp=clamp)
352
+ f_btchw = pred2frame(z)
353
+ f_arr = f_btchw[0, 0].permute(1, 2, 0).contiguous()
354
+ if isinstance(f_arr, t.Tensor):
355
+ if f_arr.dtype != t.uint8:
356
+ f_arr = f_arr.to(t.uint8)
357
+ f_np = f_arr.to("cpu", non_blocking=True).numpy()
358
+ else:
359
+ f_np = f_arr.astype(np.uint8, copy=False)
360
+ frames_png.append(_png_base64_from_uint8(f_np))
361
+
362
+ return jsonify({'success': True, 'frames': frames_png, 'num_frames': len(frames_png)})
363
+
364
+ except Exception as e:
365
+ print("Batch generation error:", repr(e))
366
+ return jsonify({'success': False, 'error': str(e)}), 500
367
+
368
+ # --------------------------
369
+ # Socket events & helpers
370
+ # --------------------------
371
+ def start_stream(n_steps=8, cfg=0.0, fps=30, clamp=True):
372
+ global stream_thread, stream_running, frame_index, target_fps, latest_action
373
+ if not server_ready:
374
+ _broadcast_ready()
375
+ raise RuntimeError("Server not ready")
376
+ with stream_lock:
377
+ stop_stream()
378
+ target_fps = int(fps)
379
+ frame_index = 0
380
+ _reset_cache_fresh()
381
+ latest_action = 0 # first action = 0 (init)
382
+ stream_thread = FrameScheduler(fps=target_fps, n_steps=n_steps, cfg=cfg, clamp=clamp)
383
+ stream_running = True
384
+ stream_thread.start()
385
+
386
+ def stop_stream():
387
+ global stream_thread, stream_running
388
+ if stream_thread is not None:
389
+ stream_thread.stop()
390
+ stream_thread.join(timeout=1.0)
391
+ stream_thread = None
392
+ stream_running = False
393
+
394
+ @socketio.on_error_default
395
+ def default_error_handler(e):
396
+ print(f"SocketIO error: {e}")
397
+ import traceback
398
+ traceback.print_exc()
399
+
400
+ @socketio.on('connect')
401
+ def handle_connect():
402
+ try:
403
+ sid = request.sid
404
+ print(f'Client connected: {sid}')
405
+
406
+ # Immediately tell the new client current readiness
407
+ emit('server_status', {
408
+ 'ready': server_ready,
409
+ 'busy': False
410
+ })
411
+ emit('connected', {
412
+ 'status': 'connected',
413
+ 'model_loaded': model is not None,
414
+ 'ready': server_ready
415
+ })
416
+ except Exception as e:
417
+ print(f"Error in handle_connect: {e}")
418
+ import traceback
419
+ traceback.print_exc()
420
+
421
+ @socketio.on('disconnect')
422
+ def handle_disconnect(*args):
423
+ sid = request.sid
424
+ print(f'Client disconnected: {sid}')
425
+ # Note: We don't stop the stream on disconnect since multiple users can be connected
426
+
427
+ @socketio.on('start_stream')
428
+ def handle_start_stream(data):
429
+ try:
430
+ if not server_ready:
431
+ # Tell client to keep showing spinner
432
+ emit('server_status', {'ready': server_ready, 'busy': False})
433
+ return
434
+
435
+ n_steps = int(data.get('n_steps', 8))
436
+ cfg = float(data.get('cfg', 0))
437
+ fps = int(data.get('fps', 30))
438
+ clamp = bool(data.get('clamp', True))
439
+ print(f"Starting stream @ {fps} FPS (n_steps={n_steps}, cfg={cfg}, clamp={clamp})")
440
+ try:
441
+ start_stream(n_steps=n_steps, cfg=cfg, fps=fps, clamp=clamp)
442
+ emit('stream_started', {'status': 'ok'})
443
+ except Exception as e:
444
+ print(f"Error starting stream: {e}")
445
+ import traceback
446
+ traceback.print_exc()
447
+ emit('error', {'message': str(e)})
448
+ except Exception as e:
449
+ print(f"Error in handle_start_stream: {e}")
450
+ import traceback
451
+ traceback.print_exc()
452
+ emit('error', {'message': f'Failed to start stream: {str(e)}'})
453
+
454
+ @socketio.on('action')
455
+ def handle_action(data):
456
+ global latest_action
457
+ action = int(data.get('action', 1))
458
+ with stream_lock:
459
+ latest_action = action
460
+ emit('action_ack', {'received': action, 'will_apply_to_frame_index': frame_index})
461
+
462
+ @socketio.on('stop_stream')
463
+ def handle_stop_stream():
464
+ print('Stopping stream')
465
+ stop_stream()
466
+
467
+ # --------------------------
468
+ # Entrypoint
469
+ # --------------------------
470
+ if __name__ == '__main__':
471
+ # Start model initialization in background thread so server starts immediately
472
+ init_thread = threading.Thread(target=initialize_model, daemon=True)
473
+ init_thread.start()
474
+
475
+ # Use PORT environment variable for Hugging Face Spaces, default to 7860
476
+ port = int(os.environ.get('PORT', 7860))
477
+ print(f"Starting Flask server on http://0.0.0.0:{port}")
478
+ print("Model will load in background...")
479
+ socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True, use_reloader=False)
480
+
checkpoints/ckpt-step=053700-metric=0.00092727.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f3813cf639d5370bb90be4bc3974de5b6858a9cb4216458f757c0d415537d0d6
3
+ size 235359093
cleanup.sh ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Cleanup script - removes temporary files and scripts from toy-wm-hf-space
3
+
4
+ set -e
5
+
6
+ echo "🧹 Cleaning up temporary files..."
7
+ echo ""
8
+
9
+ CLEANUP_DIR="/share/u/wendler/code/toy-wm-hf-space"
10
+
11
+ if [ ! -d "$CLEANUP_DIR" ]; then
12
+ echo "⚠️ Directory $CLEANUP_DIR doesn't exist, skipping cleanup"
13
+ exit 0
14
+ fi
15
+
16
+ cd "$CLEANUP_DIR"
17
+
18
+ echo "Removing temporary scripts..."
19
+ rm -f push.sh push-now.sh push-force.sh setup-git.sh setup.sh fix-remote.sh fix-and-push.sh push-to-hf.py 2>/dev/null || true
20
+
21
+ echo "Removing temporary documentation files..."
22
+ rm -f RUN_SETUP.md TROUBLESHOOTING.md SETUP_STEPS.md SETUP_GUIDE.md QUICKSTART.md START_HERE.md 2>/dev/null || true
23
+
24
+ echo "Keeping essential files:"
25
+ echo " ✅ app.py"
26
+ echo " ✅ Dockerfile"
27
+ echo " ✅ requirements.txt"
28
+ echo " ✅ README.md"
29
+ echo " ✅ DEPLOYMENT.md"
30
+ echo " ✅ SOURCE_FILES.md"
31
+ echo " ✅ .gitignore"
32
+ echo " ✅ All source code and checkpoints"
33
+
34
+ echo ""
35
+ echo "✅ Cleanup complete!"
36
+ echo ""
37
+ echo "The toy-wm-hf-space directory still contains your files"
38
+ echo "but temporary scripts have been removed."
39
+
40
+
configs/inference.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ model_id: "dit_dforce"
3
+ width: 24
4
+ height: 24
5
+ T: 1000
6
+ in_channels: 3
7
+ n_window: 30
8
+ patch_size: 3
9
+ n_heads: 12
10
+ d_model: 384
11
+ n_blocks: 8
12
+ C: 5000
13
+ bidirectional: false
14
+ nocompile: false
15
+ checkpoint: "checkpoints/ckpt-step=053700-metric=0.00092727.pt"
16
+ # "experiments/dulcet-disco-547/ckpt-step=000800-metric=0.00384521.pt"
17
+ #"experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
18
+ # "experiments/polished-paper-531/model.pt"
19
+ #"experiments/polished-paper-531/ckpt-step=000200-metric=0.00251065.pt"
20
+ #"experiments/polished-paper-531/ckpt-step=000800-metric=0.00450636.pt"
21
+ #checkpoint: "experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
22
+ # few frame 2-step distilled model experiments/smart-waterfall-528/model.pt
23
+ # few frame 1-step distilled model experiments/blooming-flower-530/model.pt
24
+ #"experiments/rich-meadow-488/model.pt"
25
+ # "experiments/rich-meadow-488/ckpt-step=003700-metric=0.00309512.pt"
26
+ #checkpoint: "experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
27
+ #checkpoint: "experiments/glad-water-486/model.pt"
28
+ #checkpoint: "experiments/dutiful-river-427/ckpt-step=011600-metric=0.00229805.pt"
29
+ #checkpoint: "experiments/iconic-paper-421/ckpt-step=001600-metric=0.00471355.pt"
30
+ #"experiments/radiant-forest-398/ckpt-step=053700-metric=0.00092727.pt"
31
+ #checkpoint: "experiments/frosty-sunset-395/ckpt-step=002100-metric=0.00160125.pt"
32
+ #checkpoint: "experiments/vivid-sea-390/ckpt-step=000700-metric=0.01958773.pt"
33
+
34
+ train:
35
+ lr1: 0.0002
36
+ lr2: 1.5e-6
37
+ betas: [0.9, 0.95]
38
+ weight_decay: 1.0e-5
39
+ max_steps: 20000
40
+ batch_size: 16
41
+ noclip: false
42
+ duration: 1
43
+ fps: 31
44
+ debug: false
45
+ p_pretrain: 0.95
46
+
47
+ wandb:
48
+ name: null
49
+ project: "toy-wm"
50
+ run_name: "causal-layers8-heads12-d384"
push-and-cleanup.sh ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Complete script: Push from pong directory and clean up
3
+
4
+ set -e
5
+
6
+ echo "🚀 Neural Pong - Complete Push and Cleanup"
7
+ echo "==========================================="
8
+ echo ""
9
+
10
+ # Step 1: Push from pong directory
11
+ echo "Step 1: Pushing from /share/u/wendler/code/pong"
12
+ echo "-----------------------------------------------"
13
+ cd /share/u/wendler/code/pong
14
+
15
+ if [ ! -d ".git" ]; then
16
+ echo "❌ Error: Not a git repository in pong directory"
17
+ exit 1
18
+ fi
19
+
20
+ # Stage and commit
21
+ git add .
22
+ if ! git diff --cached --quiet; then
23
+ git commit -m "Add Neural Pong application files" || git commit --amend --no-edit
24
+ fi
25
+
26
+ # Push
27
+ BRANCH=$(git branch --show-current 2>/dev/null || echo "main")
28
+ echo "Pushing to origin/$BRANCH..."
29
+ if ! git push -u origin $BRANCH 2>&1; then
30
+ echo "Trying force push..."
31
+ git push -u origin $BRANCH --force
32
+ fi
33
+
34
+ echo ""
35
+ echo "✅ Successfully pushed to Hugging Face Spaces!"
36
+ echo ""
37
+
38
+ # Step 2: Cleanup temporary files
39
+ echo "Step 2: Cleaning up temporary files"
40
+ echo "-----------------------------------"
41
+ CLEANUP_DIR="/share/u/wendler/code/toy-wm-hf-space"
42
+
43
+ if [ -d "$CLEANUP_DIR" ]; then
44
+ cd "$CLEANUP_DIR"
45
+
46
+ echo "Removing temporary scripts..."
47
+ rm -f push.sh push-now.sh push-force.sh setup-git.sh setup.sh \
48
+ fix-remote.sh fix-and-push.sh push-to-hf.py 2>/dev/null || true
49
+
50
+ echo "Removing temporary documentation..."
51
+ rm -f RUN_SETUP.md TROUBLESHOOTING.md SETUP_STEPS.md \
52
+ SETUP_GUIDE.md QUICKSTART.md START_HERE.md 2>/dev/null || true
53
+
54
+ echo "✅ Cleanup complete!"
55
+ else
56
+ echo "⚠️ Cleanup directory not found, skipping"
57
+ fi
58
+
59
+ echo ""
60
+ echo "==========================================="
61
+ echo "✅ All done!"
62
+ echo "==========================================="
63
+ echo ""
64
+ echo "🌐 Your Space: https://huggingface.co/spaces/wendlerc/pong"
65
+ echo "📁 Working directory: /share/u/wendler/code/pong"
66
+ echo ""
67
+ echo "The build should start automatically. Check the Logs tab for progress."
68
+
69
+
push.sh ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Push script for /share/u/wendler/code/pong
3
+ # This will push all files to Hugging Face Spaces
4
+
5
+ set -e
6
+
7
+ cd /share/u/wendler/code/pong
8
+
9
+ echo "🚀 Pushing Neural Pong to Hugging Face Spaces..."
10
+ echo ""
11
+
12
+ # Check if we're in a git repo
13
+ if [ ! -d ".git" ]; then
14
+ echo "❌ Error: Not a git repository"
15
+ exit 1
16
+ fi
17
+
18
+ # Stage all files
19
+ echo "📁 Staging files..."
20
+ git add .
21
+
22
+ # Check status
23
+ echo ""
24
+ echo "📋 Files to be committed:"
25
+ git status --short | head -20
26
+
27
+ # Commit changes
28
+ echo ""
29
+ if git diff --cached --quiet; then
30
+ echo "✅ No changes to commit"
31
+ else
32
+ echo "💾 Committing changes..."
33
+ git commit -m "Add Neural Pong application files" || git commit --amend --no-edit
34
+ echo "✅ Changes committed"
35
+ fi
36
+
37
+ # Check remote
38
+ echo ""
39
+ echo "🔗 Checking remote..."
40
+ git remote -v
41
+
42
+ # Check branch
43
+ BRANCH=$(git branch --show-current 2>/dev/null || echo "main")
44
+ echo "🌿 Current branch: $BRANCH"
45
+
46
+ # Push
47
+ echo ""
48
+ echo "📤 Pushing to Hugging Face Spaces..."
49
+ if git push -u origin $BRANCH 2>&1; then
50
+ echo ""
51
+ echo "✅ Successfully pushed!"
52
+ else
53
+ echo ""
54
+ echo "⚠️ Push failed, trying force push..."
55
+ git push -u origin $BRANCH --force
56
+ echo ""
57
+ echo "✅ Force pushed successfully!"
58
+ fi
59
+
60
+ echo ""
61
+ echo "🌐 Your Space is available at:"
62
+ echo " https://huggingface.co/spaces/wendlerc/pong"
63
+ echo ""
64
+ echo "The build should start automatically. Check the Logs tab for progress."
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core ML framework
2
+ torch>=2.0.0
3
+ torchvision>=0.15.0
4
+
5
+ # Web framework
6
+ flask>=3.1.0
7
+ flask-cors>=6.0.0
8
+ flask-socketio>=5.5.0
9
+ eventlet>=0.40.0
10
+
11
+ # Data processing
12
+ numpy>=1.24.0
13
+ pillow>=10.0.0
14
+ einops>=0.7.0
15
+
16
+ # Configuration
17
+ pyyaml>=6.0
18
+ omegaconf>=2.3.0
19
+
20
+ # Type hints
21
+ jaxtyping>=0.2.0
22
+
23
+ # Hugging Face Hub (for model loading if needed)
24
+ huggingface-hub>=0.20.0
25
+
setup.sh ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # Quick setup script for Hugging Face Space deployment
3
+
4
+ set -e
5
+
6
+ echo "🚀 Neural Pong - Hugging Face Space Setup"
7
+ echo "=========================================="
8
+ echo ""
9
+
10
+ # Check if we're in the right directory
11
+ if [ ! -f "app.py" ] || [ ! -f "Dockerfile" ]; then
12
+ echo "❌ Error: Please run this script from the toy-wm-hf-space directory"
13
+ exit 1
14
+ fi
15
+
16
+ # Check if checkpoint exists
17
+ if [ ! -f "checkpoints/ckpt-step=053700-metric=0.00092727.pt" ]; then
18
+ echo "❌ Error: Checkpoint file not found!"
19
+ echo " Expected: checkpoints/ckpt-step=053700-metric=0.00092727.pt"
20
+ exit 1
21
+ fi
22
+
23
+ echo "✅ Checkpoint file found"
24
+ echo "✅ All required files present"
25
+ echo ""
26
+
27
+ # Check if git is initialized
28
+ if [ ! -d ".git" ]; then
29
+ echo "📦 Initializing git repository..."
30
+ git init
31
+ echo "✅ Git initialized"
32
+ else
33
+ echo "✅ Git repository already initialized"
34
+ fi
35
+
36
+ echo ""
37
+ echo "📋 Next steps:"
38
+ echo ""
39
+ echo "1. Create a Hugging Face Space:"
40
+ echo " - Go to https://huggingface.co/spaces"
41
+ echo " - Click 'Create new Space'"
42
+ echo " - Name: neural-pong (or your choice)"
43
+ echo " - SDK: Docker"
44
+ echo " - Hardware: GPU (T4 small or larger)"
45
+ echo ""
46
+ echo "2. Add the remote and push:"
47
+ echo " git remote add origin https://huggingface.co/spaces/YOUR_USERNAME/YOUR_SPACE_NAME"
48
+ echo " git add ."
49
+ echo " git commit -m 'Initial commit'"
50
+ echo " git push -u origin main"
51
+ echo ""
52
+ echo "3. Wait for build (5-15 minutes)"
53
+ echo ""
54
+ echo "📖 For detailed instructions, see SETUP_GUIDE.md"
55
+ echo ""
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (162 Bytes). View file
 
src/config.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional
3
+ import yaml
4
+ from omegaconf import OmegaConf
5
+
6
+ @dataclass
7
+ class TransformerConfig:
8
+ model_id : str = None
9
+ width : int = 24
10
+ height : int = 24
11
+ T : int = 1000
12
+ in_channels : int = 3
13
+ n_window : int = 7
14
+ patch_size : int = 2
15
+ n_heads : int = 4
16
+ d_model : int = 64
17
+ n_blocks : int = 12
18
+ n_heads : int = 12
19
+ d_model : int = 384
20
+ patch_size : int = 1
21
+ bidirectional : bool = True
22
+ nocompile : bool = False
23
+ checkpoint : str = None
24
+
25
+
26
+ @dataclass
27
+ class TrainingConfig:
28
+ lr1 : float = 0.002
29
+ lr2 : float = 3e-5
30
+ betas : tuple = (0.9, 0.95)
31
+ weight_decay : float = 1e-5
32
+ max_steps : int = 26000
33
+ batch_size : int = 32
34
+ noclip : bool = False
35
+ duration : int = 1
36
+ fps : int = 7
37
+ in_channels : int = 3
38
+ debug : bool = False
39
+
40
+
41
+ @dataclass
42
+ class WANDBConfig:
43
+ name : str = "toy-wm"
44
+ project : str = None
45
+ run_name : str = None
46
+
47
+ @dataclass
48
+ class Config:
49
+ model: TransformerConfig
50
+ train: TrainingConfig
51
+ wandb: WANDBConfig
52
+
53
+ @classmethod
54
+ def from_yaml(cls, path):
55
+ with open(path) as f:
56
+ raw_cfg = yaml.safe_load(f)
57
+
58
+ cfg = OmegaConf.create(raw_cfg)
59
+ return OmegaConf.structured(cls(**cfg))
src/datasets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Datasets module
2
+
src/datasets/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (171 Bytes). View file
 
src/datasets/__pycache__/pong1m.cpython-311.pyc ADDED
Binary file (4.54 kB). View file
 
src/datasets/pong1m.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import TensorDataset, DataLoader
2
+ from torch import nn
3
+ import torch as t
4
+ import numpy as np
5
+ from einops import rearrange
6
+
7
+ mean = t.tensor([[[[[0.0352]],
8
+ [[0.1046]],
9
+ [[0.1046]]]]])
10
+ std = t.tensor([[[[[0.1066]],
11
+ [[0.0995]],
12
+ [[0.0995]]]]])
13
+
14
+ def fixed2frame(y, lam=1e-6):
15
+ y = y.clamp(-1, 1) * 0.5 + 0.5
16
+ frames = (y * 255.0).round().byte()
17
+ return frames
18
+
19
+ def z2frame(y, lam=1e-6, mean=mean, std=std):
20
+ y = y*std.to(y.dtype).to(y.device) + mean.to(y.dtype).to(y.device)
21
+ frames = (y.clamp(0, 1) * 255.0).round().byte()
22
+ return frames
23
+
24
+ def get_loader(batch_size=64, fps=30, duration=5, shuffle=True, debug=False, mode="-1,1", mean=mean, std=std, drop_duration=False):
25
+ frames = t.from_numpy(np.load("./datasets/pong1M/frames.npy"))
26
+ actions = t.from_numpy(np.load("./datasets/pong1M/actions.npy"))
27
+ height, width, channels = frames.shape[-3:]
28
+ n = frames.shape[0]//(fps*duration)
29
+ frames = frames[:n*fps*duration]
30
+ frames = frames.reshape(n, fps*duration, height, width, channels)
31
+ frames = frames.permute(0, 1, 4, 2, 3)
32
+ actions = actions[:n*fps*duration]
33
+ actions = actions.reshape(-1, fps*duration)
34
+ b, dur, c, h, w = frames.shape
35
+ if mode == "-1,1":
36
+ z = rearrange(frames, "b dur c h w -> (b dur h w) c")
37
+ mask = (z == t.tensor([6, 24, 24], dtype=z.dtype)).all(dim=1)
38
+ z = (z.float()/255.0 - 0.5)*2
39
+ z[mask] = 0
40
+ z = rearrange(z, "(b dur h w) c -> b dur c h w", b=b, dur=dur, c=c, h=h, w=w)
41
+ frames = z
42
+ pred2frame = fixed2frame
43
+ elif mode == "z":
44
+ frames = frames.float()/255.0
45
+ frames = (frames - mean) / (std + 1e-6)
46
+ pred2frame = z2frame
47
+ else:
48
+ raise ValueError(f"Invalid mode: {mode}")
49
+
50
+ firstf = frames[0]
51
+ firsta = actions[0]
52
+ if debug:
53
+ frames = 0*frames + firstf[None]
54
+ actions = 0*actions + firsta[None]
55
+ frames = 0*frames + frames[:,0].unsqueeze(1)
56
+ if drop_duration:
57
+ dataset = TensorDataset(frames[:, 0], actions[:,0]*0)
58
+ else:
59
+ dataset = TensorDataset(frames, actions)
60
+ loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
61
+ print(f"{frames.shape[0]//batch_size} batches")
62
+ return loader, pred2frame
src/inference/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .sampling import sample, sample_with_grad
src/inference/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (258 Bytes). View file
 
src/inference/__pycache__/sampling.cpython-311.pyc ADDED
Binary file (2.13 kB). View file
 
src/inference/sampling.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as t
2
+
3
@t.no_grad()
def sample(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
    # Inference-only entry point: identical to sample_with_grad but with
    # autograd disabled to save memory during generation.
    return sample_with_grad(v, z, actions, num_steps, cfg, negative_actions)
6
+
7
def sample_with_grad(v, z, actions, num_steps=10, cfg=0, negative_actions=None):
    """Integrate the velocity model `v` along a shifted flow-matching
    schedule, starting from `z`, conditioned on `actions`.

    Args:
        v: velocity network, called as v(z, actions, t_cond); must expose
            a `.device` attribute.
        z: initial latent/noise tensor, shape (batch, ...).
        actions: per-frame conditioning actions.
        num_steps: number of Euler integration steps.
        cfg: classifier-free guidance strength; 0 disables guidance.
        negative_actions: actions for the negative branch of guidance;
            when None, zeros are used as the unconditional input.

    Returns:
        The integrated sample (gradients flow through the steps).
    """
    device = v.device
    # Timesteps run 1 -> 0, warped by t -> 3t/(2t+1) (keeps endpoints fixed).
    schedule = 1 - t.linspace(0, 1, num_steps + 1, device=device)
    schedule = 3 * schedule / (2 * schedule + 1)
    state = z.clone().to(device)
    for t_now, t_next in zip(schedule[:-1], schedule[1:]):
        t_cond = t_now.repeat(state.shape[0], 1)
        velocity = v(state.to(device), actions.to(device), t_cond.to(device))
        if cfg > 0:
            # Classifier-free guidance: push away from the negative branch.
            if negative_actions is not None:
                neg_cond = negative_actions.to(device)
            else:
                neg_cond = t.zeros_like(actions, dtype=t.long, device=device)
            v_neg = v(state.to(device), neg_cond, t_cond.to(device))
            velocity = v_neg + cfg * (velocity - v_neg)
        state = state + (t_now - t_next) * velocity  # Euler step
    return state
src/models/__init__.py ADDED
File without changes
src/models/dit_dforce.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as t
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+ from ..nn.attn import Attention, AttentionEinOps, KVCache
6
+ from ..nn.patch import Patch, UnPatch
7
+ from ..nn.geglu import GEGLU
8
+ from ..nn.pe import FrameRoPE, NumericEncoding, RoPE
9
+ from jaxtyping import Float, Bool, Int
10
+ from torch import Tensor
11
+ from typing import Optional
12
+
13
+ import math
14
+
15
def modulate(x, shift, scale):
    """AdaLN-style modulation: scale `x` by (1 + scale), then add `shift`."""
    return shift + x * (scale + 1)
17
+
18
class CausalBlock(nn.Module):
    """One DiT transformer block: adaLN-modulated self-attention followed by
    an adaLN-modulated GEGLU MLP, with optional KV caching for
    frame-by-frame autoregressive generation."""

    def __init__(self, layer_idx, d_model, expansion, n_heads, rope=None):
        super().__init__()
        self.layer_idx = layer_idx
        self.d_model = d_model
        self.expansion = expansion
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(d_model)
        # NOTE(review): both branches construct the same AttentionEinOps —
        # the MPS check is currently a no-op kept while flex-attention is
        # suspected buggy (see trailing comment).
        if t.backends.mps.is_available():
            self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope)
        else:
            self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope) # there is a problem with flexattn i think
        self.norm2 = nn.LayerNorm(d_model)
        self.geglu = GEGLU(d_model, expansion*d_model, d_model)

        # Produces the six adaLN parameters (shift/scale/gate for attention
        # and for the MLP) from the conditioning vector.
        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 6 * d_model, bias=True),
        )

    def forward(self, z, cond, mask_self, cache: Optional[KVCache] = None):
        # z:    (batch, dur*seq, d) token stream
        # cond: (batch, dur*seq, d) per-token conditioning (time + action)
        mu1, sigma1, c1, mu2, sigma2, c2 = self.modulation(cond).chunk(6, dim=-1)
        residual = z
        z = modulate(self.norm1(z), mu1, sigma1)
        if cache is not None:
            k, v = cache.get(self.layer_idx)
            offset = cache.global_location # this enables to include rope and ln into the cache
            # Deliberately overridden: RoPE is re-applied over the whole
            # cached window each step to stay closer to training behaviour.
            offset = 0 # this is for reapplying rope again and again to stay more similar to training
            z, k_new, v_new = self.selfattn(z, z, mask=mask_self, k_cache=k, v_cache=v, offset=offset)
            cache.extend(self.layer_idx, k_new, v_new)
        else:
            z, _, _ = self.selfattn(z, z, mask=mask_self)

        z = residual + c1*z  # gated residual (attention path)

        residual = z
        z = modulate(self.norm2(z), mu2, sigma2)
        z = self.geglu(z)
        z = residual + c2*z  # gated residual (MLP path)
        return z
60
+
61
+
62
class CausalDit(nn.Module):
    """Causal diffusion transformer over video frames.

    Each frame is patchified into `toks_per_frame` tokens (patches plus
    learned register tokens); frames attend causally to previous frames
    within a window of `n_window` frames. Conditioning is per-frame
    action embedding plus diffusion-time embedding, injected via adaLN.
    """

    def __init__(self, height, width, n_window, d_model, T=1000, in_channels=3,
                 patch_size=2, n_heads=8, expansion=4, n_blocks=6,
                 n_registers=1, n_actions=4, bidirectional=False,
                 debug=False,
                 legacy=False,
                 frame_rope=False,
                 rope_C=10000,
                 rope_tmax=None):
        super().__init__()
        self.height = height
        self.width = width
        self.n_window = n_window
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.n_blocks = n_blocks
        self.expansion = expansion
        self.n_registers = n_registers
        self.T = T
        self.patch_size = patch_size
        self.debug = debug
        self.legacy = legacy
        self.bidirectional = bidirectional
        self.frame_rope = frame_rope
        # Tokens per frame = spatial patches + register tokens.
        self.toks_per_frame = (height//patch_size)*(width//patch_size) + n_registers
        self.rope_C = rope_C
        if frame_rope:
            # Frame-level rotary PE + a learned spatial grid PE per patch.
            print("Using frame rope")
            print(self.toks_per_frame)
            self.rope_seq = FrameRoPE(d_model//n_heads, self.n_window, self.toks_per_frame, C=rope_C)
            self.grid_pe = nn.Parameter(t.randn(self.toks_per_frame - n_registers, d_model) * 1/d_model**0.5)
        else:
            # Token-level rotary PE over the flattened (frame, token) sequence.
            if rope_tmax is None:
                rope_tmax = self.n_window*self.toks_per_frame
            self.rope_seq = RoPE(d_model//n_heads, rope_tmax, C=rope_C)
            self.grid_pe = None
        self.rope_tmax = rope_tmax

        self.blocks = nn.ModuleList([CausalBlock(lidx, d_model, expansion, n_heads, rope=self.rope_seq) for lidx in range(n_blocks)])
        self.patch = Patch(in_channels=in_channels, out_channels=d_model, patch_size=patch_size)
        self.norm = nn.LayerNorm(d_model)
        self.unpatch = UnPatch(height, width, in_channels=d_model, out_channels=in_channels, patch_size=patch_size)
        self.action_emb = nn.Embedding(n_actions, d_model)
        self.registers = nn.Parameter(t.randn(n_registers, d_model) * 1/d_model**0.5)
        self.time_emb = NumericEncoding(dim=d_model, n_max=T)
        self.time_emb_mixer = nn.Linear(d_model, d_model)
        # Final-layer adaLN: shift and scale only (no gate).
        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 2 * d_model, bias=True),
        )
        self.cache = None

    def activate_caching(self, batch_size, max_frames=None, cache_rope=False):
        """Enable rolling KV caching for autoregressive generation.

        When `max_frames` is given, the rotary PE tables are rebuilt to
        cover max_frames*toks_per_frame positions and shared into every
        block. NOTE(review): `cache_rope` is currently unused, and this
        method does not call cache.update_global_location — presumably the
        generation loop commits frames; verify against the caller.
        """
        self.cache = KVCache(batch_size, self.n_blocks, self.n_heads, self.d_head, self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)
        if max_frames is not None:
            self.rope_seq = RoPE(self.d_head, max_frames*self.toks_per_frame, C=self.rope_C)
            print(self.rope_seq.sins.shape, self.rope_seq.coss.shape)
            self.rope_seq.to(self.device)
            self.rope_seq.to(self.dtype)
            for idx, block in enumerate(self.blocks):
                print("updating rope for block", idx)
                print(self.blocks[idx].selfattn.rope.sins.shape, self.blocks[idx].selfattn.rope.coss.shape)
                self.blocks[idx].selfattn.rope = self.rope_seq
                print(self.blocks[idx].selfattn.rope.sins.shape, self.blocks[idx].selfattn.rope.coss.shape)
    def deactivate_caching(self):
        # Drop the cache; subsequent forwards run uncached.
        self.cache = None

    def forward(self,
                z: Float[Tensor, "batch dur channels height width"],
                actions: Float[Tensor, "batch dur"],
                ts: Int[Tensor, "batch dur"]):
        """Predict the velocity field for noisy frames `z` at diffusion
        times `ts`, conditioned on per-frame `actions`."""

        # Broadcast a single timestep across the whole clip if needed.
        if ts.shape[1] == 1:
            ts = ts.repeat(1, z.shape[1])

        a = self.action_emb(actions) # batch dur d
        # Map continuous ts in [0, 1) to discrete embedding indices [0, T).
        ts_scaled = (ts * self.T).clamp(0, self.T - 1).long()
        cond = self.time_emb_mixer(self.time_emb(ts_scaled)) + a
        # One conditioning vector per token of the frame.
        cond = cond.repeat_interleave(self.toks_per_frame, dim=1)
        z = self.patch(z) # batch dur seq d
        if self.grid_pe is not None:
            z = z + self.grid_pe[None, None]
        # Append register tokens to every frame.
        zr = t.cat((z, self.registers[None, None].repeat([z.shape[0], z.shape[1], 1, 1])), dim=2)# z plus registers
        if self.bidirectional:
            mask_self = None
        else:
            mask_self = self.causal_mask
        batch, durzr, seqzr, d = zr.shape
        zr = zr.reshape(batch, -1, d) # batch durseq d

        for block in self.blocks:
            zr = block(zr, cond, mask_self, cache=self.cache)
        mu, sigma = self.modulation(cond).chunk(2, dim=-1)
        zr = modulate(self.norm(zr), mu, sigma)
        zr = zr.reshape(batch, durzr, seqzr, d)
        # Drop register tokens before unpatchifying back to pixel space.
        out = self.unpatch(zr[:, :, :-self.n_registers])
        return out # batch dur channels height width

    @property
    def causal_mask(self):
        # Frame-level causal mask expanded to token level: token i of frame
        # f may attend to all tokens of frames <= f. Returned with True
        # meaning "masked out".
        size = self.n_window
        m_self = t.tril(t.ones((size, size), dtype=t.int8, device=self.device)) #- t.tril(t.ones((size, size), dtype=t.int8, device=self.device), diagonal=-self.n_window)
        m_self = t.kron(m_self, t.ones((self.toks_per_frame, self.toks_per_frame), dtype=t.int8, device=self.device))
        m_self = m_self.to(bool)
        return ~ m_self # we want to mask out the ones

    @property
    def device(self):
        # Device of the first parameter (assumes all parameters co-located).
        return self.parameters().__next__().device

    @property
    def dtype(self):
        # Dtype of the first parameter.
        return self.parameters().__next__().dtype
178
+
179
+
180
def get_model(height, width, n_window=5, d_model=64, T=100, n_blocks=2, patch_size=2, n_heads=8, bidirectional=False, in_channels=3, frame_rope=False, C=10000):
    """Factory: build a CausalDit with the given geometry and capacity."""
    model = CausalDit(
        height,
        width,
        n_window,
        d_model,
        T,
        in_channels=in_channels,
        n_blocks=n_blocks,
        patch_size=patch_size,
        n_heads=n_heads,
        bidirectional=bidirectional,
        frame_rope=frame_rope,
        rope_C=C,
    )
    return model
182
+
183
if __name__ == "__main__":
    # Smoke test 1: full-sequence forward pass without a KV cache.
    print("running w/o cache")
    dit = CausalDit(20, 20, 100, 64, 5, n_blocks=2)
    z = t.rand((2, 6, 3, 20, 20))
    actions = t.randint(4, (2, 6))
    ts = t.rand((2, 6))
    out = dit(z, actions, ts)
    print(z.shape)
    print(out.shape)

    # Smoke test 2: frame-by-frame forward with caching enabled; iterates
    # past the window size (10) to exercise the rolling ring buffer.
    print("running w cache")
    dit = CausalDit(20, 20, 10, 64, 5, n_blocks=2)
    dit.activate_caching(2)
    print(dit.cache.toks_per_frame)
    print(dit.cache.size)
    for i in range(30):
        print(dit.cache.local_loc)
        print(dit.cache.global_loc)
        z = t.rand((2, 1, 3, 20, 20))
        actions = t.randint(4, (2, 1))
        ts = t.rand((2, 1))
        out = dit(z, actions, ts)
        print(i, z.shape)
        print(i, out.shape)
src/nn/__init__.py ADDED
File without changes
src/nn/attn.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torch.nn import functional as F
3
+ import torch as t
4
+ import einops
5
+ from jaxtyping import Float, Bool
6
+ from torch import Tensor
7
+ from typing import Optional
8
+ from torch.nn.attention.flex_attention import flex_attention
9
+
10
+
11
class KVCache(nn.Module):
    """
    Rolling KV cache implemented as a ring buffer.
    - Shapes:
        keys/values per extend(): (batch_size, T, n_heads, d_head)
    - Internal storage:
        (n_layers, batch_size, size, n_heads, d_head) where size = toks_per_frame * n_window
    - Semantics:
        Call `extend(layer_idx, k, v)` once per layer for the *same* frame.
        Call `update_global_location(n_frames)` once after all layers to commit the frame(s).
    """
    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window, *, dtype=None, device=None, enforce_layer_order=True):
        super().__init__()
        self.batch_size = batch_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.d_head = d_head
        self.toks_per_frame = toks_per_frame
        self.n_window = n_window
        # Ring capacity in tokens: one full attention window of frames.
        self.size = (toks_per_frame * n_window) #toks_per_frame # (toks_per_frame * n_window)

        # Pointers / counters
        self.curr_layer = 0   # which layer are we writing for this frame
        self.global_loc = 0   # total tokens ever committed
        self.local_loc = 0    # valid tokens in buffer (<= size)
        self._write_ptr = 0   # ring-buffer write pointer (index of next commit position)

        # Storage (registered as buffers so .to()/.state_dict() work).
        dtype = dtype if dtype is not None else t.float32
        self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))
        self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head, dtype=dtype, device=device))

        # Misc
        self.enforce_layer_order = enforce_layer_order

    # -------------- Public API --------------
    def get(self, layer_idx):
        """Return (K, V) for given layer in chronological order: shape (B, L, H, D) where L = local_loc."""
        self._check_layer(layer_idx)
        if self.local_loc == 0:
            # return empty views
            empty = self.keys[layer_idx, :, :0]
            return empty, empty

        # Oldest committed token sits local_loc slots behind the write pointer.
        start = (self._write_ptr - self.local_loc) % self.size
        if start + self.local_loc <= self.size:
            # contiguous slice
            k = self.keys[layer_idx, :, start:start + self.local_loc]
            v = self.values[layer_idx, :, start:start + self.local_loc]
        else:
            # wrap: concatenate two slices to maintain chronological order
            first = self.size - start
            k = t.cat([
                self.keys[layer_idx, :, start:self.size],
                self.keys[layer_idx, :, 0:(self.local_loc - first)]
            ], dim=1)
            v = t.cat([
                self.values[layer_idx, :, start:self.size],
                self.values[layer_idx, :, 0:(self.local_loc - first)]
            ], dim=1)
        return k, v

    @t.no_grad()
    def extend(self, layer_idx, keys, values):
        """
        Stage (but do not commit) tokens for the current frame for the given layer.
        Call update_global_location(n_frames) to commit after all layers wrote.
        """
        assert keys.shape == values.shape, f"keys and values shapes must match, got {keys.shape} vs {values.shape}"
        self._check_layer(layer_idx)

        # Expected shape: (B, T, H, D)
        B, T, H, D = keys.shape
        assert B == self.batch_size, f"batch mismatch: expected {self.batch_size}, got {B}"
        assert H == self.n_heads and D == self.d_head, f"heads/d_head mismatch: expected {(self.n_heads, self.d_head)}, got {(H, D)}"
        assert T > 0 and T <= self.size, f"T must be in 1..{self.size}, got {T}"
        # Optional: if you only ever append whole frames:
        # assert T == self.toks_per_frame, f"T must equal toks_per_frame ({self.toks_per_frame}), got {T}"

        # Cast to buffer dtype/device if needed
        if keys.dtype != self.keys.dtype or keys.device != self.keys.device:
            keys = keys.to(dtype=self.keys.dtype, device=self.keys.device)
        if values.dtype != self.values.dtype or values.device != self.values.device:
            values = values.to(dtype=self.values.dtype, device=self.values.device)

        # Write into the ring at the *current* write_ptr (uncommitted until update_global_location)
        i0 = self._write_ptr
        i1 = (self._write_ptr + T) % self.size
        if i0 < i1:
            self.keys[layer_idx, :, i0:i1] = keys
            self.values[layer_idx, :, i0:i1] = values
        else:
            # wraps: split write
            split = self.size - i0
            self.keys[layer_idx, :, i0:self.size] = keys[:, :split]
            self.values[layer_idx, :, i0:self.size] = values[:, :split]
            self.keys[layer_idx, :, 0:i1] = keys[:, split:]
            self.values[layer_idx, :, 0:i1] = values[:, split:]

        # Advance expected layer (but do *not* advance write_ptr/local_len here)
        self.curr_layer = (self.curr_layer + 1) % self.n_layers

    @t.no_grad()
    def update_global_location(self, n_frames):
        """
        Commit staged writes for n_frames (advances the ring write pointer once per frame).
        Keep calling extend(layer_idx, ...) for each layer before you call this.
        """
        assert n_frames >= 0, f"n_frames must be >= 0, got {n_frames}"
        tokens = n_frames * self.toks_per_frame
        if tokens == 0:
            return
        assert tokens <= self.size, f"Cannot commit {tokens} tokens (> buffer size {self.size})."

        self.global_loc += tokens
        # Update valid length (never exceeds capacity)
        self.local_loc = min(self.size, self.local_loc + tokens)
        # Advance write pointer
        self._write_ptr = (self._write_ptr + tokens) % self.size

    @t.no_grad()
    def reset(self, zero_memory: bool = True):
        # Return the cache to its initial empty state; optionally zero the
        # underlying tensors (not strictly required for correctness since
        # local_loc gates what is read back).
        self.global_loc = 0
        self.local_loc = 0
        self.curr_layer = 0
        self._write_ptr = 0
        if zero_memory:
            self.keys.zero_()
            self.values.zero_()

    # -------------- Convenience / Introspection --------------
    @property
    def local_location(self):
        return self.local_loc

    @property
    def global_location(self):
        return self.global_loc

    @property
    def device(self):
        return self.keys.device

    @property
    def dtype(self):
        return self.keys.dtype

    def get_recent(self, layer_idx, last_T):
        """Return the most recent last_T tokens for a layer (chronological)."""
        self._check_layer(layer_idx, allow_any=True)
        last_T = min(last_T, self.local_loc)
        if last_T == 0:
            empty = self.keys[layer_idx, :, :0]
            return empty, empty
        start = (self._write_ptr - last_T) % self.size
        if start + last_T <= self.size:
            k = self.keys[layer_idx, :, start:start + last_T]
            v = self.values[layer_idx, :, start:start + last_T]
        else:
            first = self.size - start
            k = t.cat([self.keys[layer_idx, :, start:self.size], self.keys[layer_idx, :, 0:(last_T - first)]], dim=1)
            v = t.cat([self.values[layer_idx, :, start:self.size], self.values[layer_idx, :, 0:(last_T - first)]], dim=1)
        return k, v

    # -------------- Internal checks --------------
    def _check_layer(self, layer_idx, allow_any=False):
        # Validates the layer index and (optionally) that layers are visited
        # in round-robin order, which the staged-write protocol requires.
        assert 0 <= layer_idx < self.n_layers, f"layer_idx out of range: 0..{self.n_layers-1}, got {layer_idx}"
        if self.enforce_layer_order and not allow_any:
            assert layer_idx == (self.curr_layer % self.n_layers), \
                f"Layer order mismatch: expected {self.curr_layer % self.n_layers}, got {layer_idx}"
181
+
182
+
183
class KVCacheMine(nn.Module): # this does not work because it destroys the cache of later timesteps when the earlier ones overflow and move to the left. --> fix as an exercise.
    # NOTE(review): kept intentionally as a known-broken alternative to
    # KVCache (see the comment above). Do not use in production paths.
    def __init__(self, batch_size, n_layers, n_heads, d_head, toks_per_frame, n_window):
        """
        This is a rolling KVCache
        """
        super().__init__()
        self.batch_size = batch_size
        self.n_heads = n_heads
        self.d_head = d_head
        self.toks_per_frame = toks_per_frame
        self.n_window = n_window
        self.size = toks_per_frame * n_window#5*n_window#(n_window + 1)
        self.n_layers = n_layers
        self.curr_layer = 0   # expected next layer index (round-robin)
        self.global_loc = 0   # total tokens ever added
        self.local_loc = 0    # valid tokens currently held (<= size)
        self.register_buffer('keys', t.zeros(n_layers, batch_size, self.size, n_heads, d_head))
        self.register_buffer('values', t.zeros(n_layers, batch_size, self.size, n_heads, d_head))

    def get(self, layer_idx):
        # Return the valid prefix of the cache for this layer.
        assert layer_idx == self.curr_layer, f"layer idx should be the same as our internal counter but we got {layer_idx} and internal is {self.curr_layer}."
        return self.keys[layer_idx, :, :self.local_loc], self.values[layer_idx, :, :self.local_loc]

    def extend(self, layer_idx, keys, values):
        # Append new K/V; when full, shift everything left by one frame.
        # BUG (acknowledged above): the left-shift happens per-layer while
        # other layers may still reference older offsets, corrupting state.
        assert keys.shape == values.shape, f"keys and values shapes must match {self.keys.shape} != {self.values.shape}"
        assert layer_idx == self.curr_layer, f"layer idx should be the same as our internal counter but we got {layer_idx} and internal is {self.curr_layer}."
        assert self.local_loc <= self.size, f"the cache size should be between 0 and {self.size}"
        local_loc = self.local_loc
        if local_loc == self.size:
            # move to the left
            local_loc -= keys.shape[1]
            assert local_loc >= 0, f"the cache update {keys.shape[1]} was larger than the cache {self.size}, that's not supported for now."
            assert local_loc % self.toks_per_frame == 0, f"the number of elements in the cache {local_loc} must be a multiple of the number of tokens per frame {self.toks_per_frame}"
            self.keys[layer_idx, :, :local_loc] = self.keys[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
            self.values[layer_idx, :, :local_loc] = self.values[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame].clone()
            #self.keys[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame] = self.keys[layer_idx, :, -local_loc:].clone()
            #self.values[layer_idx, :, self.toks_per_frame:local_loc+self.toks_per_frame] = self.values[layer_idx, :, -local_loc:].clone()

        assert local_loc + keys.shape[1] <= self.size, f"{local_loc + keys.shape[1]} out of bounds {self.size}"
        self.keys[layer_idx, :, local_loc:local_loc + keys.shape[1]] = keys
        self.values[layer_idx, :, local_loc:local_loc + keys.shape[1]] = values
        self.curr_layer = (self.curr_layer + 1) % self.n_layers

    def update_global_location(self, n_frames):
        # Commit n_frames worth of tokens after all layers have extended.
        self.global_loc += n_frames * self.toks_per_frame
        if self.local_loc < self.size:
            self.local_loc += n_frames * self.toks_per_frame
        assert self.local_loc <= self.size, f"the local loc {self.local_loc} should never be bigger than {self.size}, something went wrong."

    def reset(self):
        # Clear all counters and zero the stored tensors.
        self.global_loc = 0
        self.local_loc = 0
        self.curr_layer = 0
        self.keys.zero_()
        self.values.zero_()

    @property
    def local_location(self):
        return self.local_loc

    @property
    def global_location(self):
        return self.global_loc

    @property
    def device(self):
        return self.keys.device

    @property
    def dtype(self):
        return self.keys.dtype
254
+
255
+
256
class AttentionEinOps(nn.Module):
    """Multi-head attention implemented with explicit einsums.

    Uses QK-LayerNorm (per-head) instead of the usual 1/sqrt(d_head)
    scaling — note there is deliberately no scale factor on the logits.
    Supports an external KV cache for autoregressive decoding; when a
    cache is supplied, the mask is dropped and queries attend to the
    full cached history.
    """
    IGNORE: Float[Tensor, ""]

    def __init__(self, d_model, n_heads, rope=None):
        super().__init__()
        assert d_model % n_heads == 0, f"{d_model} must be divisble by {n_heads}"
        self.d_head = d_model // n_heads
        d_head = self.d_head
        # Per-head projection weights, shape (n_heads, d_model, d_head).
        self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_V = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_O = nn.Parameter(t.empty((n_heads, d_head, d_model)))
        self.b_Q = nn.Parameter(t.zeros((n_heads, d_head)))
        self.b_K = nn.Parameter(t.zeros((n_heads, d_head)))
        self.b_V = nn.Parameter(t.zeros((n_heads, d_head)))
        self.b_O = nn.Parameter(t.zeros((d_model)))
        nn.init.normal_(self.W_Q, 1/d_model**0.5)
        nn.init.normal_(self.W_K, 1/d_model**0.5)
        nn.init.normal_(self.W_V, 1/d_model**0.5)
        nn.init.normal_(self.W_O, 1/d_head**0.5)
        self.register_buffer("IGNORE", t.tensor(float('-inf'), dtype=t.float32))
        self.rope = rope
        # QK-LayerNorms (applied after rope — see comment in forward).
        self.ln1 = nn.LayerNorm(d_head)
        self.ln2 = nn.LayerNorm(d_head)


    def forward(
        self,
        x_q: Float[Tensor, "batch posq d_model"],
        x_kv: Float[Tensor, "batch posk d_model"],
        mask: Bool[Tensor, "posq posk"] = None, # True entries are masked out
        k_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
        v_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
        offset: int = 0
    ) -> Float[Tensor, "batch posq d_model"]:
        """Returns (output, k_new, v_new); k_new/v_new are the freshly
        projected K/V of x_kv so the caller can append them to its cache.
        NOTE(review): the `offset` parameter is accepted but never used —
        the caller (CausalBlock) re-applies rope over the full history
        instead; confirm before removing.
        """
        assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
        d_head = self.d_head
        if k_cache is not None and v_cache is not None:
            q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
            k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
            v_new = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h') + self.b_V

            k = t.cat([k_cache, k_new], dim=1)
            v = t.cat([v_cache, v_new], dim=1)

            if self.rope is not None:
                # Rope the query at its absolute position; re-rope the whole
                # (cached + new) key sequence from position 0.
                q = self.rope(q, offset=k_cache.shape[1])
                k = self.rope(k, offset=0)
            q = self.ln1(q) # this should be before rope
            k = self.ln2(k)
            # Cached path: attend to everything in the window, no mask.
            mask = None
        else:
            q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h') + self.b_Q
            k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h') + self.b_K
            v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h') + self.b_V
            if self.rope is not None:
                q = self.rope(q)
                k = self.rope(k)
            q = self.ln1(q)
            k = self.ln2(k) # this learns much faster using layernorm here
            k_new = k
            v_new = v

        # Unscaled dot-product logits (QK-LayerNorm replaces 1/sqrt(d)).
        attention = einops.einsum(q, k, 'b sq n h, b sk n h -> b n sq sk')
        if mask is not None and k_cache is not None:
            # NOTE(review): unreachable — mask is set to None above whenever
            # k_cache is provided; kept for safety if that changes.
            attention = t.where(mask[k_cache.shape[1]:k_cache.shape[1]+q.shape[1], :k.shape[1]], self.IGNORE, attention)
        elif mask is not None:
            if attention.shape[-1] != mask.shape[-1] or attention.shape[-2] != mask.shape[-2]:
                # Crop an oversized mask to the actual attention shape.
                mask = mask[:attention.shape[-1], :attention.shape[-2]]
            attention = t.where(mask, self.IGNORE, attention)
        probas = attention.softmax(dim=3)
        z = einops.einsum(probas, v, 'b n sq sk, b sk n h -> b sq n h')
        out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
        out = out.sum(dim=2) + self.b_O  # combine heads + output bias
        return out, k_new, v_new
334
+
335
+
336
class Attention(nn.Module):
    """Flex-/SDPA-based attention variant.

    NOTE(review): this class is currently disabled — __init__ raises
    NotImplementedError immediately, so everything below it is dead code.
    CausalBlock uses AttentionEinOps instead.
    """
    IGNORE: Float[Tensor, ""]

    def __init__(self, d_model, n_heads, rope=None, use_flex_attention=False):
        raise NotImplementedError("Attention is not implemented yet")
        super().__init__()
        assert d_model % n_heads == 0, f"{d_model} must be divisble by {n_heads}"
        self.d_head = d_model // n_heads
        d_head = self.d_head
        self.W_Q = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_K = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_V = nn.Parameter(t.empty((n_heads, d_model, d_head)))
        self.W_O = nn.Parameter(t.empty((n_heads, d_head, d_model)))
        #self.b_Q = nn.Parameter(t.zeros((n_heads, d_head)))
        #self.b_K = nn.Parameter(t.zeros((n_heads, d_head)))
        #self.b_V = nn.Parameter(t.zeros((n_heads, d_head)))
        #self.b_O = nn.Parameter(t.zeros((d_model)))
        nn.init.normal_(self.W_Q, 1/d_model**0.5)
        nn.init.normal_(self.W_K, 1/d_model**0.5)
        nn.init.normal_(self.W_V, 1/d_model**0.5)
        nn.init.normal_(self.W_O, 1/d_head**0.5)
        self.register_buffer("IGNORE", t.tensor(float('-inf'), dtype=t.float32))
        self.rope = rope
        self.use_flex_attention = use_flex_attention
        self.ln1 = nn.LayerNorm(d_head)
        self.ln2 = nn.LayerNorm(d_head)


    def forward(
        self,
        x_q: Float[Tensor, "batch posq d_model"],
        x_kv: Float[Tensor, "batch posk d_model"],
        mask: Bool[Tensor, "posq posk"] = None, # True entries are masked out
        k_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
        v_cache: Optional[Float[Tensor, "batch posk n_head d_head"]] = None,
    ) -> Float[Tensor, "batch posq d_model"]:
        assert (k_cache is None and v_cache is None) or (k_cache is not None and v_cache is not None), "k_cache and v_cache go together."
        d_head = self.d_head
        if k_cache is not None and v_cache is not None:
            raise NotImplementedError("kv cache not implemented yet")
            # FIXME: dead code below; `x` is undefined (NameError) — should
            # be `x_q` if this branch is ever enabled.
            q = einops.einsum(x, self.W_Q, 'b s d, n d h -> b s n h')
            k_new = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
            v_new = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')
            k = t.cat([k_cache, k_new], dim=1)
            v = t.cat([v_cache, v_new], dim=1)
        else:
            q = einops.einsum(x_q, self.W_Q, 'b s d, n d h -> b s n h')
            k = einops.einsum(x_kv, self.W_K, 'b s d, n d h -> b s n h')
            v = einops.einsum(x_kv, self.W_V, 'b s d, n d h -> b s n h')

        # QK-LayerNorm, here applied before rope (opposite order to
        # AttentionEinOps' cached path).
        q = self.ln1(q)
        k = self.ln2(k)
        if self.rope is not None:
            q = self.rope(q)
            k = self.rope(k)

        # Convert to (batch, num_heads, seq_len, head_dim) format for flex_attention
        q_perm = q.permute(0, 2, 1, 3) # (batch, n_heads, posq, d_head)
        k_perm = k.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)
        v_perm = v.permute(0, 2, 1, 3) # (batch, n_heads, posk, d_head)

        # Ensure tensors are contiguous to avoid flex_attention indexing bugs
        q_perm = q_perm.contiguous()
        k_perm = k_perm.contiguous()
        v_perm = v_perm.contiguous()

        if self.use_flex_attention:
            # Handle mask using score_mod if needed
            if mask is not None:
                # Store mask and IGNORE for use in score_mod closure
                mask_tensor = mask # (posq, posk)
                ignore_val = self.IGNORE
                def score_mod(score, b, h, q_idx, kv_idx):
                    # score_mod operates on individual scalar scores
                    # Apply mask: where mask is True, set to -inf
                    # Use torch ops that work in compiled context
                    mask_val = mask_tensor[q_idx, kv_idx]
                    return t.where(mask_val, ignore_val, score)
                z = flex_attention(q_perm, k_perm, v_perm, score_mod=score_mod)
            else:
                z = flex_attention(q_perm, k_perm, v_perm)
        else:
            # SDPA fallback: flash kernel only when maskless and low-precision.
            condi = mask is None and not self.dtype == t.float32
            with t.backends.cuda.sdp_kernel(
                enable_flash=condi,
                enable_math=not condi,
                enable_mem_efficient=not condi
            ):
                z = F.scaled_dot_product_attention(
                    q_perm, k_perm, v_perm,
                    attn_mask = mask.logical_not() if mask is not None else None,
                    dropout_p = 0.0,
                    is_causal = False,
                    scale = 1.0  # no 1/sqrt(d): QK-LayerNorm handles scale
                )
        z = z.permute(0, 2, 1, 3) # Back to (batch, posq, n_heads, d_head)
        out = einops.einsum(z, self.W_O, 'b s n h, n h d -> b s n d')
        out = out.sum(dim=2)
        # NOTE(review): returns z (not k/v) in the second slot — signature
        # differs from AttentionEinOps; callers must not rely on k_new/v_new.
        return out, z, None

    @property
    def dtype(self):
        # Dtype of the first parameter.
        return self.parameters().__next__().dtype

    @property
    def device(self):
        # Device of the first parameter.
        return self.parameters().__next__().device
444
+
445
+
446
if __name__ == "__main__":
    # Ad-hoc equivalence check between two attention implementations.
    # FIXME: `AttentionSlow` is not defined anywhere in this file, and
    # `Attention.__init__` raises NotImplementedError — this script will
    # fail at runtime as-is; kept for reference only.
    from .pe import RoPE
    import inspect
    rope = RoPE(256//8, 10000)
    dtype = t.float32
    rope = rope.to(dtype)
    attn_slow = AttentionSlow(d_model=256, n_heads=8, rope=rope)
    attn = Attention(d_model=256, n_heads=8, rope=rope)
    attn.load_state_dict(attn_slow.state_dict(), strict=False)
    attn.to(dtype)
    attn_slow.to(dtype)
    x = t.randn(1, 1000, 256, dtype=dtype)*10
    xkv = t.randn(1, 1000, 256, dtype=dtype)*10
    mask = t.randint(0, 2, (1000, 1000), dtype=t.bool)
    y, z, _ = attn(x, xkv, mask=mask)
    y_slow, z_slow, _ = attn_slow(x, xkv, mask=mask)
    #assert t.allclose(z, z_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(z - z_slow).abs().max()}"
    #assert t.allclose(y, y_slow, atol=1e-5), f"Attention and AttentionSlow should be the same: {(y - y_slow).abs().max()}"
    print("Attention and AttentionSlow are the same")

    # Backprop a dummy loss and dump gradient stats for both modules.
    loss = t.nn.functional.mse_loss(y, y_slow)
    loss.backward()
    print("-"*100)
    for n, p in attn.named_parameters():
        print(n, p.grad.shape, p.grad.max(), p.grad.min())
    print("-"*100)
    for n, p in attn_slow.named_parameters():
        print(n, p.grad.shape, p.grad.max(), p.grad.min())
src/nn/geglu.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
class GEGLU(nn.Module):
    """Gated feed-forward block: ``down(up_proj(x) * SiLU(up_gate(x)))``.

    Despite the name, the gate nonlinearity is SiLU (SwiGLU-style).
    All linear biases are zero-initialised, so a zero input maps to a
    zero output at initialisation.
    """

    def __init__(self, d_in, d_mid, d_out):
        super().__init__()
        self.d_in = d_in
        self.d_mid = d_mid
        self.d_out = d_out
        # Submodule creation order is kept stable so RNG-dependent weight
        # init is reproducible across refactors.
        self.up_proj = nn.Linear(d_in, d_mid, bias=True)
        self.up_proj.bias.data.zero_()
        self.up_gate = nn.Linear(d_in, d_mid, bias=True)
        self.up_gate.bias.data.zero_()
        self.down = nn.Linear(d_mid, d_out, bias=True)
        self.down.bias.data.zero_()
        self.nonlin = nn.SiLU()

    def forward(self, x):
        """Apply the gated MLP: (..., d_in) -> (..., d_out)."""
        gate = self.nonlin(self.up_gate(x))
        hidden = self.up_proj(x) * gate
        return self.down(hidden)
src/nn/patch.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from einops import rearrange
3
+ import torch as t
4
+
5
class Patch(nn.Module):  # adapted from https://github.com/cloneofsimo/minRF
    """Embed a batch of video frames into per-frame patch tokens.

    Input:  (batch, dur, in_channels, H, W)
    Output: (batch, dur, (H/patch_size) * (W/patch_size), out_channels)
    """

    def __init__(self, in_channels=3, out_channels=64, patch_size=2):
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        dim = out_channels
        # GroupNorm(32, ...) requires dim // 2 >= 32, hence the branch.
        if dim % 32 == 0 and dim > 32:
            layers = [
                nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.GroupNorm(32, dim // 2),
                nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.GroupNorm(32, dim // 2),
            ]
        else:
            layers = [
                nn.Conv2d(in_channels, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
                nn.Conv2d(dim // 2, dim // 2, kernel_size=5, padding=2, stride=1),
                nn.SiLU(),
            ]
        self.init_conv_seq = nn.Sequential(*layers)

        self.x_embedder = nn.Linear(patch_size * patch_size * dim // 2, dim, bias=True)
        nn.init.constant_(self.x_embedder.bias, 0)

    def forward(self, x):
        """Convolve, patchify and linearly embed each frame."""
        batch, dur, c, h, w = x.shape
        frames = x.reshape(-1, c, h, w)  # fold time into the batch axis
        frames = self.init_conv_seq(frames)
        tokens = self.patchify(frames)
        tokens = self.x_embedder(tokens)
        return tokens.reshape(batch, dur, -1, self.out_channels)

    def patchify(self, x):
        """(B, C, H, W) -> (B, (H/p)*(W/p), C*p*p) non-overlapping patches."""
        B, C, H, W = x.size()
        p = self.patch_size
        x = x.view(B, C, H // p, p, W // p, p)
        # (B, H/p, W/p, C, p, p) -> flatten patch content, then patch grid.
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x
53
+
54
class UnPatch(nn.Module):
    """Project patch tokens back into image frames (inverse of Patch).

    Input:  (batch, dur, (H/patch_size) * (W/patch_size), in_channels)
    Output: (batch, dur, out_channels, height, width)
    """

    def __init__(self, height, width, in_channels=64, out_channels=3, patch_size=2):
        super().__init__()
        self.width = width
        self.height = height
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.unpatch = nn.Linear(in_channels, out_channels*patch_size**2)

    def forward(self, x):
        """Expand each token into a p x p pixel patch and reassemble frames."""
        x = self.unpatch(x)
        batch, dur, seq, d = x.shape
        frames = self.unpatchify(x.reshape(-1, seq, d))  # fold time into batch
        return frames.reshape(batch, dur, self.out_channels, self.height, self.width)

    def unpatchify(self, x):
        """(N, (H/p)*(W/p), c*p*p) -> (N, c, H, W)."""
        c = self.out_channels
        p = self.patch_size
        h = self.height // p
        w = self.width // p
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        # Move channels out front and interleave the patch grid with the
        # intra-patch pixel axes.
        x = t.einsum("nhwpqc->nchpwq", x)
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))
src/nn/pe.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as t
2
+ import torch.nn as nn
3
+ import math
4
+
5
+ from jaxtyping import Float, Bool, Int
6
+ from torch import Tensor
7
+ from typing import Optional
8
+
9
class NumericEncoding(nn.Module):
    """Sinusoidal lookup table for encoding integers.

    pe[n, 2i] = sin(n * C^(-2i/dim)), pe[n, 2i+1] = cos(n * C^(-2i/dim)).
    The full (n_max, dim) table is precomputed and registered as a buffer.
    """

    def __init__(self, C = 1e4, dim = 64, n_max = 10000):
        super().__init__()
        freqs = t.exp(- math.log(C) * t.arange(0, dim, 2)/dim)
        angles = t.arange(n_max)[:, None] * freqs[None, :]
        table = t.empty((n_max, dim))
        table[:, 0::2] = t.sin(angles)  # sin on even feature slots
        table[:, 1::2] = t.cos(angles)  # cos on odd feature slots
        self.register_buffer("pe", table)

    def forward(self, num):
        """
        expects integers between 0 and n_max
        """
        assert num.dtype == t.int32 or num.dtype == t.int64, f"wrong dtype {num.dtype}"
        return self.pe[num]
27
+
28
+
29
class RoPE(nn.Module):
    """Rotary position embedding over the last (d_head) axis.

    The `sins`/`coss` buffers have shape (1, n_ctx, 1, d_head) so they
    broadcast over the batch and head axes.
    """

    def __init__(self, d_head, n_ctx, C=10000):
        super().__init__()
        thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
        # Duplicate each frequency so the (even, odd) pair shares an angle.
        thetas = thetas.repeat([2,1]).T.flatten()
        all_thetas = t.arange(n_ctx).unsqueeze(1) * thetas.unsqueeze(0)
        self.register_buffer('sins', t.sin(all_thetas).unsqueeze(0).unsqueeze(2))
        self.register_buffer('coss', t.cos(all_thetas).unsqueeze(0).unsqueeze(2))

    def forward(self, key_or_query: Float[Tensor, "batch sequence n_head d_head"],
                offset: int = 0):
        """Rotate `key_or_query`; `offset` shifts the position index."""
        x = key_or_query
        # Pairwise rotation companion: (even, odd) slots become (-odd, even).
        rotated = t.empty(x.shape, device=x.device, dtype=x.dtype)
        even = t.arange(0, x.shape[-1], 2)
        odd = t.arange(1, x.shape[-1], 2)
        rotated[:, :, :, even] = -x[:, :, :, odd]
        rotated[:, :, :, odd] = x[:, :, :, even]
        assert x.shape[1] >= 1, f"x.shape[1] must be >= 1, got {x.shape}"
        seq = x.shape[1]
        return self.coss[:, offset:offset+seq]*x + self.sins[:, offset:offset+seq]*rotated
52
+
53
+
54
class FrameRoPE(nn.Module):
    """RoPE variant where every token within a frame shares one position.

    The sequence axis is treated as `dur * toks_per_frame`; each run of
    `toks_per_frame` tokens is rotated with the angle of its frame index
    instead of its raw token index.
    """

    def __init__(self, d_head, n_ctx, toks_per_frame, C=10000):
        super().__init__()
        thetas = t.exp(-math.log(C)*t.arange(0,d_head,2)/d_head)
        # Duplicate each frequency so the (even, odd) pair shares an angle.
        thetas = thetas.repeat([2,1]).T.flatten()
        all_thetas = t.arange(n_ctx).unsqueeze(1) * thetas.unsqueeze(0)
        self.register_buffer('sins', t.sin(all_thetas).unsqueeze(0).unsqueeze(2))
        self.register_buffer('coss', t.cos(all_thetas).unsqueeze(0).unsqueeze(2))
        self.toks_per_frame = toks_per_frame

    def forward(self, key_or_query: Float[Tensor, "batch dur*seq n_head d_head"]):
        """Rotate `key_or_query` using one shared angle per frame."""
        x = key_or_query
        # Pairwise rotation companion: (even, odd) slots become (-odd, even).
        rotated = t.empty(x.shape, dtype=x.dtype, device=x.device)
        even = t.arange(0, x.shape[-1], 2)
        odd = t.arange(1, x.shape[-1], 2)
        rotated[:, :, :, even] = -x[:, :, :, odd]
        rotated[:, :, :, odd] = x[:, :, :, even]
        # Position indices 0,0,...,1,1,... — one per frame, repeated per token.
        idcs = t.arange(0, x.shape[1]//self.toks_per_frame, device=x.device)
        idcs = idcs[:, None].repeat(1, self.toks_per_frame).flatten()
        return self.coss[:,idcs]*x + self.sins[:,idcs]*rotated
src/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .logging import log_video
2
+ from .checkpoint import load_model_from_config
src/utils/checkpoint.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import time
5
+ import shutil
6
+ from pathlib import Path
7
+ from tempfile import NamedTemporaryFile
8
+ from typing import Optional, Dict, Any, List
9
+
10
+ import torch as t
11
+ from torch import nn
12
+
13
+ from ..models.dit_dforce import get_model
14
+ from ..config import Config
15
+
16
+ import yaml
17
+
18
+
19
def load_model_from_config(config_path: str, checkpoint_path: str = None, strict: bool = True) -> nn.Module:
    """Build the DiT model described by a YAML config and optionally load weights.

    Args:
        config_path: YAML file parsed by ``Config.from_yaml``; its ``model``
            section supplies all architecture hyperparameters.
        checkpoint_path: Explicit checkpoint to load. When None, falls back to
            ``model.checkpoint`` from the config; when both are None the model
            is returned with freshly initialised weights.
        strict: Forwarded to ``load_state_dict``; set False to tolerate
            missing/unexpected keys.

    Returns:
        The constructed (and possibly weight-loaded) model.
    """
    print(f"loading {config_path}")
    cmodel = Config.from_yaml(config_path).model
    model = get_model(cmodel.height, cmodel.width,
                      n_window=cmodel.n_window,
                      patch_size=cmodel.patch_size,
                      n_heads=cmodel.n_heads,d_model=cmodel.d_model,
                      n_blocks=cmodel.n_blocks,
                      T=cmodel.T,
                      in_channels=cmodel.in_channels,
                      bidirectional=cmodel.bidirectional)
    if checkpoint_path is None and cmodel.checkpoint is not None:
        checkpoint_path = cmodel.checkpoint
    if checkpoint_path is not None:
        # NOTE(review): weights_only=False unpickles arbitrary objects — only
        # load checkpoints from trusted sources.
        state_dict = t.load(checkpoint_path, weights_only=False)
        # Training checkpoints (see CheckpointManager) wrap weights under "model".
        if "model" in state_dict:
            state_dict = state_dict["model"]
        # Strip "_orig_mod." prefixes — presumably added by torch.compile;
        # note detection inspects the first key only.
        if "_orig_mod." in list(state_dict.keys())[0]:
            state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items() if k.startswith("_orig_mod.")}
        model.load_state_dict(state_dict, strict=strict)
        print('loaded state dict')
    return model
41
+
42
+
43
+
44
class CheckpointManager:
    """
    Manage top-K checkpoints by a metric. On each save:
    - Write a new checkpoint atomically
    - Keep only the top-K files by metric (max or min)
    - Delete files not in top-K
    - Maintain a small JSON index for quick reloads
    Also scans the directory on init to reconstruct state.

    Filenames are of the form: ckpt-step=<step>-metric=<metric>.pt
    """

    # Matches the filename format produced by save(); named groups recover
    # step and metric when rebuilding state from a directory scan.
    CKPT_PATTERN = re.compile(
        r"^ckpt-step=(?P<step>\d+)-metric=(?P<metric>[+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)\.pt$"
    )

    def __init__(
        self,
        dirpath: str | Path,
        k: int = 5,
        mode: str = "max",  # or "min"
        metric_name: str = "score",
        is_main_process: bool = True,
        index_filename: str = "ckpt_index.json",
    ):
        self.dir = Path(dirpath)
        self.dir.mkdir(parents=True, exist_ok=True)
        assert mode in {"max", "min"}
        self.k = int(k)
        self.mode = mode
        self.metric_name = metric_name
        self.is_main = bool(is_main_process)
        self.index_path = self.dir / index_filename

        # entries: list of {path(str), step(int), metric(float), ts(float)}
        self.entries: List[Dict[str, Any]] = []

        # Rebuild state: load the JSON index, pick up any checkpoint files the
        # index does not know about, then enforce the top-K invariant.
        self._load_index()
        self._scan_and_merge()
        self._prune_and_persist()

    # ---------- Public API ----------

    @property
    def best(self) -> Optional[Dict[str, Any]]:
        # Entries are kept sorted best-first by _prune_and_persist.
        return self.entries[0] if self.entries else None

    @property
    def paths(self) -> List[str]:
        # Paths of the currently kept (top-K) checkpoints, best first.
        return [e["path"] for e in self.entries]

    @property
    def should_save(self) -> bool:
        """Use inside DDP loops to gate saving to rank-0 only."""
        return self.is_main

    def save(
        self,
        *,
        metric: float,
        step: int,
        model: Optional[nn.Module] = None,
        optimizer: Optional[t.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        extra: Optional[Dict[str, Any]] = None,  # NOTE(review): currently unused
        state_dict: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Save a checkpoint and keep only top-K by metric.

        Provide either `state_dict` or a `model` (optionally optimizer/scheduler).
        The saved file always contains:
          - 'model', 'optimizer', 'scheduler' (if provided)
          - 'step', metric_name, 'timestamp', 'manager'
        Returns info about the saved file and whether it made the top-K.
        """
        if not self.should_save:
            return {"saved": False, "kept": False, "reason": "not main process"}

        # Assemble the payload from the individual objects when the caller
        # did not pass a prebuilt state_dict.
        if state_dict is None:
            state_dict = {}
            if model is not None:
                state_dict["model"] = model.state_dict()
            if optimizer is not None:
                state_dict["optimizer"] = optimizer.state_dict()
            if scheduler is not None:
                # Some schedulers (e.g., OneCycleLR) have state_dict
                try:
                    state_dict["scheduler"] = scheduler.state_dict()
                except Exception:
                    pass

        ts = time.time()
        filename = f"ckpt-step={int(step):06d}-metric={float(metric):.8f}.pt"
        fpath = self.dir / filename

        # Attach metadata for convenience
        payload = {
            **state_dict,
            "step": int(step),
            self.metric_name: float(metric),
            "timestamp": ts,
            "manager": {
                "mode": self.mode,
                "k": self.k,
                "metric_name": self.metric_name,
                "filename": filename,
            },
        }

        # Atomic write: save to a temp file in the same directory, then
        # rename over the final path so readers never see a partial file.
        with NamedTemporaryFile(dir=self.dir, delete=False) as tmp:
            tmp_path = Path(tmp.name)
            try:
                t.save(payload, tmp_path)
                os.replace(tmp_path, fpath)  # atomic on POSIX
            finally:
                # Clean up the temp file if the replace did not happen.
                if tmp_path.exists():
                    try:
                        tmp_path.unlink()
                    except Exception:
                        pass

        # Update entries and prune
        new_entry = {
            "path": str(fpath),
            "step": int(step),
            "metric": float(metric),
            "ts": ts,
        }
        self.entries.append(new_entry)
        kept = self._prune_and_persist()  # returns True if new file in top-K

        return {"saved": True, "kept": kept, "path": str(fpath), "best": self.best}

    # ---------- Internal helpers ----------

    def _sort_key(self, e: Dict[str, Any]):
        # For MAX: better first => sort by (-metric, step)
        # For MIN: better first => sort by (metric, step)
        return ((-e["metric"], e["step"]) if self.mode == "max" else (e["metric"], e["step"]))

    def _load_index(self):
        # Restore entries from the JSON index; tolerate a missing or
        # corrupted index (the directory scan will recover entries).
        if not self.index_path.exists():
            self.entries = []
            return
        try:
            data = json.loads(self.index_path.read_text())
            entries = data.get("entries", [])
            # Drop missing files
            self.entries = [e for e in entries if Path(e["path"]).exists()]
            # Normalize types
            for e in self.entries:
                e["metric"] = float(e["metric"])
                e["step"] = int(e["step"])
                e["ts"] = float(e.get("ts", time.time()))
        except Exception:
            # If index is corrupted, fall back to empty and rescan
            self.entries = []

    def _scan_and_merge(self):
        """Scan directory for checkpoint files and merge with current entries."""
        seen = {Path(e["path"]).name for e in self.entries}
        for p in self.dir.glob("ckpt-step=*-metric=*.pt"):
            name = p.name
            if name in seen:
                continue
            m = self.CKPT_PATTERN.match(name)
            if not m:
                continue
            step = int(m.group("step"))
            try:
                metric = float(m.group("metric"))
            except ValueError:
                continue
            # File mtime stands in for the original save timestamp.
            self.entries.append(
                {"path": str(p), "step": step, "metric": metric, "ts": p.stat().st_mtime}
            )

    def _prune_and_persist(self) -> bool:
        """Sort by metric, keep top-K, delete the rest. Return True if newest file is kept."""
        if not self.entries:
            self._persist_index()
            return False

        # Sort best-first
        self.entries.sort(key=self._sort_key)

        # Determine which to keep and which to delete
        keep = self.entries[: self.k]
        drop = self.entries[self.k :]

        keep_paths = {e["path"] for e in keep}
        newest_path = max(self.entries, key=lambda e: e["ts"])["path"]
        newest_kept = newest_path in keep_paths

        # Delete files not in top-K
        for e in drop:
            try:
                Path(e["path"]).unlink(missing_ok=True)
            except Exception:
                pass

        # Commit the top-K
        self.entries = keep
        self._persist_index()
        return newest_kept

    def _persist_index(self):
        # Write the index atomically (tmp file + rename), mirroring save().
        data = {
            "k": self.k,
            "mode": self.mode,
            "metric_name": self.metric_name,
            "entries": self.entries,
            "updated_at": time.time(),
        }
        tmp = self.index_path.with_suffix(".json.tmp")
        tmp.write_text(json.dumps(data, indent=2))
        os.replace(tmp, self.index_path)
264
+
265
+ # ---------------------- Example usage ----------------------
266
if __name__ == "__main__":
    # Example (single process). In DDP, construct with is_main_process=(rank==0).
    mgr = CheckpointManager("checkpoints", k=5, mode="max", metric_name="val_acc")

    model = nn.Linear(10, 2)
    opt = t.optim.AdamW(model.parameters(), lr=1e-3)

    # Fake loop
    for epoch in range(10):
        metric = 0.5 + 0.1 * t.rand(1).item()  # pretend validation accuracy
        info = mgr.save(metric=metric, step=epoch, model=model, optimizer=opt)
        # Format the best metric separately: applying ":.4f" directly to the
        # conditional expression raises TypeError when mgr.best is None.
        best_metric = f"{mgr.best['metric']:.4f}" if mgr.best else "n/a"
        print(
            f"epoch {epoch:02d} metric={metric:.4f} saved={info['saved']} kept={info['kept']} "
            f"best_metric={best_metric}"
        )

    print("Top-K paths:", mgr.paths)
    print("Best:", mgr.best)
static/index.html ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!doctype html>
<html>
<head>
  <meta charset="utf-8" />
  <title>Pong</title>
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <!-- Socket.IO client library (CDN) -->
  <script src="https://cdn.socket.io/4.5.4/socket.io.min.js"></script>
  <style>
    html, body { margin:0; height:100%; background:#111; color:#eee; font-family: system-ui, sans-serif; }
    #overlay {
      position: fixed; inset: 0; display: flex; align-items: center; justify-content: center;
      background: rgba(0,0,0,0.8); z-index: 9999; transition: opacity 200ms ease;
    }
    #overlay.hidden { opacity: 0; pointer-events: none; }
    .spinner {
      width: 64px; height: 64px; border: 6px solid #444; border-top-color: #09f; border-radius: 50%;
      animation: spin 0.9s linear infinite;
    }
    @keyframes spin { to { transform: rotate(360deg); } }
    #statusText { margin-top: 12px; color: #aaa; text-align: center; font-size: 14px; white-space: pre-line; }
    #app { padding: 16px; }
    button { padding: 8px 12px; background:#09f; color:#fff; border:none; border-radius:8px; cursor:pointer; }
    button:disabled { opacity: .5; cursor: not-allowed; }
    img#frame { image-rendering: pixelated; width: 240px; height: 240px; background:#222; display:block; margin-top:12px; }
  </style>
</head>
<body>
  <div id="overlay">
    <div>
      <div class="spinner"></div>
      <div id="statusText">Loading model…</div>
    </div>
  </div>

  <div id="app">
    <h1>Pong</h1>
    <div style="margin-bottom: 12px;">
      <label style="display: block; margin-bottom: 8px;">
        FPS: <input type="number" id="fpsInput" value="20" min="1" max="30" step="1" style="width: 60px; padding: 4px; margin-left: 8px;" />
        <span style="color: #aaa; font-size: 12px; margin-left: 8px;">frames per second</span>
      </label>
      <label style="display: block; margin-bottom: 8px;">
        Steps: <input type="number" id="stepsInput" value="4" min="1" max="10" step="1" style="width: 60px; padding: 4px; margin-left: 8px;" />
        <span style="color: #aaa; font-size: 12px; margin-left: 8px;">diffusion steps</span>
      </label>
    </div>
    <div>
      <button id="startBtn" disabled>Start Stream</button>
      <button id="stopBtn" disabled>Stop Stream</button>
    </div>
    <img id="frame" alt="Latest frame" />
    <div id="actionDisplay" style="margin-top: 12px; font-size: 16px; font-family: monospace;">
      Action: <span id="actionValue">-</span>
    </div>
    <div id="fpsDisplay" style="margin-top: 8px; font-size: 16px; font-family: monospace;">
      Achieved FPS: <span id="fpsValue">-</span>
    </div>
    <div>
      This is the output of a small frame-autoregressive transformer trained with rectified flow matching to simulate pong frames conditioned on user inputs for the blue paddle. It should reach 20 FPS when using 4 steps for generation unless something else is running on my machine.
    </div>
  </div>

  <script>
    // If you serve socket.io client at /socket.io/socket.io.js you can use global io():
    const socket = io({ transports: ['websocket', 'polling'] });

    const overlay = document.getElementById('overlay');
    const statusText = document.getElementById('statusText');
    const startBtn = document.getElementById('startBtn');
    const stopBtn = document.getElementById('stopBtn');
    const frameImg = document.getElementById('frame');

    // Toggle the loading overlay and the stream controls based on readiness.
    function setStatus(isReady) {
      if (!isReady) {
        // Model is still loading
        overlay.classList.remove('hidden');
        startBtn.disabled = true;
        stopBtn.disabled = true;
        statusText.textContent = 'Loading model…';
      } else {
        // Server is ready and available
        overlay.classList.add('hidden');
        startBtn.disabled = false;
        stopBtn.disabled = false;
        statusText.textContent = 'Ready';
      }
    }

    // Initial state: assume not ready (show spinner)
    setStatus(false);

    socket.on('connect', () => {
      // server will immediately emit 'server_status' with current readiness
      console.log('connected');
    });

    // Backend broadcasts readiness changes
    socket.on('server_status', (payload) => {
      const ready = !!(payload && payload.ready);
      console.log('Server status:', { ready });
      setStatus(ready);
    });

    // Start/stop controls (parseInt with explicit radix 10)
    startBtn.addEventListener('click', () => {
      const fps = parseInt(document.getElementById('fpsInput').value, 10) || 12;
      const n_steps = parseInt(document.getElementById('stepsInput').value, 10) || 1;
      socket.emit('start_stream', { n_steps: n_steps, cfg: 0.0, fps: fps, clamp: true });
    });
    stopBtn.addEventListener('click', () => {
      socket.emit('stop_stream');
    });

    const actionValue = document.getElementById('actionValue');
    const fpsValue = document.getElementById('fpsValue');

    // Incoming frames
    socket.on('frame', ({ frame, frame_index, action, fps }) => {
      frameImg.src = `data:image/png;base64,${frame}`;
      // Display action: 0=START, 1=NOOP, 2=UP, 3=DOWN
      const actionLabels = ['START','NOOP', 'UP', 'DOWN'];
      actionValue.textContent = `${action} (${actionLabels[action] || 'UNKNOWN'})`;
      // Display achieved FPS
      if (fps !== undefined) {
        fpsValue.textContent = fps.toFixed(1);
      }
    });

    socket.on('error', (e) => {
      console.warn('server error', e);
      // The server_status event will handle showing the appropriate overlay
      // Just log the error for now
      if (e && e.message) {
        console.error('Server error message:', e.message);
      }
    });

    // Keyboard controls for paddle
    // Actions: 1=NOOP, 2=UP, 3=DOWN (0=START is only sent by the server)
    document.addEventListener('keydown', (e) => {
      let action = null;
      if (e.key === 'ArrowUp' || e.key === 'w' || e.key === 'W') {
        action = 2; // UP
      } else if (e.key === 'ArrowDown' || e.key === 's' || e.key === 'S') {
        action = 3; // DOWN
      }
      if (action !== null) {
        socket.emit('action', { action });
        e.preventDefault();
      }
    });

    document.addEventListener('keyup', (e) => {
      if (['ArrowUp', 'ArrowDown', 'w', 'W', 's', 'S'].includes(e.key)) {
        socket.emit('action', { action: 1 }); // NOOP when key released
        e.preventDefault();
      }
    });
  </script>
</body>
</html>