File size: 3,719 Bytes
53ec084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
"""
Simple inference script for anime style classification.
"""

import torch
from torchvision import models, transforms
from PIL import Image
import json
from pathlib import Path


def load_model(model_path=None, config_path='config.json'):
    """Build the EfficientNet-B0 classifier and load its trained weights.

    Args:
        model_path: Path (str or ``os.PathLike``) to the model weights. If None,
            auto-detects in the current working directory, preferring
            ``model.safetensors`` over ``pytorch_model.pth``.
        config_path: Path to the JSON config file. It must contain
            ``num_classes``; the whole parsed dict is returned to the caller.

    Returns:
        A ``(model, config)`` tuple: the model switched to eval mode with
        weights loaded onto CPU, and the parsed config dict.

    Raises:
        FileNotFoundError: If the config file is missing, or if ``model_path``
            is None and neither default weights file exists.
        KeyError: If the config has no ``num_classes`` entry.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Randomly initialised backbone; the trained weights are loaded below.
    # `weights=None` is the torchvision >= 0.13 spelling of `pretrained=False`
    # (the old keyword emits a deprecation warning).
    model = models.efficientnet_b0(weights=None)
    num_classes = config['num_classes']
    # Swap the final classifier layer so its output size matches our labels.
    model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, num_classes)

    # Auto-detect model format if not specified (.safetensors preferred).
    if model_path is None:
        for candidate in ('model.safetensors', 'pytorch_model.pth'):
            if Path(candidate).exists():
                model_path = candidate
                break
        else:
            raise FileNotFoundError("No model weights found (model.safetensors or pytorch_model.pth)")

    # Accept pathlib.Path as well as str; str.endswith alone failed on Path.
    model_path = str(model_path)

    if model_path.endswith('.safetensors'):
        from safetensors.torch import load_file
        state_dict = load_file(model_path)
    else:
        # NOTE(review): torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources (consider passing
        # weights_only=True on torch >= 1.13).
        checkpoint = torch.load(model_path, map_location='cpu')
        # Training checkpoints wrap the weights under 'model_state_dict'; also
        # accept a bare state dict saved via torch.save(model.state_dict()).
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
    model.load_state_dict(state_dict)

    model.eval()
    return model, config


def preprocess_image(image_path, img_size=224):
    """Load an image file and turn it into a normalized single-image batch.

    Args:
        image_path: Path to the image file (anything ``PIL.Image.open`` accepts).
        img_size: Side length the image is resized to; aspect ratio is not
            preserved because both dimensions are forced to ``img_size``.

    Returns:
        A float tensor of shape ``(1, 3, img_size, img_size)``.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        # Standard ImageNet per-channel normalization constants.
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Image.open is lazy and keeps the underlying file open. convert() yields an
    # independent in-memory image, so the file can be closed right after.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return transform(image).unsqueeze(0)


def classify_image(model, config, image_path):
    """Classify a single image and rank every class by confidence.

    Args:
        model: Classifier module, e.g. the one returned by ``load_model``.
        config: Config dict whose ``id2label`` maps the stringified class
            index to its label.
        image_path: Path to the image to classify.

    Returns:
        A list of ``{'style': label, 'confidence': float}`` dicts, one per
        class, sorted from most to least confident.
    """
    input_tensor = preprocess_image(image_path)

    # Send the input to whatever device the model lives on, so a caller that
    # moved the model to GPU does not hit a device-mismatch error.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    with torch.no_grad():
        output = model(input_tensor)
        # Batch of one: softmax over the class scores of the only row.
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    id2label = config['id2label']
    results = [
        {'style': id2label[str(idx)], 'confidence': prob}
        for idx, prob in enumerate(probabilities.tolist())
    ]

    # Sort by confidence, highest first.
    results.sort(key=lambda r: r['confidence'], reverse=True)
    return results


def main():
    """Command-line entry point: classify one image and print the top-K styles."""
    import argparse

    parser = argparse.ArgumentParser(description='Classify anime style')
    parser.add_argument('image', type=str, help='Path to image')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to model weights (auto-detects .safetensors or .pth if not specified)')
    parser.add_argument('--config', type=str, default='config.json')
    parser.add_argument('--top-k', type=int, default=3, help='Show top-K predictions')

    args = parser.parse_args()
    # K == 0 would print no predictions; a negative K would slice from the end
    # of the ranked list and silently drop entries.
    if args.top_k < 1:
        parser.error('--top-k must be a positive integer')

    # Load model. Without --model the path is auto-detected inside load_model,
    # so don't report "Loading model from None...".
    if args.model is None:
        print("Loading model (auto-detecting weights)...")
    else:
        print(f"Loading model from {args.model}...")
    model, config = load_model(args.model, args.config)

    # Classify
    print(f"Classifying {args.image}...")
    results = classify_image(model, config, args.image)

    # Display results
    print()
    print("=" * 60)
    print("PREDICTIONS")
    print("=" * 60)
    for i, result in enumerate(results[:args.top_k], 1):
        print(f"{i}. {result['style']:12s} {result['confidence']:>7.2%}")
    print("=" * 60)
    print()
    print(f"Top prediction: {results[0]['style']} ({results[0]['confidence']:.2%})")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()