# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
"""

import os
from typing import Optional

import torch

from transformers import Trainer
from transformers.modeling_utils import unwrap_model
from transformers.utils import logging

logger = logging.get_logger(__name__)

# Only the trainable (LoRA) parameters are written to this file, not the full model.
WEIGHTS_NAME = "pytorch_model.pt"
TRAINING_ARGS_NAME = "training_args.bin"


class LoRATrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        # The model is expected to compute its own loss from the labels in `inputs`.
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")

        # Save only the trainable (LoRA) parameters, moved to CPU so the
        # checkpoint can be loaded on machines without a GPU.
        model_to_save = unwrap_model(self.model)
        saved_params = {
            k: v.to("cpu") for k, v in model_to_save.named_parameters() if v.requires_grad
        }
        torch.save(saved_params, os.path.join(output_dir, WEIGHTS_NAME))

        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)

        # Good practice: save the training arguments together with the trained weights.
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
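

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of this module): it assumes
# a LoRA-wrapped `model`, a `tokenizer`, and a `train_dataset` that are defined
# elsewhere, and shows how the partial checkpoint written by
# `LoRATrainer.save_model` can be restored with `load_state_dict(..., strict=False)`,
# since only the trainable parameters were saved.
#
#     from transformers import TrainingArguments
#
#     args = TrainingArguments(output_dir="lora-checkpoint", num_train_epochs=1)
#     trainer = LoRATrainer(
#         model=model, args=args, train_dataset=train_dataset, tokenizer=tokenizer
#     )
#     trainer.train()
#     trainer.save_model()
#
#     # Later, restore the LoRA weights on top of a freshly initialized model:
#     lora_state = torch.load(
#         os.path.join("lora-checkpoint", WEIGHTS_NAME), map_location="cpu"
#     )
#     model.load_state_dict(lora_state, strict=False)
# -----------------------------------------------------------------------------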