Newer
Older
TelosDB / src-tauri / src / utils / lsa.rs
use ndarray::{Array1, Array2, Axis};
use std::collections::HashMap;
use anyhow::{Result, anyhow};

#[derive(Clone)]
pub struct LsaModel {
    pub u: Array2<f64>,      // 左特異ベクトル (単語-概念)
    pub sigma: Vec<f64>,    // 特異値
    pub vocabulary: HashMap<String, usize>,
    pub idfs: Vec<f64>,     // 学習時の IDF
    pub k: usize,            // 圧縮後の次元数
}

impl LsaModel {
    /// 単語-文書行列 (TF-IDF) から LSA モデルを構築する
    pub fn train(matrix: &Array2<f64>, vocabulary: HashMap<String, usize>, idfs: Vec<f64>, k: usize) -> Result<Self> {
        let rows = matrix.nrows();
        let cols = matrix.ncols();
        let k_actual = std::cmp::min(k, std::cmp::min(rows, cols));

        if k_actual == 0 {
            return Err(anyhow!("Cannot train LSA with 0 documents or 0 features"));
        }

        // --- Iterative SVD (Power Method with Deflation) ---
        // A A^T の上位特異ベクトルを求める
        let mut u_k = Array2::zeros((rows, k_actual));
        let mut sigma_k = Vec::new();
        let mut working_matrix = matrix.clone();

        for i in 0..k_actual {
            // 第i主成分を抽出 (Power Method)
            // 初期ベクトルを少しずつ変えて、収束を安定させる
            let mut v = Array1::from_elem(rows, 1.0); 
            if i > 0 {
                // 前の成分と異なる方向を向かせるための簡単な摂動
                v[i % rows] += 1.0;
            }
            v /= (v.dot(&v) as f64).sqrt();

            for _ in 0..150 { // イテレーション回数を増やして精度向上
                // v = (A A^T) v = A * (A^T * v)
                let at_v = working_matrix.t().dot(&v);
                let a_at_v = working_matrix.dot(&at_v);
                
                let norm = (a_at_v.dot(&a_at_v) as f64).sqrt();
                if norm < 1e-15 { break; }
                v = a_at_v / norm;
            }

            // 特異値 s = ||A^T v||
            let at_v_final = working_matrix.t().dot(&v);
            let s = (at_v_final.dot(&at_v_final) as f64).sqrt();
            
            // Deflation: A_next = A - s * u * vt^T
            if s > 1e-15 {
                let vt_i = &at_v_final / s;
                let v_col = v.clone().insert_axis(Axis(1));
                let vt_row = vt_i.insert_axis(Axis(0));
                let projection = v_col.dot(&vt_row);
                working_matrix = working_matrix - (s * projection);
                
                for j in 0..rows {
                    u_k[[j, i]] = v[j];
                }
                sigma_k.push(s);
            } else {
                sigma_k.push(0.0);
            }
        }

        Ok(LsaModel {
            u: u_k,
            sigma: sigma_k,
            vocabulary,
            idfs,
            k: k_actual,
        })
    }

    /// クエリを潜在概念空間へ射影する(TF-IDF重み付け & 正規化済み)
    pub fn project_query(&self, query_tf: &ndarray::Array1<f64>) -> Result<ndarray::Array1<f64>> {
        // TF-IDF 重み付けの適用
        let mut query_tfidf = query_tf.clone();
        for (i, &idf) in self.idfs.iter().enumerate() {
            query_tfidf[i] *= idf;
        }

        // Project: Query_LSA = U_k^T * Query_TFIDF
        let mut query_lsa = self.u.t().dot(&query_tfidf);
        
        // 正規化 (L2距離をコサイン類似度に対応させるため)
        let norm = (query_lsa.dot(&query_lsa) as f64).sqrt();
        if norm > 1e-12 {
            query_lsa /= norm;
        } else {
            // クエリが完全に語彙に含まれない場合は零ベクトルを返す
            // この場合、類似度は全て 0.0 になる
            return Ok(ndarray::Array1::zeros(query_lsa.len()));
        }
        
        Ok(query_lsa)
    }

    /// 二つのベクトル間のコサイン類似度を計算
    pub fn cosine_similarity(a: &ndarray::Array1<f64>, b: &ndarray::Array1<f64>) -> f64 {
        // 両者が正規化されているなら単なるドット積
        a.dot(b)
    }
}

/// 単語出現頻度をカウントして TF-IDF 行列の元を作成するヘルパー
pub struct TermDocumentMatrixBuilder {
    pub vocabulary: HashMap<String, usize>,
    pub counts: Vec<HashMap<usize, f64>>, // 文書ごとの単語出現カウント
}

impl TermDocumentMatrixBuilder {
    pub fn new() -> Self {
        TermDocumentMatrixBuilder {
            vocabulary: HashMap::new(),
            counts: Vec::new(),
        }
    }

    pub fn add_document(&mut self, tokens: Vec<String>) {
        if tokens.is_empty() { return; }
        let mut doc_counts = HashMap::new();
        for token in tokens {
            let id = if let Some(&id) = self.vocabulary.get(&token) {
                id
            } else {
                let id = self.vocabulary.len();
                self.vocabulary.insert(token, id);
                id
            };
            *doc_counts.entry(id).or_insert(0.0) += 1.0;
        }
        self.counts.push(doc_counts);
    }

    pub fn build_matrix(&self) -> (Array2<f64>, Vec<f64>) {
        let num_terms = self.vocabulary.len();
        let num_docs = self.counts.len();
        if num_terms == 0 || num_docs == 0 {
            return (Array2::zeros((num_terms, num_docs)), Vec::new());
        }

        let mut matrix = Array2::zeros((num_terms, num_docs));

        // IDF の計算
        let mut doc_freq = vec![0.0; num_terms];
        for doc_counts in &self.counts {
            for &term_id in doc_counts.keys() {
                doc_freq[term_id] += 1.0;
            }
        }

        let idfs: Vec<f64> = doc_freq.iter()
            .map(|&df| ((num_docs as f64) / (df + 1.0)).ln() + 1.0)
            .collect();

        for (col, doc_counts) in self.counts.iter().enumerate() {
            let max_tf = doc_counts.values().fold(0.0, |a, &b| f64::max(a, b));
            for (&row, &count) in doc_counts {
                // TF-IDF = (tf / max_tf) * idf
                matrix[[row, col]] = (count / max_tf) * idfs[row];
            }
        }
        (matrix, idfs)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lsa_variance() {
        let mut builder = TermDocumentMatrixBuilder::new();
        // 重なりのあるカテゴリ
        builder.add_document(vec!["自然".to_string(), "山".to_string(), "登山".to_string()]);
        builder.add_document(vec!["自然".to_string(), "空".to_string(), "雲".to_string()]);
        builder.add_document(vec!["寿司".to_string(), "魚".to_string(), "飯".to_string()]);

        let matrix = builder.build_matrix();
        let model = LsaModel::train(&matrix, builder.vocabulary.clone(), 3).unwrap(); // rank=3

        println!("Sigma: {:?}", model.sigma);

        // クエリ: 「山」
        let mut q_vec = ndarray::Array1::zeros(builder.vocabulary.len());
        q_vec[*builder.vocabulary.get("山").unwrap()] = 1.0;
        let q_lsa = model.project_query(&q_vec).unwrap();

        let doc_vectors: Vec<_> = (0..3).map(|i| {
            let mut d_vec = ndarray::Array1::zeros(builder.vocabulary.len());
            for (&tid, &count) in &builder.counts[i] {
                d_vec[tid] = count;
            }
            model.project_query(&d_vec).unwrap()
        }).collect();

        let sims: Vec<_> = doc_vectors.iter()
            .map(|d| LsaModel::cosine_similarity(&q_lsa, d))
            .collect();

        println!("Similarities: {:?}", sims);
        // 山 doc (0) > 空 doc (1) > 寿司 doc (2) となるはず
        assert!(sims[0] > sims[1], "山 should be closer than 空 ({} vs {})", sims[0], sims[1]);
        assert!(sims[1] > sims[2], "空 should be closer than 寿司 ({} vs {})", sims[1], sims[2]);
    }
}