# app.py
"""Flask front-end that classifies submitted text with a Hugging Face model."""

import operator  # itemgetter is used to pick the highest-probability label

from flask import Flask, request, render_template
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Initialize Flask app
app = Flask(__name__)

# --- Load Tokenizer and Model ---
# Loaded once at import time so every request reuses the same instances.
model_name = "CrabInHoney/urlbert-tiny-v3-malicious-url-classifier"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    # Evaluation mode disables training-only behaviour (e.g. dropout).
    model.eval()
    print(f"Tokenizer and Model '{model_name}' loaded successfully.")
except Exception as e:
    print(f"Error loading tokenizer or model '{model_name}': {e}")
    tokenizer = None
    model = None  # Flag that loading failed

# Display names used by earlier versions of this app, in logit order.
# NOTE(review): these appear to come from a 4-class *email* model, not the URL
# classifier loaded above. They are used only when the model config carries no
# meaningful label names and the model has exactly four outputs — confirm the
# order against the model card.
_FALLBACK_LABELS = (
    "Legitimate Email",
    "Phishing Link Detected",
    "Legitimate Link Detected",
    "Phishing Link Detected (Alt)",
)

# Upper bound on the number of tokens fed to the model per request.
_MAX_TOKENS = 512


def _label_names(num_labels):
    """Return one unique display name per model output index.

    Prefers ``model.config.id2label``. When that only holds generic
    ``LABEL_<i>`` placeholders and the model has four outputs, the legacy
    ``_FALLBACK_LABELS`` are used instead so existing output is preserved.
    """
    id2label = getattr(model.config, "id2label", None) or {}
    names = [str(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
    generic = all(name == f"LABEL_{i}" for i, name in enumerate(names))
    if generic and num_labels == len(_FALLBACK_LABELS):
        return list(_FALLBACK_LABELS)
    if len(set(names)) != len(names):
        # Duplicate names would collapse entries in the probabilities dict.
        return [f"LABEL_{i}" for i in range(num_labels)]
    return names


# --- Prediction Function ---
def predict_email(email_text):
    """Classify *email_text* and report the most likely label.

    Args:
        email_text: Raw text submitted by the user.

    Returns:
        A dict with ``"prediction"`` (label name with the highest probability),
        ``"confidence"`` (that probability as a float) and
        ``"all_probabilities"`` (label name -> probability for every class).

    Raises:
        RuntimeError: if the tokenizer or model failed to load at startup.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Tokenizer or Model not loaded.")

    # Never ask for more tokens than the tokenizer/model supports; compact
    # models can have a limit below 512. An unset model_max_length is a huge
    # sentinel value, so min() falls back to _MAX_TOKENS in that case.
    max_length = min(_MAX_TOKENS, tokenizer.model_max_length)

    inputs = tokenizer(
        email_text,
        return_tensors="pt",  # PyTorch tensors
        truncation=True,      # Truncate long inputs instead of failing
        max_length=max_length,
    )

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Softmax over the class dimension; [0] selects the single input in the batch.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].tolist()

    # Pair each probability with its label name, sized to the model's output.
    labels = dict(zip(_label_names(len(probs)), probs))

    prediction, confidence = max(labels.items(), key=operator.itemgetter(1))
    return {
        "prediction": prediction,     # The label name with the highest probability
        "confidence": confidence,     # The highest probability value
        "all_probabilities": labels,  # Dictionary of all labels and their probabilities
    }


# --- Flask Routes ---
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the input form; on POST, classify the submitted text and show the result."""
    if request.method != 'POST':
        # First page load (GET): render the clean, empty form.
        return render_template('index.html')

    email_text_input = request.form.get('text', '')
    error_message = None
    result_details = None

    # Whitespace-only submissions carry nothing to classify.
    if email_text_input.strip():
        try:
            result_details = predict_email(email_text_input)
        except Exception as e:
            # Keep the full traceback in the server log; show a short message.
            app.logger.exception("Prediction failed")
            error_message = f"An error occurred: {e}"
    else:
        error_message = "Please enter some text to analyze."

    return render_template(
        'index.html',
        result=result_details,
        text=email_text_input,
        error=error_message,
    )

# --- Run the App ---