use ndarray::Array2;
use std::collections::HashMap;
use anyhow::{Result, anyhow};
pub struct LsaModel {
pub u: Array2<f64>, // 左特異ベクトル (単語-概念)
pub sigma: Vec<f64>, // 特異値
pub vt: Array2<f64>, // 右特異ベクトル (概念-文書)
pub vocabulary: HashMap<String, usize>,
pub k: usize, // 圧縮後の次元数
}
impl LsaModel {
/// 単語-文書行列 (TF-IDF) から LSA モデルを構築する
pub fn train(matrix: &Array2<f64>, vocabulary: HashMap<String, usize>, k: usize) -> Result<Self> {
// TODO: ndarray-linalg なしの純 Rust SVD 実装に差し替え
// let (u, sigma, vt) = matrix.svd(true, true).map_err(|e| anyhow!("SVD failed: {}", e))?;
let rows = matrix.nrows();
let cols = matrix.ncols();
let k_actual = std::cmp::min(k, std::cmp::min(rows, cols));
// ダミーデータでビルド確認用
let u_k = Array2::zeros((rows, k_actual));
let sigma_k = vec![1.0; k_actual];
let vt_k = Array2::zeros((k_actual, cols));
Ok(LsaModel {
u: u_k,
sigma: sigma_k.to_vec(),
vt: vt_k,
vocabulary,
k: k_actual,
})
}
/// クエリを潜在概念空間へ射影する
pub fn project_query(&self, query_vector: &ndarray::Array1<f64>) -> Result<ndarray::Array1<f64>> {
// Query_LSA = Query_TFIDF^T * U_k * Sigma_k^-1
// 単純化のため、ここでは U_k^T * Query_TFIDF を計算
let query_lsa = self.u.t().dot(query_vector);
Ok(query_lsa)
}
/// 二つのベクトル間のコサイン類似度を計算
pub fn cosine_similarity(a: &ndarray::Array1<f64>, b: &ndarray::Array1<f64>) -> f64 {
let dot = a.dot(b);
let norm_a = a.dot(a).sqrt();
let norm_b = b.dot(b).sqrt();
if norm_a == 0.0 || norm_b == 0.0 { return 0.0; }
dot / (norm_a * norm_b)
}
}
/// 単語出現頻度をカウントして TF-IDF 行列の元を作成するヘルパー
pub struct TermDocumentMatrixBuilder {
pub vocabulary: HashMap<String, usize>,
pub counts: Vec<HashMap<usize, f64>>, // 文書ごとの単語出現カウント
}
impl TermDocumentMatrixBuilder {
pub fn new() -> Self {
TermDocumentMatrixBuilder {
vocabulary: HashMap::new(),
counts: Vec::new(),
}
}
pub fn add_document(&mut self, tokens: Vec<String>) {
let mut doc_counts = HashMap::new();
for token in tokens {
let id = if let Some(&id) = self.vocabulary.get(&token) {
id
} else {
let id = self.vocabulary.len();
self.vocabulary.insert(token, id);
id
};
*doc_counts.entry(id).or_insert(0.0) += 1.0;
}
self.counts.push(doc_counts);
}
pub fn build_matrix(&self) -> Array2<f64> {
let rows = self.vocabulary.len();
let cols = self.counts.len();
let mut matrix = Array2::zeros((rows, cols));
for (col, doc_counts) in self.counts.iter().enumerate() {
for (&row, &count) in doc_counts {
matrix[[row, col]] = count;
}
}
// 本来はここで TF-IDF 変換を行うほうが精度が高い
matrix
}
}