Spaces:
Paused
Paused
File size: 6,463 Bytes
4721aa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Besides the usual pretrained-model loading options (config, tokenizer, cache dir,
    revision, auth token), this carries an optional quantization bit width, an optional
    path to existing LoRA checkpoints, and the LoRA hyper-parameters (rank, alpha,
    dropout). Every field's ``metadata["help"]`` is the command-line help text.
    """

    # Required: the only field without a default.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # Annotated Optional because the default is None (no checkpoint path given).
    lora_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "Path to lora checkpoints"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    resize_position_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            "help": (
                "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
                "the model's position embeddings."
            )
        },
    )
    # None disables quantization (per the help text below).
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "An optional parameter specifying the number of bits used for quantization. "
                "Quantization is a process that reduces the model size by limiting the number of "
                "bits that represent each weight in the model. A lower number of bits can reduce "
                "the model size and speed up inference, but might also decrease model accuracy. "
                "If not set (None), quantization is not applied."
            )
        },
    )
    lora_rank: Optional[int] = field(
        default=8,
        metadata={
            "help": (
                "The rank of the LoRA update matrices, balancing between complexity and model flexibility. "
                "A higher rank allows more complex adaptations but increases the number of parameters "
                "and computational cost."
            )
        },
    )
    # Default is the int 32; kept as-is for backward compatibility even though the
    # annotation is float.
    lora_alpha: Optional[float] = field(
        default=32,
        metadata={
            "help": (
                "The LoRA scaling factor (alpha). "
                "A higher value results in more significant adjustments, potentially improving adaptation to new tasks or data, "
                "but might also risk overfitting. A lower value makes smaller adjustments, possibly maintaining better generalization."
            )
        },
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={
            "help": (
                "The dropout probability applied to the LoRA layers "
                "during training to prevent the model from overly relying on specific patterns in the training data. "
                "Higher dropout rates can improve model generalization but may reduce learning efficiency."
            )
        },
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    ``__post_init__`` validates the arguments: ``train_file`` must be a path ending in
    ``.jsonl`` or ``.json``, and ``train_format`` must be ``"multi-turn"`` or
    ``"input-output"``.

    Raises:
        ValueError: if ``train_file`` is missing or has an unsupported extension, or if
            ``train_format`` is missing or not one of the supported values.
    """

    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonl or json file)."}
    )
    # Declared exactly once. This file previously also had a dead duplicate declaration
    # with default 2048; the later declaration (default 1024) always won, so 1024 is
    # kept, in the field's original (second) position.
    max_seq_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_source_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    max_target_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total sequence length for target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    # Default None is rejected by __post_init__; the user must pick a format.
    train_format: Optional[str] = field(
        default=None, metadata={"help": "The format of the training data file (multi-turn or input-output)"},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to model maximum sentence length. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                "efficient on GPU but very bad for TPU."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )

    def __post_init__(self):
        """Validate ``train_file`` and ``train_format``; raise ValueError on bad input."""
        # Explicit raises instead of `assert`, so validation still runs under `python -O`.
        if self.train_file is None:
            raise ValueError("`train_file` is required and should be a jsonl or a json file.")
        # Text after the last '.' (the whole string if there is no '.').
        extension = self.train_file.rsplit(".", 1)[-1]
        if extension not in {"jsonl", "json"}:
            raise ValueError("`train_file` should be a jsonl or a json file.")
        if self.train_format not in {"multi-turn", "input-output"}:
            raise ValueError(
                "`train_format` should be 'multi-turn' or 'input-output', "
                f"got {self.train_format!r}."
            )
|