FP8 Quantization for LLMs#

This tutorial demonstrates how to perform FP8 quantization on Large Language Models (LLMs) using AMD-Quark. We will guide you through the following steps:

1. Setup and installation
2. Data preparation
3. Model loading
4. Quantization
5. Evaluation
6. Execution

1. Setup#

First, ensure you have the necessary libraries installed.

pip install torch
pip install transformers==4.52.1
pip install tqdm
pip install datasets
pip install accelerate
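
The examples below also require AMD-Quark itself. If it is not already installed, it is available from PyPI; the package name amd-quark is assumed here, so check the AMD-Quark installation guide for the recommended method on your platform.

pip install amd-quark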

Import the required modules.

from typing import Any

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors

2. Data Preparation#

Quantization requires a small calibration dataset to collect the statistics used to compute scaling factors. Here we use a subset of the Pile validation set. We also define helper functions to initialize the tokenizer and build the calibration dataloader.

# -----------------------------
# Dataset / Tokenizer
# -----------------------------
def get_pileval(
    tokenizer: PreTrainedTokenizer,
    nsamples: int,
    seqlen: int,
    device: str | None,
    seed: int = 0,
) -> torch.Tensor:
    dataset: Any = load_dataset("mit-han-lab/pile-val-backup", split="validation").shuffle(seed=seed)
    samples, n_run = [], 0

    # Keep up to `nsamples` lines whose token length fits within `seqlen`.
    for data in dataset:
        line_encoded = tokenizer.encode(data["text"].strip())
        if 0 < len(line_encoded) <= seqlen:
            samples.append(torch.tensor([line_encoded], device=device))
            n_run += 1
        if n_run == nsamples:
            break

    # Concatenate the collected samples into one long sequence, then split it
    # into non-overlapping blocks of `seqlen` tokens.
    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // seqlen
    train_dataset = [cat_samples[:, i * seqlen : (i + 1) * seqlen] for i in range(n_split)]

    return torch.cat(train_dataset, dim=0)


def get_tokenizer(model_id: str, max_seq_len: int = 512) -> PreTrainedTokenizer:
    print(f"Initializing tokenizer from {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        model_max_length=max_seq_len,
        padding_side="left",
        trust_remote_code=True,
        use_fast=False,
    )
    # Use the EOS token for padding unless the tokenizer already defines "<unk>" as its pad token.
    if tokenizer.pad_token != "<unk>":
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    assert tokenizer.pad_token is not None, "Failed to set a pad token for the tokenizer!"
    return tokenizer


def get_dataloader(
    tokenizer: PreTrainedTokenizer,
    batch_size: int,
    device: str | None,
    seq_len: int = 512,
) -> DataLoader:
    samples: torch.Tensor = get_pileval(tokenizer, nsamples=128, seqlen=seq_len, device=device, seed=42)
    return DataLoader(samples, batch_size=batch_size, shuffle=False, drop_last=True)
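
To verify the calibration pipeline in isolation (a minimal sketch, assuming the Qwen/Qwen2.5-0.5B tokenizer used later in this tutorial), you can build the dataloader and inspect one batch:

tokenizer = get_tokenizer("Qwen/Qwen2.5-0.5B", max_seq_len=512)
calib_dataloader = get_dataloader(tokenizer, batch_size=4, device="cpu", seq_len=512)
batch = next(iter(calib_dataloader))
print(batch.shape)  # expected: torch.Size([4, 512])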

3. Model Loading#

Load the pre-trained model. We configure it to use the eager attention implementation and automatic dtype selection (torch_dtype="auto"), move it to the target device, and put it in evaluation mode.

# -----------------------------
# Model / Quantization
# -----------------------------
def get_model(model_id: str, device: str | None) -> PreTrainedModel:
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="eager",
        torch_dtype="auto",
    )
    return model.eval().to(device)
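
As a quick sanity check (a minimal sketch, assuming the Qwen/Qwen2.5-0.5B checkpoint used later in this tutorial), you can load the model on CPU and confirm its architecture and parameter dtype:

model = get_model("Qwen/Qwen2.5-0.5B", device="cpu")
print(model.config.model_type)         # e.g. "qwen2"
print(next(model.parameters()).dtype)  # e.g. torch.bfloat16 with torch_dtype="auto"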

4. Quantization Process#

Now we define the quantization pipeline:

1. Fetch the quantization configuration for the model architecture from the LLMTemplate, specifying fp8 for both the weights and the KV cache.
2. Initialize the ModelQuantizer with this configuration.
3. Run quantize_model with the model and the calibration dataloader.
4. Export the quantized model to the safetensors format.

def quantize_model_pipeline(
    model: PreTrainedModel,
    calib_dataloader: DataLoader,
    tokenizer: PreTrainedTokenizer,
) -> PreTrainedModel:
    template = LLMTemplate.get(model.config.model_type)
    quant_config = template.get_config(scheme="fp8", kv_cache_scheme="fp8")

    quantizer = ModelQuantizer(quant_config, multi_device=True)
    quantized_model: PreTrainedModel = quantizer.quantize_model(model, calib_dataloader)

    print("[INFO] Export Quant Model.")
    export_safetensors(model=quantized_model, output_dir="./")
    tokenizer.save_pretrained("./")

    return quantized_model
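
After export, the output directory should contain the quantized weights in safetensors format together with the model, quantization, and tokenizer configuration files; the exact filenames depend on the AMD-Quark and Transformers versions. A minimal sketch to list them:

import os

exported_files = sorted(f for f in os.listdir("./") if f.endswith((".safetensors", ".json")))
print(exported_files)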

5. Evaluation#

To verify the quantization quality, we calculate the perplexity (PPL) on the WikiText-2 test set.
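
Perplexity is the exponentiated average per-token negative log-likelihood. With $N$ non-overlapping windows of length $L = 2048$ tokens and $\mathrm{CE}_i$ the mean cross-entropy of window $i$, the loop below computes

$$\mathrm{PPL} = \exp\left(\frac{\sum_{i=1}^{N} L \cdot \mathrm{CE}_i}{N \cdot L}\right)$$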

# -----------------------------
# Evaluation
# -----------------------------
@torch.no_grad()
def ppl_eval(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device: str | None,
) -> torch.Tensor:
    # Tokenize the entire WikiText-2 test split as a single long sequence.
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt").input_ids.to(device)

    seqlen_for_eval = 2048
    nsamples = testenc.numel() // seqlen_for_eval
    nlls: list[torch.Tensor] = []

    # Evaluate in non-overlapping windows of 2048 tokens.
    for i in tqdm(range(nsamples)):
        batch = testenc[:, i * seqlen_for_eval : (i + 1) * seqlen_for_eval]
        lm_logits = model(batch)["logits"]

        # Shift logits and labels for next-token prediction.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:]

        loss = torch.nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        # Accumulate the window's total negative log-likelihood (mean loss × window length).
        nlls.append(loss.float() * seqlen_for_eval)

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen_for_eval))
    return ppl
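
For reference, the same function can measure the baseline (unquantized) perplexity so the two numbers can be compared. This is an optional sketch that reloads the original checkpoint, assuming the Qwen/Qwen2.5-0.5B model used in this tutorial:

# Optional: measure the baseline (unquantized) perplexity for comparison.
device = "cuda" if torch.cuda.is_available() else "cpu"
baseline_model = get_model("Qwen/Qwen2.5-0.5B", device)
baseline_tokenizer = get_tokenizer("Qwen/Qwen2.5-0.5B", max_seq_len=512)
baseline_ppl = ppl_eval(baseline_model, baseline_tokenizer, device)
print(f"[INFO] Baseline perplexity: {baseline_ppl.item():.4f}")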

6. Execution#

Put everything together: define the main execution function and run it. In this example, we use the Qwen/Qwen2.5-0.5B model.

# -----------------------------
# Pipeline
# -----------------------------
def run_quark_fp8_example() -> None:
    model_id = "Qwen/Qwen2.5-0.5B"
    batch_size, seq_len = 4, 512
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"[INFO] Loading model: {model_id}")
    model = get_model(model_id, device)
    tokenizer = get_tokenizer(model_id, max_seq_len=seq_len)
    calib_dataloader = get_dataloader(tokenizer, batch_size, device, seq_len)

    print("[INFO] Starting quantization...")
    quantized_model = quantize_model_pipeline(model, calib_dataloader, tokenizer)
    print("[INFO] Quantization complete.")

    print("[INFO] Simple test PPL with wikitext-2.")
    ppl = ppl_eval(quantized_model, tokenizer, device)
    print(f"[INFO] Perplexity: {ppl.item():.4f}")


if __name__ == "__main__":
    with torch.no_grad():
        run_quark_fp8_example()