LemonFM / model_loader.py
chengan98's picture
Update model_loader.py
c257c5c verified
raw
history blame contribute delete
754 Bytes
import os
import torch
import torchvision
def build_LemonFM(
    nclasses: int = 2,
    pretrained: bool = True,
    pretrained_weights: "str | None" = None,
) -> torch.nn.Module:
    """Build a ConvNeXt-Large feature extractor, optionally loading LemonFM weights.

    The final linear layer of torchvision's ``convnext_large`` classifier
    (``classifier[2]``) is replaced with ``torch.nn.Identity``. The returned
    network therefore outputs the features that fed that layer, not class
    logits.

    Args:
        nclasses: Currently unused; the classification layer is removed.
            Kept so existing call sites keep working.
        pretrained: If False, no checkpoint is loaded, even when
            ``pretrained_weights`` is given.
        pretrained_weights: Path to a checkpoint file. It is expected to be
            a mapping with a ``'teacher'`` state dict whose backbone
            parameters are prefixed with ``backbone.``. If ``pretrained`` is
            True but this is ``None`` or not an existing file, a warning is
            issued. The network is then returned as built by
            ``convnext_large()``, which is called without a ``weights``
            argument.

    Returns:
        The network. It is moved to the GPU only when CUDA is available.

    Raises:
        KeyError: If the loaded checkpoint has no ``'teacher'`` entry.
    """
    import warnings

    net = torchvision.models.convnext_large()
    embed_dim = net.classifier[2].in_features
    # Drop the last linear layer so the network acts as a feature extractor.
    net.classifier[2] = torch.nn.Identity()

    if pretrained:
        # os.path.isfile(None) raises TypeError, so check for None first.
        if pretrained_weights is not None and os.path.isfile(pretrained_weights):
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(pretrained_weights, map_location="cpu")
            teacher_state = checkpoint["teacher"]
            prefix = "backbone."
            # Keep only the backbone weights and strip the leading prefix
            # added by the multi-crop wrapper (prefix only, not every occurrence).
            state_dict = {
                key[len(prefix):]: value
                for key, value in teacher_state.items()
                if key.startswith(prefix)
            }
            # strict=False: keys absent from the checkpoint keep their current values;
            # the returned message reports missing/unexpected keys.
            msg = net.load_state_dict(state_dict, strict=False)
            print(msg, embed_dim)
        else:
            warnings.warn(
                f"pretrained_weights={pretrained_weights!r} is not an existing file; "
                "returning the network without loading a checkpoint."
            )

    # Previously .cuda() was unconditional and failed on CPU-only machines.
    if torch.cuda.is_available():
        net.cuda()
    return net