PyTorch quantization#
Quark Quantization API for PyTorch.
- class quark.torch.quantization.api.ModelQuantizer(config: Config, multi_device: bool = False)[source]#
Provides an API for quantizing deep learning models using PyTorch.
This class handles the configuration and processing of the model for quantization based on user-defined parameters. It is essential to ensure that the ‘config’ provided has all necessary quantization parameters defined. This class assumes that the model is compatible with the quantization settings specified in ‘config’.
- Parameters:
config (Config) – The model quantization configuration.
- quantize_model(model: Module, dataloader: DataLoader[Tensor] | DataLoader[List[Dict[str, Tensor]]] | DataLoader[Dict[str, Tensor]] | DataLoader[List[BatchFeature]] | None = None) Module [source]#
Quantizes the given PyTorch model to optimize its performance and reduce its size.
The dataloader is used to provide data necessary for calibration during the quantization process. Depending on the type of data provided (either tensors directly or structured as lists or dictionaries of tensors), the function will adapt the quantization approach accordingly.
It is important that the model and dataloader are compatible in terms of the data they expect and produce. Misalignment in data handling between the model and the dataloader can lead to errors during the quantization process.
- Parameters:
model (torch.nn.Module) – The PyTorch model to be quantized. This model should be already trained and ready for quantization.
dataloader (Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]], DataLoader[Dict[str, torch.Tensor]], DataLoader[List[BatchFeature]]]]) – The `torch.utils.data.DataLoader` providing the data that the quantization process will use for calibration. This can be a simple `DataLoader` returning tensors, or a more complex structure returning either a list of dictionaries or a dictionary of tensors.
- Returns:
The quantized version of the input model. This model is now optimized for inference with reduced size and potentially improved performance on targeted devices.
- Return type:
torch.nn.Module
Example:
# Model & Data preparation from torch.utils.data import DataLoader from transformers import AutoModelForCausalLM, AutoTokenizer from quark.torch.quantization.config.config import Config from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType from quark.torch.quantization.observer.observer import PerGroupMinMaxObserver from quark.torch import ModelQuantizer model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") model.eval() tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") quant_spec = QuantizationSpec( dtype=Dtype.uint4, observer_cls=PerGroupMinMaxObserver, symmetric=False, scale_type=ScaleType.float, round_method=RoundType.half_even, qscheme=QSchemeType.per_group, ch_axis=1, is_dynamic=False, group_size=128 ) quant_config = Config(global_quant_config=QuantizationConfig(weight=quant_spec)) text = "Hello, how are you?" tokenized_outputs = tokenizer(text, return_tensors="pt") calib_dataloader = DataLoader(tokenized_outputs['input_ids']) quantizer = ModelQuantizer(quant_config) quant_model = quantizer.quantize(model, calib_dataloader)
- static freeze(model: Module | GraphModule) Module | GraphModule [source]#
Freezes the quantized model by replacing `FakeQuantize` modules with `FreezedFakeQuantize` modules. To compile a quantized model with `torch.compile`, this method must be applied first (see the sketch after this method's description).
- Parameters:
model (torch.nn.Module) – The neural network model containing quantized layers.
- Returns:
The modified model with `FakeQuantize` modules replaced by `FreezedFakeQuantize` modules.
- Return type:
torch.nn.Module
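A minimal usage sketch of freezing before compilation, assuming `quant_model` is the quantized model returned by `quantize_model()` in the earlier example:

```python
import torch
from quark.torch import ModelQuantizer

# Replace FakeQuantize modules with FreezedFakeQuantize modules so the model can be compiled.
frozen_model = ModelQuantizer.freeze(quant_model)

# The frozen model can now be compiled for inference.
compiled_model = torch.compile(frozen_model)
```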
- quark.torch.quantization.api.load_params(model: Module | None = None, json_path: str = '', safetensors_path: str = '', pth_path: str = '', quant_mode: QuantizationMode = QuantizationMode.eager_mode, compressed: bool = False, reorder: bool = True) Module [source]#
Instantiates a quantized model from saved model files, which are generated by the `quark.torch.export.api.save_params()` function.
- Parameters:
model (torch.nn.Module) – The original PyTorch model.
json_path (str) – The path of the saved json file. Only available for eager mode quantization.
safetensors_path (str) – The path of the saved safetensors file. Only available for eager mode quantization.
pth_path (str) – The path of the saved `.pth` file. Only available for `fx_graph` mode quantization.
quant_mode (QuantizationMode) – The quantization mode. The choices are `QuantizationMode.eager_mode` and `QuantizationMode.fx_graph_mode`. Default is `QuantizationMode.eager_mode`.
compressed (bool) – Whether the quantized model to load is stored using its compressed data type, or in a "fake quantized" (QDQ) format.
reorder (bool) – Whether to reorder when loading. Default is True.
- Returns:
The reloaded quantized version of the input model.
- Return type:
torch.nn.Module
Examples:
```python
# eager mode:
from quark.torch import load_params

model = load_params(model, json_path=json_path, safetensors_path=safetensors_path)
```
```python
# fx_graph mode:
from quark.torch.quantization.api import load_params

model = load_params(pth_path=model_file_path, quant_mode=QuantizationMode.fx_graph_mode)
```
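If the checkpoint stores weights in their compressed data type rather than the QDQ ("fake quantized") format, the `compressed` flag should be set to match; a sketch reusing the eager-mode paths from the example above:

```python
from quark.torch import load_params

# compressed=True indicates the safetensors file stores weights in the compressed
# data type; leave it False (the default) for QDQ / fake-quantized checkpoints.
model = load_params(model, json_path=json_path, safetensors_path=safetensors_path, compressed=True)
```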
- Note:
This function does not currently support dynamic quantization.