Spaces:
Sleeping
Sleeping
File size: 2,833 Bytes
d95ff5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
#!/usr/bin/env python
"""
Test script to verify that the Qwen2.5-VL architecture detection and quantization recipe work correctly
"""
from transformers import AutoConfig
from app import determine_model_class, get_quantization_recipe
import torch
def test_qwen2_5_vl_detection():
    """
    Check that quantization recipes can be built for the Qwen2.5-VL architecture.

    The architecture name is hard-coded to mirror the value that would
    normally be read from ``model.config.architectures[0]``. For each of the
    GPTQ, AWQ and FP8 methods the recipe returned by
    ``get_quantization_recipe`` is printed, followed by the
    ``sequential_targets`` and ``ignore`` attributes of its entries.

    Returns:
        bool: True if all three recipes were created and printed without an
        exception; False otherwise (the error and traceback are printed).
    """
    import traceback

    # Known Qwen2.5-VL checkpoint; only used for the log output below.
    model_id = "Qwen/Qwen2.5-VL-7B-Instruct"

    # Stand-in for the architecture list a real model config would provide.
    architectures = ["Qwen2_5_VLForConditionalGeneration"]
    arch_name = architectures[0]

    print(f"Testing architecture detection for: {model_id}")
    print(f"Architectures found: {architectures}")

    try:
        for quant_method in ("GPTQ", "AWQ", "FP8"):
            print(f"\nTesting {quant_method} quantization recipe...")
            recipe = get_quantization_recipe(quant_method, arch_name)
            print(f"{quant_method} recipe created successfully: {recipe}")

            # Entries lacking the attribute are reported as 'N/A' here...
            seq_targets = [
                getattr(modifier, 'sequential_targets', 'N/A')
                for modifier in recipe
            ]
            print(f"Sequential targets: {seq_targets}")

            # ...whereas entries without an ``ignore`` attribute are skipped.
            ignored = [
                modifier.ignore
                for modifier in recipe
                if hasattr(modifier, 'ignore')
            ]
            print(f"Ignore layers: {ignored}")

        print("\n✓ All quantization methods work with Qwen2_5_VLForConditionalGeneration architecture")
    except Exception as err:
        print(f"\n✗ Error creating quantization recipe: {err}")
        traceback.print_exc()
        return False
    else:
        return True
def test_manual_model_class_detection():
    """
    Exercise ``determine_model_class`` with a manually selected model type.

    The call uses placeholder values for the model ID and token and passes
    the Qwen2.5-VL label as the manual model type, then prints whatever the
    function returns.

    Returns:
        bool: True if the call and the follow-up prints complete without an
        exception; False otherwise (the error and traceback are printed).
    """
    import traceback

    print("\nTesting manual model class detection...")

    # Label passed as the manual model-type argument.
    selected_type = "Qwen2_5_VLForConditionalGeneration (Qwen2.5-VL)"

    try:
        detected = determine_model_class("test", "dummy_token", selected_type)
        print(f"Manual detection returned: {detected}")
        print("✓ Manual model class detection works")
    except Exception as err:
        print(f"✗ Error in manual detection: {err}")
        traceback.print_exc()
        return False

    return True
if __name__ == "__main__":
    # Script entry point: run both checks, print a summary, and report the
    # outcome through the process exit status.
    import sys

    print("Testing Qwen2.5-VL architecture detection and quantization support...\n")

    # Both checks are always executed so that a failure in the first one does
    # not hide the result of the second.
    success1 = test_qwen2_5_vl_detection()
    success2 = test_manual_model_class_detection()

    if success1 and success2:
        print("\n✓ All tests passed! Qwen2.5-VL models should now be properly supported.")
    else:
        print("\n✗ Some tests failed. Please check the implementation.")
        # Exit non-zero so shells and CI pipelines can detect the failure;
        # previously the script exited with status 0 even when a check failed.
        sys.exit(1)