|
|
import os |
|
|
import torch |
|
|
import torchvision |
|
|
|
|
|
def build_LemonFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights=None, device=None):
    """Build a ConvNeXt-Large feature extractor, optionally loading LemonFM weights.

    A torchvision ``convnext_large`` model is created (without torchvision
    weights) and its final classifier layer ``classifier[2]`` is replaced by
    ``torch.nn.Identity`` so the network outputs the features that layer would
    have consumed instead of class logits.

    Args:
        nclasses: Currently unused -- no classification head is attached.
            Kept so existing callers keep working.
        pretrained: When False, ``pretrained_weights`` is ignored and the
            randomly initialised network is returned.
        pretrained_weights: Optional path to a checkpoint file. Weights are
            taken from its ``'teacher'`` entry (or from the checkpoint itself
            when that key is absent); only keys starting with ``'backbone.'``
            are kept, with that prefix removed, and loaded with
            ``strict=False``. A path that is not an existing file is skipped
            with a warning.
        device: Device to move the model to. Defaults to ``"cuda"`` when
            ``torch.cuda.is_available()`` is True, otherwise ``"cpu"``.

    Returns:
        The model (a ``torch.nn.Module``) on ``device``.
    """
    import warnings

    net = torchvision.models.convnext_large()

    # Input width of the original final Linear layer; printed after loading.
    input_emdim = net.classifier[2].in_features

    # Replace the classification layer so the model returns features.
    # `nn` is not imported in this module, so reference it through `torch`.
    net.classifier[2] = torch.nn.Identity()

    if pretrained and pretrained_weights is not None:
        if os.path.isfile(pretrained_weights):
            # NOTE(security): torch.load unpickles the file -- only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(pretrained_weights, map_location="cpu")

            # Self-distillation checkpoints keep the weights under 'teacher';
            # fall back to the checkpoint itself when that key is missing.
            state_dict = checkpoint["teacher"] if "teacher" in checkpoint else checkpoint

            # Keep only backbone weights and strip just the leading prefix
            # (str.replace would also alter any later "backbone." substring).
            prefix = "backbone."
            state_dict = {
                key[len(prefix):]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }
            if not state_dict:
                warnings.warn(
                    f"No '{prefix}*' keys found in {pretrained_weights!r}; "
                    "no pretrained weights were loaded."
                )

            # strict=False tolerates missing/unexpected keys; the returned
            # message lists them so mismatches are visible.
            msg = net.load_state_dict(state_dict, strict=False)
            print(msg, input_emdim)
        else:
            warnings.warn(
                f"pretrained_weights={pretrained_weights!r} is not a file; "
                "returning the randomly initialised model."
            )

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    net.to(device)

    return net
|
|
|