# sentence-bert-base-ja-mean-tokens-v2-int8 / export_onnx.py
"""
Export sonoisa/sentence-bert-base-ja-mean-tokens-v2 to ONNX (FP32).
The output is a mean-pooled 768-dimensional sentence embedding.
"""
from pathlib import Path

import torch
from transformers import AutoModel, AutoTokenizer

# Output layout: build/ holds intermediate files, target/ holds only the
# files that ship with the distribution.
ROOT = Path(__file__).resolve().parent
BUILD_DIR = ROOT / "build"
TARGET_DIR = ROOT / "target"
# Hugging Face model id of the Japanese Sentence-BERT encoder to export.
MODEL_ID = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
# FP32 ONNX graph written by main(); an intermediate artifact (kept in build/).
ONNX_PATH = BUILD_DIR / "model_fp32.onnx"


class BertWithMeanPooling(torch.nn.Module):
    """Wrap a BERT encoder and return mean-pooled sentence embeddings.

    The pooled vector is the average of the encoder's last hidden states
    over the sequence axis, counting only positions where
    ``attention_mask == 1``.
    """

    def __init__(self, model):
        super().__init__()
        # Underlying encoder; must expose ``last_hidden_state`` on its output.
        self.bert = model

    def forward(self, input_ids, attention_mask, token_type_ids):
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden = encoded.last_hidden_state  # (batch, seq, hidden)
        # Broadcast the mask over the hidden dim: (batch, seq) -> (batch, seq, 1).
        weights = attention_mask.unsqueeze(-1).float()
        masked_total = torch.sum(hidden * weights, dim=1)
        # Clamp guards against division by zero for an all-padding row.
        token_counts = torch.clamp(weights.sum(dim=1), min=1e-9)
        return masked_total / token_counts  # (batch, hidden)


def main():
    """Export the model to ONNX (FP32) and save tokenizer/config for shipping.

    Writes ``build/model_fp32.onnx`` (intermediate artifact) and copies the
    tokenizer files and model config into ``target/`` (the distribution set).
    """
    BUILD_DIR.mkdir(parents=True, exist_ok=True)
    TARGET_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Loading {MODEL_ID} ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModel.from_pretrained(MODEL_ID)
    wrapper = BertWithMeanPooling(model)
    wrapper.eval()

    # Dummy input for tracing (batch=2, seq=64). Dynamic axes are declared
    # below, so these concrete sizes only shape the trace.
    seq_len = 64
    dummy = tokenizer(
        ["これはテストです。", "もう一つの文です。"],
        padding="max_length",
        max_length=seq_len,
        return_tensors="pt",
        truncation=True,
    )
    input_ids = dummy["input_ids"]
    attention_mask = dummy["attention_mask"]
    # Some tokenizers do not emit token_type_ids; fall back to all-zeros
    # (single segment) so the exported graph always has three inputs.
    token_type_ids = dummy.get("token_type_ids")
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids, dtype=torch.long)

    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            (input_ids, attention_mask, token_type_ids),
            str(ONNX_PATH),
            input_names=["input_ids", "attention_mask", "token_type_ids"],
            output_names=["sentence_embedding"],
            # Batch and sequence length are free at inference time.
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "token_type_ids": {0: "batch_size", 1: "sequence_length"},
                "sentence_embedding": {0: "batch_size"},
            },
            opset_version=18,
            # NOTE(review): constant folding is deliberately disabled here —
            # presumably to keep the graph unchanged for a later INT8 pass;
            # confirm before flipping to the default (True).
            do_constant_folding=False,
        )

    print(f"Exported: {ONNX_PATH}")

    # Distribution set: only tokenizer and config go to target/ — the FP32
    # ONNX graph stays in build/ as an intermediate.
    tokenizer.save_pretrained(TARGET_DIR)
    model.config.save_pretrained(TARGET_DIR)
    print(f"Saved tokenizer and config to {TARGET_DIR} (distribution)")


# Run the export when invoked as a script (no side effects on import).
if __name__ == "__main__":
    main()