"""
sonoisa/sentence-bert-base-ja-mean-tokens-v2 を ONNX (FP32) にエクスポートする。
出力は mean pooling 済みの 768 次元文ベクトル。
"""
from pathlib import Path
import torch
from transformers import AutoModel, AutoTokenizer
# Output layout: build/ holds intermediate artifacts; target/ holds only the
# files that ship in the distribution.
ROOT = Path(__file__).resolve().parent
BUILD_DIR = ROOT / "build"
TARGET_DIR = ROOT / "target"
# Hugging Face model id of the Japanese sentence-BERT checkpoint to export.
MODEL_ID = "sonoisa/sentence-bert-base-ja-mean-tokens-v2"
# Destination of the exported FP32 ONNX graph (intermediate, not distributed).
ONNX_PATH = BUILD_DIR / "model_fp32.onnx"
class BertWithMeanPooling(torch.nn.Module):
    """Wrap a BERT encoder so its forward pass yields mean-pooled sentence vectors.

    The embedding of each sentence is the average of the encoder's last
    hidden states over the sequence dimension, counting only positions
    where ``attention_mask`` is 1 (i.e. padding is excluded).
    """

    def __init__(self, model):
        super().__init__()
        self.bert = model

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return a (batch, hidden) tensor of mean-pooled sentence embeddings."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden = encoded.last_hidden_state  # (batch, seq, hidden)
        # Broadcast the mask over the hidden dimension so masked-out token
        # positions contribute nothing to the sum.
        weights = attention_mask.unsqueeze(-1).float()
        # clamp guards against division by zero for an all-padding row.
        token_count = weights.sum(dim=1).clamp(min=1e-9)
        return (hidden * weights).sum(dim=1) / token_count
def main():
    """Export the sentence-BERT model to ONNX and stage distribution files.

    Writes the FP32 ONNX graph (mean pooling baked in) to ``ONNX_PATH``
    under build/, then saves the tokenizer and model config to target/ so
    they can be shipped alongside the converted model.
    """
    BUILD_DIR.mkdir(parents=True, exist_ok=True)
    TARGET_DIR.mkdir(parents=True, exist_ok=True)

    print(f"Loading {MODEL_ID} ...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModel.from_pretrained(MODEL_ID)
    wrapper = BertWithMeanPooling(model)
    wrapper.eval()  # disable dropout so tracing captures deterministic behavior

    # Dummy batch (2 sentences, padded to 64 tokens) used only for tracing;
    # dynamic_axes below keeps batch and sequence dimensions variable in the
    # exported graph.  (Removed unused `batch_size` local — the batch size is
    # implied by the 2-sentence input list.)
    seq_len = 64
    dummy = tokenizer(
        ["これはテストです。", "もう一つの文です。"],
        padding="max_length",
        max_length=seq_len,
        return_tensors="pt",
        truncation=True,
    )
    input_ids = dummy["input_ids"]
    attention_mask = dummy["attention_mask"]
    # Some tokenizers omit token_type_ids; fall back to all-zero segment ids.
    token_type_ids = dummy.get("token_type_ids")
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids, dtype=torch.long)

    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            (input_ids, attention_mask, token_type_ids),
            str(ONNX_PATH),
            input_names=["input_ids", "attention_mask", "token_type_ids"],
            output_names=["sentence_embedding"],
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
                "token_type_ids": {0: "batch_size", 1: "sequence_length"},
                "sentence_embedding": {0: "batch_size"},
            },
            opset_version=18,
            do_constant_folding=False,
        )
    print(f"Exported: {ONNX_PATH}")

    # Distribution staging: only tokenizer + config go to target/ (the ONNX
    # file itself is an intermediate under build/).
    tokenizer.save_pretrained(TARGET_DIR)
    model.config.save_pretrained(TARGET_DIR)
    print(f"Saved tokenizer and config to {TARGET_DIR} (distribution)")
if __name__ == "__main__":
main()