//! Pro edition: ONNX embeddings. Prefers tract; falls back to ort (ONNX Runtime) on failure.
//! Model: sonoisa sentence-bert-base-ja-mean-tokens-v2 (model.onnx or model_quantized.onnx).
use std::collections::HashMap;
use std::path::Path;
use std::sync::Mutex;
#[cfg(feature = "pro")]
use ort::session::Session;
#[cfg(feature = "pro")]
use ort::value::Value;
const MAX_LEN: usize = 512;
const EMBED_DIM: usize = 768;
/// Read the vocabulary file (vocab.txt) and return a token -> ID map.
///
/// Each non-empty line becomes one entry; its 0-based line number is the ID.
/// Blank lines are skipped but still consume a line number, matching the
/// usual BERT vocab convention.
fn load_vocab(vocab_path: &Path) -> Result<HashMap<String, i64>, Box<dyn std::error::Error + Send + Sync>> {
    let contents = std::fs::read_to_string(vocab_path)?;
    let map = contents
        .lines()
        .enumerate()
        .filter_map(|(id, line)| {
            let token = line.trim();
            if token.is_empty() {
                None
            } else {
                Some((token.to_string(), id as i64))
            }
        })
        .collect();
    Ok(map)
}
/// Tokenize text BERT-style: greedy longest match against the vocabulary,
/// wrapped in [CLS]/[SEP]. Characters not in the vocabulary map to [UNK].
///
/// Returns `(input_ids, attention_mask)`, each exactly MAX_LEN long:
/// padded with [PAD] (mask 0) when shorter, truncated with a forced
/// trailing [SEP] when at capacity.
fn tokenize(vocab: &HashMap<String, i64>, text: &str) -> (Vec<i64>, Vec<i64>) {
    // Special-token IDs, with the conventional fallbacks if the vocab lacks them.
    let pad_id = vocab.get("[PAD]").copied().unwrap_or(0);
    let unk_id = vocab.get("[UNK]").copied().unwrap_or(1);
    let cls_id = vocab.get("[CLS]").copied().unwrap_or(2);
    let sep_id = vocab.get("[SEP]").copied().unwrap_or(3);

    let chars: Vec<char> = text.chars().collect();
    let mut ids = Vec::with_capacity(MAX_LEN);
    let mut mask = Vec::with_capacity(MAX_LEN);
    ids.push(cls_id);
    mask.push(1i64);

    let mut pos = 0;
    // Reserve the final slot for the trailing [SEP].
    while pos < chars.len() && ids.len() < MAX_LEN - 1 {
        // Greedy longest match, capped at 20 characters per token.
        let max_span = chars.len().saturating_sub(pos).min(20);
        let matched = (1..=max_span).rev().find_map(|span| {
            let candidate: String = chars[pos..pos + span].iter().collect();
            vocab.get(&candidate).map(|&id| (id, span))
        });
        match matched {
            Some((id, span)) => {
                ids.push(id);
                mask.push(1);
                pos += span;
            }
            None => {
                // No match at any length: emit [UNK] for this single character.
                let single = chars[pos].to_string();
                ids.push(vocab.get(&single).copied().unwrap_or(unk_id));
                mask.push(1);
                pos += 1;
            }
        }
    }
    ids.push(sep_id);
    mask.push(1);

    if ids.len() < MAX_LEN {
        // Pad out to the fixed length; padded positions carry mask 0.
        while ids.len() < MAX_LEN {
            ids.push(pad_id);
            mask.push(0);
        }
    } else {
        // Defensive truncation: guarantee the sequence ends with [SEP].
        ids.truncate(MAX_LEN);
        mask.truncate(MAX_LEN);
        ids[MAX_LEN - 1] = sep_id;
        mask[MAX_LEN - 1] = 1;
    }
    (ids, mask)
}
/// Mean pooling: average the masked positions of a flat `[1, seq, 768]`
/// last_hidden buffer into a single EMBED_DIM vector.
///
/// Each position's contribution is weighted by its attention-mask entry
/// (1 for real tokens, 0 for padding / missing entries); the sum is then
/// divided by the number of active positions. Returns an all-zero vector
/// when every position is masked out.
///
/// Errors if `last_hidden` is too short for `seq_len` rows of EMBED_DIM.
fn mean_pool(last_hidden: &[f32], attention_mask: &[i64], seq_len: usize) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
    if last_hidden.len() < seq_len * EMBED_DIM {
        return Err(format!("last_hidden len {} < {}*{}", last_hidden.len(), seq_len, EMBED_DIM).into());
    }
    let mut pooled = vec![0f32; EMBED_DIM];
    let mut active = 0f32;
    for (s, row) in last_hidden.chunks_exact(EMBED_DIM).take(seq_len).enumerate() {
        // Weight 1.0 only for positions the mask marks as real tokens.
        let weight = if attention_mask.get(s).copied().unwrap_or(0) != 0 {
            1.0f32
        } else {
            0.0
        };
        active += weight;
        for (acc, &v) in pooled.iter_mut().zip(row) {
            *acc += v * weight;
        }
    }
    if active > 0.0 {
        pooled.iter_mut().for_each(|x| *x /= active);
    }
    Ok(pooled)
}
/// Inference backend: tract (pure Rust) or ort (ONNX Runtime).
pub enum ProEmbeddingBackend {
/// tract: a typed, optimized, runnable execution plan for the ONNX graph.
Tract(
tract_onnx::prelude::SimplePlan<
tract_onnx::prelude::TypedFact,
Box<dyn tract_onnx::prelude::TypedOp>,
tract_onnx::prelude::Graph<tract_onnx::prelude::TypedFact, Box<dyn tract_onnx::prelude::TypedOp>>,
>,
),
/// ort: an ONNX Runtime session. Wrapped in a Mutex because `encode`
/// takes `&self` but running the session requires mutable access.
/// NOTE(review): `Session` is imported under #[cfg(feature = "pro")] yet
/// used here without a cfg gate — presumably this whole module is only
/// compiled with that feature; confirm the gating at the module level.
Ort(Mutex<Session>),
}
/// Pro-edition embedding model. Runs ONNX inference plus mean pooling and
/// returns 768-dimensional vectors.
pub struct ProEmbeddingModel {
// Whichever inference engine successfully loaded (tract preferred, ort fallback).
backend: ProEmbeddingBackend,
// Token -> ID map loaded from vocab.txt.
vocab: HashMap<String, i64>,
}
impl ProEmbeddingModel {
/// Load the ONNX model and vocab.txt from the given directory.
///
/// Tries tract first and falls back to ort (ONNX Runtime) on failure.
/// Prefers model.onnx (non-quantized); uses model_quantized.onnx otherwise.
///
/// # Errors
/// Fails if vocab.txt is missing, neither ONNX file exists, the vocabulary
/// cannot be read, or both backends fail to load. The both-failed error now
/// includes each backend's failure reason (previously both were discarded).
pub fn load(model_dir: &Path) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    let vocab_path = model_dir.join("vocab.txt");
    if !vocab_path.exists() {
        return Err(format!("vocab not found: {}", vocab_path.display()).into());
    }
    // Prefer the fp32 model when both are present.
    let model_fp32 = model_dir.join("model.onnx");
    let model_quant = model_dir.join("model_quantized.onnx");
    let onnx_path = if model_fp32.exists() {
        model_fp32
    } else if model_quant.exists() {
        model_quant
    } else {
        return Err(format!(
            "ONNX not found: need model.onnx or model_quantized.onnx in {}",
            model_dir.display()
        )
        .into());
    };
    let vocab = load_vocab(&vocab_path)?;
    // Try tract first; keep its error so a total failure stays diagnosable.
    let tract_err = match Self::load_tract(&onnx_path) {
        Ok(backend) => {
            log::info!("[embedding] Loaded with tract");
            return Ok(ProEmbeddingModel { backend, vocab });
        }
        Err(e) => e,
    };
    match Self::load_ort(&onnx_path) {
        Ok(session) => {
            log::info!("[embedding] Loaded with ONNX Runtime (ort)");
            Ok(ProEmbeddingModel {
                backend: ProEmbeddingBackend::Ort(Mutex::new(session)),
                vocab,
            })
        }
        // Surface both causes instead of the previous opaque message.
        Err(ort_err) => Err(format!(
            "embedding model: tract and ort both failed (tract: {}; ort: {})",
            tract_err, ort_err
        )
        .into()),
    }
}
/// Load the ONNX model with tract, pinning both inputs to i64 shape [1, MAX_LEN].
///
/// By default the typed graph is optimized before being made runnable; set
/// TELOS_EMBEDDING_NO_OPTIMIZE=1 to skip optimization. When optimization fails
/// with an optimizer-looking error (message containing "Cast"/"optim"/"optimize"),
/// it retries without optimization; any other error is returned as-is.
fn load_tract(onnx_path: &Path) -> Result<ProEmbeddingBackend, Box<dyn std::error::Error + Send + Sync>> {
use tract_onnx::prelude::*;
// Env escape hatch to bypass tract's graph optimizer entirely.
let skip_optimize = std::env::var("TELOS_EMBEDDING_NO_OPTIMIZE").as_deref() == Ok("1");
let prepared = tract_onnx::onnx()
.model_for_path(onnx_path)
.map_err(|e| format!("ONNX load: {}", e))?
// Input 0 (input_ids, per the order encode() feeds them): fixed [1, MAX_LEN] i64.
.with_input_fact(
0,
InferenceFact::dt_shape(i64::datum_type(), tvec!(1, MAX_LEN as i64)),
)
.map_err(|e| format!("input fact 0: {}", e))?
// Input 1 (attention_mask): same fixed shape.
.with_input_fact(
1,
InferenceFact::dt_shape(i64::datum_type(), tvec!(1, MAX_LEN as i64)),
)
.map_err(|e| format!("input fact 1: {}", e))?;
let model = if skip_optimize {
log::info!("[embedding] Loading with tract (no optimize)");
prepared
.into_typed()
.map_err(|e| format!("into_typed: {}", e))?
.into_runnable()
.map_err(|e| format!("runnable: {}", e))?
} else {
// Clone so the unoptimized `prepared` graph is still available for the
// no-optimization fallback below.
match prepared
.clone()
.into_optimized()
.map_err(|e| format!("optimize: {}", e))
.and_then(|m| m.into_runnable().map_err(|e| format!("runnable: {}", e)))
{
Ok(r) => r,
Err(e) => {
// Heuristic: only retry without optimization when the error text
// looks like it came from the optimizer (e.g. unsupported Cast ops).
let err_str = e.to_string();
if err_str.contains("Cast") || err_str.contains("optim") || err_str.contains("optimize") {
log::warn!("[embedding] tract optimize failed ({}), trying without optimization", e);
prepared
.into_typed()
.map_err(|e2| format!("into_typed (fallback): {}", e2))?
.into_runnable()
.map_err(|e2| format!("runnable (fallback): {}", e2))?
} else {
return Err(e.into());
}
}
}
};
Ok(ProEmbeddingBackend::Tract(model))
}
/// Load the ONNX model with ONNX Runtime (ort) and return the session.
fn load_ort(onnx_path: &Path) -> Result<Session, Box<dyn std::error::Error + Send + Sync>> {
    Session::builder()?
        .commit_from_file(onnx_path)
        .map_err(|e| format!("ort commit_from_file: {}", e).into())
}
/// Encode one sentence into a 768-dimensional vector with mean pooling.
///
/// The text is tokenized to a fixed [1, MAX_LEN] id/mask pair, run through
/// whichever backend loaded, and the last hidden states are mean-pooled over
/// the non-padding positions.
pub fn encode(&self, text: &str) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
let (input_ids, attention_mask) = tokenize(&self.vocab, text);
match &self.backend {
ProEmbeddingBackend::Tract(model) => {
use tract_onnx::prelude::*;
// Build [1, MAX_LEN] i64 tensors for both model inputs.
let input_ids_arr = tract_ndarray::Array2::from_shape_vec((1, MAX_LEN), input_ids)
.map_err(|e| format!("input_ids shape: {}", e))?;
let attention_mask_arr = tract_ndarray::Array2::from_shape_vec((1, MAX_LEN), attention_mask.clone())
.map_err(|e| format!("attention_mask shape: {}", e))?;
let input_ids_tensor = tract_onnx::prelude::Tensor::from(input_ids_arr);
let attention_mask_tensor = tract_onnx::prelude::Tensor::from(attention_mask_arr);
let outputs = model.run(tvec!(
input_ids_tensor.into(),
attention_mask_tensor.into()
))?;
// First output is expected to be last_hidden_state with shape [1, seq, 768].
let last_hidden: &tract_onnx::prelude::Tensor = outputs.first().ok_or("no output")?;
let view = last_hidden.to_array_view::<f32>()?;
let shape = view.shape();
if shape.len() != 3 {
return Err(format!("expected [1, seq, 768], got {:?}", shape).into());
}
let seq_len = shape[1];
let dim = shape[2];
if dim != EMBED_DIM {
return Err(format!("expected dim 768, got {}", dim).into());
}
// Masked mean pooling — same math as `mean_pool`, duplicated here to
// operate directly on the 3-D array view.
let mut out = vec![0f32; EMBED_DIM];
let mut mask_count = 0f32;
for s in 0..seq_len {
// Padding (mask 0) and positions past the mask get weight 0.
let mask = if s < attention_mask.len() && attention_mask[s] != 0 {
1.0
} else {
0.0
};
mask_count += mask;
for d in 0..dim {
out[d] += view[[0, s, d]] * mask;
}
}
if mask_count > 0.0 {
for x in out.iter_mut() {
*x /= mask_count;
}
}
Ok(out)
}
ProEmbeddingBackend::Ort(session_mux) => {
// Running the session needs mutable access; hence the Mutex around it.
let mut session = session_mux.lock().map_err(|e| format!("ort session lock: {}", e))?;
let input_ids_val = Value::from_array(([1_i64, MAX_LEN as i64], input_ids))?;
let attention_mask_val = Value::from_array(([1_i64, MAX_LEN as i64], attention_mask.clone()))?;
// BERT-style models also take token_type_ids; single sentence => all zeros.
let token_type_ids: Vec<i64> = std::iter::repeat(0).take(MAX_LEN).collect();
let token_type_ids_val = Value::from_array(([1_i64, MAX_LEN as i64], token_type_ids))?;
let outputs = session.run(ort::inputs![input_ids_val, attention_mask_val, token_type_ids_val])?;
let (shape, flat) = outputs[0]
.try_extract_tensor::<f32>()
.map_err(|e| format!("ort extract last_hidden: {}", e))?;
// Some models return an already-pooled [1, 768] instead of [1, seq, 768].
if shape.len() == 2 && shape[0] == 1 && shape[1] == EMBED_DIM as i64 {
return Ok(flat.to_vec());
}
if shape.len() != 3 {
return Err(format!("ort expected [1, seq, 768] or [1, 768], got {:?}", shape).into());
}
let seq_len = shape[1] as usize;
let dim = shape[2] as usize;
if dim != EMBED_DIM {
return Err(format!("ort expected dim 768, got {}", dim).into());
}
mean_pool(flat, &attention_mask, seq_len)
}
}
}
/// Dimensionality of the vectors produced by `encode` (always EMBED_DIM = 768).
pub fn embed_dim() -> usize {
EMBED_DIM
}
}
#[cfg(all(test, feature = "pro"))]
mod tests {
use super::*;
// Sanity check on the advertised embedding dimensionality.
#[test]
fn embed_dim_is_768() {
assert_eq!(ProEmbeddingModel::embed_dim(), 768, "sentence-bert 系は 768 次元");
}
/// Runs only when a model exists under TELOS_EMBEDDING_MODEL_DIR or embedding_model/.
/// Verifies that vectorization works (768 dimensions, non-zero values).
#[test]
fn encode_returns_768_dim_when_model_loaded() {
let model_dir = std::env::var_os("TELOS_EMBEDDING_MODEL_DIR")
.map(std::path::PathBuf::from)
.or_else(|| {
// Fallback: <crate>/../../embedding_model relative to this manifest.
let manifest = std::env::var_os("CARGO_MANIFEST_DIR").map(std::path::PathBuf::from)?;
let root = manifest.parent()?.parent()?;
let p = root.join("embedding_model");
if p.join("model_quantized.onnx").exists() || p.join("model.onnx").exists() {
Some(p)
} else {
None
}
});
let Some(dir) = model_dir else {
eprintln!("encode test skipped: no embedding_model (set TELOS_EMBEDDING_MODEL_DIR or place embedding_model/)");
return;
};
// Skip (not fail) when the model cannot be loaded in this environment.
let model = match ProEmbeddingModel::load(&dir) {
Ok(m) => m,
Err(e) => {
eprintln!("encode test skipped: model load failed: {}", e);
return;
}
};
let vec = match model.encode("テスト文です") {
Ok(v) => v,
Err(e) => {
eprintln!("encode test skipped: encode failed (model shape may differ): {}", e);
return;
}
};
assert_eq!(vec.len(), 768, "encode は 768 次元を返すこと");
assert!(
vec.iter().any(|&x| x != 0.0),
"ベクトルが全て 0 だとベクトル化が効いていない"
);
}
}