use ndarray::{Array1, Array2, Axis};
use std::collections::HashMap;
use anyhow::{Result, anyhow};
#[derive(Clone)]
pub struct LsaModel {
pub u: Array2<f64>, // 左特異ベクトル (単語-概念)
pub sigma: Vec<f64>, // 特異値
pub vocabulary: HashMap<String, usize>,
pub idfs: Vec<f64>, // 学習時の IDF
pub k: usize, // 圧縮後の次元数
}
impl LsaModel {
/// 単語-文書行列 (TF-IDF) から LSA モデルを構築する
pub fn train(matrix: &Array2<f64>, vocabulary: HashMap<String, usize>, idfs: Vec<f64>, k: usize) -> Result<Self> {
let rows = matrix.nrows();
let cols = matrix.ncols();
let k_actual = std::cmp::min(k, std::cmp::min(rows, cols));
if k_actual == 0 {
return Err(anyhow!("Cannot train LSA with 0 documents or 0 features"));
}
// --- Iterative SVD (Power Method with Deflation) ---
// A A^T の上位特異ベクトルを求める
let mut u_k = Array2::zeros((rows, k_actual));
let mut sigma_k = Vec::new();
let mut working_matrix = matrix.clone();
for i in 0..k_actual {
// 第i主成分を抽出 (Power Method)
// 初期ベクトルを少しずつ変えて、収束を安定させる
let mut v = Array1::from_elem(rows, 1.0);
if i > 0 {
// 前の成分と異なる方向を向かせるための簡単な摂動
v[i % rows] += 1.0;
}
v /= (v.dot(&v) as f64).sqrt();
for _ in 0..150 { // イテレーション回数を増やして精度向上
// v = (A A^T) v = A * (A^T * v)
let at_v = working_matrix.t().dot(&v);
let a_at_v = working_matrix.dot(&at_v);
let norm = (a_at_v.dot(&a_at_v) as f64).sqrt();
if norm < 1e-15 { break; }
v = a_at_v / norm;
}
// 特異値 s = ||A^T v||
let at_v_final = working_matrix.t().dot(&v);
let s = (at_v_final.dot(&at_v_final) as f64).sqrt();
// Deflation: A_next = A - s * u * vt^T
if s > 1e-15 {
let vt_i = &at_v_final / s;
let v_col = v.clone().insert_axis(Axis(1));
let vt_row = vt_i.insert_axis(Axis(0));
let projection = v_col.dot(&vt_row);
working_matrix = working_matrix - (s * projection);
for j in 0..rows {
u_k[[j, i]] = v[j];
}
sigma_k.push(s);
} else {
sigma_k.push(0.0);
}
}
Ok(LsaModel {
u: u_k,
sigma: sigma_k,
vocabulary,
idfs,
k: k_actual,
})
}
/// クエリを潜在概念空間へ射影する(TF-IDF重み付け & 正規化済み)
pub fn project_query(&self, query_tf: &ndarray::Array1<f64>) -> Result<ndarray::Array1<f64>> {
// TF-IDF 重み付けの適用
let mut query_tfidf = query_tf.clone();
for (i, &idf) in self.idfs.iter().enumerate() {
query_tfidf[i] *= idf;
}
// Project: Query_LSA = U_k^T * Query_TFIDF
let mut query_lsa = self.u.t().dot(&query_tfidf);
// 正規化 (L2距離をコサイン類似度に対応させるため)
let norm = (query_lsa.dot(&query_lsa) as f64).sqrt();
if norm > 1e-12 {
query_lsa /= norm;
} else {
// クエリが完全に語彙に含まれない場合は零ベクトルを返す
// この場合、類似度は全て 0.0 になる
return Ok(ndarray::Array1::zeros(query_lsa.len()));
}
Ok(query_lsa)
}
/// 二つのベクトル間のコサイン類似度を計算
pub fn cosine_similarity(a: &ndarray::Array1<f64>, b: &ndarray::Array1<f64>) -> f64 {
// 両者が正規化されているなら単なるドット積
a.dot(b)
}
}
/// 単語出現頻度をカウントして TF-IDF 行列の元を作成するヘルパー
pub struct TermDocumentMatrixBuilder {
pub vocabulary: HashMap<String, usize>,
pub counts: Vec<HashMap<usize, f64>>, // 文書ごとの単語出現カウント
}
impl TermDocumentMatrixBuilder {
pub fn new() -> Self {
TermDocumentMatrixBuilder {
vocabulary: HashMap::new(),
counts: Vec::new(),
}
}
pub fn add_document(&mut self, tokens: Vec<String>) {
if tokens.is_empty() { return; }
let mut doc_counts = HashMap::new();
for token in tokens {
let id = if let Some(&id) = self.vocabulary.get(&token) {
id
} else {
let id = self.vocabulary.len();
self.vocabulary.insert(token, id);
id
};
*doc_counts.entry(id).or_insert(0.0) += 1.0;
}
self.counts.push(doc_counts);
}
pub fn build_matrix(&self) -> (Array2<f64>, Vec<f64>) {
let num_terms = self.vocabulary.len();
let num_docs = self.counts.len();
if num_terms == 0 || num_docs == 0 {
return (Array2::zeros((num_terms, num_docs)), Vec::new());
}
let mut matrix = Array2::zeros((num_terms, num_docs));
// IDF の計算
let mut doc_freq = vec![0.0; num_terms];
for doc_counts in &self.counts {
for &term_id in doc_counts.keys() {
doc_freq[term_id] += 1.0;
}
}
let idfs: Vec<f64> = doc_freq.iter()
.map(|&df| ((num_docs as f64) / (df + 1.0)).ln() + 1.0)
.collect();
for (col, doc_counts) in self.counts.iter().enumerate() {
let max_tf = doc_counts.values().fold(0.0, |a, &b| f64::max(a, b));
for (&row, &count) in doc_counts {
// TF-IDF = (tf / max_tf) * idf
matrix[[row, col]] = (count / max_tf) * idfs[row];
}
}
(matrix, idfs)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_lsa_variance() {
let mut builder = TermDocumentMatrixBuilder::new();
// 重なりのあるカテゴリ
builder.add_document(vec!["自然".to_string(), "山".to_string(), "登山".to_string()]);
builder.add_document(vec!["自然".to_string(), "空".to_string(), "雲".to_string()]);
builder.add_document(vec!["寿司".to_string(), "魚".to_string(), "飯".to_string()]);
let matrix = builder.build_matrix();
let model = LsaModel::train(&matrix, builder.vocabulary.clone(), 3).unwrap(); // rank=3
println!("Sigma: {:?}", model.sigma);
// クエリ: 「山」
let mut q_vec = ndarray::Array1::zeros(builder.vocabulary.len());
q_vec[*builder.vocabulary.get("山").unwrap()] = 1.0;
let q_lsa = model.project_query(&q_vec).unwrap();
let doc_vectors: Vec<_> = (0..3).map(|i| {
let mut d_vec = ndarray::Array1::zeros(builder.vocabulary.len());
for (&tid, &count) in &builder.counts[i] {
d_vec[tid] = count;
}
model.project_query(&d_vec).unwrap()
}).collect();
let sims: Vec<_> = doc_vectors.iter()
.map(|d| LsaModel::cosine_similarity(&q_lsa, d))
.collect();
println!("Similarities: {:?}", sims);
// 山 doc (0) > 空 doc (1) > 寿司 doc (2) となるはず
assert!(sims[0] > sims[1], "山 should be closer than 空 ({} vs {})", sims[0], sims[1]);
assert!(sims[1] > sims[2], "空 should be closer than 寿司 ({} vs {})", sims[1], sims[2]);
}
}