# sentence-bert-base-ja-mean-tokens-v2-int8 / scripts / run_inference.py
"""
target/ の量子化 ONNX で推論する。target/ を配布先にコピーしたうえで実行する。
"""
from functools import lru_cache
from pathlib import Path

import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

# Repository root: parent of the scripts/ directory containing this file.
ROOT = Path(__file__).resolve().parent.parent
# Distribution directory expected to hold model_quantized.onnx plus tokenizer files.
TARGET_DIR = ROOT / "target"


@lru_cache(maxsize=4)
def _load(target_dir: str):
    """Load and cache the (tokenizer, ONNX session) pair for *target_dir*.

    Loading the tokenizer and InferenceSession is expensive; caching them
    makes repeated encode() calls cheap. Keyed on the directory path string
    (Path is not used directly so the cache key is a plain hashable str).
    """
    tokenizer = AutoTokenizer.from_pretrained(target_dir)
    session = ort.InferenceSession(
        str(Path(target_dir) / "model_quantized.onnx"),
        providers=["CPUExecutionProvider"],
    )
    return tokenizer, session


def encode(
    text: str,
    target_dir: Path | None = None,
    max_length: int = 128,
) -> np.ndarray:
    """Encode a sentence into a 768-dimensional vector.

    Args:
        text: Input sentence to encode.
        target_dir: Directory containing model_quantized.onnx and tokenizer
            files. Defaults to TARGET_DIR.
        max_length: Token sequence length; input is padded/truncated to this.

    Returns:
        np.ndarray of shape (1, 768) — the model's single output.

    Raises:
        FileNotFoundError: If model_quantized.onnx is missing in target_dir.
    """
    target_dir = target_dir or TARGET_DIR
    if not (target_dir / "model_quantized.onnx").exists():
        raise FileNotFoundError(f"target に model_quantized.onnx がありません: {target_dir}")
    tokenizer, session = _load(str(target_dir))
    enc = tokenizer(
        text,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="np",
    )
    # Some tokenizers do not emit token_type_ids; fall back to all zeros.
    token_type_ids = enc.get("token_type_ids")
    if token_type_ids is None:
        token_type_ids = np.zeros_like(enc["input_ids"])
    # Cast every input to int64 — the tokenizer's dtype is platform-dependent
    # and the ONNX graph expects int64 tensors.
    out, = session.run(
        None,
        {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
            "token_type_ids": token_type_ids.astype(np.int64),
        },
    )
    return out  # (1, 768)


def main():
    """Smoke test: encode one Japanese sentence and report the vector shape."""
    embedding = encode("今日は良い天気です。")
    print(embedding.shape)  # expected: (1, 768)


if __name__ == "__main__":
    main()