# starry/backend/python-services/predictors/torchscript_predictor.py
# Author: k-l-lambda
# Commit: feat: add Python ML services (CPU mode) with model download (2b7aae2)
"""
TorchScript model predictor base class.
Loads and runs inference on TorchScript (.pt) models.
"""
import os
import torch
import yaml
import logging
def resolve_model_path(model_path):
    """Resolve a model file path, possibly from a checkpoint directory.

    If ``model_path`` is a file, it is returned unchanged. If it is a
    directory, the ``.state.yaml`` file inside it (when present) is consulted:
    its ``best`` field names the preferred checkpoint, and a TorchScript
    ``.pt`` sibling is preferred over a raw ``.chkpt`` checkpoint. When that
    fails, the directory is scanned for a single unambiguous model file.

    Args:
        model_path: Path to a model file, or to a directory containing one.

    Returns:
        Path to the resolved model file.

    Raises:
        ValueError: If the path does not exist, or no single model file can
            be determined inside the directory.
    """
    if os.path.isfile(model_path):
        return model_path
    if not os.path.isdir(model_path):
        raise ValueError(f'Model path not found: {model_path}')
    state_file = os.path.join(model_path, '.state.yaml')
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            # safe_load returns None for an empty file; treat that as no
            # recorded state instead of crashing on state.get below.
            state = yaml.safe_load(f) or {}
        best = state.get('best')
        if best:
            # Strip a known extension from 'best' to get the base name.
            base = best
            for ext in ('.chkpt', '.pt'):
                if best.endswith(ext):
                    base = best[:-len(ext)]
                    break
            # Prefer .pt (TorchScript) over .chkpt (checkpoint); '' keeps
            # the raw name as a last resort.
            for ext in ('.pt', '.chkpt', ''):
                candidate = os.path.join(model_path, base + ext)
                if os.path.isfile(candidate):
                    return candidate
    # Fallback: scan the directory once; prefer a lone .pt file, then any
    # single model file. Hidden files (e.g. .state.yaml) are excluded.
    entries = [f for f in os.listdir(model_path) if not f.startswith('.')]
    pt_files = [f for f in entries if f.endswith('.pt')]
    if len(pt_files) == 1:
        return os.path.join(model_path, pt_files[0])
    model_files = [f for f in entries if f.endswith(('.pt', '.chkpt'))]
    if len(model_files) == 1:
        return os.path.join(model_path, model_files[0])
    raise ValueError(f'Cannot resolve model file in directory: {model_path}')
class TorchScriptPredictor:
    """Base class for predictors backed by a TorchScript model.

    Subclasses implement preprocess/postprocess/predict; this base handles
    resolving the model path, loading the model, and gradient-free inference.
    """

    def __init__(self, model_path, device='cuda'):
        """Load the TorchScript model found at (or inside) model_path.

        model_path: a model file, or a directory resolvable via
            resolve_model_path.
        device: torch device string the model is mapped onto.
        """
        self.device = device
        path = resolve_model_path(model_path)
        self.model = self._load_model(path)
        logging.info('TorchScript model loaded: %s (device: %s)', path, device)

    def _load_model(self, model_path):
        """Deserialize the TorchScript module and switch it to eval mode."""
        module = torch.jit.load(model_path, map_location=self.device)
        module.eval()
        return module

    def preprocess(self, images):
        """
        Turn a list of numpy arrays into a model-ready torch.Tensor.
        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def postprocess(self, outputs):
        """
        Convert raw model output tensors into final results.
        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def predict(self, streams, **kwargs):
        """
        Yield prediction results for a list of image byte buffers.
        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def run_inference(self, batch):
        """Forward the batch through the model without tracking gradients."""
        with torch.no_grad():
            return self.model(batch)