File size: 4,244 Bytes
3368ea3
 
 
 
 
 
 
 
 
 
 
 
8fb66b4
3368ea3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45721a0
3368ea3
 
 
45721a0
3368ea3
 
45721a0
3368ea3
45721a0
3368ea3
45721a0
3368ea3
 
45721a0
3368ea3
 
45721a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# app.py
from flask import Flask, request, render_template
# Import necessary classes from transformers and torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import operator # To find max value in dictionary if needed (alternative to user's lambda)

# Create the Flask application object used by the route decorators below.
app = Flask(__name__)

# --- Tokenizer / model loading ---
# The pair is loaded once, when this module is imported, and shared by every
# request. If loading fails, both globals end up as None; predict_email()
# checks for that and raises instead of running a half-initialized pipeline.
model_name = "CrabInHoney/urlbert-tiny-v3-malicious-url-classifier"


def _load_pipeline(name):
    """Load the tokenizer and sequence-classification model for *name*.

    Returns:
        ``(tokenizer, model)`` on success, with the model switched to
        evaluation mode; ``(None, None)`` if any step raised. Failures are
        reported on stdout rather than propagated, so the app can still start.
    """
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(name)
        loaded_model = AutoModelForSequenceClassification.from_pretrained(name)
        loaded_model.eval()  # evaluation mode for inference
        print(f"Tokenizer and Model '{name}' loaded successfully.")
        return loaded_tokenizer, loaded_model
    except Exception as exc:
        print(f"Error loading tokenizer or model '{name}': {exc}")
        return None, None


tokenizer, model = _load_pipeline(model_name)

# --- Prediction Function ---
def predict_email(email_text, label_names=None):
    """Classify *email_text* with the globally loaded tokenizer and model.

    Args:
        email_text: Raw text to classify. It is tokenized with truncation at
            512 tokens.
            NOTE(review): confirm 512 does not exceed the model's maximum
            sequence length.
        label_names: Optional sequence of display names, one per model
            output, in the model's output order. When omitted, the four
            built-in names below are used if the model produces exactly four
            scores. Otherwise names come from ``model.config.id2label``
            (falling back to ``"Class <i>"``), so every output is reported.

    Returns:
        A dict with:
            ``"prediction"``: the label with the highest probability,
            ``"confidence"``: that probability,
            ``"all_probabilities"``: a dict mapping each label to its
            softmax probability.

    Raises:
        RuntimeError: if the tokenizer or model failed to load.
        ValueError: if *label_names* is given but its length differs from
            the number of model outputs.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("Tokenizer or Model not loaded.")

    # Tokenize into PyTorch tensors, truncating long inputs.
    inputs = tokenizer(
        email_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

    # A single input was tokenized, so take row 0 of the batch.
    probs = predictions[0].tolist()

    # NOTE(review): the order of these names is an assumption carried over
    # from the original code; verify it against model.config.id2label.
    default_names = (
        "Legitimate Email",
        "Phishing Link Detected",
        "Legitimate Link Detected",
        "Phishing Link Detected (Alt)",
    )
    if label_names is None:
        if len(probs) == len(default_names):
            label_names = default_names
        else:
            # The hard-coded names do not fit this model's output size.
            # Use the model's own labels rather than indexing out of range
            # or silently dropping classes.
            config = getattr(model, "config", None)
            id2label = getattr(config, "id2label", None) or {}
            label_names = [id2label.get(i, f"Class {i}") for i in range(len(probs))]
    elif len(label_names) != len(probs):
        raise ValueError(
            f"Expected {len(probs)} label names (one per model output), "
            f"got {len(label_names)}."
        )

    # Pair each label with its probability in output order. Duplicate
    # names would collapse into one entry, so names should be unique.
    labels = dict(zip(label_names, probs))

    # On ties, max() keeps the first label in output order.
    best_label, best_prob = max(labels.items(), key=operator.itemgetter(1))

    return {
        "prediction": best_label,
        "confidence": best_prob,
        "all_probabilities": labels,
    }

# --- Flask Routes ---
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the analysis page and, on POST, classify the submitted text.

    A non-POST request (the initial page load) renders ``index.html`` with no
    context. A POST reads the ``text`` form field and renders the same
    template with:
        ``result``: the dict returned by predict_email, or None,
        ``text``: the submitted text, echoed back unchanged,
        ``error``: a user-facing error message, or None.
    """
    # Initial page load: show the clean, empty form.
    if request.method != 'POST':
        return render_template('index.html')

    email_text_input = request.form.get('text', '')
    error_message = None
    result_details = None

    # Blank or whitespace-only submissions have nothing to classify, so they
    # get the same prompt as an empty form.
    if not email_text_input.strip():
        error_message = "Please enter some text to analyze."
    else:
        try:
            result_details = predict_email(email_text_input)
        except Exception as e:
            # Request boundary: keep the page usable, but record the full
            # traceback server-side instead of discarding it.
            app.logger.exception("Prediction failed")
            error_message = f"An error occurred: {e}"

    return render_template(
        'index.html',
        result=result_details,
        text=email_text_input,
        error=error_message,
    )


# --- Run the App ---