File size: 754 Bytes
dfdb0da c257c5c dfdb0da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import os
import torch
import torchvision
def build_LemonFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights=None):
    """Build a torchvision ConvNeXt-Large feature extractor, optionally loading weights.

    The final layer of the classifier head (``classifier[2]``) is replaced by an
    identity, so the network returns the features that previously fed that layer.

    Args:
        nclasses: Currently unused; kept for interface compatibility. The
            classification layer is removed rather than resized.
        pretrained: If False, no checkpoint is loaded even when
            ``pretrained_weights`` is given.
        pretrained_weights: Path to a checkpoint file, or None to skip loading.
            The checkpoint must contain a ``'teacher'`` entry holding a state
            dict. Only keys starting with ``backbone.`` are kept, with that
            prefix stripped, and they are loaded non-strictly. A path that does
            not name an existing file triggers a warning and no weights are
            loaded.

    Returns:
        The network, moved to the GPU when CUDA is available.
    """
    import warnings

    # Backbone: ConvNeXt-Large, built without downloading any torchvision weights.
    net = torchvision.models.convnext_large()
    # Width of the features entering the final linear layer (reported after loading).
    input_emdim = net.classifier[2].in_features
    # Drop the final classification layer so the model outputs features.
    # `nn` is not imported in this file, so reference it through `torch`.
    net.classifier[2] = torch.nn.Identity()

    # Guard against None: os.path.isfile(None) raises TypeError.
    if pretrained and pretrained_weights is not None:
        if os.path.isfile(pretrained_weights):
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(pretrained_weights, map_location="cpu")
            state_dict = checkpoint["teacher"]
            # Remove the `backbone.` prefix induced by the multi-crop wrapper.
            # Slicing strips only the leading prefix. str.replace would also
            # rewrite any later occurrence inside the key.
            prefix = "backbone."
            state_dict = {
                k[len(prefix):]: v
                for k, v in state_dict.items()
                if k.startswith(prefix)
            }
            # strict=False: mismatched keys are reported in `msg` instead of raising.
            msg = net.load_state_dict(state_dict, strict=False)
            print(msg, input_emdim)
        else:
            warnings.warn(
                f"pretrained_weights file not found: {pretrained_weights!r}; "
                "returning randomly initialised weights."
            )

    # Move to GPU only when one is present, so CPU-only hosts do not crash.
    if torch.cuda.is_available():
        net.cuda()
    return net
|