FP8 Quantization for LLMs
This script demonstrates FP8 quantization of a large language model (LLM) using AMD-Quark; ensure AMD-Quark and the following dependencies are installed before running:
pip install torch
pip install transformers==4.52.1
pip install tqdm
pip install datasets
pip install accelerate
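Before running, a quick import check can confirm the environment is ready. This is a minimal sketch; it assumes AMD-Quark is importable as quark, matching the imports used below:

python -c "import torch, transformers, datasets, quark; print(torch.__version__, transformers.__version__)"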
from typing import Any
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors
# -----------------------------
# Dataset / Tokenizer
# -----------------------------
def get_pileval(
    tokenizer: PreTrainedTokenizer,
    nsamples: int,
    seqlen: int,
    device: torch.device | str | None,
    seed: int = 0,
) -> torch.Tensor:
    """Build a calibration tensor of shape [n_split, seqlen] from pile-val."""
    dataset: Any = load_dataset("mit-han-lab/pile-val-backup", split="validation").shuffle(seed=seed)
    samples, n_run = [], 0
    for data in dataset:
        line_encoded = tokenizer.encode(data["text"].strip())
        # Keep only non-empty samples that fit within one sequence.
        if 0 < len(line_encoded) <= seqlen:
            samples.append(torch.tensor([line_encoded], device=device))
            n_run += 1
        if n_run == nsamples:
            break
    # Concatenate along the token dimension, then split into fixed-length rows.
    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // seqlen
    train_dataset = [cat_samples[:, i * seqlen : (i + 1) * seqlen] for i in range(n_split)]
    return torch.cat(train_dataset, dim=0)

def get_tokenizer(model_id: str, max_seq_len: int = 512) -> PreTrainedTokenizer:
    print(f"Initializing tokenizer from {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        model_max_length=max_seq_len,
        padding_side="left",
        trust_remote_code=True,
        use_fast=False,
    )
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    assert tokenizer.pad_token is not None, "Pad token cannot be set!"
    return tokenizer
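A quick interactive check of this helper (the printed values are illustrative and depend on the model's tokenizer):

tok = get_tokenizer("Qwen/Qwen2.5-0.5B")
print(tok.pad_token, tok.padding_side)  # e.g. <|endoftext|> left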
def get_dataloader(
    tokenizer: PreTrainedTokenizer,
    batch_size: int,
    device: torch.device | str | None,
    seq_len: int = 512,
) -> DataLoader:
    samples: torch.Tensor = get_pileval(tokenizer, nsamples=128, seqlen=seq_len, device=device, seed=42)
    return DataLoader(samples, batch_size=batch_size, shuffle=False, drop_last=True)
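As a sanity check, the calibration tensor has shape [n_split, seq_len], so each batch from the loader is [batch_size, seq_len]. A sketch, assuming the helpers above are in scope and the dataset download succeeds:

tok = get_tokenizer("Qwen/Qwen2.5-0.5B", max_seq_len=512)
loader = get_dataloader(tok, batch_size=4, device=None, seq_len=512)
first = next(iter(loader))
print(first.shape)  # torch.Size([4, 512])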
# -----------------------------
# Model / Quantization
# -----------------------------
def get_model(model_id: str, device: torch.device | str | None) -> PreTrainedModel:
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="eager",
    )
    return model.eval().to(device)
def quantize_model_pipeline(
    model: PreTrainedModel,
    calib_dataloader: DataLoader,
) -> PreTrainedModel:
    # Look up the Quark template for this architecture and build an FP8
    # configuration that also quantizes the KV cache.
    template = LLMTemplate.get(model.config.model_type)
    quant_config = template.get_config(scheme="fp8", kv_cache_scheme="fp8")
    quantizer = ModelQuantizer(quant_config, multi_device=True)
    quantized_model: PreTrainedModel = quantizer.quantize_model(model, calib_dataloader)
    print("[INFO] Exporting quantized model.")
    export_safetensors(model=quantized_model, output_dir="./")
    return quantized_model
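After quantization, the exporter writes the model artifacts into output_dir. A simple way to see what was produced (the exact file names are whatever Quark emits, typically safetensors shards plus a JSON config):

import os
print(sorted(f for f in os.listdir("./") if f.endswith((".safetensors", ".json"))))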
# -----------------------------
# Evaluation
# -----------------------------
@torch.no_grad()
def ppl_eval(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device: torch.device | str | None,
) -> torch.Tensor:
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt").input_ids.to(device)
    seqlen_for_eval = 2048
    nsamples = testenc.numel() // seqlen_for_eval
    nlls: list[torch.Tensor] = []
    for i in tqdm(range(nsamples)):
        batch = testenc[:, i * seqlen_for_eval : (i + 1) * seqlen_for_eval]
        lm_logits = model(batch)["logits"]
        # Shift so that position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:]
        loss = torch.nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        nlls.append(loss.float() * seqlen_for_eval)
    # Perplexity is the exponential of the mean per-token negative log-likelihood.
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen_for_eval))
    return ppl
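The evaluator computes the standard fixed-window perplexity, ppl = exp(sum_i nll_i / (nsamples * seqlen)). A minimal numeric illustration with made-up per-window losses:

import torch
seqlen, losses = 2048, [2.0, 2.2]  # hypothetical per-window mean losses
nlls = [torch.tensor(l) * seqlen for l in losses]
ppl = torch.exp(torch.stack(nlls).sum() / (len(losses) * seqlen))
print(ppl)  # exp(2.1), roughly 8.17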
# -----------------------------
# Pipeline
# -----------------------------
def run_quark_fp8_example() -> None:
    model_id = "Qwen/Qwen2.5-0.5B"
    batch_size, seq_len = 4, 512
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Loading model: {model_id}")
    model = get_model(model_id, device)
    tokenizer = get_tokenizer(model_id, max_seq_len=seq_len)
    calib_dataloader = get_dataloader(tokenizer, batch_size, device, seq_len)
    print("[INFO] Starting quantization...")
    quantized_model = quantize_model_pipeline(model, calib_dataloader)
    print("[INFO] Quantization complete.")
    print("[INFO] Evaluating perplexity on wikitext-2.")
    ppl = ppl_eval(quantized_model, tokenizer, device)
    print(f"[INFO] Perplexity: {ppl.item():.4f}")


if __name__ == "__main__":
    with torch.no_grad():
        run_quark_fp8_example()
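Assuming the script is saved as quantize_fp8.py (the file name is arbitrary), run it with:

python quantize_fp8.py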