Update model_loader.py
Browse files- model_loader.py +0 -20
model_loader.py
CHANGED
|
@@ -19,26 +19,6 @@ def build_model(nclasses: int = 2, mode: str = None, segment_model: str = None):
|
|
| 19 |
|
| 20 |
return net
|
| 21 |
|
| 22 |
-
def build_SurgFM(nclasses: int = 2, pretrained: bool = True, pretrained_weights = None):
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
#net of ConvNext
|
| 26 |
-
net = torchvision.models.convnext_large(weights='DEFAULT')
|
| 27 |
-
input_emdim = net.classifier[2].in_features
|
| 28 |
-
net.classifier[2] = nn.Identity()
|
| 29 |
-
|
| 30 |
-
if os.path.isfile(pretrained_weights):
|
| 31 |
-
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
| 32 |
-
state_dict = state_dict['teacher']
|
| 33 |
-
|
| 34 |
-
# remove `backbone.` prefix induced by multicrop wrapper
|
| 35 |
-
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items() if k.startswith('backbone.')}
|
| 36 |
-
msg = net.load_state_dict(state_dict, strict=False)
|
| 37 |
-
print(msg, input_emdim)
|
| 38 |
-
|
| 39 |
-
net.cuda()
|
| 40 |
-
|
| 41 |
-
return net
|
| 42 |
|
| 43 |
|
| 44 |
net = build_model(nclasses=num_classes, mode='classify')
|
|
|
|
| 19 |
|
| 20 |
return net
|
| 21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
net = build_model(nclasses=num_classes, mode='classify')
|