File size: 2,521 Bytes
c4cf665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
SafeQwen2.5-VL Configuration

This configuration class extends the official Qwen2_5_VLConfig to add safety-aware
classification capabilities for multimodal content moderation.

Author: SafeQwen Team
"""

from typing import Optional, List
from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig


class SafeQwen2_5_VLConfig(Qwen2_5_VLConfig):
    """
    Configuration class for SafeQwen2.5-VL model.

    SafeQwen2.5-VL extends Qwen2.5-VL with an additional safety classification
    head. This class adds the head's hyper-parameters on top of everything
    ``Qwen2_5_VLConfig`` already accepts. The default label set has 20 entries:
    one ``"safe"`` label followed by 19 categories of potentially unsafe content.

    Args:
        safety_categories (`List[str]`, *optional*):
            List of safety category names. When `None` or empty, the HoliSafe
            20-category taxonomy is used. The sequence is copied, so later
            changes to the caller's object do not affect the config.
        safety_head_hidden_scale (`float`, *optional*, defaults to 4.0):
            Scale factor for safety head hidden size relative to model hidden size.
        safety_loss_lambda (`float`, *optional*, defaults to 1.0):
            Weight for safety classification loss during training.
        safety_num_hidden_layers (`int`, *optional*, defaults to 1):
            Number of hidden layers in the safety classification MLP.
        **kwargs:
            Forwarded unchanged to ``Qwen2_5_VLConfig.__init__``.

    Attributes:
        num_safety_categories (`int`):
            Always recomputed as ``len(safety_categories)``; it is not a
            constructor parameter of this class.
    """

    # NOTE(review): this appears to reuse the base Qwen2.5-VL model type string
    # rather than declaring a new one -- confirm that is intentional.
    model_type = "qwen2_5_vl"

    def __init__(
        self,
        # Safety specific parameters
        safety_categories: Optional[List[str]] = None,
        safety_head_hidden_scale: float = 4.0,
        safety_loss_lambda: float = 1.0,
        safety_num_hidden_layers: int = 1,
        **kwargs
    ):
        super().__init__(**kwargs)

        if safety_categories:
            # Defensive copy: keeping a reference to the caller's list would let
            # outside mutations change the config and desynchronise
            # ``num_safety_categories`` from the stored names.
            self.safety_categories = list(safety_categories)
        else:
            # HoliSafe 20-category safety taxonomy
            self.safety_categories = [
                "safe",
                "gender",
                "race",
                "religion",
                "harassment",
                "disability_discrimination",
                "drug_related_hazards",
                "property_crime",
                "facial_data_exposure",
                "identity_data_exposure",
                "physical_self_injury",
                "suicide",
                "animal_abuse",
                "obscene_gestures",
                "physical_altercation",
                "terrorism",
                "weapon_related_violence",
                "sexual_content",
                "financial_advice",
                "medical_advice",
            ]

        self.safety_head_hidden_scale = safety_head_hidden_scale
        self.safety_loss_lambda = safety_loss_lambda
        self.safety_num_hidden_layers = safety_num_hidden_layers

        # Derived from the list above so the two can never disagree at
        # construction time.
        self.num_safety_categories = len(self.safety_categories)