|
|
|
|
|
""" |
|
|
Simple inference script for anime style classification. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
def load_model(model_path=None, config_path='config.json'):
    """Load the trained EfficientNet-B0 classifier and its configuration.

    Args:
        model_path: Path (``str`` or ``pathlib.Path``) to model weights.
            If None, auto-detects ``model.safetensors`` (preferred) or
            ``pytorch_model.pth`` relative to the current working directory.
        config_path: Path to the JSON config file. It must contain a
            ``num_classes`` entry.

    Returns:
        A ``(model, config)`` tuple: the model with weights loaded and
        switched to eval mode, and the parsed config dict.

    Raises:
        FileNotFoundError: If the config file or the weights file is missing.
        KeyError: If the config has no ``num_classes`` entry.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Validate the cheap inputs first so a bad config or a missing weights
    # file fails fast, before the network is constructed.
    num_classes = config['num_classes']

    if model_path is None:
        for candidate in ('model.safetensors', 'pytorch_model.pth'):
            if Path(candidate).exists():
                model_path = candidate
                break
        else:
            raise FileNotFoundError("No model weights found (model.safetensors or pytorch_model.pth)")

    # Accept pathlib.Path as well as str; the suffix check below needs a str.
    model_path = str(model_path)
    if not Path(model_path).is_file():
        raise FileNotFoundError(f"Model weights not found: {model_path}")

    # The bare constructor yields randomly initialised weights on every
    # torchvision version (the old ``pretrained=False`` flag is deprecated).
    model = models.efficientnet_b0()
    model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes)

    if model_path.lower().endswith('.safetensors'):
        # Imported lazily so safetensors is only required for this format.
        from safetensors.torch import load_file

        state_dict = load_file(model_path)
    else:
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(model_path, map_location='cpu')
        # Support both a training checkpoint that wraps the weights under
        # 'model_state_dict' and a file that holds the bare state_dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

    model.load_state_dict(state_dict)
    model.eval()
    return model, config
|
|
|
|
|
|
|
|
def preprocess_image(image_path, img_size=224):
    """Load an image file and turn it into a normalized model input batch.

    Args:
        image_path: Path to the image file, passed to ``PIL.Image.open``.
        img_size: Target side length in pixels. The image is resized to
            ``(img_size, img_size)``, which does not preserve aspect ratio.

    Returns:
        A float tensor of shape ``(1, 3, img_size, img_size)``: the RGB
        image converted with ``ToTensor``, normalized per channel, with a
        leading batch dimension added.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Channel statistics commonly used for ImageNet-pretrained
        # torchvision models.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # The context manager closes the underlying file deterministically;
    # convert() returns a fully loaded copy, so it stays usable afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    return transform(image).unsqueeze(0)
|
|
|
|
|
|
|
|
def classify_image(model, config, image_path):
    """Classify a single image and rank every style by confidence.

    Args:
        model: The classification model; called with a batch of one image.
        config: Config dict whose ``id2label`` entry maps the class index
            (as a string) to a style name.
        image_path: Path to the image to classify.

    Returns:
        A list of ``{'style': str, 'confidence': float}`` dicts, one per
        class, sorted from highest to lowest confidence.
    """
    input_tensor = preprocess_image(image_path)

    # Run the input on whatever device the model lives on, so a model the
    # caller moved to a GPU does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    with torch.no_grad():
        output = model(input_tensor)
        # output[0] drops the batch dimension; softmax over the class axis.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    id2label = config['id2label']
    results = [
        {'style': id2label[str(idx)], 'confidence': prob}
        for idx, prob in enumerate(probabilities.tolist())
    ]
    results.sort(key=lambda entry: entry['confidence'], reverse=True)
    return results
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: load the model, classify one image, print results.

    Parses the image path plus the optional ``--model``, ``--config`` and
    ``--top-k`` arguments, then prints the top-K predictions followed by
    the single best prediction.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Classify anime style')
    parser.add_argument('image', type=str, help='Path to image')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model weights (auto-detects .safetensors or .pth if not specified)')
    parser.add_argument('--config', type=str, default='config.json')
    parser.add_argument('--top-k', type=int, default=3, help='Show top-K predictions')

    args = parser.parse_args()

    # A zero or negative K would silently print an empty or truncated list
    # through slice semantics, so reject it up front.
    if args.top_k < 1:
        parser.error('--top-k must be a positive integer')

    # Avoid printing "Loading model from None..." when auto-detecting.
    if args.model is None:
        print("Loading model (auto-detecting weights)...")
    else:
        print(f"Loading model from {args.model}...")
    model, config = load_model(args.model, args.config)

    print(f"Classifying {args.image}...")
    results = classify_image(model, config, args.image)

    separator = "=" * 60
    print()
    print(separator)
    print("PREDICTIONS")
    print(separator)
    for rank, result in enumerate(results[:args.top_k], 1):
        print(f"{rank}. {result['style']:12s} {result['confidence']:>7.2%}")
    print(separator)
    print()
    print(f"Top prediction: {results[0]['style']} ({results[0]['confidence']:.2%})")
|
|
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|