Spaces:
Sleeping
Sleeping
| # app.py | |
| from flask import Flask, request, render_template | |
| # Import necessary classes from transformers and torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import operator # To find max value in dictionary if needed (alternative to user's lambda) | |
# Initialize Flask app
app = Flask(__name__)

# --- Load Tokenizer and Model ---
# Loaded once, at import time, so every request reuses the same objects.
model_name = "CrabInHoney/urlbert-tiny-v3-malicious-url-classifier"


def _load_classifier(model_name):
    """Load the tokenizer and sequence-classification model for *model_name*.

    On success the model is switched to evaluation mode and the pair
    ``(tokenizer, model)`` is returned. If any step raises, the error is
    printed and ``(None, None)`` is returned so that callers can detect the
    failed load.
    """
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Evaluation mode (important for inference).
        loaded_model.eval()
        print(f"Tokenizer and Model '{model_name}' loaded successfully.")
    except Exception as e:
        print(f"Error loading tokenizer or model '{model_name}': {e}")
        # Both are reset together: a half-loaded pair is flagged as a failure.
        return None, None
    return loaded_tokenizer, loaded_model


# Either both objects are usable or both are None (loading failed).
tokenizer, model = _load_classifier(model_name)
# --- Prediction Function ---
# Display labels, listed in the order of the model's output classes.
# NOTE(review): this index -> label mapping is an assumption carried over from
# the original code; verify it against model.config.id2label / the model card.
# NOTE(review): the checkpoint name describes a malicious-URL classifier while
# this function receives free-form text -- confirm the model suits the input.
_LABEL_NAMES = (
    "Legitimate Email",
    "Phishing Link Detected",
    "Legitimate Link Detected",
    "Phishing Link Detected (Alt)",
)


def predict_email(email_text):
    """Classify *email_text* and report the most probable label.

    Args:
        email_text: Text to classify. It is tokenized with truncation to at
            most 512 tokens before being passed to the model.

    Returns:
        A dict with three keys:
        ``"prediction"`` -- the label with the highest probability,
        ``"confidence"`` -- that probability as a float,
        ``"all_probabilities"`` -- a dict mapping every label to its
        probability.

    Raises:
        RuntimeError: If the tokenizer or model failed to load, or if the
            model's number of output classes does not match the number of
            labels defined in ``_LABEL_NAMES``.
    """
    # Explicit None checks: the load step stores None on failure, and relying
    # on truthiness would consult the objects' own __len__/__bool__ instead.
    if tokenizer is None or model is None:
        raise RuntimeError("Tokenizer or Model not loaded.")

    # Preprocess and tokenize into PyTorch tensors, truncating long input.
    inputs = tokenizer(
        email_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the class dimension turns logits into probabilities.
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # A single text was submitted, so take the first (only) row.
    probs = predictions[0].tolist()

    # Fail loudly on a class-count mismatch rather than hitting an opaque
    # IndexError (too few classes) or silently ignoring classes (too many).
    if len(probs) != len(_LABEL_NAMES):
        raise RuntimeError(
            f"Model returned {len(probs)} class probabilities, "
            f"expected {len(_LABEL_NAMES)}."
        )

    labels = dict(zip(_LABEL_NAMES, probs))

    # Pick the (label, probability) pair with the highest probability.
    max_label_item = max(labels.items(), key=operator.itemgetter(1))
    return {
        "prediction": max_label_item[0],
        "confidence": max_label_item[1],
        "all_probabilities": labels,
    }
# --- Flask Routes ---
# NOTE(review): '/' is assumed to be the URL the form in index.html posts to;
# confirm against the template.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the analysis page and handle form submissions.

    GET requests render the empty page. POST requests read the ``text`` form
    field, run it through ``predict_email`` and render the page with the
    result, the submitted text and an error message (if any).

    Returns:
        The rendered ``index.html`` template.
    """
    # First page load (GET): return the clean, empty page.
    if request.method != 'POST':
        return render_template('index.html')

    # Form submission (POST).
    email_text_input = request.form.get('text', '')
    error_message = None
    result_details = None

    # Whitespace-only input carries nothing to classify, so treat it as empty.
    if email_text_input.strip():
        try:
            result_details = predict_email(email_text_input)
        except Exception as e:
            # Request boundary: keep the traceback in the server log and show
            # the user a message instead of a 500 page.
            app.logger.exception("Prediction failed")
            error_message = f"An error occurred: {e}"
    else:
        error_message = "Please enter some text to analyze."

    # Return the page with the results or an error.
    return render_template(
        'index.html',
        result=result_details,
        text=email_text_input,
        error=error_message,
    )
| # --- Run the App --- | |