import gradio as gr
from huggingface_hub import HfApi, ModelCard, whoami
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier, GPTQModifier
from llmcompressor.modifiers.awq import AWQModifier, AWQMapping
from transformers import AutoModelForCausalLM, AutoTokenizer


# --- Helper Functions ---
def get_quantization_recipe(method, model_architecture):
    """
    Returns the appropriate llm-compressor recipe based on the selected method.
    """
    if method == "AWQ":
        mappings = [
            AWQMapping("re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"]),
            AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
            AWQMapping("re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"]),
            AWQMapping("re:.*up_proj", ["re:.*down_proj"]),
        ]
        return [
            AWQModifier(
                ignore=["lm_head"],
                scheme="W4A16_ASYM",
                targets=["Linear"],
                mappings=mappings,
            ),
        ]
elif method == "GPTQ":
        sequential_target_map = {
            "LlamaForCausalLM": "LlamaDecoderLayer",
            "MistralForCausalLM": "MistralDecoderLayer",
            "MixtralForCausalLM": "MixtralDecoderLayer",
        }
        sequential_target = sequential_target_map.get(model_architecture, "LlamaDecoderLayer")
        return [
            GPTQModifier(
                targets="Linear",
                scheme="W4A16",
                sequential_targets=[sequential_target],
                ignore=["re:.*lm_head"],
            ),
        ]
elif method == "FP8":
ignore_layers = ["lm_head"]
if "Mixtral" in model_architecture:
            ignore_layers.append("re:.*block_sparse_moe.gate")
        return QuantizationModifier(
            scheme="FP8",
            targets="Linear",
            ignore=ignore_layers,
        )
    else:
        raise ValueError(f"Unsupported quantization method: {method}")


# NOTE: The function signature must accept the OAuthToken so that the logged-in
# user's token is passed into this handler.
def compress_and_upload(model_id: str, quant_method: str, oauth_token: gr.OAuthToken | None):
"""
Compresses a model using llm-compressor and uploads it to a new HF repo.
"""
if not model_id:
raise gr.Error("Please select a model from the search bar.")
if oauth_token is None:
raise gr.Error("Authentication error. Please log in to continue.")
token = oauth_token.token
try:
# Use the provided token for all hub interactions
username = whoami(token=token)["name"]

        # --- 1. Load Model and Tokenizer ---
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map=None, token=token)
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        output_dir = f"{model_id.split('/')[-1]}-{quant_method}"

        # --- 2. Get Recipe ---
        recipe = get_quantization_recipe(quant_method, model.config.architectures[0])

        # --- 3. Run Compression ---
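        # One-shot calibration: 64 samples (max sequence length 512) drawn from the
        # first 1% of wikitext-2-raw-v1 are used to collect the statistics the recipe
        # needs, then the compressed model is written to output_dir.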
        oneshot(
            model=model,
            dataset="wikitext",
            dataset_config_name="wikitext-2-raw-v1",
            split="train[:1%]",
            recipe=recipe,
            save_compressed=True,
            output_dir=output_dir,
            max_seq_length=512,
            num_calibration_samples=64,
        )
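
        # Save the tokenizer alongside the compressed weights so the uploaded repo is
        # self-contained (oneshot is not guaranteed to persist the tokenizer when only
        # a dataset name is passed; saving it again here is harmless).
        tokenizer.save_pretrained(output_dir)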

        # --- 4. Create Repo and Upload ---
        api = HfApi(token=token)
        repo_id = f"{username}/{output_dir}"
        repo_url = api.create_repo(repo_id=repo_id, exist_ok=True)
        api.upload_folder(
            folder_path=output_dir,
            repo_id=repo_id,
            commit_message=f"Upload {quant_method} compressed model",
        )

        # --- 5. Create Model Card ---
card_content = f"""
---
license: apache-2.0
base_model: {model_id}
tags:
- llm-compressor
- quantization
- {quant_method.lower()}
---
# {quant_method} Compressed Model: {repo_id}
This model was compressed from [`{model_id}`](https://huggingface.co/{model_id}) using the [vLLM LLM-Compressor](https://github.com/vllm-project/llm-compressor) library.
This conversion was performed by the `llm-compressor-my-repo` Hugging Face Space.
## Quantization Method: {quant_method}
For more details on the recipe used, refer to the `recipe.yaml` file in this repository.
"""
card = ModelCard(card_content)
card.push_to_hub(repo_id, token=token)
return f'<h1>✅ Success!</h1><br/>Model compressed and saved to your new repo: <a href="{repo_url}" target="_blank" style="text-decoration:underline">{repo_id}</a>'
except Exception as e:
error_message = str(e).replace("\n", "<br/>")
return f'<h1>❌ ERROR</h1><br/><pre style="white-space:pre-wrap;">{error_message}</pre>'


# --- Gradio Interface ---
with gr.Blocks(css="footer {display: none !important;}") as demo:
    gr.Markdown("# LLM-Compressor My Repo")
    gr.Markdown(
        "Log in, choose a model, select a quantization method, and this Space will create a new compressed model repository on your Hugging Face profile."
    )

    with gr.Row():
        login_button = gr.LoginButton(min_width=250)

    gr.Markdown("### 1. Select a Model from the Hugging Face Hub")
    model_input = HuggingfaceHubSearch(
        label="Search for a Model",
        search_type="model",
    )
gr.Markdown("### 2. Choose a Quantization Method")
quant_method_dropdown = gr.Dropdown(
["AWQ", "GPTQ", "FP8"],
label="Quantization Method",
value="AWQ"
)
compress_button = gr.Button("Compress and Create Repo", variant="primary")
output_html = gr.HTML(label="Result")

    # The LoginButton is passed as an input so the logged-in user's OAuth token is
    # forwarded to compress_and_upload as its oauth_token argument.
    compress_button.click(
        fn=compress_and_upload,
        inputs=[model_input, quant_method_dropdown, login_button],
        outputs=output_html,
    )

# Note: gr.Examples was removed because it caused a TypeError.
demo.queue(max_size=5).launch()