remote_code_model_with_dots / modeling_custom.py
Rocketknight1's picture
Rocketknight1 HF Staff
Upload 14 files
6055811 verified
raw
history blame contribute delete
898 Bytes
"""
Custom model with relative import to demonstrate the bug.
"""
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
import torch
import torch.nn as nn
# This relative import should cause the bug when the folder has a dot in the name
from .another_module import custom_function
class CustomModel(PreTrainedModel):
    """Minimal model whose forward pass depends on a relatively-imported helper.

    The architecture is an embedding lookup, then ``custom_function``
    (imported from the sibling module ``another_module``), then one square
    linear projection. The relative import is the point of this model: it is
    meant to reproduce the loading bug that occurs when the model folder name
    contains a dot.
    """

    def __init__(self, config):
        """Build the layers described by ``config``.

        Args:
            config: Model configuration; must expose ``vocab_size`` and
                ``hidden_size`` attributes.
        """
        super().__init__(config)
        # Token-id -> vector lookup table of shape (vocab_size, hidden_size).
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        # Square projection: hidden_size -> hidden_size.
        self.layer = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(self, input_ids=None, **kwargs):
        """Embed ``input_ids``, apply ``custom_function``, then project.

        Args:
            input_ids: Integer token ids with values in ``[0, vocab_size)``
                (presumably shaped ``(batch, seq_len)`` -- not enforced here).
            **kwargs: Extra keyword arguments; accepted and ignored.

        Returns:
            A ``BaseModelOutput`` whose ``last_hidden_state`` is the output
            of the final linear projection.

        Raises:
            ValueError: If ``input_ids`` is None.
        """
        # The signature defaults input_ids to None, but this model has no
        # alternative input path. Fail fast with a clear message instead of
        # an opaque error from inside nn.Embedding.
        if input_ids is None:
            raise ValueError("CustomModel.forward requires `input_ids`.")
        embeddings = self.embeddings(input_ids)
        # Call the helper brought in by the relative import.
        # NOTE(review): assumes custom_function returns a tensor whose last
        # dimension is still hidden_size, as self.layer requires -- confirm
        # against another_module.
        output = custom_function(embeddings)
        hidden_states = self.layer(output)
        return BaseModelOutput(last_hidden_state=hidden_states)