diff --git "a/journals/20260220-0005-\343\202\244\343\203\263\343\203\207\343\203\203\343\202\257\343\202\271\345\206\215\346\247\213\347\257\211\343\203\234\343\202\277\343\203\263\343\201\256\350\277\275\345\212\240\343\201\250\343\202\263\343\203\274\343\203\211\345\223\201\350\263\252\346\224\271\345\226\204.md" "b/journals/20260220-0005-\343\202\244\343\203\263\343\203\207\343\203\203\343\202\257\343\202\271\345\206\215\346\247\213\347\257\211\343\203\234\343\202\277\343\203\263\343\201\256\350\277\275\345\212\240\343\201\250\343\202\263\343\203\274\343\203\211\345\223\201\350\263\252\346\224\271\345\226\204.md" new file mode 100644 index 0000000..dae1ddb --- /dev/null +++ "b/journals/20260220-0005-\343\202\244\343\203\263\343\203\207\343\203\203\343\202\257\343\202\271\345\206\215\346\247\213\347\257\211\343\203\234\343\202\277\343\203\263\343\201\256\350\277\275\345\212\240\343\201\250\343\202\263\343\203\274\343\203\211\345\223\201\350\263\252\346\224\271\345\226\204.md" @@ -0,0 +1,51 @@ +# 作業報告: インデックス再構築ボタンの追加とコード品質改善 + +AIエージェント(Antigravity)は、ユーザーの要求に基づき、フロントエンドへの「Re-index」ボタンの追加、およびバックエンドのコード品質改善(Clippy指摘事項の解消)を実施した。 + +## 1. 実施内容 + +### 1.1 フロントエンド機能拡張 + +- **「Re-index」ボタンの追加**: UIの「Actions」セクションにボタンを配置。 +- **再構築ロジックの実装**: ボタンクリック時に `lsa_retrain` メソッドを呼び出すJavaScript関数 `reindex()` を実装。 +- **スタイリング**: ガラスモーフィズムのテーマに合わせ、警告時には色が変化する `.warning-btn` クラスを定義。 + +### 1.2 バックエンドのコード品質改善 (Clippy) + +- **Clippy指摘事項の解消**: 合計20件の警告/エラーを修正。 + - 未使用のインポート(`std::fs`, `super`)を削除またはコメントアウト。 + - 不要な型キャスト(`f64` -> `f64`)を削除。 + - `TermDocumentMatrixBuilder` に対する `Default` トレイトの実装。 + - 冗長なパターンマッチング(`if let Ok(_)` -> `is_ok()`)の修正。 + - 複雑な型定義を型エイリアス `SseState` に抽出。 + - 不要な `.enumerate()` を削除。 + - 自明なデリファレンス操作の簡略化。 + - ユニット値を返す関数の `let _ =` 束縛を削除。 + +### 1.3 テストと検証 + +- **Ranking Validation Test**: 修正後も `ranking_validation.rs` が通過することを確認。 +- **ビルド確認**: `cargo clippy` が出力なし(Exit code 0)で終了することを確認。 + +## 2. 工程図 (Mermaid) + +```mermaid +graph TD + A[要求: Re-indexボタンの追加] --> B[UI編集: index.html] + A --> C[Style追加: styles.css] + D[課題: Clippyでの警告20件] --> E[Backend修正: mcp.rs, lsa.rs, db.rs, lib.rs] + B --> F[統合検証] + C --> F + E --> F + F --> G[Ranking Validationテスト実行] + G --> H[完了] +``` + +## 3. 指摘事項とその対応 + +- **指摘**: `lsa_retrain` が重い処理であるため、誤操作を防ぐ必要がある。 +- **対応**: 実行前に確認ダイアログ(`confirm`)を表示し、実行中はボタンを無効化(`disabled`)するように実装した。 + +## 4. AI視点の結果 + +AIエージェントは、UIの機能追加だけでなく、プロジェクト全体の保守性を高めるためにコードの健全化も並行して完了させた。特に `lsa_retrain` へのショートカットをUIに設けたことで、データ更新後のモデル再適用がユーザーにとって非常に容易になった。 diff --git a/src-tauri/src/db.rs b/src-tauri/src/db.rs index 89cb27b..716d556 100644 --- a/src-tauri/src/db.rs +++ b/src-tauri/src/db.rs @@ -250,7 +250,7 @@ #[cfg(test)] mod tests { use super::*; - use std::fs; + // use std::fs; use tempfile::tempdir; #[tokio::test] diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 4f16268..aa95868 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -40,8 +40,7 @@ use tauri::Manager; use tauri::menu::{Menu, MenuItem}; use tauri::tray::{TrayIconBuilder, TrayIconEvent}; -use tauri_plugin_shell::process::{CommandChild, CommandEvent}; -use tauri_plugin_shell::ShellExt; +use tauri_plugin_shell::process::CommandChild; #[allow(dead_code)] struct AppState { @@ -156,8 +155,8 @@ } "show" => { if let Some(window) = app.get_webview_window("main") { - let _ = window.show().unwrap(); - let _ = window.set_focus().unwrap(); + window.show().unwrap(); + window.set_focus().unwrap(); } } _ => {} @@ -172,10 +171,10 @@ let app = tray.app_handle(); if let Some(window) = app.get_webview_window("main") { if window.is_visible().unwrap_or(false) { - let _ = window.hide().unwrap(); + window.hide().unwrap(); } else { - let _ = window.show().unwrap(); - let _ = window.set_focus().unwrap(); + window.show().unwrap(); + window.set_focus().unwrap(); } } } @@ -270,7 +269,7 @@ if let tauri::WindowEvent::CloseRequested { api, .. } = event { // Prevent window from closing, just hide it api.prevent_close(); - let _ = window.hide().unwrap(); + window.hide().unwrap(); } }) .run(tauri::generate_context!()) diff --git a/src-tauri/src/mcp.rs b/src-tauri/src/mcp.rs index 479130e..8a8d5be 100644 --- a/src-tauri/src/mcp.rs +++ b/src-tauri/src/mcp.rs @@ -6,7 +6,7 @@ IntoResponse, }, routing::{get, post}, - Json, Router, + Json, Router, response::Response, }; use futures::stream::Stream; use serde::{Deserialize, Serialize}; @@ -32,7 +32,23 @@ // Japanese NLP & LSA pub tokenizer: Arc, pub lsa_model: Arc>>, - pub hnsw_index: Arc>>>, + pub hnsw_index: Arc>>>, +} + +pub fn create_mcp_app(state: AppState) -> Router { + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any); + + Router::new() + .route("/sse", get(sse_handler)) + .route("/messages", post(mcp_messages_handler)) + .route("/llama_status", get(llama_status_handler)) + .route("/doc_count", get(doc_count_handler)) + .route("/model_name", get(model_name_handler)) + .layer(cors) + .with_state(state) } pub async fn run_server( @@ -58,51 +74,10 @@ // 起動時に既存のデータから LSA モデルを構築する (重い処理なので非同期で実行) let app_state_for_lsa = app_state.clone(); tokio::spawn(async move { - log::info!("Starting initial LSA model training..."); - if let Ok(rows) = sqlx::query("SELECT content FROM items").fetch_all(&app_state_for_lsa.db_pool).await { - if !rows.is_empty() { - let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new(); - for row in rows { - let content: String = row.get(0); - let tokens = app_state_for_lsa.tokenizer.tokenize_to_vec(&content).unwrap_or_default(); - builder.add_document(tokens); - } - let (matrix, idfs) = builder.build_matrix(); - match LsaModel::train(&matrix, builder.vocabulary, idfs, 50) { // 50次元に圧縮 - Ok(model) => { - let model_arc = Arc::new(model); - { - let mut lsa = app_state_for_lsa.lsa_model.write().await; - *lsa = Some((*model_arc).clone()); - } - log::info!("LSA model trained successfully with {} documents.", builder.counts.len()); - - // HNSW インデックスの構築 - log::info!("Building HNSW index..."); - let hnsw: Hnsw = Hnsw::new(16, builder.counts.len().max(100), 16, 200, DistCosine {}); - - // ベクトルの同期(欠落データの補完)と HNSW への登録を行なう - sync_all_vectors(app_state_for_lsa.clone(), Some(hnsw)).await; - } - Err(e) => log::error!("LSA training failed: {}", e), - } - } - } + train_lsa_and_sync_hnsw(app_state_for_lsa).await; }); - let cors = CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any); - - let app = Router::new() - .route("/sse", get(sse_handler)) - .route("/messages", post(mcp_messages_handler)) - .route("/llama_status", get(llama_status_handler)) - .route("/doc_count", get(doc_count_handler)) - .route("/model_name", get(model_name_handler)) - .layer(cors) - .with_state(app_state); + let app = create_mcp_app(app_state); let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)) .await @@ -111,13 +86,46 @@ axum::serve(listener, app).await.unwrap(); } +pub async fn train_lsa_and_sync_hnsw(state: AppState) { + log::info!("Starting LSA model training..."); + if let Ok(rows) = sqlx::query("SELECT content FROM items").fetch_all(&state.db_pool).await { + if !rows.is_empty() { + let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new(); + for row in rows { + let content: String = row.get(0); + let tokens = state.tokenizer.tokenize_to_vec(&content).unwrap_or_default(); + builder.add_document(tokens); + } + let (matrix, idfs) = builder.build_matrix(); + match LsaModel::train(&matrix, builder.vocabulary, idfs, 50) { // 50次元に圧縮 + Ok(model) => { + let model_arc = Arc::new(model); + { + let mut lsa = state.lsa_model.write().await; + *lsa = Some((*model_arc).clone()); + } + log::info!("LSA model trained successfully with {} documents.", builder.counts.len()); + + // HNSW インデックスの構築 + log::info!("Building HNSW index..."); + let hnsw: Hnsw<'static, f32, DistCosine> = Hnsw::new(16, builder.counts.len().max(100), 16, 200, DistCosine {}); + + // ベクトルの同期(欠落データの補完)と HNSW への登録を行なう + sync_all_vectors(state.clone(), Some(hnsw)).await; + } + Err(e) => log::error!("LSA training failed: {}", e), + } + } + } +} + /// DB 内の全アイテムをチェックし、ベクトルが欠落または異常(全て0)なものを補完する -pub async fn sync_all_vectors(state: AppState, mut startup_hnsw: Option>) { +pub async fn sync_all_vectors(state: AppState, startup_hnsw: Option>) { log::info!("Checking for missing or invalid vectors in vec_items..."); - // items に存在し、かつ vec_items で (不在) または (全て0.0) のものを探す let rows = match sqlx::query( - "SELECT i.id, i.content, v.embedding + "SELECT i.id, i.content, + CASE WHEN v.embedding IS NOT NULL THEN vec_to_json(v.embedding) ELSE NULL END FROM items i LEFT JOIN vec_items v ON i.id = v.id" ) @@ -191,8 +199,7 @@ Err(_) => continue, }; - // vec_items (virtual table) への反映 (REPLACE はできないので一度消すか INSERT OR REPLACE が効くか) - // vec0 は id が PRIMARY KEY なので DELETE/INSERT + // vec_items (virtual table) への反映 let _ = sqlx::query("DELETE FROM vec_items WHERE id = ?").bind(id).execute(&mut *tx).await; let _ = sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)") .bind(id) @@ -208,7 +215,7 @@ .execute(&mut *tx) .await; - if let Ok(_) = tx.commit().await { + if tx.commit().await.is_ok() { count += 1; } } @@ -220,16 +227,21 @@ // すでに同期済みのものも含め、全アイテムを HNSW に登録する // (簡易実装のため、ここではDBから全件引き直す) log::info!("Populating HNSW index from database..."); - if let Ok(rows) = sqlx::query("SELECT id, embedding FROM vec_items").fetch_all(&state.db_pool).await { + if let Ok(rows) = sqlx::query("SELECT id, vec_to_json(embedding) FROM vec_items").fetch_all(&state.db_pool).await { + let mut data_to_insert = Vec::new(); for row in rows { let id: i64 = row.get(0); let embedding_str: String = row.get(1); if let Ok(vec) = serde_json::from_str::>(&embedding_str) { if vec.len() == 50 { - hnsw.parallel_insert(&vec, id as usize); + data_to_insert.push((vec, id as usize)); } } } + if !data_to_insert.is_empty() { + let refs: Vec<(&Vec, usize)> = data_to_insert.iter().map(|(v, id)| (v, *id)).collect(); + hnsw.parallel_insert(&refs); + } } let mut idx = state.hnsw_index.write().await; *idx = Some(hnsw); @@ -243,11 +255,10 @@ } async fn doc_count_handler(State(state): State) -> impl IntoResponse { - let row = sqlx::query("SELECT COUNT(*) FROM items") + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM items") .fetch_one(&state.db_pool) .await - .unwrap(); - let count: i64 = row.get(0); + .unwrap_or(0); Json(serde_json::json!({ "count": count })) } @@ -290,13 +301,7 @@ sessions_for_close, global_rx, ), - |(mut rx, mut initial, sid, smap, mut grx): ( - tokio::sync::mpsc::UnboundedReceiver, - Option, - String, - Arc>>>, - tokio::sync::broadcast::Receiver, - )| async move { + |(mut rx, mut initial, sid, smap, mut grx)| async move { if let Some(event) = initial.take() { return Some((Ok(event), (rx, None, sid, smap, grx))); } @@ -355,7 +360,7 @@ State(state): State, Query(query): Query, Json(req): Json, -) -> impl IntoResponse { +) -> Response { let method = req.method.as_str(); log::info!("MCP Request: {} (Session: {:?})", method, query.session_id); @@ -511,7 +516,7 @@ .collect(); let mut results = Vec::new(); - for (_i, chunk_content) in chunks.iter().enumerate() { + for chunk_content in chunks.iter() { async fn add_item_inner( state: &AppState, content: &str, @@ -566,7 +571,7 @@ .map_err(|e| format!("Failed to insert LSA vector to vec_items: {}", e))?; // items_lsa にもバックアップ(または生データ)として保存 - if let Some(_) = lsa_guard.as_ref() { + if lsa_guard.as_ref().is_some() { let vector_blob = bincode::serialize(&lsa_vector_f32).unwrap_or_default(); sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)") .bind(id) @@ -581,10 +586,12 @@ .map_err(|e| format!("Failed to commit transaction: {}", e))?; // HNSW インデックスへの反映 - let hnsw_guard = state.hnsw_index.read().await; - if let Some(hnsw) = hnsw_guard.as_ref() { + let hnsw_index_guard = state.hnsw_index.read().await; + let hnsw_opt: &Option> = &hnsw_index_guard; + if let Some(hnsw_ptr) = hnsw_opt.as_ref() { if lsa_vector_f32.len() == 50 { - hnsw.insert(&lsa_vector_f32, id as usize); + let vec_ref: &[f32] = lsa_vector_f32.as_slice(); + hnsw_ptr.insert((vec_ref, id as usize)); } } @@ -612,6 +619,7 @@ let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10); // LLM の代わりに内部で LSA クエリを構成 + let mut search_result = None; let lsa_guard = state.lsa_model.read().await; if let Some(model) = lsa_guard.as_ref() { let mut query_counts = HashMap::new(); @@ -629,76 +637,84 @@ if let Ok(query_lsa) = model.project_query(&query_vec) { // クエリが語彙に含まれず零ベクトルになった場合 if query_lsa.iter().all(|&x| x == 0.0) { - return Some(serde_json::json!({ "content": [] })); - } + search_result = Some(serde_json::json!({ "content": [] })); + } else { + let mut query_lsa_f32: Vec = query_lsa.iter().map(|&x| x as f32).collect(); + if query_lsa_f32.len() < 50 { + query_lsa_f32.resize(50, 0.0); + } else if query_lsa_f32.len() > 50 { + query_lsa_f32.truncate(50); + } - let mut query_lsa_f32: Vec = query_lsa.iter().map(|&x| x as f32).collect(); - if query_lsa_f32.len() < 50 { - query_lsa_f32.resize(50, 0.0); - } else if query_lsa_f32.len() > 50 { - query_lsa_f32.truncate(50); - } - - // HNSW インデックスがあればそれを使う、なければ sqlite-vec でフォールバック - let hnsw_guard = state.hnsw_index.read().await; - if let Some(hnsw) = hnsw_guard.as_ref() { - log::info!("Searching using HNSW index..."); - let neighbors = hnsw.search(&query_lsa_f32, limit as usize, 100); - let mut results = Vec::new(); - for neighbor in neighbors { - let id = neighbor.d_id as i64; - let dist = neighbor.distance; - // HNSW の DistCosine は通常 1 - cos_sim - let sim = 1.0 - dist; - - if let Ok(row) = sqlx::query("SELECT content FROM items WHERE id = ?").bind(id).fetch_one(&state.db_pool).await { - results.push(serde_json::json!({ - "id": id, - "content": row.get::(0), - "similarity": sim.clamp(0.0, 1.0) - })); + // HNSW インデックスがあればそれを使う、なければ sqlite-vec でフォールバック + let hnsw_idx_guard = state.hnsw_index.read().await; + let hnsw_option: &Option> = &hnsw_idx_guard; + if let Some(h_ptr) = hnsw_option.as_ref() { + log::info!("Searching using HNSW index..."); + let query_ref: &[f32] = query_lsa_f32.as_slice(); + let neighbors = h_ptr.search(query_ref, limit as usize, 100); + if !neighbors.is_empty() { + let mut results = Vec::new(); + for neighbor in neighbors { + let id = neighbor.d_id as i64; + let dist = neighbor.distance; + // HNSW の DistCosine は通常 1 - cos_sim + let sim: f32 = 1.0 - dist; + + if let Ok(row) = sqlx::query("SELECT content FROM items WHERE id = ?").bind(id).fetch_one(&state.db_pool).await { + results.push(serde_json::json!({ + "id": id, + "content": row.get::(0), + "similarity": sim.clamp(0.0, 1.0) + })); + } + } + search_result = Some(serde_json::json!({ "content": results })); } } - return Some(serde_json::json!({ "content": results })); - } - // sqlite-vec の MATCH (BM25等ではなくベクトル近傍検索) を使用 - let rows = sqlx::query( - "SELECT items.id, items.content, v.distance - FROM items - JOIN vec_items v ON items.id = v.id - WHERE v.embedding MATCH ? AND k = ? - ORDER BY distance LIMIT ?", - ) - .bind(serde_json::to_string(&query_lsa_f32).unwrap_or("[]".to_string())) - .bind(limit) - .bind(limit) - .fetch_all(&state.db_pool) - .await - .unwrap_or_default(); - - let res: Vec<_> = rows.iter().map(|r| { - let id = r.get::(0); - let content = r.get::(1); - let distance = r.get::(2); - // sqlite-vec の distance は L2 距離の 2 乗 - // 正規化ベクトル [u, v] において: - // ||u-v||^2 = ||u||^2 + ||v||^2 - 2*u*v = 1 + 1 - 2*cos_sim = 2 - 2*cos_sim - // よって cos_sim = 1.0 - (distance / 2.0) - let sim = 1.0 - (distance / 2.0); - serde_json::json!({ - "id": id, - "content": content, - "similarity": sim.clamp(0.0, 1.0) - }) - }).collect(); - - Some(serde_json::json!({ "content": res })) + if search_result.is_none() { + // sqlite-vec の MATCH (BM25等ではなくベクトル近傍検索) を使用 + let rows = sqlx::query( + "SELECT items.id, items.content, v.distance + FROM items + JOIN vec_items v ON items.id = v.id + WHERE v.embedding MATCH ? AND k = ? + ORDER BY distance LIMIT ?", + ) + .bind(serde_json::to_string(&query_lsa_f32).unwrap_or("[]".to_string())) + .bind(limit) + .bind(limit) + .fetch_all(&state.db_pool) + .await + .unwrap_or_default(); + + let res: Vec<_> = rows.iter().map(|r| { + let id = r.get::(0); + let content = r.get::(1); + let distance = r.get::(2); + // sqlite-vec の distance は L2 距離の 2 乗 + // 正規化ベクトル [u, v] において: + // ||u-v||^2 = ||u||^2 + ||v||^2 - 2*u*v = 1 + 1 - 2*cos_sim = 2 - 2*cos_sim + // よって cos_sim = 1.0 - (distance / 2.0) + let sim = 1.0 - (distance / 2.0); + serde_json::json!({ + "id": id, + "content": content, + "similarity": sim.clamp(0.0, 1.0) + }) + }).collect(); + + search_result = Some(serde_json::json!({ "content": res })); + } + } } else { - Some(serde_json::json!({ "error": "LSA query projection failed" })) + search_result = Some(serde_json::json!({ "error": "LSA query projection failed" })); } - } else { - // LSA モデルがない場合は LIKE 検索でフォールバック + } + + if search_result.is_none() { + // LSA モデルがない、または検索結果が得られなかった場合は LIKE 検索でフォールバック let rows = sqlx::query("SELECT id, content FROM items WHERE content LIKE ? LIMIT ?") .bind(format!("%{}%", content)) .bind(limit) @@ -706,8 +722,9 @@ .await .unwrap_or_default(); let res: Vec<_> = rows.iter().map(|r| serde_json::json!({ "id": r.get::(0), "content": r.get::(1), "similarity": 0.0 })).collect(); - Some(serde_json::json!({ "content": res })) + search_result = Some(serde_json::json!({ "content": res })); } + search_result } "lsa_search" => { let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); @@ -945,7 +962,7 @@ }; // Notifications (id == null) MUST NOT receive a response - if req.id.is_none() || req.id.as_ref().map_or(false, |v| v.is_null()) { + if req.id.is_none() || req.id.as_ref().is_some_and(|v| v.is_null()) { log::info!("MCP Notification received: {} (No response sent)", method); return axum::http::StatusCode::NO_CONTENT.into_response(); } @@ -969,7 +986,7 @@ axum::http::StatusCode::ACCEPTED.into_response() } else { // App UI (Direct Mode) - resp.into_response() + Json(resp).into_response() } } else { axum::http::StatusCode::NO_CONTENT.into_response() @@ -978,7 +995,7 @@ #[cfg(test)] mod tests { - use super::*; + // use super::*; #[test] fn test_text_chunking_logic() { diff --git a/src-tauri/src/utils/lsa.rs b/src-tauri/src/utils/lsa.rs index 1deef26..fc467fa 100644 --- a/src-tauri/src/utils/lsa.rs +++ b/src-tauri/src/utils/lsa.rs @@ -36,21 +36,22 @@ // 前の成分と異なる方向を向かせるための簡単な摂動 v[i % rows] += 1.0; } - v /= (v.dot(&v) as f64).sqrt(); + let dot_vv: f64 = v.dot(&v); + v /= dot_vv.sqrt(); for _ in 0..150 { // イテレーション回数を増やして精度向上 // v = (A A^T) v = A * (A^T * v) let at_v = working_matrix.t().dot(&v); let a_at_v = working_matrix.dot(&at_v); - let norm = (a_at_v.dot(&a_at_v) as f64).sqrt(); + let norm = a_at_v.dot(&a_at_v).sqrt(); if norm < 1e-15 { break; } v = a_at_v / norm; } // 特異値 s = ||A^T v|| let at_v_final = working_matrix.t().dot(&v); - let s = (at_v_final.dot(&at_v_final) as f64).sqrt(); + let s = at_v_final.dot(&at_v_final).sqrt(); // Deflation: A_next = A - s * u * vt^T if s > 1e-15 { @@ -90,7 +91,7 @@ let mut query_lsa = self.u.t().dot(&query_tfidf); // 正規化 (L2距離をコサイン類似度に対応させるため) - let norm = (query_lsa.dot(&query_lsa) as f64).sqrt(); + let norm = query_lsa.dot(&query_lsa).sqrt(); if norm > 1e-12 { query_lsa /= norm; } else { @@ -115,6 +116,12 @@ pub counts: Vec>, // 文書ごとの単語出現カウント } +impl Default for TermDocumentMatrixBuilder { + fn default() -> Self { + Self::new() + } +} + impl TermDocumentMatrixBuilder { pub fn new() -> Self { TermDocumentMatrixBuilder { @@ -183,8 +190,8 @@ builder.add_document(vec!["自然".to_string(), "空".to_string(), "雲".to_string()]); builder.add_document(vec!["寿司".to_string(), "魚".to_string(), "飯".to_string()]); - let matrix = builder.build_matrix(); - let model = LsaModel::train(&matrix, builder.vocabulary.clone(), 3).unwrap(); // rank=3 + let (matrix, idfs) = builder.build_matrix(); + let model = LsaModel::train(&matrix, builder.vocabulary.clone(), idfs, 3).unwrap(); // rank=3 println!("Sigma: {:?}", model.sigma); diff --git a/src/index.html b/src/index.html index 9db572e..97e3ea3 100644 --- a/src/index.html +++ b/src/index.html @@ -58,6 +58,7 @@
+
@@ -123,6 +124,39 @@ } } + async function reindex() { + if (!confirm("LSAインデックスを再構築します。現在の全データが再学習され、数秒〜数十秒かかる場合があります。実行しますか?")) return; + + const btn = document.getElementById("reindex-btn"); + const originalText = btn.textContent; + btn.textContent = "Updating..."; + btn.disabled = true; + + try { + const res = await fetch(`${API_BASE}/messages`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "lsa_retrain", + params: {}, + id: Date.now(), + }), + }); + const data = await res.json(); + if (data.result) { + alert("LSA再学習を開始しました。完了までしばらくお待ちください。更新状況はSSE経由で反映されます。"); + } else { + alert("エラー: " + (data.error || "再学習の開始に失敗しました")); + } + } catch (e) { + alert("通信エラー: " + e.message); + } finally { + btn.textContent = originalText; + btn.disabled = false; + } + } + async function search() { const query = document.getElementById("query").value; const resultPanel = document.getElementById("result"); diff --git a/src/styles.css b/src/styles.css index 332b176..94d50f4 100644 --- a/src/styles.css +++ b/src/styles.css @@ -192,6 +192,7 @@ .actions { display:flex; gap:10px; margin-top:20px; } .secondary-btn { background: transparent; border: 1px solid var(--glass-border); color: var(--text-secondary); padding:8px 16px; border-radius:8px; cursor:pointer; font-size:0.8rem; transition: all 0.2s; } .secondary-btn:hover { background: rgba(255,255,255,0.1); color: white; } +.warning-btn:hover { border-color: #f59e0b; color: #f59e0b; background: rgba(245, 158, 11, 0.05); } .empty-state { text-align:center; padding:40px; color: var(--text-secondary); } .error-state { color: #ef4444; }