import os
import warnings

import torch
import torchvision
from torch import nn


def build_LemonFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights=None) -> nn.Module:
    """Build a ConvNeXt-Large feature extractor, optionally loading LemonFM weights.

    The final linear layer of torchvision's ``convnext_large`` classifier is
    replaced by ``nn.Identity()``. The returned network therefore outputs the
    embedding that used to feed that layer, instead of class logits.

    Args:
        nclasses: Currently unused (no classification head is attached); kept
            for interface compatibility.
        pretrained: When False, no checkpoint is loaded even if
            ``pretrained_weights`` is given.
        pretrained_weights: Path to a checkpoint file, or ``None`` (default)
            to skip loading. The file is expected to contain a ``'teacher'``
            entry holding a state dict whose backbone keys are prefixed with
            ``backbone.`` (as produced by a multicrop wrapper).

    Returns:
        The network, moved to the GPU when CUDA is available.
    """
    # ConvNeXt-Large backbone; no torchvision weights are requested here.
    net = torchvision.models.convnext_large()

    # classifier[2] is the final linear layer: remember its input size (the
    # embedding dimension) and drop it so the net returns features.
    input_emdim = net.classifier[2].in_features
    net.classifier[2] = nn.Identity()

    # Guard against the default None: os.path.isfile(None) raises TypeError.
    if pretrained and pretrained_weights is not None:
        if os.path.isfile(pretrained_weights):
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints. On newer torch versions weights_only defaults to
            # True, which may reject checkpoints that also store non-tensor
            # objects; confirm against the checkpoints actually used.
            checkpoint = torch.load(pretrained_weights, map_location="cpu")
            teacher_state = checkpoint["teacher"]

            # Keep only backbone weights and strip just the leading prefix
            # added by the multicrop wrapper. str.replace would also rewrite
            # any later occurrence of the substring.
            prefix = "backbone."
            state_dict = {
                key[len(prefix):]: value
                for key, value in teacher_state.items()
                if key.startswith(prefix)
            }
            # strict=False: mismatches are reported in msg instead of raising.
            msg = net.load_state_dict(state_dict, strict=False)
            print(msg, input_emdim)
        else:
            warnings.warn(
                f"pretrained_weights={pretrained_weights!r} is not a file; "
                "returning a randomly initialised backbone.",
                stacklevel=2,
            )

    # Only move to the GPU when one is present, so CPU-only hosts still work.
    if torch.cuda.is_available():
        net.cuda()
    return net