File size: 7,720 Bytes
8eba60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3fd86
8eba60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3fd86
8eba60b
 
ff3fd86
8eba60b
 
 
 
 
 
ff3fd86
8eba60b
 
ff3fd86
 
8eba60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3fd86
8eba60b
 
 
 
 
 
 
 
 
 
 
 
ff3fd86
 
 
 
 
8eba60b
 
 
 
ff3fd86
 
8eba60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3fd86
8eba60b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff3fd86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from datasets import load_dataset
import random
import torch

# --- Model -------------------------------------------------------------------
# Fetch the fine-tuned ConvNeXt V2 checkpoint and its matching image
# preprocessor from the Hugging Face Hub (processor first, then the model).
model_name = "Jordiett/convnextv2-geoguessr"
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)

# --- Dataset -----------------------------------------------------------------
# Only the "test" split is used; it is the pool of images shown in rounds.
print("Loading GeoGuessr dataset...")
dataset = load_dataset("marcelomoreno26/geoguessr", split="test")
print(f"Loaded {len(dataset)} test images")

# Every label the model can output; generate_options draws distractors from it.
countries = [*model.config.id2label.values()]

# Game state
class GameState:
    """Session-wide scoreboard plus the state of the round currently on screen.

    A single module-level instance, ``game``, holds this state for the app.
    """

    def __init__(self):
        # Cumulative counters for the whole session.
        self.player_score = self.ai_score = self.rounds = 0
        # Per-round fields, all empty until a round is started.
        self.current_image = self.correct_country = None
        self.ai_prediction = self.ai_top3 = None
        # Two distinct list objects: this round's answer choices, and the
        # dataset indices already shown (so images are not repeated).
        self.options = []
        self.used_indices = []


game = GameState()

def get_ai_prediction(image, top_k=3):
    """Classify *image* with the module-level model and return its guesses.

    Args:
        image: the round's image as returned by the dataset (presumably a PIL
            image; non-RGB PIL modes are converted to RGB first).
        top_k: how many ranked (label, probability) pairs to return. Defaults
            to 3 and is clamped to the number of classes the model has.

    Returns:
        ``(predicted_label, ranked)``: the label of the highest logit, and a
        list of ``(label, probability)`` pairs ordered most probable first.
    """
    # The model presumably expects 3-channel RGB input (TODO confirm against
    # the checkpoint config); grayscale / RGBA / palette images would otherwise
    # reach the processor with a different channel count.
    if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    id2label = model.config.id2label
    predicted_id = logits.argmax(-1).item()
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

    # torch.topk raises if k exceeds the number of classes.
    k = min(top_k, probabilities.numel())
    top_prob, top_idx = torch.topk(probabilities, k)
    ranked = [(id2label[i.item()], p.item()) for i, p in zip(top_idx, top_prob)]

    return id2label[predicted_id], ranked

def generate_options(correct_country, ai_prediction, num_options=4, pool=None):
    """Build the shuffled multiple-choice list for one round.

    The correct answer is always included. When the AI's prediction is wrong
    it is added as a distractor (makes the round more interesting); remaining
    slots are filled with random, distinct labels from *pool*.

    Args:
        correct_country: the ground-truth label.
        ai_prediction: the model's top-1 label.
        num_options: total number of choices wanted (default 4).
        pool: candidate distractor labels; defaults to the module-level
            ``countries`` list.

    Returns:
        A list of distinct labels in random order. It has ``num_options``
        entries unless *pool* has too few distinct labels, in which case every
        available label is used instead of raising.
    """
    if pool is None:
        pool = countries

    options = [correct_country]
    if ai_prediction != correct_country and len(options) < num_options:
        options.append(ai_prediction)

    # dict.fromkeys de-duplicates while keeping pool order.
    chosen = set(options)
    others = [c for c in dict.fromkeys(pool) if c not in chosen]
    needed = max(0, num_options - len(options))
    # random.sample raises ValueError if asked for more items than exist.
    options.extend(random.sample(others, min(needed, len(others))))

    random.shuffle(options)
    return options

def new_round():
    """Start a new round with a random, not-yet-shown image from the dataset.

    Picks an unused test-set index (recycling the whole set once every image
    has been shown), records the ground truth and the AI's prediction on the
    shared ``game`` state, and builds the answer options. The AI's answer is
    stored but not revealed until the player submits.

    Returns:
        A 6-tuple of UI updates: image, question markdown, options radio,
        submit button, scoreboard, and an empty string clearing the result.
    """
    # A set makes each membership test O(1); scanning the list for every
    # dataset index made each round O(len(dataset) * rounds played).
    used = set(game.used_indices)
    available_indices = [i for i in range(len(dataset)) if i not in used]

    if not available_indices:
        # Every image has been shown: start over with the full pool.
        game.used_indices = []
        available_indices = list(range(len(dataset)))

    idx = random.choice(available_indices)
    game.used_indices.append(idx)

    sample = dataset[idx]
    image = sample["image"]
    label = sample["label"]
    # If the split stores labels as ClassLabel integers, map them to their
    # names so they can equal the model's string labels; no-op for strings.
    label_feature = dataset.features.get("label")
    if isinstance(label, int) and hasattr(label_feature, "int2str"):
        label = label_feature.int2str(label)
    # NOTE(review): assumes the dataset's country names match the model's
    # id2label names exactly - confirm, otherwise the AI can never score.
    game.correct_country = label
    game.current_image = image

    # Compute the AI's answer now, but keep it hidden until the player submits.
    ai_pred, top3 = get_ai_prediction(image)
    game.ai_prediction = ai_pred
    game.ai_top3 = top3

    game.options = generate_options(game.correct_country, ai_pred)

    return (
        image,
        "🌍 **Where do you think this image is from?**\n\nMake your choice before seeing the AI's prediction!",
        gr.update(choices=game.options, value=None, visible=True),
        gr.update(visible=True),
        f"🎮 Player: {game.player_score} | 🤖 AI: {game.ai_score} | 🎯 Rounds: {game.rounds}",
        "",  # Clear the previous round's result.
    )

def check_answer(player_choice):
    """Score the current round and reveal how the player and the AI did.

    Args:
        player_choice: the label selected in the options radio, or None when
            nothing has been selected yet.

    Returns:
        ``(result, start_button_update)``: the round summary as Markdown (or a
        warning, or a no-op update), and an update keeping the "New Round"
        button visible.
    """
    # The "Check Answer" button stays visible after scoring, so a second click
    # (possibly after switching to the now-revealed answer) must not count
    # again. correct_country is cleared below once a round is scored and is
    # set afresh by new_round(); the no-op update keeps the shown result.
    if game.correct_country is None:
        return gr.update(), gr.update(visible=True)

    if player_choice is None:
        return "⚠️ Please select an option!", gr.update(visible=True)

    correct = game.correct_country
    game.rounds += 1

    player_correct = player_choice == correct
    if player_correct:
        game.player_score += 1

    ai_correct = game.ai_prediction == correct
    if ai_correct:
        game.ai_score += 1

    if player_correct and ai_correct:
        verdict = "🎉 **It's a tie!** Both you and the AI got it right!\n"
    elif player_correct:
        verdict = "🏆 **You win!** The AI was wrong.\n"
    elif ai_correct:
        verdict = "🤖 **AI wins!** You were wrong.\n"
    else:
        verdict = "❌ **Both failed!**\n"

    def mark(ok):
        return "✅" if ok else "❌"

    # rounds was incremented above, so it is always >= 1 here.
    player_rate = (game.player_score / game.rounds) * 100
    ai_rate = (game.ai_score / game.rounds) * 100

    parts = [
        f"## 🎯 Round {game.rounds} Result\n\n",
        f"**Correct country:** {correct}\n\n",
        verdict,
        f"\n**Your answer:** {player_choice} {mark(player_correct)}\n",
        f"**AI prediction:** {game.ai_prediction} {mark(ai_correct)}\n",
        "\n**AI's top 3 predictions:**\n",
        *(f"- {country}: {prob*100:.1f}%\n" for country, prob in game.ai_top3),
        "\n---\n",
        f"**Your accuracy:** {player_rate:.1f}% ({game.player_score}/{game.rounds})\n",
        f"**AI accuracy:** {ai_rate:.1f}% ({game.ai_score}/{game.rounds})\n",
    ]

    # Mark this round as scored so repeated clicks are ignored (guard above).
    game.correct_country = None

    return "".join(parts), gr.update(visible=True)

def reset_game():
    """Reset scores, the seen-image history and any in-progress round.

    Returns:
        A 6-tuple of UI updates: cleared image, reset message, hidden empty
        options radio, hidden submit button, zeroed scoreboard, cleared result.
    """
    game.player_score = 0
    game.ai_score = 0
    game.rounds = 0
    game.used_indices = []
    # Also drop the current round so no stale answer, prediction or options
    # survive the reset.
    game.current_image = None
    game.correct_country = None
    game.ai_prediction = None
    game.ai_top3 = None
    game.options = []
    return (
        None,
        "🎮 **Game reset!** Click 'New Round' to start playing.",
        gr.update(choices=[], value=None, visible=False),
        gr.update(visible=False),
        "🎮 Player: 0 | 🤖 AI: 0 | 🎯 Rounds: 0",
        "",
    )

# Gradio Interface
# Layout: image + "New Round" on the left, scoreboard + "Reset" on the right,
# then the question, the answer radio, the submit button and the result.
with gr.Blocks(theme=gr.themes.Soft(), title="GeoGuessr: Player vs AI") as demo:
    gr.Markdown("""
    # 🌍 GeoGuessr: Player vs AI
    
    Compete against an AI trained with ConvNeXt V2 to guess countries from Google Street View images!
    
    **How to play:**
    1. Click "🎮 New Round" to load a random image from the GeoGuessr dataset
    2. Choose one of the 4 proposed countries
    3. Click "✅ Check Answer" to see if you beat the AI!
    
    **Model:** ConvNeXt V2 Base (61% accuracy, 51.77% F1-macro)
    """)
    
    with gr.Row():
        with gr.Column(scale=2):
            image_display = gr.Image(type="pil", label="📸 Street View Image", interactive=False)
            start_btn = gr.Button("🎮 New Round", variant="primary", size="lg")
        
        with gr.Column(scale=1):
            scoreboard = gr.Markdown("🎮 Player: 0 | 🤖 AI: 0 | 🎯 Rounds: 0")
            reset_btn = gr.Button("🔄 Reset Game", variant="secondary")
    
    question = gr.Markdown("⬇️ Click 'New Round' to start!")
    
    # Hidden until the first round supplies the choices.
    options = gr.Radio(
        choices=[],
        label="🌍 Select the country:",
        visible=False
    )
    
    submit_btn = gr.Button("✅ Check Answer", variant="primary", visible=False)
    
    result = gr.Markdown("")
    
    # Events: the output lists must match the order of the tuples returned by
    # each callback.
    start_btn.click(
        fn=new_round,
        inputs=[],
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )
    
    submit_btn.click(
        fn=check_answer,
        inputs=[options],
        outputs=[result, start_btn]
    )
    
    reset_btn.click(
        fn=reset_game,
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )
    
    gr.Markdown("""
    ---
    **Dataset:** [GeoGuessr by marcelomoreno26](https://huggingface.co/datasets/marcelomoreno26/geoguessr)  
    **Model:** [ConvNeXt V2 GeoGuessr by Jordiett](https://huggingface.co/Jordiett/convnextv2-geoguessr)
    
    Images are randomly selected from the test set of the GeoGuessr dataset.
    """)

if __name__ == "__main__":
    demo.launch()