Source code for quark.torch.pruning.config
#
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Pruning Config API for PyTorch"""
from __future__ import annotations
from abc import ABC
from dataclasses import dataclass, field, asdict
from typing import Optional, TypeVar, Type, Any, Dict, List
T = TypeVar('T', bound='ConfigBase')

@dataclass(eq=True)
class ConfigBase(ABC):
    name = ""

    @classmethod
    def from_dict(cls: Type[T], data: Dict[str, Any]) -> T:
        return cls(**data)

    def update_from_dict(self, data: Dict[str, Any]) -> None:
        for field_name in data:
            setattr(self, field_name, data[field_name])

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)
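
# Usage sketch (illustrative only, not part of the public API surface):
# any concrete config subclass defined below round-trips through plain
# dictionaries via the ConfigBase helpers, e.g.
#
#   cfg = OSSCARConfig.from_dict({"mlp_pruning_ratio": 0.2})
#   cfg.update_from_dict({"damp_percent": 0.05})
#   state = cfg.to_dict()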
@dataclass(eq=True)
class Config(ConfigBase):
"""
A class that encapsulates comprehensive pruning configurations for a machine learning model, allowing for detailed and hierarchical control over pruning parameters across different model components.
:param Optional[AlgoConfig] algo_config: Optional configuration for the pruning algorithm, such as OSSCAR. After this process, the params will be reduced. Default is None.
:param Optional[int] log_severity_level: 0:DEBUG, 1:INFO, 2:WARNING. 3:ERROR, 4:CRITICAL/FATAL. Default is 1.
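
    Example (a minimal sketch; the field values shown are illustrative, not a
    recommended setting)::

        from quark.torch.pruning.config import Config, OSSCARConfig

        config = Config(algo_config=OSSCARConfig(mlp_pruning_ratio=0.1))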
"""
    # Optional configuration for the pruning algorithm, such as OSSCAR.
    # After this process, the number of model parameters is reduced.
    algo_config: Optional[AlgoConfig] = None

    # Optional configuration for blockwise tuning after pruning.
    blockwise_tuning_config: Optional[AlgoConfig] = None

    # Log severity level for messages printed to the screen.
    log_severity_level: Optional[int] = 1

@dataclass
class AlgoConfigBase(ConfigBase):
    pass

@dataclass
class AlgoConfig(AlgoConfigBase):
    pass

@dataclass
class OSSCARConfig(AlgoConfig):
    name: str = "osscar"
    damp_percent: float = 0.01
    true_sequential: bool = True
    inside_layer_modules: List[str] = field(default_factory=list)
    mlp_pruning_modules: List[str] = field(default_factory=list)
    mlp_scaling_layers: Dict[str, Optional[str]] = field(default_factory=dict)
    mlp_pruning_ratio: float = 0.1
    mlp_intermediate_size_name: str = field(default_factory=str)
    model_decoder_layers: str = field(default_factory=str)
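
# Example (a hedged sketch): the module and attribute names below are
# hypothetical, model-specific values (e.g. for a LLaMA-style decoder),
# not defaults shipped with Quark.
#
#   osscar_config = OSSCARConfig(
#       mlp_pruning_modules=["mlp"],
#       mlp_intermediate_size_name="intermediate_size",
#       model_decoder_layers="model.layers",
#       mlp_pruning_ratio=0.25,
#   )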

@dataclass
class BlockwiseTuningConfig(AlgoConfig):
    name: str = "blockwise_tuning"
    epochs: int = 5
    weight_lr: float = 0.0001
    weight_decay: float = 0.0
    min_lr_factor: float = 20.0
    max_grad_norm: float = 0.3
    model_decoder_layers: str = field(default_factory=str)
    trainable_modules: List[str] = field(default_factory=list)
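
# Example (illustrative only): pairing OSSCAR pruning with blockwise tuning
# in a single Config; "model.layers" and "mlp" are hypothetical module paths.
#
#   config = Config(
#       algo_config=OSSCARConfig(model_decoder_layers="model.layers"),
#       blockwise_tuning_config=BlockwiseTuningConfig(
#           model_decoder_layers="model.layers",
#           trainable_modules=["mlp"],
#       ),
#   )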