"""
TorchScript model predictor base class.
Loads and runs inference on TorchScript (.pt) models.
"""
import os
import torch
import yaml
import logging
def resolve_model_path(model_path):
    """Resolve a concrete model file from a path that may be a directory.

    If model_path is an existing file, it is returned unchanged. If it is a
    directory, the directory's '.state.yaml' (key 'best') is consulted first;
    otherwise a unique model file inside the directory is used.

    Args:
        model_path: Path to a model file, or to a directory containing one.

    Returns:
        Path to the resolved model file.

    Raises:
        ValueError: If the path does not exist, or no single model file can
            be determined inside the directory.
    """
    if os.path.isfile(model_path):
        return model_path
    if not os.path.isdir(model_path):
        raise ValueError(f'Model path not found: {model_path}')

    state_file = os.path.join(model_path, '.state.yaml')
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            # safe_load returns None for an empty file; fall back to {} so an
            # empty/blank state file degrades to the directory scan below
            # instead of crashing on .get().
            state = yaml.safe_load(f) or {}
        # Guard against scalar/list YAML content as well.
        best = state.get('best') if isinstance(state, dict) else None
        if best:
            # Strip a known extension from 'best' to get the base name.
            base = best
            for ext in ('.chkpt', '.pt'):
                if best.endswith(ext):
                    base = best[:-len(ext)]
                    break
            # Prefer .pt (TorchScript) over .chkpt (training checkpoint);
            # '' covers a 'best' entry recorded without any extension.
            for ext in ('.pt', '.chkpt', ''):
                candidate = os.path.join(model_path, base + ext)
                if os.path.isfile(candidate):
                    return candidate

    # Fallback: a lone .pt file wins, then a lone model file of either kind.
    # List the directory once; dotfiles (e.g. .state.yaml) are excluded.
    entries = os.listdir(model_path)
    pt_files = [f for f in entries
                if f.endswith('.pt') and not f.startswith('.')]
    if len(pt_files) == 1:
        return os.path.join(model_path, pt_files[0])
    model_files = [f for f in entries
                   if f.endswith(('.pt', '.chkpt')) and not f.startswith('.')]
    if len(model_files) == 1:
        return os.path.join(model_path, model_files[0])
    raise ValueError(f'Cannot resolve model file in directory: {model_path}')
class TorchScriptPredictor:
    """Base class for predictors backed by a TorchScript model.

    Owns model loading and the bare forward pass; subclasses supply the
    preprocess/postprocess/predict logic.
    """

    def __init__(self, model_path, device='cuda'):
        """Load the model at model_path (file or directory) onto device."""
        self.device = device
        path = resolve_model_path(model_path)
        self.model = self._load_model(path)
        logging.info('TorchScript model loaded: %s (device: %s)', path, device)

    def _load_model(self, model_path):
        """Deserialize a TorchScript module and switch it to eval mode."""
        module = torch.jit.load(model_path, map_location=self.device)
        module.eval()
        return module

    def preprocess(self, images):
        """Convert a list of numpy-array images into an input tensor.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def postprocess(self, outputs):
        """Turn raw model output tensors into final results.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def predict(self, streams, **kwargs):
        """Yield prediction results for a list of image byte buffers.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def run_inference(self, batch):
        """Forward batch through the model with gradients disabled."""
        with torch.no_grad():
            return self.model(batch)