Source code for quark.torch.pruning.config

#
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Pruning Config API for PyTorch"""

from __future__ import annotations

from abc import ABC
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Type, TypeVar

T = TypeVar("T", bound="ConfigBase")


@dataclass(eq=True)
class ConfigBase(ABC):

    name = ""

    @classmethod
    def from_dict(cls: type[T], data: dict[str, Any]) -> T:
        return cls(**data)

    def update_from_dict(self, data: dict[str, Any]) -> None:
        for field_name in data:
            setattr(self, field_name, data[field_name])

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)
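
# Example (sketch): the dict round-trip that every config inherits from
# ConfigBase. ``_DemoConfig`` is a hypothetical subclass used only for
# illustration; it is not part of this module.
#
#     @dataclass(eq=True)
#     class _DemoConfig(ConfigBase):
#         ratio: float = 0.5
#
#     cfg = _DemoConfig.from_dict({"ratio": 0.25})
#     assert cfg.to_dict() == {"ratio": 0.25}
#     cfg.update_from_dict({"ratio": 0.75})  # in-place update of fields
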

@dataclass(eq=True)
class Config(ConfigBase):
    """
    A class that encapsulates comprehensive pruning configurations for a machine learning model,
    allowing for detailed and hierarchical control over pruning parameters across different model
    components.

    :param Optional[AlgoConfig] algo_config: Optional configuration for the pruning algorithm,
        such as OSSCAR. After this process, the number of parameters is reduced. Default is None.
    :param Optional[AlgoConfig] blockwise_tuning_config: Optional configuration for blockwise
        tuning after pruning. Default is None.
    :param Optional[int] log_severity_level: 0: DEBUG, 1: INFO, 2: WARNING, 3: ERROR,
        4: CRITICAL/FATAL. Default is 1.
    """

    # Optional configuration for the pruning algorithm, such as OSSCAR.
    # After this process, the number of parameters is reduced.
    algo_config: AlgoConfig | None = None

    # Optional configuration for blockwise tuning after pruning.
    blockwise_tuning_config: AlgoConfig | None = None

    # Log level for printing on screen.
    log_severity_level: int | None = 1
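
# Example (sketch): assembling a top-level pruning Config. ``OSSCARConfig``
# is defined later in this module; the values here are illustrative only.
#
#     config = Config(
#         algo_config=OSSCARConfig(mlp_pruning_ratio=0.1),
#         log_severity_level=0,  # 0 = DEBUG for verbose output
#     )
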

@dataclass
class AlgoConfigBase(ConfigBase):
    pass

@dataclass
class AlgoConfig(AlgoConfigBase):
    pass

@dataclass
class OSSCARConfig(AlgoConfig):
    """
    Configuration for the OSSCAR pruning algorithm (for LLM models), which
    structurally prunes the intermediate dimension of MLP modules.

    :param float damp_percent: Percent of the average Hessian diagonal used for dampening.
    :param bool true_sequential: Whether to prune layers in true sequential order.
    :param List[str] inside_layer_modules: Names of the modules inside a decoder layer to process.
    :param List[str] mlp_pruning_modules: Names of the MLP modules to prune.
    :param Dict[str, Optional[str]] mlp_scaling_layers: Mapping of MLP pruning modules to their
        associated scaling layers, if any.
    :param float mlp_pruning_ratio: Fraction of the MLP intermediate size to remove.
    :param str mlp_intermediate_size_name: Field name for the MLP intermediate size in the model config.
    :param str model_decoder_layers: Field name for decoder layers.
    """

    name: str = "osscar"
    damp_percent: float = 0.01
    true_sequential: bool = True
    inside_layer_modules: list[str] = field(default_factory=list)
    mlp_pruning_modules: list[str] = field(default_factory=list)
    mlp_scaling_layers: dict[str, str | None] = field(default_factory=dict)
    mlp_pruning_ratio: float = 0.1
    mlp_intermediate_size_name: str = field(default_factory=str)
    model_decoder_layers: str = field(default_factory=str)
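
# Example (sketch): an OSSCAR config for a hypothetical Hugging Face-style
# decoder model. The module and attribute names below are assumptions for
# illustration; check the target model's architecture for the real ones.
#
#     osscar_config = OSSCARConfig(
#         mlp_pruning_ratio=0.15,
#         mlp_pruning_modules=["mlp"],
#         mlp_intermediate_size_name="intermediate_size",
#         model_decoder_layers="model.layers",
#     )
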

@dataclass
class LayerImportancePruneConfig(AlgoConfig):
    """
    Configuration for the layer-importance depth-wise pruning algorithm (for LLM models).

    :param int delete_layer_num: Number of layers to delete (at least 1).
    :param List[int] delete_layers_index: Specific indexes of layers to delete.
    :param bool save_gpu_memory: Whether to save GPU memory (trades speed for memory).
    :param str layer_num_field: Field name for the number of layers.
    :param str model_decoder_layers: Field name for decoder layers.
    :param str layer_norm_field: Field name for the normalization layer.
    """

    name: str = "layer_importance_depth_pruning"
    delete_layer_num: int = 1  # at least one layer
    # NOTE: used for search.
    delete_layers_index: list[int] = field(default_factory=list)
    save_gpu_memory: bool = False  # balance between speed and memory
    layer_num_field: str = field(default_factory=str)
    model_decoder_layers: str = field(default_factory=str)
    layer_norm_field: str = field(default_factory=str)
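
# Example (sketch): configure deletion of two decoder layers while trading
# speed for lower GPU memory use. The field names target a hypothetical
# Hugging Face-style decoder model and are assumptions for illustration.
#
#     layer_prune_config = LayerImportancePruneConfig(
#         delete_layer_num=2,
#         save_gpu_memory=True,
#         layer_num_field="num_hidden_layers",
#         model_decoder_layers="model.layers",
#         layer_norm_field="model.norm",
#     )
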

@dataclass
class BlockwiseTuningConfig(AlgoConfig):
    """
    Configuration for blockwise tuning after pruning (for LLM models).

    :param int epochs: Number of tuning epochs.
    :param float weight_lr: Learning rate for the trainable weights.
    :param float weight_decay: Weight decay applied by the optimizer.
    :param float min_lr_factor: Factor used to derive the minimum learning rate from weight_lr.
    :param float max_grad_norm: Maximum gradient norm used for clipping.
    :param str model_decoder_layers: Field name for decoder layers.
    :param List[str] trainable_modules: Names of the modules to train.
    """

    name: str = "blockwise_tuning"
    epochs: int = 5
    weight_lr: float = 0.0001
    weight_decay: float = 0.0
    min_lr_factor: float = 20.0
    max_grad_norm: float = 0.3
    model_decoder_layers: str = field(default_factory=str)
    trainable_modules: list[str] = field(default_factory=list)
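
# Example (sketch): pairing blockwise tuning with OSSCAR pruning in a single
# Config. The hyperparameters are illustrative, not recommended defaults.
#
#     config = Config(
#         algo_config=OSSCARConfig(mlp_pruning_ratio=0.1),
#         blockwise_tuning_config=BlockwiseTuningConfig(epochs=3, weight_lr=5e-5),
#     )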