import argparse
import os

import torch
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModel, AutoTokenizer
# Argument Parser Setup
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=None,
help="The directory of the model")
parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
parser.add_argument("--lora-path", type=str, default=None,
help="Path to the LoRA model checkpoint")
parser.add_argument("--device", type=str, default="cuda", help="Device to use for computation")
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum new tokens for generation")
parser.add_argument("--lora-alpha", type=float, default=32, help="LoRA alpha")
parser.add_argument("--lora-rank", type=int, default=8, help="LoRA r")
parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
args = parser.parse_args()
if args.tokenizer is None:
    args.tokenizer = args.model
# Model and Tokenizer Configuration
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
# device_map="auto" combined with a subsequent .to() conflicts (accelerate raises
# an error for dispatched models), so place the model on the requested device directly.
model = AutoModel.from_pretrained(args.model, load_in_8bit=False, trust_remote_code=True).to(args.device)
model.eval()  # disable dropout for inference
# LoRA Model Configuration
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=True,
    target_modules=['query_key_value'],
    r=args.lora_rank, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout
)
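# Note: 'query_key_value' targets the fused attention projection used by
# GLM/GPT-NeoX-style checkpoints; LLaMA-style models expose separate
# q_proj/k_proj/v_proj modules and would need different target_modules.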
model = get_peft_model(model, peft_config)
if args.lora_path is not None and os.path.exists(args.lora_path):
    # strict=False: the checkpoint holds only the LoRA weights, not the full model.
    model.load_state_dict(torch.load(args.lora_path, map_location=args.device), strict=False)
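# Alternative sketch (assumption: applies only if the adapter was saved with
# save_pretrained rather than torch.save):
#   from peft import PeftModel
#   model = PeftModel.from_pretrained(model, args.lora_path)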
# Interactive Prompt
while True:
    prompt = input("Prompt: ")
    inputs = tokenizer(prompt, return_tensors="pt").to(args.device)
    with torch.no_grad():
        response = model.generate(input_ids=inputs["input_ids"],
                                  max_length=inputs["input_ids"].shape[-1] + args.max_new_tokens)
    # Strip the prompt tokens so only the newly generated text is decoded.
    response = response[0, inputs["input_ids"].shape[-1]:]
    print("Response:", tokenizer.decode(response, skip_special_tokens=True))