quark.torch.extensions.brevitas.api#
Module Contents#
Classes#
- class quark.torch.extensions.brevitas.api.ModelQuantizer(config: quark.torch.extensions.brevitas.config.Config)#
Provides an API for quantizing deep learning models using Brevitas.
This class's interaction with Brevitas is based on the Brevitas PTQ example found in the Xilinx/brevitas repository.
- Example usage:
weight_spec = QuantizationSpec()
global_config = QuantizationConfig(weight=weight_spec)
config = Config(global_quant_config=global_config)
quantizer = ModelQuantizer(config)
quant_model = quantizer.quantize_model(model, calib_dataloader)
- quantize_model(model: torch.nn.Module, calib_loader: Optional[torch.utils.data.DataLoader] = None) → torch.nn.Module#
Quantizes the given model.
model: The model to be quantized.
calib_loader: A DataLoader providing calibration data. Optional in the signature, but required for most quantization workflows, since calibration is what determines the quantization ranges.
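A minimal sketch of building a calibration DataLoader to pass to quantize_model. This assumes only that torch is installed; the tensor shapes are illustrative (flattened 28x28 images, matching the `torch.ones(1, 1, 784)` shape used in the export example below), and the random data is a stand-in for a representative sample of real inputs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in calibration data; in practice, use a representative sample
# of real inputs so activation ranges are calibrated correctly.
inputs = torch.randn(64, 1, 784)
dataset = TensorDataset(inputs)

# A small batch size keeps calibration memory usage low.
calib_loader = DataLoader(dataset, batch_size=16)

# Each batch is a 1-tuple because TensorDataset wraps a single tensor.
(batch,) = next(iter(calib_loader))
```

This `calib_loader` is what you would pass as the second argument to `quantizer.quantize_model(model, calib_loader)`.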
- class quark.torch.extensions.brevitas.api.ModelExporter(export_path: str)#
Provides an API for exporting PyTorch models quantized with Brevitas. This class converts the quantized model to an ONNX graph and saves it to the specified export_path.
- Example usage:
exporter = ModelExporter("model.onnx")
exporter.export_onnx_model(quant_model, args=torch.ones(1, 1, 784))
- export_onnx_model(model: torch.nn.Module, args: Union[torch.Tensor, Tuple[torch.Tensor]]) → None#
Exports the given model to ONNX.
model: The PyTorch model to export.
args: Representative tensor(s) with the same shape as the expected input(s). The values themselves are irrelevant: zeros, ones, random values, or real data all work, since only the shapes and dtypes are traced.