import pytest
from unittest.mock import MagicMock, patch

import gradio as gr
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier

from app import compress_and_upload, get_quantization_recipe


# Mock external dependencies for compress_and_upload
@pytest.fixture
def mock_hf_api():
    with patch("app.HfApi") as mock_api:
        mock_api_instance = mock_api.return_value
        mock_api_instance.create_repo.return_value = (
            "https://huggingface.co/test_user/test_model-AWQ"
        )
        yield mock_api_instance


@pytest.fixture
def mock_whoami():
    with patch("app.whoami") as mock_whoami_func:
        mock_whoami_func.return_value = {"name": "test_user"}
        yield mock_whoami_func


@pytest.fixture
def mock_auto_model_for_causal_lm():
    with patch("app.AutoModelForCausalLM") as mock_model_class:
        mock_model_instance = MagicMock()
        mock_model_instance.config.architectures = ["LlamaForCausalLM"]
        mock_model_class.from_pretrained.return_value = mock_model_instance
        yield mock_model_class


@pytest.fixture
def mock_oneshot():
    with patch("app.oneshot") as mock_oneshot_func:
        yield mock_oneshot_func


@pytest.fixture
def mock_model_card():
    with patch("app.ModelCard") as mock_card_class:
        mock_card_instance = MagicMock()
        mock_card_class.return_value = mock_card_instance
        yield mock_card_class


@pytest.fixture
def mock_gr_oauth_token():
    mock_token = MagicMock(spec=gr.OAuthToken)
    mock_token.token = "test_token"
    return mock_token


# --- Test get_quantization_recipe ---


def test_get_quantization_recipe_awq():
    recipe = get_quantization_recipe("AWQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], AWQModifier)


def test_get_quantization_recipe_gptq():
    recipe = get_quantization_recipe("GPTQ", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)


def test_get_quantization_recipe_gptq_mistral():
    recipe = get_quantization_recipe("GPTQ", "MistralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MistralDecoderLayer"]


def test_get_quantization_recipe_gptq_mixtral():
    recipe = get_quantization_recipe("GPTQ", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], GPTQModifier)
    assert recipe[0].sequential_targets == ["MixtralDecoderLayer"]


def test_get_quantization_recipe_fp8():
    recipe = get_quantization_recipe("FP8", "LlamaForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert recipe[0].ignore == ["lm_head"]


def test_get_quantization_recipe_fp8_mixtral():
    recipe = get_quantization_recipe("FP8", "MixtralForCausalLM")
    assert len(recipe) == 1
    assert isinstance(recipe[0], QuantizationModifier)
    assert recipe[0].scheme == "FP8"
    assert "re:.*block_sparse_moe.gate" in recipe[0].ignore


def test_get_quantization_recipe_unsupported():
    with pytest.raises(ValueError, match="Unsupported quantization method: INVALID"):
        get_quantization_recipe("INVALID", "LlamaForCausalLM")


# --- Test compress_and_upload ---


def test_compress_and_upload_no_model_id(mock_gr_oauth_token):
    with pytest.raises(gr.Error, match="Please select a model from the search bar."):
        compress_and_upload("", "AWQ", mock_gr_oauth_token)


def test_compress_and_upload_no_oauth_token():
    with pytest.raises(gr.Error, match="Authentication error. Please log in to continue."):
        compress_and_upload("test_model", "AWQ", None)


def test_compress_and_upload_success(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"

    result = compress_and_upload(model_id, quant_method, mock_gr_oauth_token)

    mock_whoami.assert_called_once_with(token="test_token")
    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id,
        torch_dtype="auto",
        device_map=None,
        token="test_token",
        trust_remote_code=True,
    )

    mock_oneshot.assert_called_once()
    assert (
        mock_oneshot.call_args[1]["model"]
        == mock_auto_model_for_causal_lm.from_pretrained.return_value
    )
    assert mock_oneshot.call_args[1]["recipe"] is not None
    assert mock_oneshot.call_args[1]["output_dir"] == f"test_model-{quant_method}"

    mock_hf_api.create_repo.assert_called_once_with(
        repo_id=f"test_user/test_model-{quant_method}", exist_ok=True
    )
    mock_hf_api.upload_folder.assert_called_once_with(
        folder_path=f"test_model-{quant_method}",
        repo_id=f"test_user/test_model-{quant_method}",
        commit_message=f"Upload {quant_method} compressed model",
    )

    mock_model_card.assert_called_once()
    mock_model_card.return_value.push_to_hub.assert_called_once_with(
        f"test_user/test_model-{quant_method}", token="test_token"
    )

    assert "✅ Success!" in result
    assert "https://huggingface.co/test_user/test_model-AWQ" in result


def test_compress_and_upload_with_trust_remote_code(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_oneshot,
    mock_model_card,
    mock_gr_oauth_token,
):
    model_id = "org/test_model"
    quant_method = "AWQ"

    compress_and_upload(model_id, quant_method, mock_gr_oauth_token)

    mock_auto_model_for_causal_lm.from_pretrained.assert_called_once_with(
        model_id,
        torch_dtype="auto",
        device_map=None,
        token="test_token",
        trust_remote_code=True,
    )


def test_compress_and_upload_model_no_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = []

    with pytest.raises(gr.Error, match="Could not determine model architecture."):
        compress_and_upload("test_model", "AWQ", mock_gr_oauth_token)


def test_compress_and_upload_generic_exception(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_whoami.side_effect = Exception("Network error")

    result = compress_and_upload("test_model", "AWQ", mock_gr_oauth_token)

    assert "❌ ERROR" in result
    assert "Network error" in result


def test_compress_and_upload_unrecognized_architecture(
    mock_hf_api,
    mock_whoami,
    mock_auto_model_for_causal_lm,
    mock_gr_oauth_token,
):
    mock_auto_model_for_causal_lm.from_pretrained.return_value.config.architectures = [
        "UnrecognizedArchitecture"
    ]

    result = compress_and_upload("test_model", "AWQ", mock_gr_oauth_token)

    assert "❌ ERROR" in result
    assert (
        "AWQ quantization is only supported for LlamaForCausalLM architectures, "
        "got UnrecognizedArchitecture" in result
    )
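

# A minimal usage sketch: how this suite is typically invoked. The filename
# test_app.py and its placement next to app.py are assumptions (pytest's
# default discovery picks up test_*.py files); the pytest flags themselves
# are standard.
#
#   pytest test_app.py -v
#
# Running a single test while iterating, e.g.:
#
#   pytest test_app.py::test_compress_and_upload_success -v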