Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import joblib | |
| import pandas as pd | |
| import os | |
| import json | |
| from safetensors.torch import load_file | |
| from typing import List, Tuple | |
| from network import PricePredictor | |
# Artefact locations expected next to this script. These are relative paths,
# resolved against the working directory at runtime.
MODEL_DIR = "model"
DATA_DIR = "data"


def _data_file(filename: str) -> str:
    """Return the path of *filename* inside DATA_DIR."""
    return os.path.join(DATA_DIR, filename)


SCALER_PATH = _data_file("scaler.pkl")  # fitted feature scaler, read with joblib.load
DATA_PATH = _data_file("pokemon_final_with_labels.csv")  # card table: features, ids, names, labels

# Ground-truth label column in the CSV. In this file it is used only to show the
# recorded outcome next to the prediction.
TARGET_COLUMN = "price_will_rise_30_in_6m"
def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]:
    """Load the trained PricePredictor and its feature column list from *model_dir*.

    *model_dir* must contain ``config.json``, with at least the keys
    ``input_size`` and ``feature_columns``, and ``model.safetensors``.

    Args:
        model_dir: Directory holding the exported model artefacts.

    Returns:
        A ``(model, feature_columns)`` tuple. ``model`` has its weights loaded
        and is in eval mode. ``feature_columns`` is the list stored under the
        ``feature_columns`` config key.

    Raises:
        FileNotFoundError: If ``config.json`` or ``model.safetensors`` is missing.
        KeyError: If ``config.json`` lacks ``input_size`` or ``feature_columns``.
    """
    config_path = os.path.join(model_dir, "config.json")
    # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        model_config = json.load(f)

    model = PricePredictor(input_size=model_config["input_size"])
    weights_path = os.path.join(model_dir, "model.safetensors")
    model.load_state_dict(load_file(weights_path))
    # Put the network in inference mode. This affects layers such as dropout or
    # batch-norm, if PricePredictor contains any.
    model.eval()
    return model, model_config["feature_columns"]
def perform_prediction(
    model: torch.nn.Module,
    scaler,
    input_features: pd.Series,
    *,
    threshold: float = 0.5,
) -> Tuple[bool, float]:
    """Score a single card's feature vector with the model.

    Args:
        model: Network called on a one-row float32 batch; its output must hold
            exactly one logit, because ``.item()`` is used on it.
        scaler: Fitted transformer whose ``transform`` is applied before the model.
        input_features: The card's feature values, in the column order the model
            and scaler were trained with.
        threshold: The predicted class is positive ("will rise") only when the
            probability is strictly greater than this value. The default of 0.5
            matches the previous ``round()``-based rule, under which a
            probability of exactly 0.5 counts as negative.

    Returns:
        ``(predicted_class, probability)``. ``probability`` is the sigmoid of the
        model's logit and ``predicted_class`` is ``probability > threshold``.
    """
    # A single sample is turned into a 2-D (1, n_features) batch for the scaler and model.
    features_np = input_features.to_numpy(dtype="float32").reshape(1, -1)
    features_scaled = scaler.transform(features_np)
    features_tensor = torch.tensor(features_scaled, dtype=torch.float32)
    with torch.no_grad():
        logit = model(features_tensor)
    probability = torch.sigmoid(logit).item()
    # Explicit comparison instead of bool(round(p)). round() uses banker's rounding,
    # which hid the tie-break rule, and it raises ValueError on NaN.
    predicted_class = probability > threshold
    return predicted_class, probability
# Load every inference asset once, at import time. On failure the Space keeps
# running, and predict_price_trend / the examples panel report the problem
# instead of the app crashing at startup.
model = scaler = full_data = None
feature_columns: List[str] = []
try:
    model, feature_columns = load_model_and_config(MODEL_DIR)
    # NOTE(review): joblib.load unpickles the file, so only load scalers shipped with this app.
    scaler = joblib.load(SCALER_PATH)
    full_data = pd.read_csv(DATA_PATH)
except FileNotFoundError as e:
    print(f"Error loading necessary files: {e}")
    print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.")
    ASSETS_LOADED = False
except Exception:
    # Top-level startup boundary. Corrupt JSON, missing config keys, a state-dict
    # shape mismatch or an unpickling error previously aborted the whole app.
    # Log the full traceback and fall back to the in-app error message instead.
    import traceback

    print("Error loading necessary files (present but unreadable or incompatible):")
    traceback.print_exc()
    ASSETS_LOADED = False
else:
    ASSETS_LOADED = True
def _format_true_label(card_sample: pd.Series) -> str:
    """Return a Markdown bullet with the card's recorded outcome from the dataset.

    Returns an empty string when TARGET_COLUMN is absent from the row or its
    value is missing, so the report simply omits that line.
    """
    true_value = card_sample.get(TARGET_COLUMN)  # None when the column is absent
    if pd.isna(true_value):
        return ""
    outcome = "RISE" if bool(true_value) else "NOT RISE"
    return f"- **Actual Result in Dataset:** The price did **{outcome}**."


def predict_price_trend(card_identifier: str) -> str:
    """Build the Markdown prediction report for the card with the given TCGPlayer ID.

    Args:
        card_identifier: Raw text from the UI textbox. Surrounding whitespace is
            ignored; ``None`` or blank input is treated as invalid.

    Returns:
        A Markdown string. It is either the prediction report, or a short error
        section when assets are not loaded, the input is not a number, the ID is
        unknown, or the card lacks required feature values.
    """
    if not ASSETS_LOADED:
        return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories."

    cleaned = (card_identifier or "").strip()
    # Use isdecimal() rather than isdigit(). isdigit() also accepts characters
    # such as "²", which int() cannot parse, and those inputs used to raise
    # ValueError here.
    if not cleaned.isdecimal():
        return "## Input Error\nPlease enter a valid, numeric TCGPlayer ID."
    card_id = int(cleaned)

    card_data = full_data[full_data['tcgplayer_id'] == card_id]
    if card_data.empty:
        return f"## Card Not Found\nCould not find a card with TCGPlayer ID '{card_id}'. Please check the ID and try again."
    card_sample = card_data.iloc[0]  # first matching row if the ID appears more than once

    sample_features = card_sample[feature_columns]
    # Missing feature values would reach the scaler/model as NaN and cannot yield a
    # meaningful prediction. Previously this crashed inside perform_prediction.
    if sample_features.isna().any():
        return (
            f"## Incomplete Data\nThe card with TCGPlayer ID '{card_id}' is missing "
            "feature values required by the model, so no prediction can be made."
        )

    predicted_class, probability = perform_prediction(model, scaler, sample_features)
    prediction_text = "**RISE**" if predicted_class else "**NOT RISE**"
    # Report confidence in whichever class was predicted.
    confidence = probability if predicted_class else 1 - probability
    # Build the link from the parsed integer ID. pandas may have read the CSV column
    # as float (for example, when it contains missing values), which would render
    # the link as ".../product/84198.0".
    tcgplayer_link = f"https://www.tcgplayer.com/product/{card_id}?Language=English"

    report_lines = [
        f"## 🔮 Prediction Report for {card_sample['name']}",
        f"- **Prediction:** The model predicts the card's price will {prediction_text} by 30% in the next 6 months.",
        f"- **Confidence:** {confidence:.2%}",
        f"- **View on TCGPlayer:** [Check Current Price]({tcgplayer_link})",
    ]
    true_label_line = _format_true_label(card_sample)
    if true_label_line:
        report_lines.append(true_label_line)
    return "\n".join(report_lines) + "\n"
# Gradio UI. The left column holds the ID input and a few example cards; the
# right column shows the Markdown report from predict_price_trend.
with gr.Blocks(theme=gr.themes.Soft(), title="PricePoke Predictor") as demo:
    gr.Markdown(
        """
# 📈 PricePoke: Pokémon Card Price Trend Predictor
Enter a Pokémon card's TCGPlayer ID to predict whether its market price will increase by 30% or more over the next 6 months.
This model was trained on historical TCGPlayer market data.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            card_input = gr.Textbox(
                label="TCGPlayer ID",
                placeholder="e.g., '84198'",
                info="Find the ID in the card's URL on TCGPlayer's website (e.g., tcgplayer.com/product/84198/... has ID 84198).",
            )
            predict_button = gr.Button("Predict Trend", variant="primary")
            gr.Markdown("---")
            gr.Markdown("### Example Cards")
            if ASSETS_LOADED and not full_data.empty:
                # DataFrame.sample raises ValueError when asked for more rows than exist.
                n_examples = min(5, len(full_data))
                example_df = full_data.sample(n_examples, random_state=42)[['name', 'tcgplayer_id']]
                try:
                    examples_md = example_df.to_markdown(index=False)
                except ImportError:
                    # to_markdown needs the optional `tabulate` package; fall back to plain text.
                    examples_md = "```\n" + example_df.to_string(index=False) + "\n```"
                gr.Markdown(examples_md)
            else:
                gr.Markdown("Could not load examples.")
        with gr.Column(scale=2):
            output_markdown = gr.Markdown()
    # Both the button and pressing Enter in the textbox trigger a prediction.
    predict_button.click(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])
    card_input.submit(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])

if __name__ == "__main__":
    demo.launch()