AWQ end-to-end example

This is a simple end-to-end example of the AWQ (Activation-aware Weight Quantization) algorithm. Before running this script, make sure that amd-quark has been properly installed.

!pip install torch

!pip install transformers==4.52.1

!pip install tqdm

!pip install datasets

!pip install accelerate
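
Quark itself is published on PyPI as amd-quark; if it is not already present in your environment, it can typically be installed the same way:

!pip install amd-quark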

For the AWQ algorithm, we provide default configurations for common models. Advanced users who want to use their own AWQ configuration must provide a configuration JSON file. In this notebook example, we generate the AWQ configuration JSON file with Python. Each entry in scaling_layers describes one group of layers that share an AWQ scale: prev_op is the preceding operation into which the inverse scale is fused, layers lists the linear layers to be scaled, inp names the layer whose input activations are collected, and the optional module2inspect is the parent module that is re-run when searching for the best scale.

import json

# Define the configuration to be written
awq_config = {
    "name": "awq",
    "scaling_layers": [
        {
            "prev_op": "input_layernorm",
            "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
            "inp": "self_attn.q_proj",
            "module2inspect": "self_attn",
        },
        {"prev_op": "self_attn.v_proj", "layers": ["self_attn.o_proj"], "inp": "self_attn.o_proj"},
        {
            "prev_op": "post_attention_layernorm",
            "layers": ["mlp.gate_proj", "mlp.up_proj"],
            "inp": "mlp.gate_proj",
            "module2inspect": "mlp",
        },
        {"prev_op": "mlp.up_proj", "layers": ["mlp.down_proj"], "inp": "mlp.down_proj"},
    ],
    "model_decoder_layers": "model.layers",
}


# Write configuration to a JSON file
with open("custom_awq_config.json", "w") as f:
    json.dump(awq_config, f, indent=4)

print("custom_awq_config.json has been created.")

This is an example of using the AWQ algorithm integrated in Quark to quantize an LLM. The example covers quantization, export, and a simple perplexity test.

from typing import Any

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

from quark.torch import LLMTemplate, ModelQuantizer, export_safetensors
from quark.torch.quantization.config.config import load_quant_algo_config_from_file


# -----------------------------
# Dataset / Tokenizer
# -----------------------------
def get_pileval(
    tokenizer: PreTrainedTokenizer,
    nsamples: int,
    seqlen: int,
    device: str | None,
    seed: int = 0,
) -> torch.Tensor:
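    """Build a calibration tensor from pile-val: tokenized texts are concatenated
    along the sequence dimension and re-chunked into fixed-length windows,
    returning a [n_split, seqlen] tensor of token ids."""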
    dataset: Any = load_dataset("mit-han-lab/pile-val-backup", split="validation").shuffle(seed=seed)
    samples, n_run = [], 0

    for data in dataset:
        line_encoded = tokenizer.encode(data["text"].strip())
        if 0 < len(line_encoded) <= seqlen:
            samples.append(torch.tensor([line_encoded], device=device))
            n_run += 1
        if n_run == nsamples:
            break

    cat_samples = torch.cat(samples, dim=1)
    n_split = cat_samples.shape[1] // seqlen
    train_dataset = [cat_samples[:, i * seqlen : (i + 1) * seqlen] for i in range(n_split)]

    return torch.cat(train_dataset, dim=0)


def get_tokenizer(model_id: str, max_seq_len: int = 512) -> PreTrainedTokenizer:
    print(f"Initializing tokenizer from {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        model_max_length=max_seq_len,
        padding_side="left",
        trust_remote_code=True,
        use_fast=False,
    )
    # Fall back to the EOS token for padding unless the tokenizer already uses
    # "<unk>" as its pad token.
    if tokenizer.pad_token != "<unk>":
        tokenizer.pad_token = tokenizer.eos_token

    assert tokenizer.pad_token is not None, "Pad token could not be set!"
    return tokenizer


def get_dataloader(
    tokenizer: PreTrainedTokenizer,
    batch_size: int,
    device: str | None,
    seq_len: int = 512,
) -> DataLoader:
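    """Wrap pile-val calibration samples in a DataLoader; each batch is a
    [batch_size, seq_len] tensor of token ids."""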
    samples: torch.Tensor = get_pileval(tokenizer, nsamples=128, seqlen=seq_len, device=device, seed=42)
    return DataLoader(samples, batch_size=batch_size, shuffle=False, drop_last=True)


# -----------------------------
# Model / Quantization
# -----------------------------
def get_model(model_id: str, device: str | None) -> PreTrainedModel:
    model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="eager",
    )
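    # "eager" keeps attention in plain PyTorch modules, which is generally the
    # safest choice for the module hooks used during AWQ calibration.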
    return model.eval().to(device)


def quantize_model_pipeline(
    model: PreTrainedModel,
    calib_dataloader: DataLoader,
) -> PreTrainedModel:
    # Load custom AWQ config
    custom_awq_config = load_quant_algo_config_from_file("custom_awq_config.json")
    # If you don't need a custom awq_config, you can omit it and use the default configuration.
    template = LLMTemplate(
        model_type=model.config.model_type,
        exclude_layers_name=["lm_head"],
        awq_config=custom_awq_config,
    )
    # "uint4_wo_128" selects weight-only quantization to unsigned 4-bit
    # integers with a group size of 128.
    quant_config = template.get_config(scheme="uint4_wo_128", algorithm=["awq"])

    quantizer = ModelQuantizer(quant_config, multi_device=True)
    quantized_model: PreTrainedModel = quantizer.quantize_model(model, calib_dataloader)

    print("[INFO] Export Quant Model.")
    export_safetensors(model=quantized_model, output_dir="./")

    return quantized_model


# -----------------------------
# Evaluation
# -----------------------------
@torch.no_grad()
def ppl_eval(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    device: str | None,
) -> torch.Tensor:
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt").input_ids.to(device)

    seqlen_for_eval = 2048
    nsamples = testenc.numel() // seqlen_for_eval
    nlls: list[torch.Tensor] = []

    for i in tqdm(range(nsamples)):
        batch = testenc[:, i * seqlen_for_eval : (i + 1) * seqlen_for_eval]
        lm_logits = model(batch)["logits"]

        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = batch[:, 1:]

        loss = torch.nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        nlls.append(loss.float() * seqlen_for_eval)

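    # Perplexity is exp(mean NLL per token): each window contributes its mean
    # cross-entropy times the window length, and the sum is normalized by the
    # total token count before exponentiating.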
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen_for_eval))
    return ppl


# -----------------------------
# Pipeline
# -----------------------------
def run_quark_awq_example() -> None:
    model_id = "Qwen/Qwen2.5-0.5B"
    batch_size, seq_len = 4, 512
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"[INFO] Loading model: {model_id}")
    model = get_model(model_id, device)
    tokenizer = get_tokenizer(model_id, max_seq_len=seq_len)
    calib_dataloader = get_dataloader(tokenizer, batch_size, device, seq_len)

    print("[INFO] Starting quantization...")
    quantized_model = quantize_model_pipeline(model, calib_dataloader)
    print("[INFO] Quantization complete.")

    print("[INFO] Simple test PPL with wikitext-2.")
    ppl = ppl_eval(quantized_model, tokenizer, device)
    print(f"[INFO] Perplexity: {ppl.item():.4f}")


if __name__ == "__main__":
    with torch.no_grad():
        run_quark_awq_example()