# llama31-8b-it-unroll / configuration_decilm.py
# coding=utf-8
# Copyright 2024 Nvidia Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import warnings
import os
from typing import Dict, Any
from transformers.utils import is_flash_attn_2_available
from transformers.utils import logging
from .block_config import BlockConfig
from .transformers_4_44_2__configuration_llama import LlamaConfig
from .transformers_4_44_2__modeling_rope_utils import rope_config_validation # fake import to make AutoConfig infer the dependency
rope_config_validation # this line is here to make sure that auto-formatting doesn't remove the import
logger = logging.get_logger("unroll-qwen25")
############## Block Configs -- Llama3.1-8B-Instruct ####################
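# Layer indices ranked by pruning priority: when NUM_PRUNE_ATTN / NUM_PRUNE_FFN
# (below) select the top-N layers, the first N entries of the chosen ranking are
# the ones whose attention / FFN sub-blocks are turned into no-ops. The keys name
# what is presumably the aggregation scheme each ranking was derived with.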
attn_pruning_idx_ranking = {
"weighted-avg": [25, 2, 5, 4, 7, 6, 24, 10, 20, 12],
"simple-avg": [2, 7, 12, 6, 25, 5, 24, 4, 10, 8],
"rank-based": [25, 7, 24, 6, 12, 5, 2, 20, 4, 10],
"geometric": [7, 2, 12, 25, 6, 24, 5, 4, 11, 17],
"normalized": [2, 25, 7, 6, 5, 4, 24, 12, 10, 20],
"top-5-voting": [25, 24, 5, 4, 20, 10, 2, 6, 15, 12],
}
ffn_pruning_idx_ranking = {
"weighted-avg": [20, 1, 5, 19, 6, 2, 4, 26, 16, 9],
"simple-avg": [20, 19, 1, 5, 6, 2, 9, 26, 16, 4],
"rank-based": [26, 1, 5, 20, 19, 4, 24, 6, 2, 9],
"geometric": [19, 26, 5, 1, 6, 20, 16, 9, 4, 10],
"normalized": [1, 5, 6, 4, 26, 2, 20, 19, 16, 7],
"top-5-voting": [1, 5, 26, 24, 4, 6, 2, 20, 19, 9],
}
# Local rank for distributed training
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "-1"))
# New configuration parameters for multi-layer pruning, read from environment variables
NUM_PRUNE_ATTN = int(os.environ.get("NUM_PRUNE_ATTN", "0"))
NUM_PRUNE_FFN = int(os.environ.get("NUM_PRUNE_FFN", "0"))
PRUNE_METHOD = os.environ.get("PRUNE_METHOD", "weighted-avg")
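# A minimal sketch of how these switches are driven (hypothetical caller code;
# the variables must be set before this module is imported, since they are read
# at import time):
#
#   import os
#   os.environ["NUM_PRUNE_ATTN"] = "3"
#   os.environ["PRUNE_METHOD"] = "rank-based"
#   # loading the model now turns attention into a no-op in layers 25, 7 and 24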
def get_top_indices_to_prune(method: str, num_layers: int) -> list[int]:
"""Get the top N layers to prune based on the specified method."""
assert (
method in attn_pruning_idx_ranking
), f"Unknown pruning method '{method}'."
prune_indices = attn_pruning_idx_ranking[method][:num_layers]
logger.info(
f"[DEBUG][${{LOCAL_RANK}}] Using method '{method}' to select {num_layers} layers to prune: {prune_indices}."
)
return prune_indices
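# Illustrative call (values follow directly from the rankings above; note that
# this helper only consults attn_pruning_idx_ranking):
#   get_top_indices_to_prune("weighted-avg", 3)  # -> [25, 2, 5]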
def _prune_multiple_layers(
block_configs: list[dict],
layer_indices: list[int],
module_type: str = "attention",
) -> list[dict]:
"""
Prune multiple layers based on a list of indices.
Args:
block_configs: List of block configuration dictionaries
layer_indices: List of layer indices to prune
module_type: Type of module to prune ("attention" or "ffn")
Returns:
Modified block_configs with specified layers pruned
"""
if not layer_indices or block_configs is None:
return block_configs
for layer_idx in layer_indices:
if 0 <= layer_idx < len(block_configs):
logger.info(
f"[DEBUG][${LOCAL_RANK}] Pruning {module_type} in layer id :: {layer_idx}"
)
block_configs[layer_idx][module_type]["no_op"] = True
else:
logger.warning(
f"[DEBUG] Invalid layer index {layer_idx} for {module_type} pruning"
)
return block_configs
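# Minimal sketch of the expected call pattern (hypothetical block_configs shape;
# real entries come from the model's config.json):
#   configs = [{"attention": {"no_op": False}, "ffn": {"no_op": False}} for _ in range(32)]
#   configs = _prune_multiple_layers(configs, [25, 2], module_type="attention")
#   configs[25]["attention"]["no_op"]  # -> True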
class DeciLMConfig(LlamaConfig):
model_type = "nemotron-nas"
def __init__(
self,
        block_configs: list[dict] | list[BlockConfig] | None = None,
**kwargs,
):
if NUM_PRUNE_ATTN > 0 or NUM_PRUNE_FFN > 0:
assert NUM_PRUNE_ATTN * NUM_PRUNE_FFN == 0, (
"Cannot prune both attention and ffn layers simultaneously. "
"Set either NUM_PRUNE_ATTN or NUM_PRUNE_FFN to 0."
)
if NUM_PRUNE_ATTN > 0:
assert PRUNE_METHOD in attn_pruning_idx_ranking
prune_indices = attn_pruning_idx_ranking[PRUNE_METHOD][:NUM_PRUNE_ATTN]
block_configs = _prune_multiple_layers(
block_configs,
prune_indices,
"attention",
)
if NUM_PRUNE_FFN > 0:
assert PRUNE_METHOD in ffn_pruning_idx_ranking
prune_indices = ffn_pruning_idx_ranking[PRUNE_METHOD][:NUM_PRUNE_FFN]
block_configs = _prune_multiple_layers(
block_configs,
prune_indices,
"ffn",
)
else:
logger.warning(f"[DEBUG][{LOCAL_RANK}] Use config json in model path.")
attn_implementation = kwargs.pop("attn_implementation", None)
if attn_implementation is None and is_flash_attn_2_available():
attn_implementation = "flash_attention_2"
if block_configs is not None:
if isinstance(block_configs[0], dict):
block_configs = [BlockConfig(**conf) for conf in block_configs]
            using_unshifted_sink = any(
                block_config.attention.unshifted_sink
                for block_config in block_configs
            )
if using_unshifted_sink and attn_implementation != "eager":
warnings.warn(
"Forcing attn_implementation='eager' since some attention layers use unshifted sink"
)
attn_implementation = "eager"
super().__init__(attn_implementation=attn_implementation, **kwargs)
        # Per-block values live in block_configs, so the global Llama-style fields
        # are intentionally left unset here.
        self.intermediate_size = None
        self.num_key_value_heads = None
if block_configs is not None:
            assert (
                len(block_configs) == self.num_hidden_layers
            ), "block_configs must contain one entry per hidden layer"
self.block_configs: list[BlockConfig] = block_configs
def to_dict(self) -> Dict[str, Any]:
self_dict = super().to_dict()
if self.block_configs is not None:
self_dict["block_configs"] = [
dataclasses.asdict(conf) for conf in self.block_configs
]
return self_dict
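

if __name__ == "__main__":
    # Usage sketch (hypothetical checkpoint path): this config is normally
    # resolved through AutoConfig with trust_remote_code=True, which is why the
    # fake rope_config_validation import above exists.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        "path/to/llama31-8b-it-unroll",  # placeholder; point at the real checkpoint dir
        trust_remote_code=True,
    )
    print(config.model_type)           # "nemotron-nas"
    print(len(config.block_configs))   # equals config.num_hidden_layers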