Blockwise Joint Tuning (Experimental)#
Warning
Blockwise joint tuning is an experimental feature. It depends on
quark.experimental (in particular
ExperimentalLearnableQuantizedLinear in
quark.experimental.torch.algorithm.blockwise_joint_tuning),
a temporary learnable-quantizer implementation that may change or be
replaced by the official Quark QuantLinear in a future release. APIs and
defaults described here are not yet stable.
AMD Quark supports blockwise joint tuning, a post-training optimization that
recovers accuracy lost during weight quantization by fine-tuning a quantized
model one decoder block at a time. For each block, the original (full-precision)
model provides a teacher signal, and the quantized block is trained to reproduce
that block’s output. “Joint” means the optimization updates both the block
weights and the quantization parameters (scale / zero_point)
together, each with its own learning rate and schedule.
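To make "learnable quantization parameters" concrete, here is a minimal PyTorch sketch of a fake-quantized linear layer whose scale and zero_point are trainable alongside the weight. The names ToyLearnableQuantLinear and round_ste are hypothetical, and the per-output-channel asymmetric min/max initialization and the straight-through estimator are assumptions made for illustration. It is not Quark's ExperimentalLearnableQuantizedLinear.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def round_ste(x: torch.Tensor) -> torch.Tensor:
    """Round in the forward pass; pass the gradient through unchanged."""
    return (x.round() - x).detach() + x


class ToyLearnableQuantLinear(nn.Module):
    """Hypothetical stand-in for a learnable quantized linear layer."""

    def __init__(self, linear: nn.Linear, bits: int = 4):
        super().__init__()
        self.qmin, self.qmax = 0, 2**bits - 1
        self.weight = nn.Parameter(linear.weight.detach().clone())
        self.bias = None
        if linear.bias is not None:
            self.bias = nn.Parameter(linear.bias.detach().clone())
        # ASSUMPTION: per-output-channel asymmetric min/max initialization.
        w_min = self.weight.detach().amin(dim=1, keepdim=True)
        w_max = self.weight.detach().amax(dim=1, keepdim=True)
        scale = ((w_max - w_min) / (self.qmax - self.qmin)).clamp_min(1e-8)
        # Both quantization parameters are registered as trainable.
        self.scale = nn.Parameter(scale)
        self.zero_point = nn.Parameter(-w_min / scale)

    def fake_quant_weight(self) -> torch.Tensor:
        """Quantize then dequantize the weight with the current parameters."""
        zp = round_ste(self.zero_point)
        q = (round_ste(self.weight / self.scale) + zp).clamp(self.qmin, self.qmax)
        return (q - zp) * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.fake_quant_weight(), self.bias)


torch.manual_seed(0)
layer = ToyLearnableQuantLinear(nn.Linear(8, 4), bits=4)
layer(torch.randn(2, 8)).pow(2).mean().backward()
# The weight and both quantization parameters receive gradients.
print([name for name, p in layer.named_parameters() if p.grad is not None])
```

In this sketch a single backward pass populates gradients for weight, scale and zero_point, which is what allows an optimizer to update the weights and the quantization parameters together.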
How blockwise joint tuning works#
Quantization error accumulates layer by layer. Rather than retraining the whole network end to end, blockwise joint tuning distills one decoder block at a time, which keeps the memory footprint small and the optimization well conditioned. Each block goes through four steps:
Teacher forward. The corresponding full-precision block runs on the block’s input batch to produce reference (teacher) outputs.
Make the block learnable. Every nn.Linear in the quantized block is wrapped in an ExperimentalLearnableQuantizedLinear so that its scale / zero_point become trainable parameters.
Joint training. The quantized block is trained to minimize the MSE between its output and the teacher output, optimizing two independent parameter groups: block weights and quantization parameters.
Propagate. The tuned block runs forward to produce the inputs for the next block, and the loop continues.
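The per-block loop above can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions, not Quark's BlockwiseJointTuningProcessor: tune_blockwise is a hypothetical helper, the "quantized" blocks are simply perturbed copies of the teacher blocks, the learnable-quantizer wrapping step is omitted, and a single AdamW optimizer is used for brevity.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


def tune_blockwise(fp_blocks, q_blocks, inputs, epochs=50, lr=2e-3):
    """Tune each student block against its full-precision counterpart."""
    losses = []
    x = inputs
    for fp_block, q_block in zip(fp_blocks, q_blocks):
        # Teacher forward: reference output on the block's input batch.
        with torch.no_grad():
            target = fp_block(x)
            before = F.mse_loss(q_block(x), target).item()
        # Training: minimize the MSE between student and teacher outputs.
        opt = torch.optim.AdamW(q_block.parameters(), lr=lr, weight_decay=0.0)
        for _ in range(epochs):
            opt.zero_grad()
            loss = F.mse_loss(q_block(x), target)
            loss.backward()
            opt.step()
        # Propagate: the tuned block produces the next block's input.
        with torch.no_grad():
            after = F.mse_loss(q_block(x), target).item()
            x = q_block(x)
        losses.append((before, after))
    return losses


torch.manual_seed(0)
fp_blocks = [nn.Sequential(nn.Linear(16, 16), nn.Tanh()) for _ in range(3)]
# Stand-in for quantization error: perturb copies of the teacher weights.
q_blocks = copy.deepcopy(fp_blocks)
with torch.no_grad():
    for block in q_blocks:
        for p in block.parameters():
            p.add_(0.05 * torch.randn_like(p))

losses = tune_blockwise(fp_blocks, q_blocks, torch.randn(32, 16))
for i, (before, after) in enumerate(losses):
    print(f"block {i}: mse before={before:.2e}, after={after:.2e}")
```

Because each block is trained in isolation against a fixed target, only one block's parameters and optimizer state need to be live at a time.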
Two parameter groups are optimized with separate AdamW settings and separate
CosineAnnealingLR schedules:
Weights: modules selected by trainable_modules, learning rate weight_lr, weight decay weight_decay.
Quantization parameters (scale / zero_point): modules selected by quant_trainable_modules, learning rate qparam_lr, weight decay qparam_weight_decay.
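A minimal sketch of the two-group setup, using only standard PyTorch optimizers and schedulers. Several details are assumptions rather than documented Quark behavior: the cosine floor is taken to be lr / min_lr_factor, gradient clipping is applied jointly over both groups, and total_steps, scale and zero_point are hypothetical stand-ins.

```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Values mirror the defaults listed in this document.
cfg = dict(weight_lr=1e-4, qparam_lr=1e-4, weight_decay=0.0,
           qparam_weight_decay=0.0, min_lr_factor=20.0, max_grad_norm=0.3)
total_steps = 100  # hypothetical: epochs * batches for one block

block = nn.Linear(8, 8)
# Hypothetical stand-ins for the learnable quantization parameters.
scale = nn.Parameter(torch.ones(8, 1))
zero_point = nn.Parameter(torch.zeros(8, 1))
weight_params = list(block.parameters())
qparams = [scale, zero_point]

# One optimizer and one schedule per parameter group.
weight_opt = AdamW(weight_params, lr=cfg["weight_lr"],
                   weight_decay=cfg["weight_decay"])
qparam_opt = AdamW(qparams, lr=cfg["qparam_lr"],
                   weight_decay=cfg["qparam_weight_decay"])
# ASSUMPTION: each group anneals to lr / min_lr_factor.
weight_sched = CosineAnnealingLR(
    weight_opt, T_max=total_steps,
    eta_min=cfg["weight_lr"] / cfg["min_lr_factor"])
qparam_sched = CosineAnnealingLR(
    qparam_opt, T_max=total_steps,
    eta_min=cfg["qparam_lr"] / cfg["min_lr_factor"])

x = torch.randn(4, 8)
for _ in range(total_steps):
    weight_opt.zero_grad()
    qparam_opt.zero_grad()
    # Toy loss that touches both parameter groups.
    loss = ((block(x) * scale.t() + zero_point.t()) ** 2).mean()
    loss.backward()
    # ASSUMPTION: one clipping call over both groups.
    torch.nn.utils.clip_grad_norm_(weight_params + qparams, cfg["max_grad_norm"])
    weight_opt.step()
    qparam_opt.step()
    weight_sched.step()
    qparam_sched.step()

final_weight_lr = weight_sched.get_last_lr()[0]
final_qparam_lr = qparam_sched.get_last_lr()[0]
print(final_weight_lr, final_qparam_lr)
```

Keeping two optimizers makes it easy to give each group its own learning rate, weight decay and floor without the groups sharing any schedule state.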
To bound memory, only the block currently being tuned is moved to the GPU; all other blocks stay on CPU. During training the active block is cast to FP32 and trained under automatic mixed precision (AMP), then restored to its original dtype so downstream evaluation stays dtype-consistent.
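The device and dtype handling described above can be sketched with a small context manager. active_block is a hypothetical helper, not a Quark API; it assumes every parameter in a block shares one dtype, and it uses bfloat16 autocast so the sketch also runs on CPU.

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def active_block(block: nn.Module, device: torch.device):
    """Move one block to `device` in FP32, then restore dtype and return to CPU."""
    # Simplification: assume a single dtype for the whole block.
    orig_dtype = next(block.parameters()).dtype
    block.to(device=device, dtype=torch.float32)
    try:
        yield block
    finally:
        block.to(device="cpu", dtype=orig_dtype)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blocks = [nn.Linear(8, 8).to(torch.bfloat16) for _ in range(3)]
x = torch.randn(4, 8)

for block in blocks:
    # Only the block inside the `with` statement lives on `device`.
    with active_block(block, device) as blk:
        opt = torch.optim.AdamW(blk.parameters(), lr=1e-4)
        # Master weights are FP32; the forward pass runs under AMP.
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            loss = blk(x.to(device)).float().pow(2).mean()
        loss.backward()
        opt.step()

print({p.dtype for b in blocks for p in b.parameters()})
```

After the loop every block is back on CPU in its original dtype, so later evaluation sees the same dtypes as before tuning.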
The algorithm is implemented by BlockwiseJointTuningProcessor
(in quark.torch.algorithm.blockwise_joint_tuning).
Relation to blockwise_tuning#
Quark also ships a simpler blockwise_tuning algorithm. The difference is what
gets optimized:
blockwise_tuning tunes only module weights selected by trainable_modules (plus layernorm modules).
blockwise_joint_tuning jointly tunes module weights and quantization parameters (scale / zero_point), with independent controls for each group.
Choose joint tuning when learnable quantization parameters are expected to help
(for example, aggressive weight-only INT4 with group quantization); otherwise,
plain blockwise_tuning is lighter weight.
Configuring blockwise joint tuning#
Blockwise joint tuning is configured with BlockwiseJointTuningConfig (from
quark.torch.algorithm.config) and applied with the blockwise_tuning_algo
entry point, which – unlike the regular algo_config algorithms – takes
both the full-precision reference model and the quantized model.
| Parameter | Default | Description |
|---|---|---|
| name | | Algorithm name; selects … |
| epochs | 5 | Number of training epochs run per decoder block. |
| weight_lr | 1e-4 | Learning rate for the block-weight parameter group. Set to … |
| qparam_lr | 1e-4 | Learning rate for the quantization-parameter (scale / zero_point) group. |
| weight_decay | 0.0 | Weight decay for the block-weight group. |
| qparam_weight_decay | 0.0 | Weight decay for the quantization-parameter group. |
| min_lr_factor | 20.0 | Cosine schedule floor: each group anneals to … |
| max_grad_norm | 0.3 | Gradient-norm clipping threshold. |
| model_decoder_layers | | Attribute path to the list of decoder blocks to tune (for example …). |
| trainable_modules | | Substring filters selecting which modules’ weights are trainable. Empty means match all modules in the block. |
| quant_trainable_modules | | Substring filters selecting which modules’ quantization parameters are trainable. Empty means match all. |
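The substring-filter semantics described for trainable_modules and quant_trainable_modules can be illustrated with a few lines of plain Python. select_modules and the module names below are hypothetical, and the sketch assumes a module is selected when any filter string occurs in its name; Quark's actual matching code may differ.

```python
def select_modules(module_names, filters):
    """Return names containing any filter substring; no filters means match all."""
    if not filters:
        return list(module_names)
    return [name for name in module_names if any(f in name for f in filters)]


# Hypothetical module names inside one decoder block.
names = ["self_attn.q_proj", "self_attn.k_proj", "mlp.up_proj", "mlp.down_proj"]

print(select_modules(names, []))          # all four names
print(select_modules(names, ["mlp"]))     # ['mlp.up_proj', 'mlp.down_proj']
print(select_modules(names, ["q_proj"]))  # ['self_attn.q_proj']
```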
Usage#
Blockwise joint tuning is typically applied as a post-PTQ step: quantize the model (for example weight-only INT4), then tune it against a fresh full-precision reference model on a calibration/fine-tuning dataloader.
```python
from quark.torch.algorithm.api import blockwise_tuning_algo
from quark.torch.algorithm.config import BlockwiseJointTuningConfig

# `model` is the already-quantized model.
# `ref_model` is a freshly loaded full-precision copy (the teacher).
# `dataloader` is the calibration/fine-tuning dataloader.
# All three are assumed to have been created earlier in the script.
ref_model.eval()

blockwise_cfg = BlockwiseJointTuningConfig.from_dict({
    "name": "blockwise_joint_tuning",
    "epochs": 5,
    "weight_lr": 1e-4,
    "qparam_lr": 1e-4,
    "weight_decay": 0.0,
    "qparam_weight_decay": 0.0,
    "min_lr_factor": 20.0,
    "max_grad_norm": 0.3,
    "model_decoder_layers": "model.layers",
    "trainable_modules": [],
    "quant_trainable_modules": [],
})

# Models dispatched across devices by Accelerate carry `hf_device_map`.
is_accelerate = hasattr(model, "hf_device_map")

# The entry point takes both the reference and the quantized model.
model = blockwise_tuning_algo(
    ref_model,
    model,
    blockwise_cfg,
    is_accelerate,
    dataloader,
)
```
For a complete, validated end-to-end flow (load model, PTQ weight-only
quantization, blockwise joint tuning, then export) see the EfficientQAT example
at examples/torch/language_modeling/llm_qat/efficientqat/main.py.
Notes and limitations#
Memory. Two models are resident at once (quantized + full-precision reference). Per-block GPU offload keeps peak usage low, but the full-precision reference must be available throughout tuning.
Data. A calibration/fine-tuning dataloader is required; the cached block-level activations scale with the number of calibration samples.
Placement. Joint tuning is intended to run after weight quantization (for example PTQ weight-only INT4), not before.
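To get a rough feel for how the cached activations grow with the calibration set, the following back-of-the-envelope estimate may help. It assumes one block input and one teacher output are cached per token in 16-bit precision; that layout, the helper name cached_activation_bytes and the example sizes are all hypothetical and do not describe Quark's actual cache.

```python
def cached_activation_bytes(num_samples, seq_len, hidden_size,
                            bytes_per_elem=2, copies=2):
    """Estimate the cache size in bytes for block-level activations.

    ASSUMPTION: `copies` tensors of shape (num_samples, seq_len, hidden_size)
    are kept, for example block inputs plus teacher outputs.
    """
    return num_samples * seq_len * hidden_size * bytes_per_elem * copies


GIB = 1024 ** 3
# Hypothetical sizes: 128 samples of 2048 tokens, hidden size 4096.
print(cached_activation_bytes(128, 2048, 4096) / GIB)  # 4.0
# Doubling the number of samples doubles the cache.
print(cached_activation_bytes(256, 2048, 4096) / GIB)  # 8.0
```

Under these assumptions the cache grows linearly with the number of samples, so the sample count is the first thing to reduce when host memory is tight.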