#!/usr/bin/env python3
"""
Simple inference script for anime style classification.
"""

import json
from pathlib import Path

import torch
from PIL import Image
from torchvision import models, transforms


def _find_default_weights():
    """Return the first default weights file found in the working directory.

    ``model.safetensors`` is preferred over ``pytorch_model.pth``.

    Raises:
        FileNotFoundError: If neither file exists.
    """
    for candidate in ('model.safetensors', 'pytorch_model.pth'):
        if Path(candidate).exists():
            return candidate
    raise FileNotFoundError("No model weights found (model.safetensors or pytorch_model.pth)")


def _load_state_dict(model_path):
    """Read a state dict from a ``.safetensors`` file or a ``torch.save`` checkpoint."""
    # str() so that pathlib.Path arguments work as well as plain strings.
    if str(model_path).endswith('.safetensors'):
        # Imported lazily so safetensors is only needed for this format.
        from safetensors.torch import load_file
        return load_file(str(model_path))

    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Training checkpoints wrap the weights under 'model_state_dict'; also
    # accept a bare state dict saved via torch.save(model.state_dict()).
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return checkpoint['model_state_dict']
    return checkpoint


def load_model(model_path=None, config_path='config.json'):
    """Load the trained EfficientNet-B0 classifier and its config.

    Args:
        model_path: Path (str or pathlib.Path) to the model weights. If None,
            auto-detects in the current working directory, preferring
            ``model.safetensors`` over ``pytorch_model.pth``.
        config_path: Path to the JSON config file. Must contain
            ``num_classes``; ``classify_image`` also reads ``id2label``.

    Returns:
        A ``(model, config)`` tuple: the model in eval mode and the parsed
        config dict.

    Raises:
        FileNotFoundError: If the config file is missing, or if
            ``model_path`` is None and no default weights file exists.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Build the architecture without downloading ImageNet weights; the trained
    # weights are loaded below. ``weights=None`` is the non-deprecated
    # spelling of ``pretrained=False`` (torchvision >= 0.13).
    model = models.efficientnet_b0(weights=None)
    num_classes = config['num_classes']
    # Replace the final linear layer so its output size matches the label set.
    model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes)

    if model_path is None:
        model_path = _find_default_weights()

    model.load_state_dict(_load_state_dict(model_path))
    model.eval()
    return model, config


def preprocess_image(image_path, img_size=224):
    """Load an image and turn it into a normalized single-image batch.

    Args:
        image_path: Path to the image file.
        img_size: Side length the image is resized to (aspect ratio is not
            preserved).

    Returns:
        A tensor of shape ``(1, 3, img_size, img_size)``.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean/std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Close the file handle promptly; convert() returns a new, fully loaded
    # image, so it remains usable after the with-block exits.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image).unsqueeze(0)


def classify_image(model, config, image_path):
    """Classify a single image.

    Args:
        model: Model returned by ``load_model``.
        config: Config dict whose ``id2label`` maps stringified class indices
            to style names.
        image_path: Path to the image file.

    Returns:
        A list with one ``{'style': str, 'confidence': float}`` dict per
        class, sorted by softmax confidence from highest to lowest.
    """
    input_tensor = preprocess_image(image_path)

    with torch.no_grad():
        output = model(input_tensor)
        # Batch of one: take the first row and normalize over the classes.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    id2label = config['id2label']
    results = [
        # JSON object keys are always strings, hence str(idx).
        {'style': id2label[str(idx)], 'confidence': float(prob)}
        for idx, prob in enumerate(probabilities.tolist())
    ]
    results.sort(key=lambda r: r['confidence'], reverse=True)
    return results


def main():
    """Command-line entry point: classify one image and print the top-K styles."""
    import argparse

    parser = argparse.ArgumentParser(description='Classify anime style')
    parser.add_argument('image', type=str, help='Path to image')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model weights (auto-detects .safetensors or .pth if not specified)')
    parser.add_argument('--config', type=str, default='config.json')
    parser.add_argument('--top-k', type=int, default=3, help='Show top-K predictions')
    args = parser.parse_args()

    # K <= 0 would print nothing (0) or drop the lowest-ranked entries via
    # negative slicing, so reject it explicitly.
    if args.top_k < 1:
        parser.error('--top-k must be at least 1')

    # args.model is None when auto-detecting; avoid printing "from None".
    source = args.model if args.model is not None else 'auto-detected weights'
    print(f"Loading model from {source}...")
    model, config = load_model(args.model, args.config)

    print(f"Classifying {args.image}...")
    results = classify_image(model, config, args.image)

    print()
    print("=" * 60)
    print("PREDICTIONS")
    print("=" * 60)
    for i, result in enumerate(results[:args.top_k], 1):
        print(f"{i}. {result['style']:12s} {result['confidence']:>7.2%}")
    print("=" * 60)
    print()
    print(f"Top prediction: {results[0]['style']} ({results[0]['confidence']:.2%})")


if __name__ == '__main__':
    main()