import dataclasses
import warnings
import os

from typing import Dict, Any

from transformers.utils import is_flash_attn_2_available
from transformers.utils import logging

from .block_config import BlockConfig
from .transformers_4_44_2__configuration_llama import LlamaConfig
from .transformers_4_44_2__modeling_rope_utils import rope_config_validation

# Bare reference so the rope_config_validation import is not flagged or removed as unused; it has no runtime effect.
rope_config_validation

logger = logging.get_logger("unroll-qwen25")


# Layer indices ordered by pruning priority (pruned first) for each ranking method.
attn_pruning_idx_ranking = {
    "weighted-avg": [25, 2, 5, 4, 7, 6, 24, 10, 20, 12],
    "simple-avg": [2, 7, 12, 6, 25, 5, 24, 4, 10, 8],
    "rank-based": [25, 7, 24, 6, 12, 5, 2, 20, 4, 10],
    "geometric": [7, 2, 12, 25, 6, 24, 5, 4, 11, 17],
    "normalized": [2, 25, 7, 6, 5, 4, 24, 12, 10, 20],
    "top-5-voting": [25, 24, 5, 4, 20, 10, 2, 6, 15, 12],
}

ffn_pruning_idx_ranking = {
    "weighted-avg": [20, 1, 5, 19, 6, 2, 4, 26, 16, 9],
    "simple-avg": [20, 19, 1, 5, 6, 2, 9, 26, 16, 4],
    "rank-based": [26, 1, 5, 20, 19, 4, 24, 6, 2, 9],
    "geometric": [19, 26, 5, 1, 6, 20, 16, 9, 4, 10],
    "normalized": [1, 5, 6, 4, 26, 2, 20, 19, 16, 7],
    "top-5-voting": [1, 5, 26, 24, 4, 6, 2, 20, 19, 9],
}

LOCAL_RANK = int(os.getenv("LOCAL_RANK", "-1"))

NUM_PRUNE_ATTN = int(os.environ.get("NUM_PRUNE_ATTN", "0"))
NUM_PRUNE_FFN = int(os.environ.get("NUM_PRUNE_FFN", "0"))
PRUNE_METHOD = os.environ.get("PRUNE_METHOD", "weighted-avg")
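
# Example (hypothetical invocation): prune the four highest-priority attention layers
# under the "rank-based" ranking by exporting these variables before launching, e.g.
#   NUM_PRUNE_ATTN=4 NUM_PRUNE_FFN=0 PRUNE_METHOD=rank-based python <your_script>.py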


def get_top_indices_to_prune(method: str, num_layers: int) -> list[int]:
    """Get the top ``num_layers`` layer indices to prune, based on the attention ranking for ``method``."""
    assert (
        method in attn_pruning_idx_ranking
    ), f"Unknown pruning method '{method}'."
    prune_indices = attn_pruning_idx_ranking[method][:num_layers]

    logger.info(
        f"[DEBUG][{LOCAL_RANK}] Using method '{method}' to select {num_layers} layers to prune: {prune_indices}."
    )
    return prune_indices
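
# Usage sketch: with the rankings above, selecting four layers to prune under the
# "rank-based" method returns the first four entries of that ranking.
#   get_top_indices_to_prune("rank-based", 4)  # -> [25, 7, 24, 6]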


def _prune_multiple_layers(
    block_configs: list[dict],
    layer_indices: list[int],
    module_type: str = "attention",
) -> list[dict]:
    """
    Prune multiple layers based on a list of indices.

    Args:
        block_configs: List of block configuration dictionaries.
        layer_indices: List of layer indices to prune.
        module_type: Type of module to prune ("attention" or "ffn").

    Returns:
        The modified block_configs with the specified layers pruned.
    """
    if not layer_indices or block_configs is None:
        return block_configs

    for layer_idx in layer_indices:
        if 0 <= layer_idx < len(block_configs):
            logger.info(
                f"[DEBUG][{LOCAL_RANK}] Pruning {module_type} in layer id :: {layer_idx}"
            )
            block_configs[layer_idx][module_type]["no_op"] = True
        else:
            logger.warning(
                f"[DEBUG] Invalid layer index {layer_idx} for {module_type} pruning"
            )

    return block_configs
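
# Usage sketch (hypothetical block_configs layout; assumes each entry holds "attention"
# and "ffn" sub-dicts accepting a "no_op" flag, as set above):
#   configs = [{"attention": {"no_op": False}, "ffn": {"no_op": False}} for _ in range(28)]
#   configs = _prune_multiple_layers(configs, [25, 2, 5], module_type="attention")
#   configs[25]["attention"]["no_op"]  # -> True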


class DeciLMConfig(LlamaConfig):
    model_type = "nemotron-nas"

    def __init__(
        self,
        block_configs: list[dict] | list[BlockConfig] | None = None,
        **kwargs,
    ):
        # Optionally prune (no-op) attention or FFN modules in selected layers, controlled
        # by the NUM_PRUNE_ATTN / NUM_PRUNE_FFN / PRUNE_METHOD environment variables.
        if NUM_PRUNE_ATTN > 0 or NUM_PRUNE_FFN > 0:
            assert NUM_PRUNE_ATTN * NUM_PRUNE_FFN == 0, (
                "Cannot prune both attention and ffn layers simultaneously. "
                "Set either NUM_PRUNE_ATTN or NUM_PRUNE_FFN to 0."
            )
            if NUM_PRUNE_ATTN > 0:
                assert PRUNE_METHOD in attn_pruning_idx_ranking
                prune_indices = attn_pruning_idx_ranking[PRUNE_METHOD][:NUM_PRUNE_ATTN]
                block_configs = _prune_multiple_layers(
                    block_configs,
                    prune_indices,
                    "attention",
                )
            if NUM_PRUNE_FFN > 0:
                assert PRUNE_METHOD in ffn_pruning_idx_ranking
                prune_indices = ffn_pruning_idx_ranking[PRUNE_METHOD][:NUM_PRUNE_FFN]
                block_configs = _prune_multiple_layers(
                    block_configs,
                    prune_indices,
                    "ffn",
                )
        else:
            logger.warning(
                f"[DEBUG][{LOCAL_RANK}] Using the config JSON from the model path (no pruning requested)."
            )

        attn_implementation = kwargs.pop("attn_implementation", None)
        if attn_implementation is None and is_flash_attn_2_available():
            attn_implementation = "flash_attention_2"

        if block_configs is not None:
            if isinstance(block_configs[0], dict):
                block_configs = [BlockConfig(**conf) for conf in block_configs]

            using_unshifted_sink = any(
                block_config.attention.unshifted_sink
                for block_config in block_configs
            )
            if using_unshifted_sink and attn_implementation != "eager":
                warnings.warn(
                    "Forcing attn_implementation='eager' since some attention layers use unshifted sink"
                )
                attn_implementation = "eager"

        super().__init__(attn_implementation=attn_implementation, **kwargs)

        # Per-block values in block_configs supersede these global fields, so clear them here.
        self.intermediate_size = None
        self.num_key_value_heads = None

        if block_configs is not None:
            assert len(block_configs) == self.num_hidden_layers

        self.block_configs: list[BlockConfig] = block_configs

    def to_dict(self) -> Dict[str, Any]:
        self_dict = super().to_dict()
        if self.block_configs is not None:
            self_dict["block_configs"] = [
                dataclasses.asdict(conf) for conf in self.block_configs
            ]
        return self_dict
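

# Usage sketch (hypothetical values): block_configs may be passed as plain dicts and are
# converted to BlockConfig instances in __init__; to_dict() serializes them back to dicts,
# so the config round-trips through save_pretrained()/from_pretrained().
#   cfg = DeciLMConfig(block_configs=[{...}, {...}], num_hidden_layers=2)
#   cfg.to_dict()["block_configs"]  # -> list of plain dicts again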