File size: 1,148 Bytes
7bfd23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from typing import List, Optional, Tuple

from transformers import PretrainedConfig


class MammoConfig(PretrainedConfig):
    """Configuration for the ``mammo`` model.

    Holds backbone/head hyperparameters plus per-model preprocessing
    settings. ``image_sizes`` and ``pad_to_aspect_ratio`` each carry one
    entry per model, so both lengths must equal ``num_models``.

    Args:
        backbone: Backbone identifier string (presumably a timm model
            name — confirm against the modeling code).
        feature_dim: Feature dimension (presumably the backbone's pooled
            output width — confirm).
        dropout: Dropout probability.
        num_classes: Number of output classes.
        in_chans: Number of input image channels.
        num_models: Number of models configured; must equal the lengths of
            ``image_sizes`` and ``pad_to_aspect_ratio``.
        image_sizes: One 2-element image size per model. Axis order is not
            established here — presumably ``(height, width)``; confirm.
            Defaults to ``[(2048, 1024), (1920, 1280), (1536, 1536)]``.
        pad_to_aspect_ratio: One boolean flag per model.
            Defaults to ``[True, True, False]``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``len(image_sizes)``, ``len(pad_to_aspect_ratio)``
            and ``num_models`` are not all equal.
    """

    # Identifier stored by PretrainedConfig for this config class.
    model_type = "mammo"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_s",
        feature_dim: int = 1280,
        dropout: float = 0.1,
        num_classes: int = 5,
        in_chans: int = 1,
        num_models: int = 3,
        image_sizes: Optional[List[Tuple[int, int]]] = None,
        pad_to_aspect_ratio: Optional[List[bool]] = None,
        **kwargs,
    ):
        # Build the defaults on every call instead of using mutable default
        # arguments: a shared default list stored on `self` would let one
        # instance's mutation leak into every later default-constructed config.
        if image_sizes is None:
            image_sizes = [(2048, 1024), (1920, 1280), (1536, 1536)]
        if pad_to_aspect_ratio is None:
            pad_to_aspect_ratio = [True, True, False]
        # Shallow-copy caller-supplied sequences so the config never aliases
        # a list the caller may keep mutating.
        image_sizes = list(image_sizes)
        pad_to_aspect_ratio = list(pad_to_aspect_ratio)

        self.backbone = backbone
        self.feature_dim = feature_dim
        self.dropout = dropout
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.num_models = num_models
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not len(image_sizes) == len(pad_to_aspect_ratio) == num_models:
            raise ValueError(
                f"length of `image_sizes` [{len(image_sizes)}] and `pad_to_aspect_ratio` "
                f"[{len(pad_to_aspect_ratio)}] must be equal to `num_models` [{num_models}]."
            )
        self.image_sizes = image_sizes
        self.pad_to_aspect_ratio = pad_to_aspect_ratio
        super().__init__(**kwargs)