Vision Model Quantization using Quark FX Graph Mode#
In this example, we present a vision model quantization workflow. The user specifies an nn.Module and transforms the model to torch.fx.GraphModule format using the PyTorch API. During the quantization process, after annotation and insertion of quantizers, this modified fx.GraphModule can be used to perform PTQ (Post-Training Quantization) and/or QAT (Quantization-Aware Training). We supply demonstration code and show how users assign the quantization configuration; more information can be found in the User Guide.
PTQ#
In PTQ, after the FakeQuantize modules are inserted, the observer is activated during calibration to record the tensor's distribution: values such as min and max are recorded to calculate the quantization parameters, while fake quantization is not performed. This means all calculations run at FP32 precision. After calibration, we activate the fake quantizer to perform quantization and evaluation.
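To illustrate what the observer's recorded statistics are used for, here is a minimal sketch of deriving a per-tensor symmetric INT8 scale from min/max values. It is an illustration of the general mechanism, not Quark's actual implementation.
import torch

def int8_symmetric_scale(x_min: float, x_max: float) -> float:
    # Symmetric quantization fixes the zero point at 0, so the scale depends
    # only on the largest absolute value recorded by the observer.
    abs_max = max(abs(x_min), abs(x_max))
    return abs_max / 127.0  # map [-abs_max, abs_max] onto the INT8 range

# Example: the observer recorded min=-0.8 and max=1.2 during calibration.
scale = int8_symmetric_scale(-0.8, 1.2)
x = torch.tensor([1.2, -0.8, 0.05])
x_q = torch.clamp(torch.round(x / scale), -128, 127)  # integer codes
x_dq = x_q * scale                                     # dequantized approximation of x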
QAT#
As in PTQ, the model is first prepared. During the training process, both the observer and fake_quant are active: the observer records the tensor's distribution (such as min and max values) to calculate the quantization parameters, and the tensor is quantized by fake_quant.
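To make the fake_quant behavior concrete, the following is a minimal sketch of a quantize-dequantize step with a straight-through estimator, assuming per-tensor symmetric INT8. It is not Quark's actual module, only an illustration of the operation applied during QAT.
import torch

def fake_quantize(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Quantize to the INT8 grid and immediately dequantize back to float,
    # so downstream layers see the rounding error introduced by quantization.
    q = torch.clamp(torch.round(x / scale), -128, 127)
    x_dq = q * scale
    # Straight-through estimator: gradients flow as if the rounding were identity.
    return x + (x_dq - x).detach()

x = torch.randn(4, requires_grad=True)
y = fake_quantize(x, scale=0.02)
y.sum().backward()  # gradients pass through to x thanks to the STE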
Quick Start#
Perform PTQ to get the quantized model and export it to ONNX:
python3 quantize.py --data_dir [Train and Test Data folder] \
                    --model_name [mobilenetv2 or resnet18] \
                    --pretrained [Pre-trained model file address] \
                    --model_export onnx \
                    --export_dir [dir to save exported model]
Users can also choose to perform QAT to further improve classification accuracy. Typically, some training parameters need to be adjusted to reach higher accuracy.
python3 quantize.py --data_dir [Train and Test Data folder] \
                    --model_name [mobilenetv2 or resnet18] \
                    --pretrained [Pre-trained model file address] \
                    --model_export onnx \
                    --export_dir [dir to save exported model] \
                    --qat True
Fine-Grained User Guide#
Step 1: Prepare the floating-point model, dataset, and loss function
import torch
import torch.nn as nn
from torchvision.models import resnet18

float_model = resnet18(pretrained=False)
float_model.load_state_dict(torch.load(args.pretrained))
# prepare_calib_dataset and prepare_data_loaders are helpers defined in the example script.
calib_loader = prepare_calib_dataset(args.data_dir, device, calib_length=args.train_batch_size * 10)
train_loader, val_loader = prepare_data_loaders(args.data_dir)
criterion = nn.CrossEntropyLoss().to(device)
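The data helpers above come from the example script. As a rough sketch of what such helpers could look like (paths, transforms, and batch sizes below are assumptions, not the script's exact code), calibration only needs a small, representative slice of the training data:
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

def prepare_data_loaders(data_dir):
    # Standard ImageNet-style preprocessing; the real script may differ.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    train_set = datasets.ImageFolder(f"{data_dir}/train", transform=transform)
    val_set = datasets.ImageFolder(f"{data_dir}/val", transform=transform)
    return (DataLoader(train_set, batch_size=64, shuffle=True),
            DataLoader(val_set, batch_size=64, shuffle=False))

def prepare_calib_dataset(data_dir, device, calib_length=640):
    # device is unused in this sketch; kept for interface parity with the script.
    train_loader, _ = prepare_data_loaders(data_dir)
    subset = Subset(train_loader.dataset, range(calib_length))
    return DataLoader(subset, batch_size=64, shuffle=False)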
Step 2: Transform the ``torch.nn.Module`` to ``torch.fx.GraphModule``
from torch._export import capture_pre_autograd_graph
example_inputs = (torch.rand(args.train_batch_size, 3, 224, 224).to(device), )
graph_model = capture_pre_autograd_graph(float_model, example_inputs)
Step 3: Initialize the quantizer and quantization configuration
from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import QuantizationSpec, QuantizationConfig, Config
from quark.torch.quantization.config.type import Dtype, QSchemeType, ScaleType, RoundType, QuantizationMode
from quark.torch.quantization.observer.observer import PerTensorMinMaxObserver

# INT8, per-tensor, symmetric quantization with a float scale factor.
INT8_PER_TENSOR_SPEC = QuantizationSpec(dtype=Dtype.int8,
                                        qscheme=QSchemeType.per_tensor,
                                        observer_cls=PerTensorMinMaxObserver,
                                        symmetric=True,
                                        scale_type=ScaleType.float,
                                        round_method=RoundType.half_even,
                                        is_dynamic=False)
# Apply the same spec to input tensors, output tensors, weights, and biases.
quant_config = QuantizationConfig(input_tensors=INT8_PER_TENSOR_SPEC,
                                  output_tensors=INT8_PER_TENSOR_SPEC,
                                  weight=INT8_PER_TENSOR_SPEC,
                                  bias=INT8_PER_TENSOR_SPEC)
config = Config(global_quant_config=quant_config,
                quant_mode=QuantizationMode.fx_graph_mode)
quantizer = ModelQuantizer(config)
Step 4: Generate the quantized graph model by performing calibration
quantized_model = quantizer.quantize_model(graph_model, calib_loader)
Step 5 (Optional): QAT for higher accuracy
train(quantized_model, train_loader, val_loader, criterion, device_ids)
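The train helper is defined in the example script; its device_ids argument suggests multi-GPU training. The following is a minimal single-device sketch of a QAT fine-tuning loop with assumed hyperparameters, shown only to illustrate that the prepared model is trained like any other nn.Module:
def train(model, train_loader, val_loader, criterion, device, num_epochs=2):
    # QAT is typically a short fine-tuning run at a small learning rate.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    model.train()
    for epoch in range(num_epochs):
        for images, targets in train_loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), targets)
            loss.backward()  # gradients flow through fake_quant via the straight-through estimator
            optimizer.step()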
Step 6: Validate model performance and export
import os

# validate is the evaluation helper defined in the example script.
acc1_quant = validate(val_loader, quantized_model, criterion, device)
freezed_model = quantizer.freeze(quantized_model)
acc1_freeze = validate(val_loader, freezed_model, criterion, device)
# Check that acc1_quant == acc1_freeze: freezing must not change model accuracy.

# ============== export to ONNX ==================
from quark.torch import ModelExporter
from quark.torch.export.config.custom_config import DEFAULT_EXPORTER_CONFIG

export_config = DEFAULT_EXPORTER_CONFIG
exporter = ModelExporter(config=export_config, export_dir=args.export_dir)
example_inputs = (torch.rand(args.train_batch_size, 3, 224, 224).to(device),)
exporter.export_onnx_model(freezed_model, example_inputs[0])

# ========== export using torch.export ============
example_inputs = (next(iter(val_loader))[0].to(device),)
model_file_path = os.path.join(args.export_dir, args.model_name + ".pth")
exported_model = torch.export.export(freezed_model, example_inputs)
torch.export.save(exported_model, model_file_path)
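To confirm the torch.export artifact round-trips, it can be reloaded and run; this is a usage sketch assuming the same example_inputs as above:
# Reload the saved ExportedProgram and run it on the example input.
loaded_program = torch.export.load(model_file_path)
output = loaded_program.module()(example_inputs[0])
print(output.shape)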
Experiment Results#
We conducted PTQ and QAT on both ResNet-18 and MobileNet-V2. In these models, all weights, biases, and activations are quantized. All tensors are quantized to INT8, per-tensor, symmetric (zero point is 0), and the scale factor is stored in float format. The following table shows the validation accuracy (Top-1 / Top-5) on the ImageNet dataset produced by the above script.
| Method      | ResNet-18       | MobileNetV2     |
|-------------|-----------------|-----------------|
| Float Model | 69.764 / 89.085 | 71.881 / 90.301 |
| PTQ (INT8)  | 69.084 / 88.648 | 65.291 / 86.254 |
| QAT (INT8)  | 69.469 / 88.872 | 68.562 / 88.484 |