"""
Pong backend (GPU, eager) for Hugging Face Spaces.

Broadcasts readiness via Socket.IO so the frontend can auto-hide a loading
overlay once the model is ready.
"""

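# eventlet must be imported and monkey patched before anything else touches
# sockets, threading, or time, so that Flask-SocketIO's eventlet async mode
# can cooperatively schedule the rest of the stack.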
import eventlet

eventlet.monkey_patch()

import sys
import os
import time
import threading
import base64
import traceback
from contextlib import contextmanager
from io import BytesIO

import torch as t
import torch._dynamo as _dynamo
import numpy as np
from PIL import Image
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from flask_socketio import SocketIO, emit

project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.utils.checkpoint import load_model_from_config
from src.inference.sampling import sample
from src.datasets.pong1m import fixed2frame
from src.config import Config

app = Flask(__name__, static_folder='static')
CORS(app)

socketio = SocketIO(
    app,
    cors_allowed_origins="*",
    async_mode='eventlet',
    logger=False,
    engineio_logger=False,
    ping_timeout=60,
    ping_interval=25,
    max_http_buffer_size=1e8
)

model = None
pred2frame = None
device = None

server_ready = False

stream_lock = threading.Lock()
stream_thread = None
stream_running = False
latest_action = 1
target_fps = 30
frame_index = 0

noise_buf = None
action_buf = None
cpu_png_buffer = None

step_once = None

# Opportunistic GPU speedups: cuDNN autotuning plus TF32 for fp32 convs and
# matmuls ("high" is a value for torch.set_float32_matmul_precision; the
# fp32_precision knobs expect "tf32" or "ieee").
t.backends.cudnn.benchmark = True
t.backends.cudnn.conv.fp32_precision = "tf32"
t.backends.cuda.matmul.fp32_precision = "tf32"

def _shape(x):
    try:
        return f"{tuple(x.shape)} | {x.dtype} | {x.device}"
    except Exception:
        return "<?>"


def _shape_attr(obj, name):
    try:
        ten = getattr(obj, name, None)
        return None if ten is None else _shape(ten)
    except Exception:
        return None

def _fail(msg, extra=None):
    lines = [f"[GEN ERROR] {msg}"]
    if extra:
        for k, v in extra.items():
            lines.append(f" - {k}: {v}")
    raise RuntimeError("\n".join(lines))

@contextmanager
def log_step_debug(action_tensor=None, noise_tensor=None):
    try:
        yield
    except Exception as e:
        tb = traceback.format_exc(limit=6)
        _fail("Step failed",
              extra={
                  "action": _shape(action_tensor),
                  "noise": _shape(noise_tensor),
                  "model.device": str(device),
                  "cache.keys": _shape_attr(getattr(model, "cache", None), "keys"),
                  "cache.values": _shape_attr(getattr(model, "cache", None), "values"),
                  "frame_index": str(frame_index),
                  "exception": f"{type(e).__name__}: {e}",
                  "trace": tb.strip()
              })

def _ensure_cuda():
    if not t.cuda.is_available():
        raise RuntimeError("CUDA GPU required; torch.cuda.is_available() is False.")
    return t.device("cuda:0")

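# PNG frames are encoded into one reusable module-level BytesIO to avoid a
# fresh allocation per frame. Under eventlet's cooperative scheduling nothing
# yields mid-encode, so the shared buffer is left unguarded; a preemptively
# threaded server would need a lock around it.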
def _png_base64_from_uint8(frame_uint8) -> str:
    global cpu_png_buffer
    if cpu_png_buffer is None:
        cpu_png_buffer = BytesIO()
    else:
        cpu_png_buffer.seek(0)
        cpu_png_buffer.truncate(0)
    Image.fromarray(frame_uint8).save(cpu_png_buffer, format="PNG")
    return base64.b64encode(cpu_png_buffer.getvalue()).decode()

def _reset_cache_fresh():
    model.cache.reset()

def _broadcast_ready():
    """Tell all clients whether the server is ready."""
    socketio.emit('server_status', {'ready': server_ready, 'busy': False})

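# One-time startup: load the config and checkpoint, enable the model's frame
# cache, build the per-frame sampling closure, run a short warmup, then flip
# server_ready and broadcast it to any connected clients.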
def initialize_model():
    global model, pred2frame, device
    global noise_buf, action_buf, step_once, server_ready

    t_start = time.time()
    print("Loading model and preparing GPU runtime...")
    device = _ensure_cuda()

    config_path = os.path.join(project_root, "configs/inference.yaml")

    cfg = Config.from_yaml(config_path)
    checkpoint_path = cfg.model.checkpoint

    model = load_model_from_config(config_path, checkpoint_path=checkpoint_path, strict=False)
    model.to(device)
    model.eval()

    model.activate_caching(1, 300)

    globals()["pred2frame"] = fixed2frame

    H = W = 24
    noise_buf = t.empty((1, 1, 3, H, W), device=device)
    action_buf = t.empty((1, 1), dtype=t.long, device=device)

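    # Per-frame sampling step. @_dynamo.disable keeps torch.compile/dynamo from
    # tracing the cache bookkeeping; this Space intentionally runs the model
    # eagerly. Each call draws fresh noise, writes the action id into the
    # preallocated action buffer, sanity-checks the KV cache, and advances the
    # cache by one frame after sampling.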
    @_dynamo.disable
    def _step(model_, action_scalar_long: int, n_steps: int, cfg: float, clamp: bool):

        noise = t.randn(1, 1, 3, 24, 24, device=device)
        action_buf.fill_(int(action_scalar_long))

        assert action_buf.shape == (1, 1) and action_buf.dtype == t.long and action_buf.device == device, \
            f"action_buf wrong: {_shape(action_buf)}"
        assert noise.shape == (1, 1, 3, 24, 24) and noise.device == device, \
            f"noise wrong: {_shape(noise)}"

        if model_.cache is not None:
            cache_loc = model_.cache.local_location
            if cache_loc == 0:
                pass
            elif cache_loc > 0:
                k_test, v_test = model_.cache.get(0)
                if k_test.shape[1] == 0:
                    print(f"Warning: Cache returned empty tensors at frame {frame_index}, resetting...")
                    _reset_cache_fresh()

        z = sample(model_, noise, action_buf, num_steps=n_steps, cfg=cfg, negative_actions=None)

        model_.cache.update_global_location(1)

        if clamp:
            z = t.clamp(z, -1, 1)
        return z

    globals()["step_once"] = _step
    print("Mode: eager (no torch.compile)")

    # Warm up the sampler so the first streamed frames do not pay one-off costs.
    _reset_cache_fresh()
    with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
        for _ in range(4):
            with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
                _ = step_once(model, action_scalar_long=1, n_steps=4, cfg=0.0, clamp=True)

    server_ready = True
    print(f"Model ready on {device} after {time.time() - t_start:.1f}s")
    _broadcast_ready()
    return model, pred2frame

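# Background generator: one frame per tick at the requested FPS. If a tick is
# already more than ~75% of a frame period late, that frame is skipped so the
# loop catches up instead of drifting. The most recent client action is read
# under stream_lock, and each frame is pushed to clients as a base64 PNG along
# with a rolling FPS estimate.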
class FrameScheduler(threading.Thread):
    def __init__(self, fps=30, n_steps=8, cfg=0.0, clamp=True):
        super().__init__(daemon=True)
        self.frame_period = 1.0 / max(1, int(fps))
        self.n_steps = int(n_steps)
        self.cfg = float(cfg)
        self.clamp = bool(clamp)
        self._stop = threading.Event()

        self.frame_times = []
        self.last_frame_time = None

    def stop(self):
        self._stop.set()

    def run(self):
        global frame_index, latest_action
        next_tick = time.perf_counter()
        while not self._stop.is_set():
            start = time.perf_counter()
            if start - next_tick > self.frame_period * 0.75:
                next_tick = start + self.frame_period
                continue
            try:
                with stream_lock:
                    action = int(latest_action)
                with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
                    with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
                        z = step_once(model, action_scalar_long=action,
                                      n_steps=self.n_steps, cfg=self.cfg, clamp=self.clamp)
                    frames_btchw = pred2frame(z)

                if frame_index < 3:
                    print(f"Frame {frame_index}: z range [{z.min().item():.3f}, {z.max().item():.3f}], "
                          f"frames_btchw dtype={frames_btchw.dtype}, range [{frames_btchw.min().item()}, {frames_btchw.max().item()}]")

                frame_arr = frames_btchw[0, 0].permute(1, 2, 0).contiguous()
                if isinstance(frame_arr, t.Tensor):
                    # Mirror the /api/generate path: PNG encoding expects uint8 pixels.
                    if frame_arr.dtype != t.uint8:
                        frame_arr = frame_arr.to(t.uint8)
                    frame_np = frame_arr.to("cpu", non_blocking=True).numpy()
                else:
                    frame_np = frame_arr.astype(np.uint8, copy=False)
                img_b64 = _png_base64_from_uint8(frame_np)

                current_time = time.perf_counter()
                if self.last_frame_time is not None:
                    frame_delta = current_time - self.last_frame_time
                    self.frame_times.append(frame_delta)

                    if len(self.frame_times) > 30:
                        self.frame_times.pop(0)
                    avg_frame_time = sum(self.frame_times) / len(self.frame_times)
                    achieved_fps = 1.0 / avg_frame_time if avg_frame_time > 0 else 0
                else:
                    achieved_fps = 0
                self.last_frame_time = current_time

                socketio.emit('frame', {'frame': img_b64,
                                        'frame_index': frame_index,
                                        'action': action,
                                        'fps': achieved_fps})
                frame_index += 1
            except Exception as e:
                print("Generation error:", repr(e))
                socketio.emit('error', {'message': str(e)})
            next_tick += self.frame_period
            now = time.perf_counter()
            sleep_for = next_tick - now
            if sleep_for > 0:
                time.sleep(sleep_for)

@app.route('/')
def index():
    return send_from_directory('static', 'index.html')

@app.errorhandler(500)
def handle_500(e):
    """Handle WSGI errors gracefully."""
    import traceback
    print(f"Flask error handler caught: {e}")
    traceback.print_exc()
    return jsonify({'error': 'Internal server error'}), 500

@app.route('/api/health', methods=['GET'])
def health():
    return jsonify({
        'status': 'ok',
        'ready': server_ready,
        'model_loaded': model is not None,
        'device': str(device) if device else None,
        'stream_running': stream_running,
        'target_fps': target_fps
    })

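# Batch (non-streaming) generation. The cache is reset and, if the request does
# not already start with action 0, a 0 is prepended; action 0 appears to act as
# the reset/initial action for the Pong world model (assumption, based on how
# start_stream seeds latest_action).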
@app.route('/api/generate', methods=['POST'])
def generate_frames():
    try:
        if not server_ready:
            return jsonify({'success': False, 'error': 'Server not ready'}), 503

        data = request.json or {}
        actions_list = data.get('actions', [1])
        n_steps = int(data.get('n_steps', 8))
        cfg = float(data.get('cfg', 0))
        clamp = bool(data.get('clamp', True))

        if len(actions_list) == 0 or actions_list[0] != 0:
            actions_list = [0] + actions_list

        _reset_cache_fresh()

        frames_png = []
        with t.inference_mode(), t.autocast(device_type="cuda", dtype=t.bfloat16):
            for a in actions_list:
                with log_step_debug(action_tensor=action_buf, noise_tensor=noise_buf):
                    z = step_once(model, action_scalar_long=int(a), n_steps=n_steps, cfg=cfg, clamp=clamp)
                f_btchw = pred2frame(z)
                f_arr = f_btchw[0, 0].permute(1, 2, 0).contiguous()
                if isinstance(f_arr, t.Tensor):
                    if f_arr.dtype != t.uint8:
                        f_arr = f_arr.to(t.uint8)
                    f_np = f_arr.to("cpu", non_blocking=True).numpy()
                else:
                    f_np = f_arr.astype(np.uint8, copy=False)
                frames_png.append(_png_base64_from_uint8(f_np))

        return jsonify({'success': True, 'frames': frames_png, 'num_frames': len(frames_png)})

    except Exception as e:
        print("Batch generation error:", repr(e))
        return jsonify({'success': False, 'error': str(e)}), 500

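# Stream control: at most one FrameScheduler runs at a time. Starting a stream
# stops any previous one, resets the model cache, and zeroes the frame counter.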
def start_stream(n_steps=8, cfg=0.0, fps=30, clamp=True):
    global stream_thread, stream_running, frame_index, target_fps, latest_action
    if not server_ready:
        _broadcast_ready()
        raise RuntimeError("Server not ready")
    # Stop any previous scheduler before taking the lock: the scheduler acquires
    # stream_lock on every frame, so joining it while holding the lock would
    # stall until the join timeout.
    stop_stream()
    with stream_lock:
        target_fps = int(fps)
        frame_index = 0
        _reset_cache_fresh()
        latest_action = 0
        stream_thread = FrameScheduler(fps=target_fps, n_steps=n_steps, cfg=cfg, clamp=clamp)
        stream_running = True
        stream_thread.start()

def stop_stream():
    global stream_thread, stream_running
    if stream_thread is not None:
        stream_thread.stop()
        stream_thread.join(timeout=1.0)
        stream_thread = None
    stream_running = False

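# Minimal client sketch (assumes the python-socketio client package; event and
# payload names match the handlers below):
#
#   import socketio
#   sio = socketio.Client()
#   sio.on('server_status', lambda d: d.get('ready') and sio.emit('start_stream', {'fps': 30}))
#   sio.on('frame', lambda d: print(d['frame_index'], len(d['frame'])))
#   sio.connect('http://localhost:7860')
#   sio.emit('action', {'action': 1})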
@socketio.on_error_default
def default_error_handler(e):
    print(f"SocketIO error: {e}")
    import traceback
    traceback.print_exc()

@socketio.on('connect')
def handle_connect():
    try:
        sid = request.sid
        print(f'Client connected: {sid}')

        emit('server_status', {
            'ready': server_ready,
            'busy': False
        })
        emit('connected', {
            'status': 'connected',
            'model_loaded': model is not None,
            'ready': server_ready
        })
    except Exception as e:
        print(f"Error in handle_connect: {e}")
        import traceback
        traceback.print_exc()

@socketio.on('disconnect')
def handle_disconnect(*args):
    sid = request.sid
    print(f'Client disconnected: {sid}')

@socketio.on('start_stream')
def handle_start_stream(data):
    try:
        if not server_ready:
            emit('server_status', {'ready': server_ready, 'busy': False})
            return

        n_steps = int(data.get('n_steps', 8))
        cfg = float(data.get('cfg', 0))
        fps = int(data.get('fps', 30))
        clamp = bool(data.get('clamp', True))
        print(f"Starting stream @ {fps} FPS (n_steps={n_steps}, cfg={cfg}, clamp={clamp})")
        try:
            start_stream(n_steps=n_steps, cfg=cfg, fps=fps, clamp=clamp)
            emit('stream_started', {'status': 'ok'})
        except Exception as e:
            print(f"Error starting stream: {e}")
            import traceback
            traceback.print_exc()
            emit('error', {'message': str(e)})
    except Exception as e:
        print(f"Error in handle_start_stream: {e}")
        import traceback
        traceback.print_exc()
        emit('error', {'message': f'Failed to start stream: {str(e)}'})

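# Only the most recent action is kept (no queue): the scheduler reads it once
# per frame, so inputs arriving between frames collapse to the last one.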
@socketio.on('action')
def handle_action(data):
    global latest_action
    action = int(data.get('action', 1))
    with stream_lock:
        latest_action = action
    emit('action_ack', {'received': action, 'will_apply_to_frame_index': frame_index})

@socketio.on('stop_stream')
def handle_stop_stream():
    print('Stopping stream')
    stop_stream()

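# The model loads in a background thread so the port opens immediately; until
# initialize_model flips server_ready, clients can poll /api/health or wait for
# the 'server_status' broadcast.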
if __name__ == '__main__':
    init_thread = threading.Thread(target=initialize_model, daemon=True)
    init_thread.start()

    port = int(os.environ.get('PORT', 7860))
    print(f"Starting Flask server on http://0.0.0.0:{port}")
    print("Model will load in background...")
    socketio.run(app, host='0.0.0.0', port=port, debug=False, allow_unsafe_werkzeug=True, use_reloader=False)