Mitchins's picture
Upload folder using huggingface_hub
53ec084 verified
raw
history blame
3.72 kB
#!/usr/bin/env python3
"""
Simple inference script for anime style classification.
"""
import torch
from torchvision import models, transforms
from PIL import Image
import json
from pathlib import Path
def load_model(model_path=None, config_path='config.json'):
    """Load the trained classifier together with its config.

    Args:
        model_path: Path (str or os.PathLike) to the model weights. If None,
            the current working directory is searched for
            ``model.safetensors`` first, then ``pytorch_model.pth``.
        config_path: Path to the JSON config file; it must contain a
            ``num_classes`` entry.

    Returns:
        ``(model, config)``: an EfficientNet-B0 whose last classifier layer
        has ``num_classes`` outputs, loaded with the weights and switched to
        eval mode, plus the parsed config dict.

    Raises:
        FileNotFoundError: if the config file or the weights file is missing.
        KeyError: if the config has no ``num_classes`` entry.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Resolve the weights file before building the network so a missing file
    # fails fast instead of after constructing the whole model.
    if model_path is None:
        for candidate in ('model.safetensors', 'pytorch_model.pth'):
            if Path(candidate).exists():
                model_path = candidate
                break
        else:
            raise FileNotFoundError("No model weights found (model.safetensors or pytorch_model.pth)")
    elif not Path(model_path).exists():
        raise FileNotFoundError(f"Model weights not found: {model_path}")
    # Normalise to str so os.PathLike inputs work with the suffix check below.
    model_path = str(model_path)

    # Build the architecture with random init (weights=None replaces the
    # deprecated pretrained=False); trained weights are loaded below.
    model = models.efficientnet_b0(weights=None)
    num_classes = config['num_classes']
    model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes)

    # Load weights based on file format.
    if model_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        state_dict = load_file(model_path)
    else:
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load trusted checkpoints (consider weights_only=True on
        # torch versions that support it).
        checkpoint = torch.load(model_path, map_location='cpu')
        # Training checkpoints wrap the weights under 'model_state_dict';
        # also accept a file that is a bare state dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.eval()
    return model, config
def preprocess_image(image_path, img_size=224):
    """Read an image file and turn it into a single-image model input batch.

    The image is converted to RGB, resized to ``img_size`` x ``img_size``
    (aspect ratio is not preserved), converted to a tensor and normalised
    with the fixed per-channel mean/std below (presumably the same values
    used at training time — confirm against the training script).

    Args:
        image_path: Path of the image to read.
        img_size: Side length in pixels of the square resize target.

    Returns:
        A tensor of shape ``(1, 3, img_size, img_size)``.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # Use a context manager so the underlying file handle is closed.
    # convert() returns a new, fully loaded image, so it stays usable after
    # the source file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image).unsqueeze(0)
def classify_image(model, config, image_path):
    """Score one image against every style the model knows.

    Args:
        model: The classifier returned by ``load_model`` (called on a
            preprocessed batch; its first output row is used).
        config: Config dict whose ``'id2label'`` maps the class index, as a
            string, to the style name.
        image_path: Path of the image to classify.

    Returns:
        A list with one ``{'style': name, 'confidence': probability}`` dict
        per class, ordered from most to least confident.
    """
    batch = preprocess_image(image_path)

    # Forward pass without autograd bookkeeping; softmax over the class
    # scores of the single image in the batch.
    with torch.no_grad():
        logits = model(batch)
        scores = torch.nn.functional.softmax(logits[0], dim=0)

    ranked = [
        {'style': config['id2label'][str(class_idx)], 'confidence': float(score)}
        for class_idx, score in enumerate(scores)
    ]
    return sorted(ranked, key=lambda entry: entry['confidence'], reverse=True)
def main():
    """Command-line entry point: classify one image and print the top-K styles."""
    import argparse
    parser = argparse.ArgumentParser(description='Classify anime style')
    parser.add_argument('image', type=str, help='Path to image')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model weights (auto-detects .safetensors or .pth if not specified)')
    parser.add_argument('--config', type=str, default='config.json',
                        help='Path to config file (default: config.json)')
    parser.add_argument('--top-k', type=int, default=3, help='Show top-K predictions')
    args = parser.parse_args()
    # A non-positive K would print nothing (0) or silently drop the tail of
    # the ranking (negative slice), so reject it up front.
    if args.top_k < 1:
        parser.error('--top-k must be at least 1')

    # Load model; args.model is None when the weights are auto-detected,
    # so avoid printing "Loading model from None...".
    if args.model is None:
        print("Loading model (auto-detecting weights)...")
    else:
        print(f"Loading model from {args.model}...")
    model, config = load_model(args.model, args.config)

    # Classify
    print(f"Classifying {args.image}...")
    results = classify_image(model, config, args.image)

    # Display results
    print()
    print("=" * 60)
    print("PREDICTIONS")
    print("=" * 60)
    for i, result in enumerate(results[:args.top_k], 1):
        print(f"{i}. {result['style']:12s} {result['confidence']:>7.2%}")
    print("=" * 60)
    print()
    print(f"Top prediction: {results[0]['style']} ({results[0]['confidence']:.2%})")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()