"""TorchScript model predictor base class.

Loads and runs inference on TorchScript (.pt) models.
"""
import logging
import os

import torch
import yaml

# Recognised model-file extensions, in order of preference: a TorchScript
# export (.pt) is preferred over a training checkpoint (.chkpt).
_MODEL_EXTENSIONS = ('.pt', '.chkpt')


def _resolve_from_state(model_dir):
    """Return the model file named by ``<model_dir>/.state.yaml``, or None.

    The state file's ``best`` entry names a model file. A trailing ``.pt`` or
    ``.chkpt`` is stripped to get a base name, then ``<base>.pt``,
    ``<base>.chkpt`` and ``<base>`` are tried in that order. Returns None when
    the state file is missing, empty, not a mapping, lacks a usable ``best``
    string, or names no existing file.
    """
    state_file = os.path.join(model_dir, '.state.yaml')
    if not os.path.isfile(state_file):
        return None
    with open(state_file, 'r', encoding='utf-8') as f:
        state = yaml.safe_load(f)
    # An empty document loads as None; anything other than a mapping has no
    # 'best' key to read, so fall back to scanning instead of crashing.
    if not isinstance(state, dict):
        return None
    best = state.get('best')
    if not best or not isinstance(best, str):
        return None

    base = best
    for ext in _MODEL_EXTENSIONS:
        if best.endswith(ext):
            base = best[:-len(ext)]
            break
    # Prefer .pt (TorchScript) over .chkpt (checkpoint), then the bare name.
    for ext in (*_MODEL_EXTENSIONS, ''):
        candidate = os.path.join(model_dir, base + ext)
        if os.path.isfile(candidate):
            return candidate
    return None


def _resolve_by_scan(model_dir):
    """Return the single model file in *model_dir*, or None if ambiguous/absent.

    A lone visible ``.pt`` file wins; otherwise a lone visible ``.pt`` or
    ``.chkpt`` file is accepted.
    """
    # Hidden entries (such as .state.yaml) and sub-directories are never
    # candidates; list the directory only once.
    files = [
        name for name in os.listdir(model_dir)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(model_dir, name))
    ]
    for suffixes in (('.pt',), _MODEL_EXTENSIONS):
        matches = [name for name in files if name.endswith(suffixes)]
        if len(matches) == 1:
            return os.path.join(model_dir, matches[0])
    return None


def resolve_model_path(model_path):
    """Resolve model path from directory using .state.yaml if needed.

    If *model_path* is a file, it is returned as-is. If it is a directory, the
    ``best`` field of its ``.state.yaml`` selects the model file. When that
    does not yield an existing file, the directory is scanned for a single
    ``.pt`` file, then for a single ``.pt``/``.chkpt`` file.

    Args:
        model_path: Path to a model file or to a model directory.

    Returns:
        Path to the model file.

    Raises:
        ValueError: If *model_path* does not exist, or no unique model file
            can be determined inside the directory.
    """
    if os.path.isfile(model_path):
        return model_path
    if not os.path.isdir(model_path):
        raise ValueError(f'Model path not found: {model_path}')

    resolved = _resolve_from_state(model_path)
    if resolved is None:
        resolved = _resolve_by_scan(model_path)
    if resolved is None:
        raise ValueError(f'Cannot resolve model file in directory: {model_path}')
    return resolved


class TorchScriptPredictor:
    """Base class for TorchScript model predictors.

    Subclasses implement ``preprocess``, ``postprocess`` and ``predict``.
    This base class resolves and loads the model, and runs it with gradient
    tracking disabled.
    """

    def __init__(self, model_path, device='cuda'):
        """Resolve *model_path* and load the TorchScript model onto *device*.

        Args:
            model_path: A model file, or a directory resolved with
                ``resolve_model_path``.
            device: Passed to ``torch.jit.load`` as ``map_location``.

        Raises:
            ValueError: If the model file cannot be located.
        """
        self.device = device
        resolved_path = resolve_model_path(model_path)
        self.model = self._load_model(resolved_path)
        logging.info('TorchScript model loaded: %s (device: %s)', resolved_path, device)

    def _load_model(self, model_path):
        """Load a TorchScript model from *model_path* and switch it to eval mode."""
        model = torch.jit.load(model_path, map_location=self.device)
        model.eval()
        return model

    def preprocess(self, images):
        """Preprocess images before inference. Override in subclass.

        images: list of numpy arrays
        returns: torch.Tensor
        """
        raise NotImplementedError(f'{type(self).__name__} must implement preprocess()')

    def postprocess(self, outputs):
        """Postprocess model outputs. Override in subclass.

        outputs: torch.Tensor
        returns: processed results
        """
        raise NotImplementedError(f'{type(self).__name__} must implement postprocess()')

    def predict(self, streams, **kwargs):
        """Run prediction on input streams. Override in subclass.

        streams: list of image byte buffers
        yields: prediction results
        """
        raise NotImplementedError(f'{type(self).__name__} must implement predict()')

    def run_inference(self, batch):
        """Run the model on *batch* inside a ``torch.no_grad()`` context."""
        with torch.no_grad():
            return self.model(batch)