Newer
Older
TelosDB / src-tauri / src / utils / lsa.rs
use ndarray::{Array2, ArrayBase, Data, Ix2};
use ndarray_linalg::SVD;
use std::collections::HashMap;
use anyhow::{Result, anyhow};

pub struct LsaModel {
    pub u: Array2<f64>,      // 左特異ベクトル (単語-概念)
    pub sigma: Vec<f64>,    // 特異値
    pub vt: Array2<f64>,     // 右特異ベクトル (概念-文書)
    pub vocabulary: HashMap<String, usize>,
    pub k: usize,            // 圧縮後の次元数
}

impl LsaModel {
    /// 単語-文書行列 (TF-IDF) から LSA モデルを構築する
    pub fn train(matrix: &Array2<f64>, vocabulary: HashMap<String, usize>, k: usize) -> Result<Self> {
        // SVD の実行
        let (u, sigma, vt) = matrix.svd(true, true).map_err(|e| anyhow!("SVD failed: {}", e))?;
        
        // 次元 k に切り詰め (Truncated SVD)
        let u_val = u.ok_or_else(|| anyhow!("U matrix missing"))?;
        let vt_val = vt.ok_or_else(|| anyhow!("Vt matrix missing"))?;
        
        let k_actual = std::cmp::min(k, sigma.len());
        
        let u_k = u_val.slice(ndarray::s![.., ..k_actual]).to_owned();
        let sigma_k = sigma.slice(ndarray::s![..k_actual]).to_owned();
        let vt_k = vt_val.slice(ndarray::s![..k_actual, ..]).to_owned();

        Ok(LsaModel {
            u: u_k,
            sigma: sigma_k.to_vec(),
            vt: vt_k,
            vocabulary,
            k: k_actual,
        })
    }

    /// クエリを潜在概念空間へ射影する
    pub fn project_query(&self, query_vector: &ndarray::Array1<f64>) -> Result<ndarray::Array1<f64>> {
        // Query_LSA = Query_TFIDF^T * U_k * Sigma_k^-1
        // 単純化のため、ここでは U_k^T * Query_TFIDF を計算
        let query_lsa = self.u.t().dot(query_vector);
        Ok(query_lsa)
    }

    /// 二つのベクトル間のコサイン類似度を計算
    pub fn cosine_similarity(a: &ndarray::Array1<f64>, b: &ndarray::Array1<f64>) -> f64 {
        let dot = a.dot(b);
        let norm_a = a.dot(a).sqrt();
        let norm_b = b.dot(b).sqrt();
        if norm_a == 0.0 || norm_b == 0.0 { return 0.0; }
        dot / (norm_a * norm_b)
    }
}

/// 単語出現頻度をカウントして TF-IDF 行列の元を作成するヘルパー
pub struct TermDocumentMatrixBuilder {
    pub vocabulary: HashMap<String, usize>,
    pub counts: Vec<HashMap<usize, f64>>, // 文書ごとの単語出現カウント
}

impl TermDocumentMatrixBuilder {
    pub fn new() -> Self {
        TermDocumentMatrixBuilder {
            vocabulary: HashMap::new(),
            counts: Vec::new(),
        }
    }

    pub fn add_document(&mut self, tokens: Vec<String>) {
        let mut doc_counts = HashMap::new();
        for token in tokens {
            let id = if let Some(&id) = self.vocabulary.get(&token) {
                id
            } else {
                let id = self.vocabulary.len();
                self.vocabulary.insert(token, id);
                id
            };
            *doc_counts.entry(id).or_insert(0.0) += 1.0;
        }
        self.counts.push(doc_counts);
    }

    pub fn build_matrix(&self) -> Array2<f64> {
        let rows = self.vocabulary.len();
        let cols = self.counts.len();
        let mut matrix = Array2::zeros((rows, cols));

        for (col, doc_counts) in self.counts.iter().enumerate() {
            for (&row, &count) in doc_counts {
                matrix[[row, col]] = count;
            }
        }
        // 本来はここで TF-IDF 変換を行うほうが精度が高い
        matrix
    }
}