xieli committed on
Commit 0dc7005 · 1 Parent(s): d94f450

feat: support int4/int8 quantization when loading

Files changed (5)
  1. README.md +1 -0
  2. app.py +50 -11
  3. model_loader.py +159 -30
  4. requirements.txt +1 -0
  5. tts.py +13 -3
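
Based on the --quantization, --torch-dtype, and --device-map flags added to app.py in this diff, quantized loading would be enabled at launch roughly like this (other arguments omitted; the values are illustrative, not the only valid ones):

python app.py --quantization int4
python app.py --quantization int8 --device-map cuda
python app.py --torch-dtype float16    # full-precision path, no quantization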
README.md CHANGED
@@ -11,3 +11,4 @@ short_description: Try out Step-Audio-EditX
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
app.py CHANGED
@@ -81,7 +81,10 @@ def initialize_models():
         os.path.join(args_global.model_path, "Step-Audio-EditX"),
         encoder,
         model_source=model_source,
-        tts_model_id=args_global.tts_model_id
+        tts_model_id=args_global.tts_model_id,
+        quantization_config=args_global.quantization,
+        torch_dtype=torch_dtype,
+        device_map=args_global.device_map,
     )
     logger.info("✓ StepCommonAudioTTS loaded")
     print("Models initialized inside GPU context.")
@@ -477,26 +480,62 @@ if __name__ == "__main__":
         default=None,
         help="TTS model ID for online loading (if different from model-path)"
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+        default=None,
+        choices=["int4", "int8"],
+        help="Enable quantization for the TTS model to reduce memory usage. "
+             "Choices: int4 (online), int8 (online). "
+             "When quantization is enabled, data types are handled automatically by the quantization library."
+    )
+    parser.add_argument(
+        "--torch-dtype",
+        type=str,
+        default="bfloat16",
+        choices=["float16", "bfloat16", "float32"],
+        help="PyTorch data type for model operations. This setting only applies when quantization is disabled. "
+             "When quantization is enabled, data types are managed automatically."
+    )
+    parser.add_argument(
+        "--device-map",
+        type=str,
+        default="cuda",
+        help="Device mapping for model loading (default: cuda)"
+    )
 
     args = parser.parse_args()
 
     # Store args globally for model configuration
     args_global = args
-
     logger.info(f"Configuration loaded:")
-    logger.info(f"Model source: {args.model_source}")
+
+    # Map string arguments to actual types
+    source_mapping = {
+        "auto": ModelSource.AUTO,
+        "local": ModelSource.LOCAL,
+        "modelscope": ModelSource.MODELSCOPE,
+        "huggingface": ModelSource.HUGGINGFACE
+    }
+    model_source = source_mapping[args.model_source]
+
+    # Map torch dtype string to actual torch dtype
+    dtype_mapping = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32
+    }
+    torch_dtype = dtype_mapping[args.torch_dtype]
+
+    logger.info(f"Loading models with source: {args.model_source}")
     logger.info(f"Model path: {args.model_path}")
     logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
+    logger.info(f"Torch dtype: {args.torch_dtype}")
+    logger.info(f"Device map: {args.device_map}")
     if args.tts_model_id:
         logger.info(f"TTS model ID: {args.tts_model_id}")
-
-    # Models will be initialized on first GPU call to avoid ZeroGPU main process errors
-
-    if ZEROGPU_AVAILABLE:
-        logger.info("🎉 ZeroGPU detected - using dynamic GPU duration management!")
-        logger.info("💡 First call: 300s (model loading), subsequent calls: 120s (inference only)")
-    else:
-        logger.info("💻 Running in local mode - models will be loaded on first call")
+    if args.quantization:
+        logger.info(f"🔧 {args.quantization.upper()} quantization enabled")
 
     # Create EditxTab instance
     editx_tab = EditxTab(args)
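
The point of these flags is to cut weight memory at load time. As a rough, illustrative estimate (weights only, ignoring activations, the KV cache, and quantization overhead such as NF4 scales), bfloat16 stores 2 bytes per parameter, int8 about 1 byte, and 4-bit NF4 about 0.5 bytes; the parameter count below is a hypothetical placeholder, not the actual size of Step-Audio-EditX:

# Back-of-the-envelope weight-memory estimate (illustrative only)
params = 3e9  # hypothetical parameter count
for name, bytes_per_param in {"bfloat16": 2, "int8": 1, "int4 (nf4)": 0.5}.items():
    print(f"{name}: ~{params * bytes_per_param / 1e9:.1f} GB")
# bfloat16: ~6.0 GB, int8: ~3.0 GB, int4 (nf4): ~1.5 GB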
model_loader.py CHANGED
@@ -1,17 +1,14 @@
 """
 Unified model loading utility supporting ModelScope, HuggingFace and local path loading
 """
-import importlib
 import os
 import logging
-from pathlib import Path
-import sys
 import threading
-from typing import Union, Optional, Dict, Any
-import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from typing import Optional, Dict, Any, Tuple
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from awq import AutoAWQForCausalLM
 from funasr_detach import AutoModel
-from transformers.models.auto import tokenization_auto, configuration_auto
 
 # Global cache for downloaded models to avoid repeated downloads
 # Key: (model_path, source)
@@ -104,19 +101,71 @@ class UnifiedModelLoader:
         modelscope_patterns = []
         return any(pattern in model_path for pattern in modelscope_patterns)
 
+    def _prepare_quantization_config(self, quantization_config: Optional[str], torch_dtype: Optional[torch.dtype] = None) -> Tuple[Dict[str, Any], bool]:
+        """
+        Prepare quantization configuration for model loading
+
+        Args:
+            quantization_config: Quantization type ('int4', 'int8', 'int4_offline_awq', or None)
+            torch_dtype: PyTorch data type for compute operations
+
+        Returns:
+            Tuple of (quantization parameters dict, should_set_torch_dtype)
+        """
+        if not quantization_config:
+            return {}, True
+
+        quantization_config = quantization_config.lower()
+
+        if quantization_config == "int4_offline_awq":
+            # For pre-quantized AWQ models, no additional quantization needed
+            self.logger.info("🔧 Loading pre-quantized AWQ 4-bit model (offline)")
+            return {}, True  # Load pre-quantized model normally, allow torch_dtype setting
+
+        elif quantization_config == "int8":
+            # Use user-specified torch_dtype for compute, default to bfloat16
+            compute_dtype = torch_dtype if torch_dtype is not None else torch.bfloat16
+            self.logger.info(f"🔧 INT8 quantization: using {compute_dtype} for compute operations")
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_8bit=True,
+                bnb_8bit_compute_dtype=compute_dtype,
+            )
+            return {
+                "quantization_config": bnb_config
+            }, False  # INT8 quantization handles data types automatically, don't set torch_dtype
+        elif quantization_config == "int4":
+            # Use user-specified torch_dtype for compute, default to bfloat16
+            compute_dtype = torch_dtype if torch_dtype is not None else torch.bfloat16
+            self.logger.info(f"🔧 INT4 quantization: using {compute_dtype} for compute operations")
+
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=compute_dtype,
+                bnb_4bit_use_double_quant=True,
+            )
+            return {
+                "quantization_config": bnb_config
+            }, False  # INT4 quantization handles torch_dtype internally, don't set it again
+        else:
+            raise ValueError(f"Unsupported quantization config: {quantization_config}. Supported: 'int4', 'int8', 'int4_offline_awq'")
+
     def load_transformers_model(
         self,
         model_path: str,
         source: str = ModelSource.AUTO,
+        quantization_config: Optional[str] = None,
         **kwargs
-    ) -> tuple:
+    ) -> Tuple:
        """
        Load Transformers model (for StepAudioTTS)

        Args:
            model_path: Model path or ID
            source: Model source, auto means auto-detect
-            **kwargs: Other parameters
+            quantization_config: Quantization configuration ('int4', 'int8', 'int4_offline_awq', or None for no quantization)
+            **kwargs: Other parameters (torch_dtype, device_map, etc.)

        Returns:
            (model, tokenizer) tuple
@@ -125,17 +174,47 @@ class UnifiedModelLoader:
             source = self.detect_model_source(model_path)
 
         self.logger.info(f"Loading Transformers model from {source}: {model_path}")
+        if quantization_config:
+            self.logger.info(f"🔧 {quantization_config.upper()} quantization enabled")
+
+        # Prepare quantization configuration
+        quantization_kwargs, should_set_torch_dtype = self._prepare_quantization_config(quantization_config, kwargs.get("torch_dtype"))
 
         try:
             if source == ModelSource.LOCAL:
                 # Local loading
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_path,
-                    torch_dtype=kwargs.get("torch_dtype"),
-                    device_map=kwargs.get("device_map", "auto"),
-                    trust_remote_code=True,
-                    local_files_only=True
-                )
+                load_kwargs = {
+                    "device_map": kwargs.get("device_map", "auto"),
+                    "trust_remote_code": True,
+                    "local_files_only": True
+                }
+
+                # Add quantization configuration if specified
+                load_kwargs.update(quantization_kwargs)
+
+                # Add torch_dtype based on quantization requirements
+                if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
+                    load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
+
+                # Check if using AWQ quantization
+                if quantization_config and quantization_config.lower() == "int4_offline_awq":
+                    # Use AWQ loading for pre-quantized AWQ models
+                    awq_model_path = os.path.join(model_path, "awq_quantized")
+                    if not os.path.exists(awq_model_path):
+                        raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
+
+                    self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
+                    model = AutoAWQForCausalLM.from_quantized(
+                        awq_model_path,
+                        device_map=kwargs.get("device_map", "auto"),
+                        trust_remote_code=True
+                    )
+                else:
+                    # Standard loading
+                    model = AutoModelForCausalLM.from_pretrained(
+                        model_path,
+                        **load_kwargs
+                    )
                 tokenizer = AutoTokenizer.from_pretrained(
                     model_path,
                     trust_remote_code=True,
@@ -148,13 +227,38 @@ class UnifiedModelLoader:
                 from modelscope import AutoTokenizer as MSAutoTokenizer
                 model_path = self._cached_snapshot_download(model_path, ModelSource.MODELSCOPE)
 
-                model = MSAutoModelForCausalLM.from_pretrained(
-                    model_path,
-                    torch_dtype=kwargs.get("torch_dtype"),
-                    device_map=kwargs.get("device_map", "auto"),
-                    trust_remote_code=True,
-                    local_files_only=True
-                )
+                load_kwargs = {
+                    "device_map": kwargs.get("device_map", "auto"),
+                    "trust_remote_code": True,
+                    "local_files_only": True
+                }
+
+                # Add quantization configuration if specified
+                load_kwargs.update(quantization_kwargs)
+
+                # Add torch_dtype based on quantization requirements
+                if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
+                    load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
+
+                # Check if using AWQ quantization
+                if quantization_config and quantization_config.lower() == "int4_offline_awq":
+                    # Use AWQ loading for pre-quantized AWQ models
+                    awq_model_path = os.path.join(model_path, "awq_quantized")
+                    if not os.path.exists(awq_model_path):
+                        raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
+
+                    self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
+                    model = AutoAWQForCausalLM.from_quantized(
+                        awq_model_path,
+                        device_map=kwargs.get("device_map", "auto"),
+                        trust_remote_code=True
+                    )
+                else:
+                    # Standard loading
+                    model = MSAutoModelForCausalLM.from_pretrained(
+                        model_path,
+                        **load_kwargs
+                    )
                 tokenizer = MSAutoTokenizer.from_pretrained(
                     model_path,
                     trust_remote_code=True,
@@ -165,13 +269,38 @@ class UnifiedModelLoader:
                 model_path = self._cached_snapshot_download(model_path, ModelSource.HUGGINGFACE)
 
                 # Load from HuggingFace
-                model = AutoModelForCausalLM.from_pretrained(
-                    model_path,
-                    torch_dtype=kwargs.get("torch_dtype"),
-                    device_map=kwargs.get("device_map", "auto"),
-                    trust_remote_code=True,
-                    local_files_only=True
-                )
+                load_kwargs = {
+                    "device_map": kwargs.get("device_map", "auto"),
+                    "trust_remote_code": True,
+                    "local_files_only": True
+                }
+
+                # Add quantization configuration if specified
+                load_kwargs.update(quantization_kwargs)
+
+                # Add torch_dtype based on quantization requirements
+                if should_set_torch_dtype and kwargs.get("torch_dtype") is not None:
+                    load_kwargs["torch_dtype"] = kwargs.get("torch_dtype")
+
+                # Check if using AWQ quantization
+                if quantization_config and quantization_config.lower() == "int4_offline_awq":
+                    # Use AWQ loading for pre-quantized AWQ models
+                    awq_model_path = os.path.join(model_path, "awq_quantized")
+                    if not os.path.exists(awq_model_path):
+                        raise FileNotFoundError(f"AWQ quantized model not found at {awq_model_path}. Please run quantize_model_offline.py first.")
+
+                    self.logger.info(f"🔧 Loading AWQ quantized model from: {awq_model_path}")
+                    model = AutoAWQForCausalLM.from_quantized(
+                        awq_model_path,
+                        device_map=kwargs.get("device_map", "auto"),
+                        trust_remote_code=True
+                    )
+                else:
+                    # Standard loading
+                    model = AutoModelForCausalLM.from_pretrained(
+                        model_path,
+                        **load_kwargs
+                    )
                 tokenizer = AutoTokenizer.from_pretrained(
                     model_path,
                     trust_remote_code=True,
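
A minimal usage sketch of the extended loader, assuming a UnifiedModelLoader instance named model_loader and the repo's ModelSource enum as used in tts.py; the model ID is a placeholder, and the three-value unpacking mirrors how tts.py consumes the return value:

import torch

# Hypothetical call: on-the-fly NF4 int4 quantization via bitsandbytes
model, tokenizer, resolved_path = model_loader.load_transformers_model(
    "<tts-model-id-or-path>",        # placeholder
    source=ModelSource.AUTO,
    quantization_config="int4",      # "int4", "int8", "int4_offline_awq", or None
    torch_dtype=torch.bfloat16,      # used as the 4-bit compute dtype
    device_map="cuda",
)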
requirements.txt CHANGED
@@ -22,3 +22,4 @@ gradio>=5.16.0
 nvidia-cuda-nvrtc-cu12==12.8.93
 spaces==0.42.1
 matplotlib==3.10.7
+autoawq==0.2.8
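
The new autoawq pin backs the top-level "from awq import AutoAWQForCausalLM" import in model_loader.py; it is only exercised on the int4_offline_awq path, but the package must be installed for the module to import at all. With pip:

pip install autoawq==0.2.8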
tts.py CHANGED
@@ -60,7 +60,10 @@ class StepAudioTTS:
         model_path,
         audio_tokenizer,
         model_source=ModelSource.AUTO,
-        tts_model_id=None
+        tts_model_id=None,
+        quantization_config=None,
+        torch_dtype=torch.bfloat16,
+        device_map="cuda"
     ):
         """
         Initialize StepAudioTTS
@@ -70,6 +73,9 @@ class StepAudioTTS:
             audio_tokenizer: Audio tokenizer for wav2token processing
             model_source: Model source (auto/local/modelscope/huggingface)
             tts_model_id: TTS model ID, if None use model_path
+            quantization_config: Quantization configuration ('int4', 'int8', or None)
+            torch_dtype: PyTorch data type for model weights (default: torch.bfloat16)
+            device_map: Device mapping for model (default: "cuda")
         """
         # Determine model ID or path to load
         if tts_model_id is None:
@@ -87,8 +93,9 @@ class StepAudioTTS:
             self.llm, self.tokenizer, model_path = model_loader.load_transformers_model(
                 tts_model_id,
                 source=model_source,
-                torch_dtype=torch.bfloat16,
-                device_map="cuda"
+                quantization_config=quantization_config,
+                torch_dtype=torch_dtype,
+                device_map=device_map
             )
             logger.info(f"✅ Successfully loaded LLM and tokenizer: {tts_model_id}")
         except Exception as e:
@@ -100,6 +107,9 @@ class StepAudioTTS:
             os.path.join(model_path, "CosyVoice-300M-25Hz")
         )
 
+        # Log that all TTS sub-models have finished loading
+        logger.info("🎤 CosyVoice model loaded successfully")
+
         # Use system prompts from config module
         self.edit_clone_sys_prompt_tpl = AUDIO_EDIT_CLONE_SYSTEM_PROMPT_TPL
         self.edit_sys_prompt = AUDIO_EDIT_SYSTEM_PROMPT
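
A construction sketch using the extended StepAudioTTS signature; the path is a placeholder and audio_tokenizer stands for the encoder built elsewhere in app.py:

import torch

# Hypothetical instantiation with int8 quantization of the underlying LLM
tts = StepAudioTTS(
    "<path-to-Step-Audio-EditX>",   # placeholder model path
    audio_tokenizer,                # wav2token audio tokenizer from the app
    model_source=ModelSource.AUTO,
    quantization_config="int8",     # "int4", "int8", or None
    torch_dtype=torch.bfloat16,     # compute dtype when quantized, weight dtype otherwise
    device_map="cuda",
)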