import copy
from functools import partial
import json
import einops
import torch
from torch import nn


from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, \
    BartModel, BartConfig, BartForCausalLM, BertForMaskedLM, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig


from peft import prepare_model_for_kbit_training
from peft import LoraConfig
from peft import get_peft_model


from mmcv.cnn import build_norm_layer
from mmcv.runner import BaseModule
import math
from ipdb import set_trace

class mixEmbed(nn.Module):
    """Mixes text token embeddings with pre-computed audio embeddings.

    Convention: non-negative input ids are ordinary vocabulary ids, while a
    negative id encodes an index into ``audio_embeddings`` as ``-(index + 1)``.
    """

    def __init__(self, lm_embed: nn.Embedding, audio_embeddings, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lm_embed = lm_embed
        self.audio_embeddings = audio_embeddings

    def forward(self, input_ids):
        # Negative ids are clamped away here; those positions are served by the audio branch.
        text_ids = torch.clamp(input_ids.clone(), 0).long()
        # Recover audio indices from the negative ids: -1 -> 0, -2 -> 1, ...
        au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()
        text_embeds = self.lm_embed(text_ids)
        au_embeds = self.audio_embeddings[au_ids]
        with torch.no_grad():
            # Text positions are those with non-negative ids (id 0 is a valid token).
            embed_mask = (input_ids >= 0)
        mix_embeds = au_embeds.clone()
        mix_embeds[embed_mask] = text_embeds[embed_mask]
        return mix_embeds


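# Hedged usage sketch for mixEmbed (illustration only, not part of the training
# pipeline): audio features occupy the positions whose ids are negative, where
# id -(k + 1) selects row k of the flattened audio embedding table. The shapes
# and token ids below are made up for the example.
def _mix_embed_example():
    vocab_embed = nn.Embedding(50257, 768)      # e.g. a GPT-2-sized embedding table
    audio_feats = torch.randn(4, 768)           # 4 pre-computed audio embeddings
    mixer = mixEmbed(vocab_embed, audio_feats)
    # ids >= 0 are text tokens, ids -1..-4 point at audio_feats[0..3]
    input_ids = torch.tensor([[101, -1, -2, 57, -3]])
    return mixer(input_ids)                     # -> (1, 5, 768) mixed embeddings

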
class LMDecoder(nn.Module):
    def __init__(self,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 ):
        super().__init__()
        self.decoder_type = decoder_type
        self.load_lm()

        self.lm_embed = self.lm.get_input_embeddings()
        try:
            self.lm_pos_embed = self.lm.get_position_embeddings()
        except NotImplementedError:
            # Models with rotary or relative position encodings expose no
            # absolute position-embedding table.
            self.lm_pos_embed = None

        if hasattr(self.lm, 'embed_dim'):
            self.embed_dim = self.lm.embed_dim
        else:
            self.embed_dim = decoder_embed_dim

        # Optionally freeze the language model so only the audio-side modules are trained.
        self.freeze_decoder = freeze_decoder
        if freeze_decoder:
            for para in self.lm.parameters():
                para.requires_grad = False

    def load_lm(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # For gated checkpoints, authenticate via `huggingface-cli login` or the
        # HF_TOKEN environment variable instead of hard-coding an access token.
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type)
        self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type)

    def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
        # Temporarily swap in the mixed text/audio embedding layer, run the LM
        # with labels to obtain the causal-LM loss, then restore the original table.
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)
        output = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels, output_hidden_states=True, **kwargs)
        self.lm.set_input_embeddings(self.lm_embed)
        return output

    def generate(self, input_ids, flatten_embs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)
        outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
        self.lm.set_input_embeddings(self.lm_embed)
        return outputs


# Reference inference parameters:
#   max_input_tokens: 40
#   batch_size_test: 16
#   max_new_tokens: 64
#   min_length: 2
#   num_beams: 5
#   length_penalty: -2.0
#   top_p: 0.9
#   top_k: 3
#   no_repeat_ngram_size: 2
#   apply_lemmatizer: False
#   use_nucleus_sampling: True

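# Hedged sketch (not from the original pipeline) of how the reference
# parameters above could map onto Hugging Face `generate` kwargs; the decoder,
# input_ids and flatten_embs arguments are placeholders supplied by the caller.
def _generate_with_reference_params(decoder: LMDecoder, input_ids, flatten_embs):
    mix_embed = mixEmbed(decoder.lm_embed, flatten_embs)
    decoder.lm.set_input_embeddings(mix_embed)
    outputs = decoder.lm.generate(
        input_ids=input_ids,
        max_new_tokens=64,
        min_length=2,
        num_beams=5,
        length_penalty=-2.0,
        top_p=0.9,
        top_k=3,
        no_repeat_ngram_size=2,
        do_sample=True,  # "use_nucleus_sampling: True" above
    )
    decoder.lm.set_input_embeddings(decoder.lm_embed)
    return outputs

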
class LMDecoder_qlora(LMDecoder):
    def __init__(self,
                 img_size=(80, 512),
                 patch_size: int = 16,
                 in_chans: int = 3,
                 embed_dim=1024,
                 decoder_embed_dim=512,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 decoder_type='gpt2',
                 freeze_decoder=True,
                 additional_layer: int = 0,
                 ):
        super().__init__(img_size, patch_size, in_chans, embed_dim, decoder_embed_dim,
                         norm_cfg, decoder_type, freeze_decoder, additional_layer)

    def load_lm(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True)
        # 4-bit quantization with double quantization for QLoRA-style training.
        double_quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.decoder_type,
            quantization_config=double_quant_config,
            trust_remote_code=True,
        )

        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            # Target-module names depend on the base architecture, e.g.
            # "query_key_value" for GPT-NeoX/Falcon blocks or "c_attn" for GPT-2.
            target_modules=["query_key_value"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )

        self.lm = get_peft_model(model, lora_config)
        self.lm.print_trainable_parameters()
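

# Hedged construction sketch (illustration only): building the QLoRA decoder
# around a GPT-NeoX-style checkpoint whose attention projection is named
# "query_key_value", matching the `target_modules` above. The checkpoint name
# is a placeholder, not necessarily the one used by this project, and 4-bit
# loading assumes a CUDA device with bitsandbytes installed.
def _build_qlora_decoder_example():
    # freeze_decoder=False keeps the LoRA adapters trainable; the quantized
    # base weights are already frozen by prepare_model_for_kbit_training.
    decoder = LMDecoder_qlora(decoder_type='EleutherAI/pythia-410m', freeze_decoder=False)
    # Report how few parameters remain trainable under 4-bit + LoRA.
    decoder.lm.print_trainable_parameters()
    return decoder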