import argparse
import os

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModel, AutoTokenizer

# Argument parser setup
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default=None, help="The directory of the model")
parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer path")
parser.add_argument("--lora-path", type=str, default=None, help="Path to the LoRA model checkpoint")
parser.add_argument("--device", type=str, default="cuda", help="Device to use for computation")
parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum new tokens for generation")
parser.add_argument("--lora-alpha", type=float, default=32, help="LoRA alpha")
parser.add_argument("--lora-rank", type=int, default=8, help="LoRA r")
parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout")
args = parser.parse_args()

# Fall back to the model directory for the tokenizer when no separate path is given.
if args.tokenizer is None:
    args.tokenizer = args.model

# Model and tokenizer configuration.
# Note: combining device_map="auto" with a later .to(...) can raise an error once
# the model has been dispatched by accelerate, so the model is placed on the
# requested device explicitly instead.
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
model = AutoModel.from_pretrained(args.model, trust_remote_code=True).to(args.device)

# LoRA model configuration ('query_key_value' matches ChatGLM-style attention blocks)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=True,
    target_modules=["query_key_value"],
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
)
model = get_peft_model(model, peft_config)

# Load the fine-tuned LoRA weights on top of the base model, if a checkpoint exists.
# strict=False because the checkpoint holds only the adapter weights, not the full model.
if args.lora_path is not None and os.path.exists(args.lora_path):
    model.load_state_dict(torch.load(args.lora_path, map_location="cpu"), strict=False)
else:
    print(f"Warning: no LoRA checkpoint found at {args.lora_path}; using base weights only.")
model.eval()

# Interactive prompt loop
while True:
    prompt = input("Prompt: ")
    inputs = tokenizer(prompt, return_tensors="pt").to(args.device)
    response = model.generate(
        input_ids=inputs["input_ids"],
        max_new_tokens=args.max_new_tokens,
    )
    # Strip the prompt tokens so only the newly generated text is decoded.
    response = response[0, inputs["input_ids"].shape[-1]:]
    print("Response:", tokenizer.decode(response, skip_special_tokens=True))
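
# A minimal, illustrative invocation. The script name, model identifier, and
# checkpoint path below are assumptions for the example, not values defined by
# this script:
#
#   python chat_cli.py \
#       --model THUDM/chatglm-6b \
#       --lora-path ./output/lora_adapter.bin \
#       --max-new-tokens 256
#
# --tokenizer may be omitted, in which case it defaults to the --model path.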