diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml
index 4f565d3..dcfc0cf 100644
--- a/src-tauri/Cargo.toml
+++ b/src-tauri/Cargo.toml
@@ -43,6 +43,15 @@
 dirs = "6.0"
 libsqlite3-sys = { version = "*", features = ["bundled"] } # Force bundled sqlite
 uuid = { version = "1", features = ["v4"] }
+bincode = "1.3"
+# Japanese NLP & LSA
+lindera = { version = "0.33", features = ["ipadic"] }
+lindera-core = "0.33"
+lindera-dictionary = "0.33"
+lindera-ipadic = { version = "0.33", features = ["ipadic"] }
+ndarray = "0.15"
+ndarray-linalg = "0.16"
+rsvd = "0.1"
 
 [dev-dependencies]
 tempfile = "3.10"
diff --git a/src-tauri/src/db.rs b/src-tauri/src/db.rs
index dba379d..89cb27b 100644
--- a/src-tauri/src/db.rs
+++ b/src-tauri/src/db.rs
@@ -59,6 +59,18 @@
         .await
         .map_err(|e| e.to_string())?;
 
+    // LSA vector table (used by the semantic search)
+    sqlx::query(
+        "CREATE TABLE IF NOT EXISTS items_lsa (
+            id INTEGER PRIMARY KEY,
+            vector BLOB NOT NULL,
+            FOREIGN KEY(id) REFERENCES items(id) ON DELETE CASCADE
+        )",
+    )
+    .execute(pool)
+    .await
+    .map_err(|e| e.to_string())?;
+
     // Create triggers
     sqlx::query(
         "CREATE TRIGGER IF NOT EXISTS update_items_updated_at
diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs
index f0e790b..9b1f07f 100644
--- a/src-tauri/src/lib.rs
+++ b/src-tauri/src/lib.rs
@@ -33,7 +33,8 @@
     Ok(lines[start..].join("\n"))
 }
 pub mod db;
-mod mcp;
+pub mod utils;
+pub mod mcp;
 
 use std::sync::{Arc, Mutex};
 use tauri::Manager;
diff --git a/src-tauri/src/mcp.rs b/src-tauri/src/mcp.rs
index a520efa..08a158d 100644
--- a/src-tauri/src/mcp.rs
+++ b/src-tauri/src/mcp.rs
@@ -16,8 +16,10 @@
 use std::convert::Infallible;
 use std::sync::Arc;
 use tokio::sync::broadcast;
 use tokio::sync::{mpsc, RwLock};
 use tower_http::cors::{Any, CorsLayer};
+use crate::utils::tokenizer::JapaneseTokenizer;
+use crate::utils::lsa::LsaModel;
 
 #[derive(Clone)]
 pub struct AppState {
@@ -27,6 +28,9 @@
     pub model_name: String,
     // MCP sessions map
     pub sessions: Arc<RwLock<HashMap<String, mpsc::Sender<String>>>>,
+    // Japanese NLP & LSA
+    pub tokenizer: Arc<JapaneseTokenizer>,
+    pub lsa_model: Arc<RwLock<Option<LsaModel>>>,
 }
 
 pub async fn run_server(
@@ -60,13 +64,40 @@
     });
 
     let app_state = AppState {
-        db_pool,
+        db_pool: db_pool.clone(),
         tx,
         llama_status: llama_status.clone(),
         model_name,
         sessions,
+        tokenizer: Arc::new(JapaneseTokenizer::new().expect("Failed to init tokenizer")),
+        lsa_model: Arc::new(RwLock::new(None)),
     };
 
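+    // How the pieces fit together: TermDocumentMatrixBuilder (utils/lsa.rs)
+    // turns tokenized documents into a term-document count matrix X,
+    // LsaModel::train factors it with a truncated SVD (X ~ U_k * Sigma_k * Vt_k),
+    // and queries/documents are folded into the k-dimensional concept space via
+    // project_query, where cosine similarity can relate texts that share no
+    // surface terms.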
+    // Build the LSA model from the existing data at startup
+    // (training is heavy, so run it asynchronously)
+    let app_state_for_lsa = app_state.clone();
+    tokio::spawn(async move {
+        log::info!("Starting initial LSA model training...");
+        if let Ok(rows) = sqlx::query("SELECT content FROM items").fetch_all(&app_state_for_lsa.db_pool).await {
+            if !rows.is_empty() {
+                let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new();
+                for row in rows {
+                    let content: String = row.get(0);
+                    let tokens = app_state_for_lsa.tokenizer.tokenize_to_vec(&content).unwrap_or_default();
+                    builder.add_document(tokens);
+                }
+                let matrix = builder.build_matrix();
+                match LsaModel::train(&matrix, builder.vocabulary, 50) { // truncate to 50 dimensions
+                    Ok(model) => {
+                        let mut lsa = app_state_for_lsa.lsa_model.write().await;
+                        *lsa = Some(model);
+                        log::info!("LSA model trained successfully with {} documents.", builder.counts.len());
+                    }
+                    Err(e) => log::error!("LSA training failed: {}", e),
+                }
+            }
+        }
+    });
+
     let cors = CorsLayer::new()
         .allow_origin(Any)
         .allow_methods(Any)
@@ -294,6 +325,23 @@
                 }
             },
             {
+                "name": "lsa_search",
+                "description": "Lightweight Japanese semantic search using LSA (Latent Semantic Analysis). No LLM required.",
+                "inputSchema": {
+                    "type": "object",
+                    "properties": {
+                        "query": { "type": "string" },
+                        "limit": { "type": "number" }
+                    },
+                    "required": ["query"]
+                }
+            },
+            {
+                "name": "lsa_retrain",
+                "description": "Rebuild the LSA semantic model from all current documents. Use this when you've added many new items.",
+                "inputSchema": { "type": "object", "properties": {} }
+            },
+            {
                 "name": "update_item",
                 "description": "Update existing text and its embedding.",
                 "inputSchema": {
@@ -402,6 +450,32 @@
                 .await
                 .map_err(|e| format!("Failed to insert vector: {}", e))?;
 
+            // Compute and store the LSA vector
+            let lsa_guard = state.lsa_model.read().await;
+            if let Some(model) = lsa_guard.as_ref() {
+                let mut query_counts = HashMap::new();
+                let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
+                for token in tokens {
+                    if let Some(&tid) = model.vocabulary.get(&token) {
+                        *query_counts.entry(tid).or_insert(0.0) += 1.0;
+                    }
+                }
+                let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
+                for (tid, count) in query_counts {
+                    query_vec[tid] = count;
+                }
+
+                if let Ok(projected) = model.project_query(&query_vec) {
+                    let vector_blob = bincode::serialize(&projected.to_vec()).unwrap_or_default();
+                    sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)")
+                        .bind(id)
+                        .bind(vector_blob)
+                        .execute(&mut *tx)
+                        .await
+                        .map_err(|e| format!("Failed to insert LSA vector: {}", e))?;
+                }
+            }
+
             tx.commit()
                 .await
                 .map_err(|e| format!("Failed to commit transaction: {}", e))?;
@@ -526,6 +600,69 @@
                 }
             }
         }
+        "lsa_search" => {
+            let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
+            let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10);
+
+            let lsa_guard = state.lsa_model.read().await;
+            if let Some(model) = lsa_guard.as_ref() {
+                // Vectorize the query (raw term frequencies)
+                let mut query_counts = HashMap::new();
+                let tokens = state.tokenizer.tokenize_to_vec(query).unwrap_or_default();
+                for token in tokens {
+                    if let Some(&id) = model.vocabulary.get(&token) {
+                        *query_counts.entry(id).or_insert(0.0) += 1.0;
+                    }
+                }
+
+                let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
+                for (id, count) in query_counts {
+                    query_vec[id] = count;
+                }
+
+                // Project into the LSA concept space
+                if let Ok(query_lsa) = model.project_query(&query_vec) {
+                    // Fetch all stored vectors and compare in memory (assumes a
+                    // small item count). With many items a full BLOB scan gets
+                    // slow; consider an in-memory cache or similar.
+                    let rows = sqlx::query("SELECT id, vector FROM items_lsa")
+                        .fetch_all(&state.db_pool)
+                        .await
+                        .unwrap_or_default();
+
+                    let mut results = Vec::new();
+                    for row in rows {
+                        let id: i64 = row.get(0);
+                        let vector_blob: Vec<u8> = row.get(1);
+                        if let Ok(vector_f64) = bincode::deserialize::<Vec<f64>>(&vector_blob) {
+                            let doc_vec = ndarray::Array1::from_vec(vector_f64);
+                            let sim = crate::utils::lsa::LsaModel::cosine_similarity(&query_lsa, &doc_vec);
+                            results.push((id, sim));
+                        }
+                    }
+
+                    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+                    results.truncate(limit as usize);
+
+                    let mut filtered_results = Vec::new();
+                    for (id, sim) in results {
+                        if let Ok(doc_row) = sqlx::query("SELECT content FROM items WHERE id = ?").bind(id).fetch_one(&state.db_pool).await {
+                            let content: String = doc_row.get(0);
+                            filtered_results.push(serde_json::json!({
+                                "id": id,
+                                "content": content,
+                                "similarity": sim
+                            }));
+                        }
+                    }
+
+                    Some(serde_json::json!({ "content": filtered_results }))
+                } else {
+                    Some(serde_json::json!({ "error": "Query projection failed" }))
+                }
+            } else {
+                Some(serde_json::json!({ "error": "LSA model not initialized or no data available" }))
+            }
+        }
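+        // Illustrative call shape (field names follow the schema above):
+        //   {"name": "lsa_search", "arguments": {"query": "...", "limit": 5}}
+        // responds with items ranked by cosine similarity, e.g.
+        //   {"content": [{"id": 3, "content": "...", "similarity": 0.82}, ...]}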
         "update_item" => {
             let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
             let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
@@ -559,6 +696,32 @@
                 .await
                 .map_err(|e| format!("Failed to update vector: {}", e))?;
 
+            // Update the LSA vector
+            let lsa_guard = state.lsa_model.read().await;
+            if let Some(model) = lsa_guard.as_ref() {
+                let mut query_counts = HashMap::new();
+                let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
+                for token in tokens {
+                    if let Some(&tid) = model.vocabulary.get(&token) {
+                        *query_counts.entry(tid).or_insert(0.0) += 1.0;
+                    }
+                }
+                let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
+                for (tid, count) in query_counts {
+                    query_vec[tid] = count;
+                }
+
+                if let Ok(projected) = model.project_query(&query_vec) {
+                    let vector_blob = bincode::serialize(&projected.to_vec()).unwrap_or_default();
+                    sqlx::query("INSERT OR REPLACE INTO items_lsa (id, vector) VALUES (?, ?)")
+                        .bind(id)
+                        .bind(vector_blob)
+                        .execute(&mut *tx)
+                        .await
+                        .map_err(|e| format!("Failed to update LSA vector: {}", e))?;
+                }
+            }
+
             tx.commit()
                 .await
                 .map_err(|e| format!("Failed to commit transaction: {}", e))?;
@@ -614,6 +777,59 @@
                 )
             }
         }
+        "lsa_retrain" => {
+            log::info!("Manual LSA retrain triggered.");
+            let state_clone = state.clone();
+            tokio::spawn(async move {
+                if let Ok(rows) = sqlx::query("SELECT id, content FROM items").fetch_all(&state_clone.db_pool).await {
+                    if !rows.is_empty() {
+                        let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new();
+                        let mut ids = Vec::new();
+                        for row in rows {
+                            let id: i64 = row.get(0);
+                            let content: String = row.get(1);
+                            let tokens = state_clone.tokenizer.tokenize_to_vec(&content).unwrap_or_default();
+                            builder.add_document(tokens);
+                            ids.push(id);
+                        }
+                        let matrix = builder.build_matrix();
+                        match LsaModel::train(&matrix, builder.vocabulary, 50) {
+                            Ok(model) => {
+                                // Recompute every document vector and persist it
+                                let mut tx = state_clone.db_pool.begin().await.unwrap();
+                                sqlx::query("DELETE FROM items_lsa").execute(&mut *tx).await.unwrap();
+
+                                for (i, &id) in ids.iter().enumerate() {
+                                    // Document i's vector is VT[.., i] * Sigma. Because project_query
+                                    // computes U^T * TF, a bulk pass over U/VT would be faster, but for
+                                    // consistency we rebuild each document's TF vector and project it
+                                    // exactly the way queries are projected.
+                                    let mut doc_tf = ndarray::Array1::zeros(model.vocabulary.len());
+                                    for (&tid, &count) in &builder.counts[i] {
+                                        doc_tf[tid] = count;
+                                    }
+                                    if let Ok(projected) = model.project_query(&doc_tf) {
+                                        let vector_blob = bincode::serialize(&projected.to_vec()).unwrap_or_default();
+                                        sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)")
+                                            .bind(id)
+                                            .bind(vector_blob)
+                                            .execute(&mut *tx)
+                                            .await
+                                            .unwrap();
+                                    }
+                                }
+                                tx.commit().await.unwrap();
+
+                                let mut lsa = state_clone.lsa_model.write().await;
+                                *lsa = Some(model);
+                                log::info!("Manual LSA retrain completed successfully.");
+                            }
+                            Err(e) => log::error!("Manual LSA training failed: {}", e),
+                        }
+                    }
+                }
+            });
+            Some(serde_json::json!({ "content": [{ "type": "text", "text": "LSA retrain started in background." }] }))
+        }
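+        // While a retrain runs, lsa_search keeps serving the previous in-memory
+        // model; items_lsa is rewritten inside a single transaction and the model
+        // swap happens under the write lock, so readers should never observe a
+        // half-rebuilt state.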
         _ => Some(serde_json::json!({ "error": "Unknown tool" })),
     }
 }
diff --git a/src-tauri/src/utils/lsa.rs b/src-tauri/src/utils/lsa.rs
new file mode 100644
index 0000000..792db97
--- /dev/null
+++ b/src-tauri/src/utils/lsa.rs
@@ -0,0 +1,99 @@
+use ndarray::Array2;
+use ndarray_linalg::SVD;
+use std::collections::HashMap;
+use anyhow::{Result, anyhow};
+
+pub struct LsaModel {
+    pub u: Array2<f64>,     // left singular vectors (term-concept)
+    pub sigma: Vec<f64>,    // singular values
+    pub vt: Array2<f64>,    // right singular vectors (concept-document)
+    pub vocabulary: HashMap<String, usize>,
+    pub k: usize,           // number of dimensions after truncation
+}
+
+impl LsaModel {
+    /// Build an LSA model from a term-document matrix (raw TF or TF-IDF weighted)
+    pub fn train(matrix: &Array2<f64>, vocabulary: HashMap<String, usize>, k: usize) -> Result<Self> {
+        // Run the SVD
+        let (u, sigma, vt) = matrix.svd(true, true).map_err(|e| anyhow!("SVD failed: {}", e))?;
+
+        // Truncate to k dimensions (truncated SVD)
+        let u_val = u.ok_or_else(|| anyhow!("U matrix missing"))?;
+        let vt_val = vt.ok_or_else(|| anyhow!("Vt matrix missing"))?;
+
+        let k_actual = std::cmp::min(k, sigma.len());
+
+        let u_k = u_val.slice(ndarray::s![.., ..k_actual]).to_owned();
+        let sigma_k = sigma.slice(ndarray::s![..k_actual]).to_owned();
+        let vt_k = vt_val.slice(ndarray::s![..k_actual, ..]).to_owned();
+
+        Ok(LsaModel {
+            u: u_k,
+            sigma: sigma_k.to_vec(),
+            vt: vt_k,
+            vocabulary,
+            k: k_actual,
+        })
+    }
+
+    /// Project a query into the latent concept space
+    pub fn project_query(&self, query_vector: &ndarray::Array1<f64>) -> Result<ndarray::Array1<f64>> {
+        // Standard folding-in is Query_LSA = Sigma_k^-1 * U_k^T * Query_TF.
+        // For simplicity we compute U_k^T * Query_TF here; documents are
+        // projected the same way, so comparisons remain consistent.
+        let query_lsa = self.u.t().dot(query_vector);
+        Ok(query_lsa)
+    }
+
+    /// Cosine similarity between two vectors
+    pub fn cosine_similarity(a: &ndarray::Array1<f64>, b: &ndarray::Array1<f64>) -> f64 {
+        let dot = a.dot(b);
+        let norm_a = a.dot(a).sqrt();
+        let norm_b = b.dot(b).sqrt();
+        if norm_a == 0.0 || norm_b == 0.0 { return 0.0; }
+        dot / (norm_a * norm_b)
+    }
+}
+
+/// Helper that counts term occurrences to build the term-document matrix
+pub struct TermDocumentMatrixBuilder {
+    pub vocabulary: HashMap<String, usize>,
+    pub counts: Vec<HashMap<usize, f64>>, // per-document term counts
+}
+
+impl TermDocumentMatrixBuilder {
+    pub fn new() -> Self {
+        TermDocumentMatrixBuilder {
+            vocabulary: HashMap::new(),
+            counts: Vec::new(),
+        }
+    }
+
+    pub fn add_document(&mut self, tokens: Vec<String>) {
+        let mut doc_counts = HashMap::new();
+        for token in tokens {
+            let id = if let Some(&id) = self.vocabulary.get(&token) {
+                id
+            } else {
+                let id = self.vocabulary.len();
+                self.vocabulary.insert(token, id);
+                id
+            };
+            *doc_counts.entry(id).or_insert(0.0) += 1.0;
+        }
+        self.counts.push(doc_counts);
+    }
+
+    pub fn build_matrix(&self) -> Array2<f64> {
+        let rows = self.vocabulary.len();
+        let cols = self.counts.len();
+        let mut matrix = Array2::zeros((rows, cols));
+
+        for (col, doc_counts) in self.counts.iter().enumerate() {
+            for (&row, &count) in doc_counts {
+                matrix[[row, col]] = count;
+            }
+        }
+        // Applying a TF-IDF transform here would improve accuracy
+        matrix
+    }
+}
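+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Minimal smoke-test sketch of the pipeline (builder -> truncated SVD ->
+    // query projection). Assumes an ndarray-linalg LAPACK backend is linked
+    // when tests run.
+    #[test]
+    fn test_train_and_project() {
+        let mut builder = TermDocumentMatrixBuilder::new();
+        builder.add_document(vec!["りんご".into(), "果物".into(), "りんご".into()]);
+        builder.add_document(vec!["犬".into(), "動物".into()]);
+        let matrix = builder.build_matrix();
+        let model = LsaModel::train(&matrix, builder.vocabulary.clone(), 2).unwrap();
+
+        // A query containing only "りんご" should land near document 0.
+        let mut query = ndarray::Array1::zeros(model.vocabulary.len());
+        query[model.vocabulary["りんご"]] = 1.0;
+        let query_lsa = model.project_query(&query).unwrap();
+
+        let mut doc0 = ndarray::Array1::zeros(model.vocabulary.len());
+        doc0[model.vocabulary["りんご"]] = 2.0;
+        doc0[model.vocabulary["果物"]] = 1.0;
+        let doc0_lsa = model.project_query(&doc0).unwrap();
+
+        assert!(LsaModel::cosine_similarity(&query_lsa, &doc0_lsa) > 0.9);
+    }
+}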
diff --git a/src-tauri/src/utils/mod.rs b/src-tauri/src/utils/mod.rs
new file mode 100644
index 0000000..d21bb8a
--- /dev/null
+++ b/src-tauri/src/utils/mod.rs
@@ -0,0 +1,2 @@
+pub mod tokenizer;
+pub mod lsa;
diff --git a/src-tauri/src/utils/tokenizer.rs b/src-tauri/src/utils/tokenizer.rs
new file mode 100644
index 0000000..8b9728e
--- /dev/null
+++ b/src-tauri/src/utils/tokenizer.rs
@@ -0,0 +1,56 @@
+use lindera::tokenizer::Tokenizer;
+use lindera::mode::{Mode, Penalty};
+use anyhow::Result;
+
+pub struct JapaneseTokenizer {
+    tokenizer: Tokenizer,
+}
+
+impl JapaneseTokenizer {
+    pub fn new() -> Result<Self> {
+        // Initialize a tokenizer backed by the IPADIC dictionary;
+        // the dictionary ships as an embedded binary
+        let tokenizer = Tokenizer::new_with_config(lindera::tokenizer::TokenizerConfig {
+            dictionary: lindera::tokenizer::DictionaryConfig {
+                kind: Some(lindera_core::lexicon::DictionaryKind::IPADIC),
+                path: None,
+            },
+            user_dictionary: None,
+            mode: Mode::Decompose(Penalty::default()), // decompose mode suits search
+        })?;
+
+        Ok(JapaneseTokenizer { tokenizer })
+    }
+
+    /// Morphologically analyze text and return a space-delimited (wakachi-gaki) string
+    pub fn tokenize_to_string(&self, text: &str) -> Result<String> {
+        let tokens = self.tokenizer.tokenize(text)?;
+        let result: Vec<&str> = tokens
+            .iter()
+            .map(|token| token.text)
+            .collect();
+
+        Ok(result.join(" "))
+    }
+
+    /// Return the tokens as a vector of words
+    pub fn tokenize_to_vec(&self, text: &str) -> Result<Vec<String>> {
+        let tokens = self.tokenizer.tokenize(text)?;
+        Ok(tokens.iter().map(|t| t.text.to_string()).collect())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_japanese_tokenization() {
+        let tokenizer = JapaneseTokenizer::new().unwrap();
+        let text = "すもももももももものうち";
+        let tokenized = tokenizer.tokenize_to_string(text).unwrap();
+        assert!(tokenized.contains("すもも"));
+        assert!(tokenized.contains("もも"));
+    }
+}
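+
+// End-to-end sketch: tokenizer output feeding the LSA matrix builder.
+// Assumes the embedded IPADIC dictionary is available when tests run.
+#[cfg(test)]
+mod lsa_integration_tests {
+    use super::*;
+    use crate::utils::lsa::TermDocumentMatrixBuilder;
+
+    #[test]
+    fn test_tokens_feed_lsa_builder() {
+        let tokenizer = JapaneseTokenizer::new().unwrap();
+        let mut builder = TermDocumentMatrixBuilder::new();
+        builder.add_document(tokenizer.tokenize_to_vec("すもももももももものうち").unwrap());
+        let matrix = builder.build_matrix();
+        // One document column; one row per distinct surface form.
+        assert_eq!(matrix.ncols(), 1);
+        assert_eq!(matrix.nrows(), builder.vocabulary.len());
+    }
+}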