gueoguessr-game / app.py
jlincar
Changes
ff3fd86
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForImageClassification
from datasets import load_dataset
import random
import torch
# Load model from Hugging Face
model_name = "Jordiett/convnextv2-geoguessr"
# Preprocessing config and classifier weights both come from the same Hub
# repository, so they are guaranteed to match each other.
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)

# Load dataset (only the held-out "test" split is used for rounds)
print("Loading GeoGuessr dataset...")
dataset = load_dataset("marcelomoreno26/geoguessr", split="test")
print(f"Loaded {len(dataset)} test images")

# Every label the classifier can output, in the order stored in id2label.
# These names are the pool that wrong answer choices are drawn from.
countries = [*model.config.id2label.values()]
# Game state
class GameState:
    """Scoreboard plus the data for the round currently in progress.

    The round handlers read and write these attributes directly.

    NOTE(review): a single module-level instance (``game``) backs every
    request, so all concurrent visitors of the app share one scoreboard;
    per-session state (e.g. ``gr.State``) would isolate players.
    """

    def __init__(self):
        # Running totals across all rounds played.
        self.player_score = self.ai_score = self.rounds = 0
        # Filled in when a round starts: the image, its true label and the
        # model's hidden answers for it.
        self.current_image = self.correct_country = None
        self.ai_prediction = self.ai_top3 = None
        # Answer choices offered for the current round.
        self.options = []
        # Dataset indices already shown, so images are not repeated.
        self.used_indices = []


game = GameState()
def get_ai_prediction(image, top_k=3):
    """Classify one image with the GeoGuessr model.

    Args:
        image: The image to classify, in any form ``processor`` accepts
            (presumably a PIL image from the dataset -- confirm).
        top_k: How many of the most probable countries to report. Clamped to
            ``[0, number of model labels]`` so a model with few classes does
            not make ``torch.topk`` raise.

    Returns:
        ``(prediction, top)``: the single most likely country name, and a
        list of up to ``top_k`` ``(country, probability)`` pairs, most
        probable first, with probabilities taken from a softmax over all
        labels.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(**inputs).logits
    id2label = model.config.id2label
    # A single image goes in, so the argmax holds exactly one id.
    predicted_id = logits.argmax(-1).item()
    # Row 0 is the only image in the batch.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    k = max(0, min(top_k, probabilities.numel()))
    # torch.topk returns values sorted in descending order by default.
    top_prob, top_idx = torch.topk(probabilities, k)
    top = [
        (id2label[idx.item()], prob.item())
        for idx, prob in zip(top_idx, top_prob)
    ]
    return id2label[predicted_id], top
def generate_options(correct_country, ai_prediction, num_options=4, candidates=None):
    """Build the shuffled answer choices for a round.

    The correct country is always included. If the AI guessed wrong, its
    guess is included as well, which makes the round more interesting. The
    remaining slots are filled with distinct countries drawn at random from
    the candidate pool.

    Args:
        correct_country: The true label of the round's image.
        ai_prediction: The model's top-1 label for the same image.
        num_options: Total number of choices wanted (default 4).
        candidates: Pool of country names to draw distractors from. Defaults
            to the module-level ``countries`` list (the model's labels).

    Returns:
        A shuffled list of distinct country names. It holds fewer than
        ``num_options`` entries only when the candidate pool is too small,
        instead of raising ``ValueError`` from ``random.sample``.
    """
    if candidates is None:
        candidates = countries
    options = [correct_country]
    # Add AI prediction if it's wrong (makes it more interesting)
    if ai_prediction != correct_country:
        options.append(ai_prediction)
    already_chosen = set(options)
    # dict.fromkeys drops duplicate candidates while keeping their order.
    pool = [c for c in dict.fromkeys(candidates) if c not in already_chosen]
    # Never request more samples than the pool holds, and never a negative count.
    needed = max(0, min(num_options - len(options), len(pool)))
    options.extend(random.sample(pool, needed))
    random.shuffle(options)
    return options
def new_round():
    """Start a new round with a random, not-yet-shown image from the dataset.

    Picks an unused test image and records its true country. It then runs
    the model on the image; the prediction stays hidden until the player
    answers. Finally it builds the answer options.

    Returns:
        A 6-tuple for the outputs wired to the "New Round" button:
        (image, question markdown, radio update, submit-button update,
        scoreboard markdown, cleared result markdown).
    """
    # Build the exclusion set once. Testing `i not in game.used_indices` against
    # the list cost O(len(used)) per dataset index, which is quadratic over a
    # long session. game.used_indices itself stays a list.
    used = set(game.used_indices)
    available_indices = [i for i in range(len(dataset)) if i not in used]
    if not available_indices:
        # Every image has been shown: start a new cycle over the whole split.
        game.used_indices = []
        available_indices = list(range(len(dataset)))
    idx = random.choice(available_indices)
    game.used_indices.append(idx)

    # Get image and label from dataset
    sample = dataset[idx]
    image = sample["image"]
    label = sample["label"]
    # NOTE(review): if the "label" column is a ClassLabel, rows hold integer
    # ids. Map them back to names so they compare equal to the model's string
    # labels. String labels pass through unchanged.
    label_feature = getattr(dataset, "features", {}).get("label")
    if isinstance(label, int) and hasattr(label_feature, "int2str"):
        label = label_feature.int2str(label)
    game.correct_country = label
    game.current_image = image

    # Get AI prediction (but don't show it yet!)
    ai_pred, top3 = get_ai_prediction(image)
    game.ai_prediction = ai_pred
    game.ai_top3 = top3

    # Generate options
    game.options = generate_options(game.correct_country, ai_pred)
    return (
        image,
        "๐ŸŒ **Where do you think this image is from?**\n\nMake your choice before seeing the AI's prediction!",
        gr.update(choices=game.options, value=None, visible=True),
        gr.update(visible=True),
        f"๐ŸŽฎ Player: {game.player_score} | ๐Ÿค– AI: {game.ai_score} | ๐ŸŽฏ Rounds: {game.rounds}",
        "",  # Clear previous result
    )
def check_answer(player_choice):
    """Score the current round and report how the player did against the AI.

    Each round is scored at most once. The "Check Answer" button stays
    visible after answering, so the round is marked as consumed. Further
    clicks are then no-ops until the next round starts.

    Args:
        player_choice: Country selected in the radio group, or None when
            nothing is selected.

    Returns:
        Two values for the outputs wired to the "Check Answer" button: the
        result markdown (or a no-op update when there is nothing to score),
        and an update for the "New Round" button.
    """
    if player_choice is None:
        return "โš ๏ธ Please select an option!", gr.update(visible=True)
    correct_country = game.correct_country
    if correct_country is None:
        # No unanswered round: none has started yet, or this one was already
        # scored. Previously every extra click re-scored the same round and
        # inflated `rounds` and both scores. Leave the shown result as-is.
        return gr.update(), gr.update(visible=True)
    # Consume the round; new_round() sets the next correct country.
    game.correct_country = None
    game.rounds += 1

    player_correct = player_choice == correct_country
    ai_correct = game.ai_prediction == correct_country
    if player_correct:
        game.player_score += 1
    if ai_correct:
        game.ai_score += 1

    if player_correct and ai_correct:
        verdict = "๐ŸŽ‰ **It's a tie!** Both you and the AI got it right!\n"
    elif player_correct:
        verdict = "๐Ÿ† **You win!** The AI was wrong.\n"
    elif ai_correct:
        verdict = "๐Ÿค– **AI wins!** You were wrong.\n"
    else:
        verdict = "โŒ **Both failed!**\n"

    parts = [
        f"## ๐ŸŽฏ Round {game.rounds} Result\n\n",
        f"**Correct country:** {correct_country}\n\n",
        verdict,
        f"\n**Your answer:** {player_choice} {'โœ…' if player_correct else 'โŒ'}\n",
        f"**AI prediction:** {game.ai_prediction} {'โœ…' if ai_correct else 'โŒ'}\n",
        "\n**AI's top 3 predictions:**\n",
    ]
    parts.extend(f"- {country}: {prob*100:.1f}%\n" for country, prob in game.ai_top3)

    # rounds was just incremented, so it is always >= 1 here.
    player_rate = game.player_score / game.rounds * 100
    ai_rate = game.ai_score / game.rounds * 100
    parts += [
        "\n---\n",
        f"**Your accuracy:** {player_rate:.1f}% ({game.player_score}/{game.rounds})\n",
        f"**AI accuracy:** {ai_rate:.1f}% ({game.ai_score}/{game.rounds})\n",
    ]
    return "".join(parts), gr.update(visible=True)
def reset_game():
    """Zero the scoreboard and forget which images have been shown.

    Returns:
        Six values for the same outputs the "New Round" button fills: no
        image, a reset prompt, an emptied and hidden radio group, a hidden
        submit button, a zeroed scoreboard, and an empty result area.
    """
    game.player_score = game.ai_score = game.rounds = 0
    game.used_indices = []

    hidden_choices = gr.update(choices=[], value=None, visible=False)
    hidden_submit = gr.update(visible=False)
    prompt = "๐ŸŽฎ **Game reset!** Click 'New Round' to start playing."
    board = "๐ŸŽฎ Player: 0 | ๐Ÿค– AI: 0 | ๐ŸŽฏ Rounds: 0"
    return None, prompt, hidden_choices, hidden_submit, board, ""
# Gradio Interface
# NOTE(review): original indentation was lost; this nesting (question, options,
# submit and result inside the right-hand column) is reconstructed -- confirm.
with gr.Blocks(theme=gr.themes.Soft(), title="GeoGuessr: Player vs AI") as demo:
    # Header with rules and model summary.
    gr.Markdown("""
# ๐ŸŒ GeoGuessr: Player vs AI
Compete against an AI trained with ConvNeXt V2 to guess countries from Google Street View images!
**How to play:**
1. Click "๐ŸŽฎ New Round" to load a random image from the GeoGuessr dataset
2. Choose one of the 4 proposed countries
3. Click "โœ… Check Answer" to see if you beat the AI!
**Model:** ConvNeXt V2 Base (61% accuracy, 51.77% F1-macro)
""")
    with gr.Row():
        # Left column: the round's image and the button that starts a round.
        with gr.Column(scale=2):
            image_display = gr.Image(type="pil", label="๐Ÿ“ธ Street View Image", interactive=False)
            start_btn = gr.Button("๐ŸŽฎ New Round", variant="primary", size="lg")
        # Right column: scoreboard, reset, question, answer choices and result.
        with gr.Column(scale=1):
            scoreboard = gr.Markdown("๐ŸŽฎ Player: 0 | ๐Ÿค– AI: 0 | ๐ŸŽฏ Rounds: 0")
            reset_btn = gr.Button("๐Ÿ”„ Reset Game", variant="secondary")
            question = gr.Markdown("โฌ‡๏ธ Click 'New Round' to start!")
            # Hidden until new_round() fills in the choices and makes it visible.
            options = gr.Radio(
                choices=[],
                label="๐ŸŒ Select the country:",
                visible=False
            )
            # Hidden until the first round starts; reset_game() hides it again.
            submit_btn = gr.Button("โœ… Check Answer", variant="primary", visible=False)
            result = gr.Markdown("")
    # Events
    # new_round and reset_game each return a 6-tuple in exactly this output order.
    start_btn.click(
        fn=new_round,
        inputs=[],
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )
    # check_answer returns (result markdown, update for start_btn). start_btn is
    # created visible and is never hidden, so the second output only re-asserts that.
    submit_btn.click(
        fn=check_answer,
        inputs=[options],
        outputs=[result, start_btn]
    )
    reset_btn.click(
        fn=reset_game,
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )
    # NOTE(review): every handler mutates the single module-level `game`, so all
    # concurrent visitors share one scoreboard; consider per-session gr.State.
    # Footer with dataset and model attribution links.
    gr.Markdown("""
---
**Dataset:** [GeoGuessr by marcelomoreno26](https://huggingface.co/datasets/marcelomoreno26/geoguessr)
**Model:** [ConvNeXt V2 GeoGuessr by Jordiett](https://huggingface.co/Jordiett/convnextv2-geoguessr)
Images are randomly selected from the test set of the GeoGuessr dataset.
""")
# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()