| import torch | |
| from transformers import PreTrainedModel | |
| from .configuration_resnet import ResinConfig | |
class ResinModel(PreTrainedModel):
    """Minimal ``PreTrainedModel`` whose only submodule is a linear layer.

    The wrapped layer is ``torch.nn.Linear(5, 10)``; its sizes are hard-coded
    here and are not read from the config.
    """

    # Configuration class associated with this model.
    config_class = ResinConfig

    def __init__(self, config):
        """Initialise the base class with ``config`` and build the layer.

        Args:
            config: Configuration object forwarded unchanged to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        self.model = torch.nn.Linear(5, 10)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped linear layer to ``tensor``.

        Args:
            tensor: Input whose last dimension is 5 (the ``in_features`` of
                the linear layer built in ``__init__``).

        Returns:
            The output of the linear layer; its last dimension is 10.
        """
        # ``torch.nn.Linear`` has no ``forward_features`` method, so the
        # previous call raised AttributeError. Invoke the module itself so
        # that hooks registered on it also run.
        return self.model(tensor)