| | import os |
| | import torch |
| | from unsloth import FastLanguageModel, is_bfloat16_supported |
| | from trl import SFTTrainer |
| | from transformers import TrainingArguments |
| | from datasets import load_dataset |
| | import gradio as gr |
| | import json |
| | from huggingface_hub import HfApi |
| |
|
# --- Run configuration -------------------------------------------------------
max_seq_length = 4096        # context window used for tokenization and training
dtype = None                 # let unsloth auto-detect the compute dtype
load_in_4bit = True          # 4-bit quantized base weights to fit the 27B model
hf_token = os.getenv("HF_TOKEN")   # Hugging Face auth token (model access + push)
current_num = os.getenv("NUM")     # run counter stored as a Space variable; bumped after training

# Bug fix: the original used shell-style "${...}" inside a Python f-string,
# which printed a literal "$" (e.g. "stage $3"); "{...}" is the correct form.
print(f"stage {current_num}")

api = HfApi(token=hf_token)

# Base checkpoint: unsloth's pre-quantized 4-bit Gemma-2 27B.
model_base = "unsloth/gemma-2-27b-bnb-4bit"
| |
|
print("Starting model and tokenizer loading...")

# Load the pre-quantized 4-bit base model and its tokenizer in one call.
# dtype=None lets unsloth choose the compute dtype for the current GPU.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_base,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    token=hf_token
)

print("Model and tokenizer loaded successfully.")
| |
|
print("Configuring PEFT model...")
# Wrap the base model with LoRA adapters so only the low-rank matrices are
# trained instead of the full 27B parameters.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,  # LoRA rank
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,  # scaling factor (alpha/r = 1.0 here)
    lora_dropout=0,  # no dropout on adapter activations
    bias="none",  # bias terms are not trained
    use_gradient_checkpointing="unsloth",  # unsloth's memory-saving checkpointing
    random_state=3407,
    use_rslora=False,  # plain LoRA scaling, not rank-stabilized
    loftq_config=None,
)
print("PEFT model configured.")
| |
|
| | |
# Prompt templates keyed by task type; each template has two "{}" slots that
# formatting_prompts_func fills with (instruction, output).
alpaca_prompt = {
    "learning_from": """Below is a CVE definition.
### CVE definition:
{}
### detail CVE:
{}""",
    "definition": """Below is a definition about software vulnerability. Explain it.
### Definition:
{}
### Explanation:
{}""",
    "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
### Code Snippet:
{}
### Vulnerability solution:
{}"""
}

# End-of-sequence token appended to every training example so the model
# learns where generations should stop.
EOS_TOKEN = tokenizer.eos_token
| |
|
def detect_prompt_type(instruction):
    """Classify an instruction string into one of the known prompt-template keys.

    Returns "code_vulnerability", "learning_from", "definition", or "unknown".
    The table order matters: the code-vulnerability prefix itself begins with
    "what is", so it must be matched before the generic definition prefix.
    """
    prefix_table = (
        ("what is code vulnerable of this code:", "code_vulnerability"),
        ("Learning from", "learning_from"),
        ("what is", "definition"),
    )
    for prefix, prompt_type in prefix_table:
        if instruction.startswith(prefix):
            return prompt_type
    return "unknown"
| |
|
def formatting_prompts_func(examples):
    """Build the "text" column for a batch of dataset examples.

    Each (instruction, output) pair is rendered through the template in
    ``alpaca_prompt`` that matches its detected type, falling back to plain
    instruction/output concatenation for unrecognized instructions. Every
    rendered example is terminated with the tokenizer's EOS token.
    """
    rendered = []
    for instruction, output in zip(examples["instruction"], examples["output"]):
        template = alpaca_prompt.get(detect_prompt_type(instruction))
        if template is not None:
            body = template.format(instruction, output)
        else:
            body = instruction + "\n\n" + output
        rendered.append(body + EOS_TOKEN)
    return {"text": rendered}
| |
|
print("Loading dataset...")
# Pull the training split of the DCSV dataset from the Hugging Face Hub.
dataset = load_dataset("dad1909/DCSV", split="train")
print("Dataset loaded successfully.")

print("Applying formatting function to the dataset...")
# Add the "text" column that SFTTrainer consumes (batched map for speed).
dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatting function applied.")
| |
|
print("Initializing trainer...")
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",  # column produced by formatting_prompts_func
    max_seq_length=max_seq_length,
    dataset_num_proc=2,  # parallel workers for dataset preprocessing
    packing=False,  # one example per sequence; no example packing
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,  # effective batch size stays 1
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),  # fp16 only where bf16 is unavailable
        bf16=is_bfloat16_supported(),
        warmup_steps=5,
        logging_steps=10,
        max_steps=100,  # short fixed-length run; overrides epoch-based training
        optim="adamw_8bit",  # memory-efficient 8-bit AdamW
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs"
    ),
)
print("Trainer initialized.")

print("Starting training...")
trainer_stats = trainer.train()
print("Training completed.")
| |
|
# Advance the run counter read from the NUM Space variable.
# NOTE(review): int(None) raises TypeError if NUM is unset — consider a guard.
num = int(current_num)
num += 1

# Versioned repo name for this run.
# NOTE(review): uploads_models is never used below — the push targets `up`
# instead. Confirm which repo name is actually intended.
uploads_models = f"cybersentinal-2.0-{str(num)}"

up = "sentinal-3.1-70B"

print("Saving the trained model...")
# Merge the LoRA adapters into the base weights and save locally in 16-bit.
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
print("Model saved successfully.")

print("Pushing the model to the hub...")
# Upload the merged 16-bit model and tokenizer to the Hub repo named by `up`.
model.push_to_hub_merged(
    up,
    tokenizer,
    save_method="merged_16bit",
    token=hf_token
)
print("Model pushed to hub successfully.")

# Persist the incremented counter back to the Space's NUM variable
# (delete then re-add; there is no in-place update call used here).
api.delete_space_variable(repo_id="dad1909/CyberCode", key="NUM")
api.add_space_variable(repo_id="dad1909/CyberCode", key="NUM", value=str(num))