# Spaces: Sleeping
| import gradio as gr | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from datasets import load_dataset | |
| import random | |
| import torch | |
# --- Model -----------------------------------------------------------------
# Image-classification checkpoint from the Hugging Face Hub; this app treats
# its class labels (model.config.id2label) as country names.
model_name = "Jordiett/convnextv2-geoguessr"
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)

# --- Data ------------------------------------------------------------------
# Every round draws its image from the "test" split only.
print("Loading GeoGuessr dataset...")
dataset = load_dataset("marcelomoreno26/geoguessr", split="test")
print(f"Loaded {len(dataset)} test images")

# All label names the model can output, in id2label order; this is the pool
# that generate_options() draws distractor answers from.
countries = [*model.config.id2label.values()]
# Game state
class GameState:
    """Mutable state of the running game: scoreboard, current round, history.

    NOTE(review): a single module-level instance (``game`` below) is used by
    every callback, so all browser sessions of the app share one scoreboard
    and one current round. Per-session state would need ``gr.State``.
    """

    def __init__(self):
        # Running scoreboard.
        self.player_score = self.ai_score = self.rounds = 0
        # Data for the round in progress (populated by new_round()).
        self.current_image = self.correct_country = None
        self.ai_prediction = self.ai_top3 = None
        self.options = []
        # Dataset indices already shown, so images are not repeated.
        self.used_indices = []


# The one shared instance read and written by the Gradio callbacks.
game = GameState()
def get_ai_prediction(image, top_k=3):
    """Classify one image with the module-level ``model``.

    Args:
        image: Image handed to ``processor``. In this app it is the dataset's
            ``"image"`` field (presumably a PIL image — confirm).
        top_k: Number of ranked guesses to return (default 3). It is clamped
            to ``[1, number of labels]`` so ``torch.topk`` never receives an
            out-of-range ``k``.

    Returns:
        ``(label, ranked)``: the arg-max label name, plus a list of
        ``(label, probability)`` pairs ordered from most to least probable.
    """
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    id2label = model.config.id2label
    predicted_id = logits.argmax(-1).item()
    # One image per call in this app, so only the first row of the batch is used.
    probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
    k = max(1, min(int(top_k), probabilities.numel()))
    top_prob, top_idx = torch.topk(probabilities, k)
    ranked = [
        (id2label[idx], prob)
        for idx, prob in zip(top_idx.tolist(), top_prob.tolist())
    ]
    return id2label[predicted_id], ranked
def generate_options(correct_country, ai_prediction, num_options=4, candidates=None):
    """Build the shuffled multiple-choice list shown to the player.

    The list always contains ``correct_country``. When the AI's guess is wrong,
    that guess is included as a distractor, so the player sees the model's
    mistake among the choices. The remaining slots are filled with random,
    distinct labels from ``candidates``.

    Args:
        correct_country: Ground-truth label for the round.
        ai_prediction: Label predicted by the model.
        num_options: Desired number of choices (default 4, matching the UI).
        candidates: Pool of labels to draw distractors from. Defaults to the
            module-level ``countries`` list.

    Returns:
        A shuffled list with ``num_options`` entries, or fewer if the pool
        does not contain enough other labels (instead of raising ``ValueError``).
    """
    if candidates is None:
        candidates = countries
    options = [correct_country]
    # A wrong AI guess makes for a more interesting distractor.
    if ai_prediction != correct_country and len(options) < num_options:
        options.append(ai_prediction)
    taken = set(options)
    pool = [c for c in candidates if c not in taken]
    needed = max(0, num_options - len(options))
    options.extend(random.sample(pool, min(needed, len(pool))))
    random.shuffle(options)
    return options
def new_round():
    """Start a new round with a random, not-yet-shown image from the dataset.

    Picks an unused dataset index, starting over once every image has been
    shown. It then stores on ``game`` the image, its ground-truth label, the
    model's prediction (not revealed yet) and the multiple-choice options.

    Returns:
        Updates for ``(image_display, question, options, submit_btn,
        scoreboard, result)``, in that order.
    """
    # A set keeps the "already used?" test O(1); testing against the list
    # would make this scan O(len(dataset) * rounds played).
    used = set(game.used_indices)
    available_indices = [i for i in range(len(dataset)) if i not in used]
    if not available_indices:
        # Every image has been shown: start over with the full dataset.
        game.used_indices = []
        available_indices = list(range(len(dataset)))
    idx = random.choice(available_indices)
    game.used_indices.append(idx)

    # Get image and label from dataset
    sample = dataset[idx]
    image = sample["image"]
    label = sample["label"]
    # If the split stores labels as ClassLabel ids instead of names, map them to
    # names so they compare equal to the model's id2label strings.
    # NOTE(review): assumes the dataset's class names match the model's labels.
    label_feature = dataset.features.get("label")
    if isinstance(label, int) and hasattr(label_feature, "int2str"):
        label = label_feature.int2str(label)
    game.correct_country = label
    game.current_image = image

    # Predict now, but only reveal the prediction after the player answers.
    ai_pred, top3 = get_ai_prediction(image)
    game.ai_prediction = ai_pred
    game.ai_top3 = top3
    game.options = generate_options(game.correct_country, ai_pred)

    return (
        image,
        "๐ **Where do you think this image is from?**\n\nMake your choice before seeing the AI's prediction!",
        gr.update(choices=game.options, value=None, visible=True),
        gr.update(visible=True),
        f"๐ฎ Player: {game.player_score} | ๐ค AI: {game.ai_score} | ๐ฏ Rounds: {game.rounds}",
        "",  # Clear previous result
    )
def check_answer(player_choice):
    """Score the player's choice for the current round and describe the result.

    Each round is scored at most once. After scoring, ``game.correct_country``
    is cleared (``new_round`` sets it again). Because the "Check Answer"
    button stays visible, later clicks re-display the last result instead of
    adding points a second time.

    Args:
        player_choice: Country selected in the radio group, or ``None`` if
            nothing is selected.

    Returns:
        ``(markdown, update)``: the text for the ``result`` component, and
        ``gr.update(visible=True)`` for the second wired output.
    """
    if player_choice is None:
        return "โ ๏ธ Please select an option!", gr.update(visible=True)
    if game.correct_country is None:
        # No round started yet, or this round has already been scored.
        previous = getattr(game, "last_result", None)
        return previous or "Please click 'New Round' first!", gr.update(visible=True)

    correct = game.correct_country
    player_correct = player_choice == correct
    ai_correct = game.ai_prediction == correct

    game.rounds += 1
    if player_correct:
        game.player_score += 1
    if ai_correct:
        game.ai_score += 1

    if player_correct and ai_correct:
        verdict = "๐ **It's a tie!** Both you and the AI got it right!\n"
    elif player_correct:
        verdict = "๐ **You win!** The AI was wrong.\n"
    elif ai_correct:
        verdict = "๐ค **AI wins!** You were wrong.\n"
    else:
        verdict = "โ **Both failed!**\n"

    # rounds >= 1 here, so the accuracy divisions are safe.
    player_rate = (game.player_score / game.rounds) * 100
    ai_rate = (game.ai_score / game.rounds) * 100
    parts = [
        f"## ๐ฏ Round {game.rounds} Result\n\n",
        f"**Correct country:** {correct}\n\n",
        verdict,
        f"\n**Your answer:** {player_choice} {'โ ' if player_correct else 'โ'}\n",
        f"**AI prediction:** {game.ai_prediction} {'โ ' if ai_correct else 'โ'}\n",
        f"\n**AI's top {len(game.ai_top3)} predictions:**\n",
        *(f"- {country}: {prob*100:.1f}%\n" for country, prob in game.ai_top3),
        "\n---\n",
        f"**Your accuracy:** {player_rate:.1f}% ({game.player_score}/{game.rounds})\n",
        f"**AI accuracy:** {ai_rate:.1f}% ({game.ai_score}/{game.rounds})\n",
    ]
    result = "".join(parts)

    # Mark the round as scored so repeated clicks cannot add points again.
    game.correct_country = None
    game.last_result = result
    return result, gr.update(visible=True)
def reset_game():
    """Reset the scores, the image history and any round in progress.

    Returns:
        Updates for ``(image_display, question, options, submit_btn,
        scoreboard, result)``, in that order. They clear the image, show the
        reset message, hide the answer controls and zero the scoreboard.
    """
    game.player_score = 0
    game.ai_score = 0
    game.rounds = 0
    game.used_indices = []
    # Also drop the unfinished round, so no stale image, answer or prediction
    # survives a reset.
    game.current_image = None
    game.correct_country = None
    game.ai_prediction = None
    game.ai_top3 = None
    game.options = []
    return (
        None,
        "๐ฎ **Game reset!** Click 'New Round' to start playing.",
        gr.update(choices=[], value=None, visible=False),
        gr.update(visible=False),
        "๐ฎ Player: 0 | ๐ค AI: 0 | ๐ฏ Rounds: 0",
        "",
    )
# Gradio Interface
# Builds the single-page UI and wires its buttons to new_round, check_answer
# and reset_game.
# NOTE(review): the nesting of the components below was reconstructed from a
# copy whose indentation was lost. Confirm the layout against the deployed Space.
# NOTE(review): the emoji in the UI strings look mis-encoded (e.g. "๐ฎ");
# restore them from the original source file rather than editing by hand.
with gr.Blocks(theme=gr.themes.Soft(), title="GeoGuessr: Player vs AI") as demo:
    # Intro and instructions. NOTE(review): if there really are no blank lines
    # between these Markdown lines, the "**Model:**" line will likely render
    # as part of list item 3. Check the rendering.
    gr.Markdown("""
# ๐ GeoGuessr: Player vs AI
Compete against an AI trained with ConvNeXt V2 to guess countries from Google Street View images!
**How to play:**
1. Click "๐ฎ New Round" to load a random image from the GeoGuessr dataset
2. Choose one of the 4 proposed countries
3. Click "โ Check Answer" to see if you beat the AI!
**Model:** ConvNeXt V2 Base (61% accuracy, 51.77% F1-macro)
""")
    with gr.Row():
        # Left column: the round's image and the button that starts a round.
        with gr.Column(scale=2):
            image_display = gr.Image(type="pil", label="๐ธ Street View Image", interactive=False)
            start_btn = gr.Button("๐ฎ New Round", variant="primary", size="lg")
        # Right column: scoreboard, reset button and the answer controls. The
        # radio and the submit button start hidden; new_round makes them visible.
        with gr.Column(scale=1):
            scoreboard = gr.Markdown("๐ฎ Player: 0 | ๐ค AI: 0 | ๐ฏ Rounds: 0")
            reset_btn = gr.Button("๐ Reset Game", variant="secondary")
            question = gr.Markdown("โฌ๏ธ Click 'New Round' to start!")
            options = gr.Radio(
                choices=[],
                label="๐ Select the country:",
                visible=False
            )
            submit_btn = gr.Button("โ Check Answer", variant="primary", visible=False)
    # Round result: filled by check_answer, cleared by new_round and reset_game.
    result = gr.Markdown("")
    # Events
    # The outputs order must match the 6-tuple that new_round returns.
    start_btn.click(
        fn=new_round,
        inputs=[],
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )
    # check_answer's second return value is routed to start_btn, not submit_btn.
    # The submit button therefore stays visible after answering and can be
    # clicked again for the same round, so check_answer must guard against
    # scoring a round twice.
    submit_btn.click(
        fn=check_answer,
        inputs=[options],
        outputs=[result, start_btn]
    )
    # The outputs order must match the 6-tuple that reset_game returns.
    reset_btn.click(
        fn=reset_game,
        outputs=[image_display, question, options, submit_btn, scoreboard, result]
    )
    # Footer with dataset and model credits.
    gr.Markdown("""
---
**Dataset:** [GeoGuessr by marcelomoreno26](https://huggingface.co/datasets/marcelomoreno26/geoguessr)
**Model:** [ConvNeXt V2 GeoGuessr by Jordiett](https://huggingface.co/Jordiett/convnextv2-geoguessr)
Images are randomly selected from the test set of the GeoGuessr dataset.
""")
if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()