Commit: Update modeling_typhoon2audio.py

File changed: modeling_typhoon2audio.py
=== Before (old version) ===

@@ -18,6 +18,7 @@ from transformers import (
 18 |     WhisperModel,
 19 |     PreTrainedModel,
 20 |     AutoTokenizer,
 21 |     AutoModelForCausalLM,
 22 | )
 23 | from transformers.cache_utils import Cache, StaticCache

@@ -63,6 +64,7 @@ from transformers.modeling_utils import (
 63 |     apply_chunking_to_forward,
 64 |     find_pruneable_heads_and_indices,
 65 |     prune_linear_layer,
 66 | )
 67 | from transformers.models.bert.configuration_bert import BertConfig
 68 |

@@ -841,9 +843,9 @@ class Typhoon2AudioForConditionalGeneration(PreTrainedModel, GenerationMixin):
841 |         self.second_stride = config.second_stride
842 |
843 |         # 2. LLM (e.g., Llama3)
844 | -
845 | -        config.llama_base_model   [NOTE: the removed lines were only partially captured by the page extraction; the full deleted statement is not recoverable from this view]
846 | -
847 |         # tokenizer
848 |         self.llama_tokenizer = AutoTokenizer.from_pretrained(
849 |             config.llama_base_model, use_fast=False
=== After (new version) ===

 18 |     WhisperModel,
 19 |     PreTrainedModel,
 20 |     AutoTokenizer,
 21 | +   AutoConfig,
 22 |     AutoModelForCausalLM,
 23 | )
 24 | from transformers.cache_utils import Cache, StaticCache

 64 |     apply_chunking_to_forward,
 65 |     find_pruneable_heads_and_indices,
 66 |     prune_linear_layer,
 67 | +   no_init_weights
 68 | )
 69 | from transformers.models.bert.configuration_bert import BertConfig
 70 |

843 |         self.second_stride = config.second_stride
844 |
845 |         # 2. LLM (e.g., Llama3)
846 | +       with no_init_weights(_enable=True):
847 | +           llm_config = AutoConfig.from_pretrained(config.llama_base_model)
848 | +           self.llama_model = AutoModelForCausalLM.from_config(llm_config, attn_implementation=attn_implementation)
849 |         # tokenizer
850 |         self.llama_tokenizer = AutoTokenizer.from_pretrained(
851 |             config.llama_base_model, use_fast=False