"""
model_fp32.onnx をキャリブレーションデータで静的量子化し、model_quantized.onnx を出力する。
先に export_onnx.py で model_fp32.onnx を生成しておくこと。
"""
from pathlib import Path
import numpy as np
from transformers import AutoTokenizer
# Assumes onnxruntime 1.16+, where these names are consolidated under onnxruntime.quantization.
try:
from onnxruntime.quantization import (
CalibrationDataReader,
CalibrationMethod,
QuantFormat,
QuantType,
quantize_static,
)
except ImportError:
from onnxruntime.quantization.quantize import quantize_static
from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod
from onnxruntime.quantization.quant_utils import QuantFormat, QuantType
# Directory layout: calibration inputs live under build/, quantized output under target/.
ROOT = Path(__file__).resolve().parent
BUILD_DIR = ROOT / "build"
TARGET_DIR = ROOT / "target"
# HuggingFace model id — used here only to load the matching tokenizer.
MODEL_ID = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
SENTENCES_PATH = BUILD_DIR / "sentences.txt"  # calibration sentences, one per line
MODEL_FP32 = BUILD_DIR / "model_fp32.onnx"  # input: fp32 model produced by export_onnx.py
MODEL_QUANT = TARGET_DIR / "model_quantized.onnx"  # output: int8-quantized model
MAX_LENGTH = 128  # tokenizer padding/truncation length (fixed-shape calibration inputs)
BATCH_SIZE = 1  # NOTE(review): currently unused — get_next() feeds one sentence at a time
class SentenceCalibrationDataReader(CalibrationDataReader):
    """Feed tokenized Japanese calibration sentences to the ONNX quantizer.

    Each ``get_next`` call tokenizes one sentence from *sentences_path* and
    returns it as the model's ONNX input feed dict; returns ``None`` once all
    sentences have been consumed (the protocol's end-of-data signal).
    """

    def __init__(self, tokenizer, sentences_path: Path, max_length: int = MAX_LENGTH):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Keep only non-empty lines; a blank line would tokenize to a
        # padding-only input that adds nothing to calibration statistics.
        with open(sentences_path, "r", encoding="utf-8") as f:
            self.sentences = [line.strip() for line in f if line.strip()]
        self.index = 0  # cursor into self.sentences

    def __len__(self) -> int:
        return len(self.sentences)

    def get_next(self):
        """Return the next ONNX input feed dict, or None when exhausted."""
        if self.index >= len(self.sentences):
            return None
        text = self.sentences[self.index]
        self.index += 1
        enc = self.tokenizer(
            text,
            padding="max_length",  # fixed shape (1, max_length) for every sample
            max_length=self.max_length,
            truncation=True,
            return_tensors="np",
        )
        token_type_ids = enc.get("token_type_ids")
        if token_type_ids is None:
            # Some tokenizers omit token_type_ids; BERT-style models expect zeros then.
            token_type_ids = np.zeros_like(enc["input_ids"], dtype=np.int64)
        # ONNX Runtime requires int64 tensors for these inputs. BUG FIX: the
        # original cast input_ids/attention_mask but returned token_type_ids
        # uncast, so a tokenizer emitting a non-int64 array produced an
        # inconsistent feed dtype.
        return {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
            "token_type_ids": token_type_ids.astype(np.int64),
        }
def main():
    """Statically quantize MODEL_FP32 to int8 (QDQ, S8S8) and write MODEL_QUANT.

    Raises:
        FileNotFoundError: if the fp32 model or the calibration sentence file
            is missing (both are produced by earlier pipeline steps).
    """
    if not MODEL_FP32.exists():
        raise FileNotFoundError(
            f"{MODEL_FP32} がありません。先に python export_onnx.py を実行してください。"
        )
    if not SENTENCES_PATH.exists():
        raise FileNotFoundError(f"キャリブレーション用文一覧がありません: {SENTENCES_PATH}")
    print(f"Tokenizer: {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    reader = SentenceCalibrationDataReader(tokenizer, SENTENCES_PATH, MAX_LENGTH)
    print(f"Calibration samples: {len(reader)}")
    # BUG FIX: TARGET_DIR is never created elsewhere in this script;
    # quantize_static cannot write into a nonexistent directory.
    MODEL_QUANT.parent.mkdir(parents=True, exist_ok=True)
    print("Quantizing (static, QDQ, S8S8) ...")
    quantize_static(
        str(MODEL_FP32),
        str(MODEL_QUANT),
        reader,
        quant_format=QuantFormat.QDQ,  # insert QuantizeLinear/DequantizeLinear node pairs
        activation_type=QuantType.QInt8,  # signed 8-bit activations
        weight_type=QuantType.QInt8,  # signed 8-bit weights
        calibrate_method=CalibrationMethod.MinMax,
        per_channel=False,
        reduce_range=False,
    )
    print(f"Saved: {MODEL_QUANT}")


if __name__ == "__main__":
    main()