// TelosDB: src/backend/src/utils/embedding_pro.rs
//! Pro tier: ONNX embeddings. tract is preferred; on failure, falls back to ort (ONNX Runtime).
//! Model: sonoisa sentence-bert-base-ja-mean-tokens-v2 (model.onnx or model_quantized.onnx).

use std::collections::HashMap;
use std::path::Path;
use std::sync::Mutex;

#[cfg(feature = "pro")]
use ort::session::Session;
#[cfg(feature = "pro")]
use ort::value::Value;

/// Fixed input sequence length fed to the model (BERT-style max tokens).
const MAX_LEN: usize = 512;
/// Output embedding dimensionality (sentence-bert-base hidden size).
const EMBED_DIM: usize = 768;

/// Read a WordPiece `vocab.txt` and build a token → ID map.
///
/// Each token's ID is its 0-based line number in the file; blank lines keep
/// their line number but are not inserted (so later IDs stay aligned).
fn load_vocab(vocab_path: &Path) -> Result<HashMap<String, i64>, Box<dyn std::error::Error + Send + Sync>> {
    let contents = std::fs::read_to_string(vocab_path)?;
    let vocab: HashMap<String, i64> = contents
        .lines()
        .enumerate()
        .filter_map(|(line_no, line)| {
            let token = line.trim();
            if token.is_empty() {
                None
            } else {
                Some((token.to_string(), line_no as i64))
            }
        })
        .collect();
    Ok(vocab)
}

/// Tokenize `text` BERT-style: greedy longest match (up to 20 chars) against
/// the vocabulary, wrapped in [CLS]/[SEP]; characters absent from the
/// vocabulary become [UNK]. Returns `(input_ids, attention_mask)`, both
/// padded (mask 0) or truncated to exactly MAX_LEN entries.
fn tokenize(vocab: &HashMap<String, i64>, text: &str) -> (Vec<i64>, Vec<i64>) {
    // Special-token IDs, with conventional fallbacks when the vocab lacks them.
    let special = |tok: &str, default: i64| vocab.get(tok).copied().unwrap_or(default);
    let pad_id = special("[PAD]", 0);
    let unk_id = special("[UNK]", 1);
    let cls_id = special("[CLS]", 2);
    let sep_id = special("[SEP]", 3);

    let chars: Vec<char> = text.chars().collect();
    let mut ids = vec![cls_id];
    let mut mask = vec![1i64];

    let mut pos = 0;
    // Reserve one slot for the trailing [SEP].
    while pos < chars.len() && ids.len() < MAX_LEN - 1 {
        // Longest-match search, from 20 characters down to 1.
        let max_span = chars.len().saturating_sub(pos).min(20);
        let matched = (1..=max_span).rev().find_map(|span| {
            let piece: String = chars[pos..pos + span].iter().collect();
            vocab.get(&piece).map(|&id| (id, span))
        });
        match matched {
            Some((id, span)) => {
                ids.push(id);
                mask.push(1);
                pos += span;
            }
            None => {
                // Single-character fallback; unknown characters map to [UNK].
                let single = chars[pos].to_string();
                ids.push(vocab.get(&single).copied().unwrap_or(unk_id));
                mask.push(1);
                pos += 1;
            }
        }
    }

    ids.push(sep_id);
    mask.push(1);

    if ids.len() < MAX_LEN {
        // Right-pad with [PAD] (mask 0) up to the fixed length.
        ids.resize(MAX_LEN, pad_id);
        mask.resize(MAX_LEN, 0);
    } else {
        // Exactly at the cap: ensure the final position is an attended [SEP].
        ids.truncate(MAX_LEN);
        mask.truncate(MAX_LEN);
        ids[MAX_LEN - 1] = sep_id;
        mask[MAX_LEN - 1] = 1;
    }

    (ids, mask)
}

/// Mask-weighted mean over the sequence axis of a row-major `[1, seq, 768]`
/// hidden-state buffer.
///
/// Positions whose attention mask is 0 (or that lie beyond the mask's length)
/// get weight 0; when no position is active the zero vector is returned.
fn mean_pool(last_hidden: &[f32], attention_mask: &[i64], seq_len: usize) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
    if last_hidden.len() < seq_len * EMBED_DIM {
        return Err(format!("last_hidden len {} < {}*{}", last_hidden.len(), seq_len, EMBED_DIM).into());
    }

    let mut pooled = vec![0f32; EMBED_DIM];
    let mut active = 0f32;
    for (step, row) in last_hidden.chunks_exact(EMBED_DIM).take(seq_len).enumerate() {
        // Same arithmetic as a multiply-by-mask accumulate.
        let weight = match attention_mask.get(step) {
            Some(&m) if m != 0 => 1.0f32,
            _ => 0.0,
        };
        active += weight;
        for (acc, &v) in pooled.iter_mut().zip(row) {
            *acc += v * weight;
        }
    }
    if active > 0.0 {
        for acc in pooled.iter_mut() {
            *acc /= active;
        }
    }
    Ok(pooled)
}

/// Inference backend: tract (pure Rust) or ort (ONNX Runtime).
pub enum ProEmbeddingBackend {
    /// Typed, runnable tract execution plan for the ONNX graph.
    Tract(
        tract_onnx::prelude::SimplePlan<
            tract_onnx::prelude::TypedFact,
            Box<dyn tract_onnx::prelude::TypedOp>,
            tract_onnx::prelude::Graph<tract_onnx::prelude::TypedFact, Box<dyn tract_onnx::prelude::TypedOp>>,
        >,
    ),
    // NOTE(review): `Session` is imported under #[cfg(feature = "pro")] while this
    // enum is not cfg-gated here — presumably the whole module is only compiled
    // with that feature; confirm at the `mod` declaration.
    /// ONNX Runtime session behind a Mutex: `run` takes `&mut self`, so
    /// concurrent encodes are serialized on this backend.
    Ort(Mutex<Session>),
}

/// Pro-tier embedding model: ONNX inference + mean pooling producing a
/// 768-dimensional vector per input text.
pub struct ProEmbeddingModel {
    // Selected inference backend (tract preferred, ort as fallback).
    backend: ProEmbeddingBackend,
    // WordPiece token → ID map loaded from vocab.txt.
    vocab: HashMap<String, i64>,
}

impl ProEmbeddingModel {
    /// Load the ONNX model and `vocab.txt` from `model_dir`.
    ///
    /// Strategy: prefer the non-quantized `model.onnx`, falling back to
    /// `model_quantized.onnx`; try tract first and fall back to ort
    /// (ONNX Runtime) if tract fails.
    ///
    /// # Errors
    /// Fails when `vocab.txt` or an ONNX file is missing, or when both
    /// backends reject the model; in the latter case both underlying errors
    /// are included in the returned message.
    pub fn load(model_dir: &Path) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let vocab_path = model_dir.join("vocab.txt");
        if !vocab_path.exists() {
            return Err(format!("vocab not found: {}", vocab_path.display()).into());
        }

        // Prefer the full-precision model when both files are present.
        let model_fp32 = model_dir.join("model.onnx");
        let model_quant = model_dir.join("model_quantized.onnx");
        let onnx_path = if model_fp32.exists() {
            model_fp32
        } else if model_quant.exists() {
            model_quant
        } else {
            return Err(format!(
                "ONNX not found: need model.onnx or model_quantized.onnx in {}",
                model_dir.display()
            )
            .into());
        };

        let vocab = load_vocab(&vocab_path)?;

        // Try tract first (pure Rust, no native runtime dependency).
        let tract_err = match Self::load_tract(&onnx_path) {
            Ok(backend) => {
                log::info!("[embedding] Loaded with tract");
                return Ok(ProEmbeddingModel { backend, vocab });
            }
            Err(e) => e,
        };
        log::warn!("[embedding] tract load failed: {}", tract_err);

        // Fall back to ONNX Runtime.
        let ort_err = match Self::load_ort(&onnx_path) {
            Ok(session) => {
                log::info!("[embedding] Loaded with ONNX Runtime (ort)");
                return Ok(ProEmbeddingModel {
                    backend: ProEmbeddingBackend::Ort(Mutex::new(session)),
                    vocab,
                });
            }
            Err(e) => e,
        };

        // Both backends failed: report both causes instead of discarding them
        // (previously `if let Ok(..)` dropped all diagnostic detail).
        Err(format!(
            "embedding model: tract and ort both failed (tract: {}; ort: {})",
            tract_err, ort_err
        )
        .into())
    }

    /// Load the model with tract, pinning both inputs to shape `[1, MAX_LEN]`.
    ///
    /// If optimization fails with an optimizer-shaped error (mentions
    /// "Cast"/"optim"), retry without optimization; setting
    /// `TELOS_EMBEDDING_NO_OPTIMIZE=1` skips optimization up front.
    fn load_tract(onnx_path: &Path) -> Result<ProEmbeddingBackend, Box<dyn std::error::Error + Send + Sync>> {
        use tract_onnx::prelude::*;

        let skip_optimize = std::env::var("TELOS_EMBEDDING_NO_OPTIMIZE").as_deref() == Ok("1");

        // Fix input 0 (input_ids) and input 1 (attention_mask) to [1, MAX_LEN]
        // so tract can derive a fully typed plan.
        let prepared = tract_onnx::onnx()
            .model_for_path(onnx_path)
            .map_err(|e| format!("ONNX load: {}", e))?
            .with_input_fact(
                0,
                InferenceFact::dt_shape(i64::datum_type(), tvec!(1, MAX_LEN as i64)),
            )
            .map_err(|e| format!("input fact 0: {}", e))?
            .with_input_fact(
                1,
                InferenceFact::dt_shape(i64::datum_type(), tvec!(1, MAX_LEN as i64)),
            )
            .map_err(|e| format!("input fact 1: {}", e))?;

        let model = if skip_optimize {
            log::info!("[embedding] Loading with tract (no optimize)");
            prepared
                .into_typed()
                .map_err(|e| format!("into_typed: {}", e))?
                .into_runnable()
                .map_err(|e| format!("runnable: {}", e))?
        } else {
            match prepared
                .clone()
                .into_optimized()
                .map_err(|e| format!("optimize: {}", e))
                .and_then(|m| m.into_runnable().map_err(|e| format!("runnable: {}", e)))
            {
                Ok(r) => r,
                Err(e) => {
                    let err_str = e.to_string();
                    // Heuristic: retry unoptimized only for optimizer-shaped
                    // failures; anything else is a genuine load error.
                    if err_str.contains("Cast") || err_str.contains("optim") || err_str.contains("optimize") {
                        log::warn!("[embedding] tract optimize failed ({}), trying without optimization", e);
                        prepared
                            .into_typed()
                            .map_err(|e2| format!("into_typed (fallback): {}", e2))?
                            .into_runnable()
                            .map_err(|e2| format!("runnable (fallback): {}", e2))?
                    } else {
                        return Err(e.into());
                    }
                }
            }
        };

        Ok(ProEmbeddingBackend::Tract(model))
    }

    /// Load the model with ONNX Runtime (ort) using default session options.
    fn load_ort(onnx_path: &Path) -> Result<Session, Box<dyn std::error::Error + Send + Sync>> {
        let session = Session::builder()?
            .commit_from_file(onnx_path)
            .map_err(|e| format!("ort commit_from_file: {}", e))?;
        Ok(session)
    }

    /// Encode one sentence into a 768-dimensional vector via mean pooling
    /// over the attended token positions.
    ///
    /// # Errors
    /// Fails on inference errors or when the model output is not shaped
    /// `[1, seq, 768]` (the ort path additionally accepts a pre-pooled
    /// `[1, 768]` output and returns it as-is).
    pub fn encode(&self, text: &str) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
        let (input_ids, attention_mask) = tokenize(&self.vocab, text);

        match &self.backend {
            ProEmbeddingBackend::Tract(model) => {
                use tract_onnx::prelude::*;
                let input_ids_arr = tract_ndarray::Array2::from_shape_vec((1, MAX_LEN), input_ids)
                    .map_err(|e| format!("input_ids shape: {}", e))?;
                let attention_mask_arr = tract_ndarray::Array2::from_shape_vec((1, MAX_LEN), attention_mask.clone())
                    .map_err(|e| format!("attention_mask shape: {}", e))?;
                let input_ids_tensor = tract_onnx::prelude::Tensor::from(input_ids_arr);
                let attention_mask_tensor = tract_onnx::prelude::Tensor::from(attention_mask_arr);

                let outputs = model.run(tvec!(
                    input_ids_tensor.into(),
                    attention_mask_tensor.into()
                ))?;

                let last_hidden: &tract_onnx::prelude::Tensor = outputs.first().ok_or("no output")?;
                let view = last_hidden.to_array_view::<f32>()?;
                let shape = view.shape();
                if shape.len() != 3 {
                    return Err(format!("expected [1, seq, 768], got {:?}", shape).into());
                }
                let seq_len = shape[1];
                let dim = shape[2];
                if dim != EMBED_DIM {
                    return Err(format!("expected dim 768, got {}", dim).into());
                }

                // Mean pooling over the ndarray view — same math as `mean_pool`,
                // kept inline because the view may be strided (no flat slice).
                let mut out = vec![0f32; EMBED_DIM];
                let mut mask_count = 0f32;
                for s in 0..seq_len {
                    let mask = if s < attention_mask.len() && attention_mask[s] != 0 {
                        1.0
                    } else {
                        0.0
                    };
                    mask_count += mask;
                    for d in 0..dim {
                        out[d] += view[[0, s, d]] * mask;
                    }
                }
                if mask_count > 0.0 {
                    for x in out.iter_mut() {
                        *x /= mask_count;
                    }
                }
                Ok(out)
            }
            ProEmbeddingBackend::Ort(session_mux) => {
                // `Session::run` takes `&mut self`, hence the Mutex.
                let mut session = session_mux.lock().map_err(|e| format!("ort session lock: {}", e))?;
                let input_ids_val = Value::from_array(([1_i64, MAX_LEN as i64], input_ids))?;
                let attention_mask_val = Value::from_array(([1_i64, MAX_LEN as i64], attention_mask.clone()))?;
                // BERT-style third input: all-zero token_type_ids (single segment).
                let token_type_ids_val = Value::from_array(([1_i64, MAX_LEN as i64], vec![0i64; MAX_LEN]))?;

                let outputs = session.run(ort::inputs![input_ids_val, attention_mask_val, token_type_ids_val])?;
                let (shape, flat) = outputs[0]
                    .try_extract_tensor::<f32>()
                    .map_err(|e| format!("ort extract last_hidden: {}", e))?;
                // Some exports return an already-pooled [1, 768] instead of [1, seq, 768].
                if shape.len() == 2 && shape[0] == 1 && shape[1] == EMBED_DIM as i64 {
                    return Ok(flat.to_vec());
                }
                if shape.len() != 3 {
                    return Err(format!("ort expected [1, seq, 768] or [1, 768], got {:?}", shape).into());
                }
                let seq_len = shape[1] as usize;
                let dim = shape[2] as usize;
                if dim != EMBED_DIM {
                    return Err(format!("ort expected dim 768, got {}", dim).into());
                }

                mean_pool(flat, &attention_mask, seq_len)
            }
        }
    }

    /// Output dimensionality of `encode` (768 for sentence-bert-base models).
    pub fn embed_dim() -> usize {
        EMBED_DIM
    }
}

#[cfg(all(test, feature = "pro"))]
mod tests {
    use super::*;

    #[test]
    fn embed_dim_is_768() {
        assert_eq!(ProEmbeddingModel::embed_dim(), 768, "sentence-bert 系は 768 次元");
    }

    /// Runs only when a model exists under TELOS_EMBEDDING_MODEL_DIR or
    /// embedding_model/; otherwise prints a skip notice and returns.
    /// Verifies that encoding works (768 dimensions, non-zero vector).
    #[test]
    fn encode_returns_768_dim_when_model_loaded() {
        // Resolve the model directory: the env var wins; otherwise look for
        // embedding_model/ two levels above the crate manifest, but only when
        // an ONNX file is actually present there.
        let from_env = std::env::var_os("TELOS_EMBEDDING_MODEL_DIR").map(std::path::PathBuf::from);
        let candidate = from_env.or_else(|| {
            let manifest = std::env::var_os("CARGO_MANIFEST_DIR").map(std::path::PathBuf::from)?;
            let root = manifest.parent()?.parent()?;
            let dir = root.join("embedding_model");
            let has_model =
                dir.join("model_quantized.onnx").exists() || dir.join("model.onnx").exists();
            has_model.then_some(dir)
        });
        let dir = match candidate {
            Some(d) => d,
            None => {
                eprintln!("encode test skipped: no embedding_model (set TELOS_EMBEDDING_MODEL_DIR or place embedding_model/)");
                return;
            }
        };

        let model = match ProEmbeddingModel::load(&dir) {
            Err(err) => {
                eprintln!("encode test skipped: model load failed: {}", err);
                return;
            }
            Ok(m) => m,
        };

        let embedding = match model.encode("テスト文です") {
            Err(err) => {
                eprintln!("encode test skipped: encode failed (model shape may differ): {}", err);
                return;
            }
            Ok(v) => v,
        };

        assert_eq!(embedding.len(), 768, "encode は 768 次元を返すこと");
        assert!(
            embedding.iter().any(|&component| component != 0.0),
            "ベクトルが全て 0 だとベクトル化が効いていない"
        );
    }
}