File size: 754 Bytes
dfdb0da
 
 
 
c257c5c
dfdb0da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import torch
import torchvision

def build_LemonFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights=None):
    """Build a ConvNeXt-Large feature extractor, optionally loading a checkpoint.

    The model is created with ``torchvision.models.convnext_large()`` and its
    ``classifier[2]`` layer is replaced by an identity, so the network outputs
    the features that previously fed that layer instead of class scores.

    Args:
        nclasses: Accepted for interface compatibility; not used here.
        pretrained: Accepted for interface compatibility; not used here.
            Checkpoint loading is controlled solely by ``pretrained_weights``.
        pretrained_weights: Path to a checkpoint file, or ``None``. When the
            path is ``None`` or does not point to an existing file, no weights
            are loaded and the freshly constructed model is returned. The
            checkpoint must hold a ``'teacher'`` entry whose keys carry a
            ``backbone.`` prefix.

    Returns:
        The model, moved to the CUDA device when one is available and left on
        the CPU otherwise.

    Raises:
        KeyError: If the checkpoint has no ``'teacher'`` entry.
    """
    # ConvNeXt backbone with the last classifier layer removed.
    net = torchvision.models.convnext_large()
    input_emdim = net.classifier[2].in_features
    net.classifier[2] = torch.nn.Identity()

    # Guard against None first: os.path.isfile(None) raises TypeError.
    if pretrained_weights is not None and os.path.isfile(pretrained_weights):
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(pretrained_weights, map_location="cpu")
        teacher_state = checkpoint['teacher']

        # Strip only the leading `backbone.` prefix induced by the multicrop
        # wrapper; keys without the prefix are dropped.
        prefix = "backbone."
        state_dict = {
            key[len(prefix):]: value
            for key, value in teacher_state.items()
            if key.startswith(prefix)
        }

        # strict=False: the checkpoint is not required to match every key of
        # the model; the returned report lists missing/unexpected keys.
        msg = net.load_state_dict(state_dict, strict=False)
        print(msg, input_emdim)

    # Only move to the GPU when one exists, so CPU-only hosts do not crash.
    if torch.cuda.is_available():
        net.cuda()

    return net