# sentence-bert-base-ja-mean-tokens-v2-int8 / quantize.py
"""
model_fp32.onnx をキャリブレーションデータで静的量子化し、model_quantized.onnx を出力する。
先に export_onnx.py で model_fp32.onnx を生成しておくこと。
"""
from pathlib import Path

import numpy as np
from transformers import AutoTokenizer

# onnxruntime 1.16+ では onnxruntime.quantization に統合されている想定
try:
    from onnxruntime.quantization import (
        CalibrationDataReader,
        CalibrationMethod,
        QuantFormat,
        QuantType,
        quantize_static,
    )
except ImportError:
    from onnxruntime.quantization.quantize import quantize_static
    from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod
    from onnxruntime.quantization.quant_utils import QuantFormat, QuantType

ROOT = Path(__file__).resolve().parent
BUILD_DIR = ROOT / "build"    # intermediate artifacts: fp32 model, calibration sentences
TARGET_DIR = ROOT / "target"  # final quantized model output directory
MODEL_ID = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"  # Hugging Face tokenizer id
SENTENCES_PATH = BUILD_DIR / "sentences.txt"  # calibration sentences, one per line
MODEL_FP32 = BUILD_DIR / "model_fp32.onnx"    # produced beforehand by export_onnx.py
MODEL_QUANT = TARGET_DIR / "model_quantized.onnx"  # output of quantize_static
MAX_LENGTH = 128  # tokenizer sequence length (padded/truncated to this)
BATCH_SIZE = 1    # NOTE(review): not referenced anywhere in this file — confirm before removing


class SentenceCalibrationDataReader(CalibrationDataReader):
    """Tokenize Japanese calibration sentences and yield ONNX feed dicts.

    Each ``get_next`` call returns one sentence, padded/truncated to
    ``max_length``, as the three int64 arrays a BERT-style ONNX graph
    expects (``input_ids``, ``attention_mask``, ``token_type_ids``),
    or ``None`` once all sentences are consumed.
    """

    def __init__(self, tokenizer, sentences_path: Path, max_length: int = MAX_LENGTH):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Keep only non-empty lines; blank lines would yield degenerate samples.
        with open(sentences_path, "r", encoding="utf-8") as f:
            self.sentences = [line.strip() for line in f if line.strip()]
        self.index = 0

    def __len__(self):
        return len(self.sentences)

    def get_next(self):
        """Return the next calibration sample as a feed dict, or None when exhausted."""
        if self.index >= len(self.sentences):
            return None
        text = self.sentences[self.index]
        self.index += 1
        enc = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="np",
        )
        # Tokenizers without segment embeddings omit token_type_ids; BERT
        # treats all-zero segment ids as "single sentence", so zeros are safe.
        token_type_ids = enc.get("token_type_ids")
        if token_type_ids is None:
            token_type_ids = np.zeros_like(enc["input_ids"])
        # Bug fix: cast ALL inputs to int64. The original only cast the zeros
        # fallback, so a tokenizer returning int32 token_type_ids would feed a
        # wrong dtype into an ONNX graph declaring int64 inputs.
        return {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
            "token_type_ids": token_type_ids.astype(np.int64),
        }


def main():
    """Statically quantize model_fp32.onnx (MinMax calibration, QDQ, S8S8)."""
    if not MODEL_FP32.exists():
        msg = f"{MODEL_FP32} がありません。先に python export_onnx.py を実行してください。"
        raise FileNotFoundError(msg)
    if not SENTENCES_PATH.exists():
        raise FileNotFoundError(f"キャリブレーション用文一覧がありません: {SENTENCES_PATH}")

    print(f"Tokenizer: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    data_reader = SentenceCalibrationDataReader(tokenizer, SENTENCES_PATH, MAX_LENGTH)
    print(f"Calibration samples: {len(data_reader)}")

    # S8S8 QDQ static quantization over the full calibration set.
    print("Quantizing (static, QDQ, S8S8) ...")
    quantize_static(
        str(MODEL_FP32),
        str(MODEL_QUANT),
        data_reader,
        quant_format=QuantFormat.QDQ,
        activation_type=QuantType.QInt8,
        weight_type=QuantType.QInt8,
        calibrate_method=CalibrationMethod.MinMax,
        per_channel=False,
        reduce_range=False,
    )
    print(f"Saved: {MODEL_QUANT}")


# Script entry point: run quantization when executed directly.
if __name__ == "__main__":
    main()