// TelosDB / src-tauri / src / mcp.rs
// MCP (Model Context Protocol) server: SSE transport, JSON-RPC messages,
// and LSA/HNSW-backed semantic search endpoints.
// use crate::db;
use axum::{
    extract::{Query, State},
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{get, post},
    Json, Router, response::Response,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use chrono::Utc;
use sqlx::Row;
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
use tower_http::cors::{Any, CorsLayer};
use crate::utils::lsa::LsaModel;
use crate::utils::tokenizer::JapaneseTokenizer;
use hnsw_rs::prelude::*;

/// Shared application state; cloned into every axum handler via `State`.
/// All mutable pieces are behind `Arc`, so cloning is cheap.
#[derive(Clone)]
pub struct AppState {
    /// SQLite connection pool (backs `items`, `vec_items`, `items_lsa`).
    pub db_pool: sqlx::SqlitePool,
    /// Broadcast channel for global notifications pushed to all SSE clients
    /// (e.g. "data_changed" after writes).
    pub tx: broadcast::Sender<String>,
    /// Current llama.cpp status string, exposed via `/llama_status`.
    pub llama_status: Arc<RwLock<String>>,
    /// Model name reported by `/model_name`.
    pub model_name: String,
    // MCP sessions map: session_id -> per-session message sender.
    pub sessions: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<String>>>>,
    // Japanese NLP & LSA
    /// Tokenizer used to turn document/query text into terms.
    pub tokenizer: Arc<JapaneseTokenizer>,
    /// Trained LSA model; `None` until background training completes.
    pub lsa_model: Arc<RwLock<Option<LsaModel>>>,
    /// HNSW ANN index over 50-dim LSA vectors; `None` until built.
    pub hnsw_index: Arc<RwLock<Option<Hnsw<'static, f32, DistCosine>>>>,
}

/// Assemble the MCP router: SSE stream, JSON-RPC message endpoint, and a
/// few small status routes, all behind a fully permissive CORS layer.
pub fn create_mcp_app(state: AppState) -> Router {
    Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(mcp_messages_handler))
        .route("/llama_status", get(llama_status_handler))
        .route("/doc_count", get(doc_count_handler))
        .route("/model_name", get(model_name_handler))
        .layer(
            CorsLayer::new()
                .allow_origin(Any)
                .allow_methods(Any)
                .allow_headers(Any),
        )
        .with_state(state)
}

/// Bind `127.0.0.1:{port}` and serve the MCP app until the process exits.
///
/// Also spawns a background task that trains the LSA model and builds the
/// HNSW index from existing rows, since that work is too heavy to block
/// startup on. Panics if the port cannot be bound (matches prior behavior).
pub async fn run_server(
    port: u16,
    db_pool: sqlx::SqlitePool,
    llama_status: Arc<RwLock<String>>,
    model_name: String,
) {
    let (tx, _rx) = broadcast::channel(100);

    let state = AppState {
        db_pool: db_pool.clone(),
        tx,
        llama_status: llama_status.clone(),
        model_name,
        sessions: Arc::new(RwLock::new(HashMap::new())),
        tokenizer: Arc::new(JapaneseTokenizer::new().expect("Failed to init tokenizer")),
        lsa_model: Arc::new(RwLock::new(None)),
        hnsw_index: Arc::new(RwLock::new(None)),
    };

    // Build the LSA model from existing data off the request path — this is
    // an expensive operation, so it runs as a detached task.
    let lsa_state = state.clone();
    tokio::spawn(async move { train_lsa_and_sync_hnsw(lsa_state).await });

    let app = create_mcp_app(state);

    let addr = format!("127.0.0.1:{}", port);
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

/// Train the LSA model from every stored item, publish it into
/// `state.lsa_model`, then build an HNSW index and backfill any missing or
/// invalid vectors via [`sync_all_vectors`].
///
/// Heavy work — callers should run this on a background task. Does nothing
/// when the `items` table is empty or the initial query fails.
pub async fn train_lsa_and_sync_hnsw(state: AppState) {
    log::info!("Starting LSA model training...");
    if let Ok(rows) = sqlx::query("SELECT content FROM items").fetch_all(&state.db_pool).await {
        if !rows.is_empty() {
            // Feed every document's tokens into the term-document matrix builder.
            let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new();
            for row in rows {
                let content: String = row.get(0);
                let tokens = state.tokenizer.tokenize_to_vec(&content).unwrap_or_default();
                builder.add_document(tokens);
            }
            let (matrix, idfs) = builder.build_matrix();
            match LsaModel::train(&matrix, builder.vocabulary, idfs, 50) { // reduce to 50 dimensions
                Ok(model) => {
                    // Move the model straight into the shared slot. (The old
                    // code wrapped it in a throwaway Arc and deep-cloned it
                    // back out — a needless allocation and full copy.)
                    {
                        let mut lsa = state.lsa_model.write().await;
                        *lsa = Some(model);
                    }
                    log::info!("LSA model trained successfully with {} documents.", builder.counts.len());

                    // Build an empty HNSW index; capacity is floored at 100
                    // so tiny corpora still get reasonable graph parameters.
                    log::info!("Building HNSW index...");
                    let hnsw: Hnsw<'static, f32, DistCosine> =
                        Hnsw::new(16, builder.counts.len().max(100), 16, 200, DistCosine {});

                    // Backfill missing/invalid vectors and register everything
                    // into the HNSW index.
                    sync_all_vectors(state.clone(), Some(hnsw)).await;
                }
                Err(e) => log::error!("LSA training failed: {}", e),
            }
        }
    }
}

/// Scan every item in the DB and regenerate vectors that are missing or
/// invalid (all zeros), writing them to both `vec_items` and `items_lsa`.
/// When `startup_hnsw` is provided, the (empty) index is then populated from
/// `vec_items` and published into `state.hnsw_index`.
pub async fn sync_all_vectors(state: AppState, startup_hnsw: Option<Hnsw<'static, f32, DistCosine>>) {
    log::info!("Checking for missing or invalid vectors in vec_items...");
    
    // LEFT JOIN so items without a vec_items row come back with NULL embedding.
    let rows = match sqlx::query(
        "SELECT i.id, i.content, 
                CASE WHEN v.embedding IS NOT NULL THEN vec_to_json(v.embedding) ELSE NULL END
         FROM items i 
         LEFT JOIN vec_items v ON i.id = v.id"
    )
    .fetch_all(&state.db_pool)
    .await {
        Ok(rows) => rows,
        Err(e) => {
            log::error!("Failed to fetch items for sync: {}", e);
            return;
        }
    };

    // Collect (id, content) pairs whose stored vector must be regenerated.
    let mut to_sync = Vec::new();
    for row in rows {
        let id: i64 = row.get(0);
        let content: String = row.get(1);
        let embedding_str: Option<String> = row.get(2);

        let needs_sync = if let Some(s) = embedding_str {
            if let Ok(vec) = serde_json::from_str::<Vec<f32>>(&s) {
                // An all-zero vector is treated as an invalid dummy.
                vec.iter().all(|&x| x == 0.0)
            } else {
                true // unparseable JSON is also invalid
            }
        } else {
            true // no vector stored at all
        };

        if needs_sync {
            to_sync.push((id, content));
        }
    }

    if to_sync.is_empty() {
        log::info!("All vectors are healthy and synchronized.");
        return;
    }

    log::info!("Found {} items needing vector update. Processing...", to_sync.len());
    
    // NOTE(review): the read guard is held across the per-item DB awaits
    // below; this is a tokio RwLock, so that is allowed, but it blocks
    // writers (retraining) for the duration of the sync.
    let lsa_guard = state.lsa_model.read().await;
    let model = match lsa_guard.as_ref() {
        Some(m) => m,
        None => {
            log::warn!("LSA model not available for sync.");
            return;
        }
    };

    let mut count = 0;
    for (id, content) in to_sync {
        // Build a term-frequency vector over the model's vocabulary;
        // unknown tokens are simply dropped.
        let mut query_counts = HashMap::new();
        let tokens = state.tokenizer.tokenize_to_vec(&content).unwrap_or_default();
        for token in tokens {
            if let Some(&tid) = model.vocabulary.get(&token) {
                *query_counts.entry(tid).or_insert(0.0) += 1.0;
            }
        }
        let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
        for (tid, count) in query_counts {
            query_vec[tid] = count;
        }

        if let Ok(projected) = model.project_query(&query_vec) {
            // Normalize to exactly 50 dimensions (pad with zeros or truncate).
            let mut proj_f32: Vec<f32> = projected.iter().map(|&x| x as f32).collect();
            if proj_f32.len() < 50 { proj_f32.resize(50, 0.0); } else { proj_f32.truncate(50); }
            
            let mut tx = match state.db_pool.begin().await {
                Ok(t) => t,
                Err(_) => continue,
            };

            // Write to vec_items (virtual table): delete-then-insert, since
            // errors are deliberately ignored here (best-effort per item).
            let _ = sqlx::query("DELETE FROM vec_items WHERE id = ?").bind(id).execute(&mut *tx).await;
            let _ = sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
                .bind(id)
                .bind(serde_json::to_string(&proj_f32).unwrap_or("[]".to_string()))
                .execute(&mut *tx)
                .await;

            // items_lsa holds a bincode-serialized backup of the same vector.
            let vector_blob = bincode::serialize(&proj_f32).unwrap_or_default();
            let _ = sqlx::query("INSERT OR REPLACE INTO items_lsa (id, vector) VALUES (?, ?)")
                .bind(id)
                .bind(vector_blob)
                .execute(&mut *tx)
                .await;

            if tx.commit().await.is_ok() {
                count += 1;
            }
        }
    }
    log::info!("Successfully synchronized {} vectors.", count);

    // Register the HNSW index into AppState.
    if let Some(hnsw) = startup_hnsw {
        // Register ALL items into HNSW, including ones that were already in
        // sync — for simplicity, re-read every vector from the database.
        log::info!("Populating HNSW index from database...");
        if let Ok(rows) = sqlx::query("SELECT id, vec_to_json(embedding) FROM vec_items").fetch_all(&state.db_pool).await {
            let mut data_to_insert = Vec::new();
            for row in rows {
                let id: i64 = row.get(0);
                let embedding_str: String = row.get(1);
                if let Ok(vec) = serde_json::from_str::<Vec<f32>>(&embedding_str) {
                    // Only exactly-50-dim vectors go into the index.
                    if vec.len() == 50 {
                        data_to_insert.push((vec, id as usize));
                    }
                }
            }
            if !data_to_insert.is_empty() {
                let refs: Vec<(&Vec<f32>, usize)> = data_to_insert.iter().map(|(v, id)| (v, *id)).collect();
                hnsw.parallel_insert(&refs);
            }
        }
        let mut idx = state.hnsw_index.write().await;
        *idx = Some(hnsw);
        log::info!("HNSW index is now ready.");
    }
}

/// GET /llama_status — report the current llama.cpp status string as JSON.
async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    Json(serde_json::json!({ "status": state.llama_status.read().await.clone() }))
}

/// GET /doc_count — total number of rows in `items` (0 on query failure).
async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
    let total = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM items")
        .fetch_one(&state.db_pool)
        .await
        .unwrap_or(0);
    Json(serde_json::json!({ "count": total }))
}

/// GET /model_name — report the configured model name as JSON.
async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    let name = state.model_name;
    Json(serde_json::json!({ "model_name": name }))
}

/// Query string accepted by `/sse`. The session_id is currently ignored —
/// the handler always generates a fresh one — hence the dead_code allow.
#[allow(dead_code)]
#[derive(Deserialize)]
struct SseQuery {
    session_id: Option<String>,
}

/// GET /sse — open an MCP SSE session.
///
/// Generates a fresh session id, registers a per-session channel in
/// `state.sessions`, immediately emits an `endpoint` event pointing the
/// client at `/messages?session_id=...`, then streams per-session
/// `message` events and global `update` events until both channels close.
async fn sse_handler(
    State(state): State<AppState>,
    Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // Generate a simple session ID
    let session_id = uuid::Uuid::new_v4().to_string();
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();

    log::info!("New MCP SSE Session: {}", session_id);

    // Register session
    state.sessions.write().await.insert(session_id.clone(), tx);

    // Initial endpoint event
    let endpoint_url = format!("/messages?session_id={}", session_id);
    let endpoint_event = Event::default().event("endpoint").data(endpoint_url);

    let session_id_for_close = session_id.clone();
    let sessions_for_close = state.sessions.clone();
    let global_rx = state.tx.subscribe();

    // State machine driven by unfold: the tuple carries the per-session
    // receiver, the one-shot initial event, the id/map needed for cleanup,
    // and the global broadcast receiver.
    let stream = futures::stream::unfold(
        (
            rx,
            Some(endpoint_event),
            session_id_for_close,
            sessions_for_close,
            global_rx,
        ),
        |(mut rx, mut initial, sid, smap, mut grx)| async move {
            // First poll yields the endpoint event exactly once.
            if let Some(event) = initial.take() {
                return Some((Ok(event), (rx, None, sid, smap, grx)));
            }

            tokio::select! {
                Some(msg) = rx.recv() => {
                    // Per-session JSON-RPC message pushed by the server.
                    Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap, grx)))
                }
                Ok(msg) = grx.recv() => {
                    // Global notification (e.g. data update)
                    Some((Ok(Event::default().event("update").data(msg)), (rx, None, sid, smap, grx)))
                }
                else => {
                    // Both channels closed: deregister the session and end
                    // the stream.
                    log::info!("MCP SSE Session Closed: {}", sid);
                    smap.write().await.remove(&sid);
                    None
                }
            }
        },
    );

    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}

/// Incoming JSON-RPC 2.0 request envelope posted to `/messages`.
/// `id` is absent for notifications; `params` shape depends on `method`.
#[derive(Serialize, Deserialize)]
struct JsonRpcRequest {
    jsonrpc: String,
    method: String,
    params: Option<serde_json::Value>,
    id: Option<serde_json::Value>,
}

/// Outgoing JSON-RPC 2.0 response. Exactly one of `result`/`error` is
/// serialized; the other is skipped entirely per the JSON-RPC spec.
#[derive(Serialize)]
struct JsonRpcResponse {
    // Always "2.0"; static str avoids an allocation per response.
    jsonrpc: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<serde_json::Value>,
    // Echoes the request id; None for responses to notifications.
    id: Option<serde_json::Value>,
}

/// Query string for POST /messages: the SSE session this request belongs to.
#[derive(Deserialize)]
struct MessageQuery {
    session_id: Option<String>,
}

/// Let handlers return a `JsonRpcResponse` directly; it is rendered as a
/// JSON body with the appropriate content-type.
impl IntoResponse for JsonRpcResponse {
    fn into_response(self) -> axum::response::Response {
        Json(self).into_response()
    }
}


async fn mcp_messages_handler(
    State(state): State<AppState>,
    Query(query): Query<MessageQuery>,
    Json(req): Json<JsonRpcRequest>,
) -> Response {
    let method = req.method.as_str();
    log::info!("MCP Request: {} (Session: {:?})", method, query.session_id);

    // 受信データを構造化JSONで出力(timestamp と source を含む)
    let structured = serde_json::json!({
        "timestamp": Utc::now().to_rfc3339(),
        "source": "mcp",
        "session": query.session_id,
        "method": method,
        "id": req.id,
        "params": req.params,
    });
    log::info!("{}", serde_json::to_string(&structured).unwrap_or_else(|_| "{\"error\":\"serialize_failed\"}".to_string()));

    let result: Option<serde_json::Value> = match method {
        "initialize" => {
            let client_version = req.params.as_ref()
                .and_then(|p| p.get("protocolVersion"))
                .and_then(|v| v.as_str())
                .unwrap_or("2024-11-05");
            
            log::info!("MCP Handshake: Client requested protocol version {}", client_version);
            
            Some(serde_json::json!({
                "protocolVersion": client_version,
                "capabilities": { 
                    "tools": { "listChanged": false },
                    "resources": { "listChanged": false, "subscribe": false },
                    "prompts": { "listChanged": false },
                    "logging": {}
                },
                "serverInfo": { "name": "TelosDB", "version": "0.1.0" }
            }))
        },
        "notifications/initialized" => None,
        "tools/list" => Some(serde_json::json!({
            "tools": [
                {
                    "name": "add_item_text",
                    "description": "Store text with auto-generated LSA vectors (No LLM required).",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "path": { "type": "string" }
                        },
                        "required": ["content"]
                    }
                },
                {
                    "name": "search_text",
                    "description": "Semantic search using LSA (Latent Semantic Analysis). Lightweight and fast.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "limit": { "type": "number" }
                        },
                        "required": ["content"]
                    }
                },
                {
                    "name": "lsa_retrain",
                    "description": "Rebuild the LSA semantic model from all current documents. Use this when you've added many new items.",
                    "inputSchema": { "type": "object", "properties": {} }
                },
                {
                    "name": "update_item",
                    "description": "Update existing text and its LSA vector.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" },
                            "content": { "type": "string" },
                            "path": { "type": "string" }
                        },
                        "required": ["id", "content"]
                    }
                },
                {
                    "name": "delete_item",
                    "description": "Delete item by ID.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" }
                        },
                        "required": ["id"]
                    }
                },
                {
                    "name": "get_item_by_id",
                    "description": "Get text content by item ID.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" }
                        },
                        "required": ["id"]
                    }
                }
            ]
        })),
        "search_text" | "tools/call" | "add_item_text" | "update_item" | "delete_item" | "get_item_by_id" => {
            let p = req.params.clone().unwrap_or_default();
            let (actual_method, args) = if method == "tools/call" {
                (
                    p.get("name").and_then(|v| v.as_str()).unwrap_or(""),
                    p.get("arguments").cloned().unwrap_or_default(),
                )
            } else {
                (method, p)
            };

            // UIへの通知(ツール呼び出し開始)
            let _ = state.tx.send(format!("mcp:call:{}", actual_method));

            match actual_method {
                "get_item_by_id" => {
                    let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
                    let row = sqlx::query("SELECT id, content, path FROM items WHERE id = ?")
                        .bind(id)
                        .fetch_optional(&state.db_pool)
                        .await
                        .unwrap_or(None);
                    if let Some(row) = row {
                        let content: String = row.get("content");
                        let path: Option<String> = row.try_get("path").ok();
                        Some(serde_json::json!({
                            "id": id,
                            "content": content,
                            "path": path
                        }))
                    } else {
                        Some(serde_json::json!({ "error": format!("Item not found: {}", id) }))
                    }
                }
                "add_item_text" => {
                    let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
                    let path = args.get("path").and_then(|v| v.as_str());

                    log::info!(
                        "Executing add_item_text (LSA-only): content length={}, path='{:?}'",
                        content.chars().count(),
                        path
                    );

                    // 800文字ずつに分割
                    let chars: Vec<char> = content.chars().collect();
                    let chunks: Vec<String> = chars
                        .chunks(800)
                        .map(|chunk| chunk.iter().collect::<String>())
                        .collect();

                    let mut results = Vec::new();
                    for chunk_content in chunks.iter() {
                        async fn add_item_inner(
                            state: &AppState,
                            content: &str,
                            path: Option<&str>,
                        ) -> Result<i64, String> {
                            let mut tx =
                                state.db_pool.begin().await.map_err(|e| {
                                    format!("Failed to begin transaction: {}", e)
                                })?;
                            let res =
                                sqlx::query("INSERT INTO items (content, path) VALUES (?, ?)")
                                    .bind(content)
                                    .bind(path)
                                    .execute(&mut *tx)
                                    .await
                                    .map_err(|e| format!("Failed to insert item: {}", e))?;
                            let id = res.last_insert_rowid();

                            // LSA ベクトルの計算
                            let mut lsa_vector_f32: Vec<f32> = vec![0.0; 50]; 
                            let lsa_guard = state.lsa_model.read().await;
                            if let Some(model) = lsa_guard.as_ref() {
                                let mut query_counts = HashMap::new();
                                let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
                                for token in tokens {
                                    if let Some(&tid) = model.vocabulary.get(&token) {
                                        *query_counts.entry(tid).or_insert(0.0) += 1.0;
                                    }
                                }
                                let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
                                for (tid, count) in query_counts {
                                    query_vec[tid] = count;
                                }

                                if let Ok(projected) = model.project_query(&query_vec) {
                                    lsa_vector_f32 = projected.iter().map(|&x| x as f32).collect();
                                    // 50次元に満たない(モデル初期化時のランク制限等)場合はパディング
                                    if lsa_vector_f32.len() < 50 {
                                        lsa_vector_f32.resize(50, 0.0);
                                    } else if lsa_vector_f32.len() > 50 {
                                        lsa_vector_f32.truncate(50);
                                    }
                                }
                            }

                            // sqlite-vec の仮想テーブル (vec_items) に LSA ベクトルを保存
                            sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
                                .bind(id)
                                .bind(serde_json::to_string(&lsa_vector_f32).unwrap_or("[]".to_string()))
                                .execute(&mut *tx)
                                .await
                                .map_err(|e| format!("Failed to insert LSA vector to vec_items: {}", e))?;

                            // items_lsa にもバックアップ(または生データ)として保存
                            if lsa_guard.as_ref().is_some() {
                                let vector_blob = bincode::serialize(&lsa_vector_f32).unwrap_or_default();
                                sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)")
                                    .bind(id)
                                    .bind(vector_blob)
                                    .execute(&mut *tx)
                                    .await
                                    .map_err(|e| format!("Failed to insert LSA blob: {}", e))?;
                            }

                            tx.commit()
                                .await
                                .map_err(|e| format!("Failed to commit transaction: {}", e))?;

                            // HNSW インデックスへの反映
                            let hnsw_index_guard = state.hnsw_index.read().await;
                            let hnsw_opt: &Option<Hnsw<'static, f32, DistCosine>> = &hnsw_index_guard;
                            if let Some(hnsw_ptr) = hnsw_opt.as_ref() {
                                if lsa_vector_f32.len() == 50 {
                                    let vec_ref: &[f32] = lsa_vector_f32.as_slice();
                                    hnsw_ptr.insert((vec_ref, id as usize));
                                }
                            }

                            Ok(id)
                        }

                        match add_item_inner(&state, chunk_content, path).await {
                            Ok(id) => results.push(id),
                            Err(e) => log::error!("Failed to add chunk: {}", e),
                        }
                    }

                    if !results.is_empty() {
                        let _ = state.tx.send("data_changed".to_string());
                        log::info!("Successfully added {} chunks via LSA.", results.len());
                        Some(
                            serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully added {} chunks (LSA).", results.len()) }] }),
                        )
                    } else {
                        Some(serde_json::json!({ "error": "Failed to add any chunks." }))
                    }
                }
                "search_text" => {
                    let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
                    let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10);

                    // LLM の代わりに内部で LSA クエリを構成
                    let mut search_result = None;
                    let lsa_guard = state.lsa_model.read().await;
                    if let Some(model) = lsa_guard.as_ref() {
                        let mut query_counts = HashMap::new();
                        let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
                        for token in tokens {
                            if let Some(&id) = model.vocabulary.get(&token) {
                                *query_counts.entry(id).or_insert(0.0) += 1.0;
                            }
                        }
                        let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
                        for (id, count) in query_counts {
                            query_vec[id] = count;
                        }

                        if let Ok(query_lsa) = model.project_query(&query_vec) {
                            // クエリが語彙に含まれず零ベクトルになった場合
                            if query_lsa.iter().all(|&x| x == 0.0) {
                                search_result = Some(serde_json::json!({ "content": [] }));
                            } else {
                                let mut query_lsa_f32: Vec<f32> = query_lsa.iter().map(|&x| x as f32).collect();
                                if query_lsa_f32.len() < 50 {
                                    query_lsa_f32.resize(50, 0.0);
                                } else if query_lsa_f32.len() > 50 {
                                    query_lsa_f32.truncate(50);
                                }

                                // HNSW インデックスがあればそれを使う、なければ sqlite-vec でフォールバック
                                let hnsw_idx_guard = state.hnsw_index.read().await;
                                let hnsw_option: &Option<Hnsw<'static, f32, DistCosine>> = &hnsw_idx_guard;
                                if let Some(h_ptr) = hnsw_option.as_ref() {
                                    log::info!("Searching using HNSW index...");
                                    let query_ref: &[f32] = query_lsa_f32.as_slice();
                                    let neighbors = h_ptr.search(query_ref, limit as usize, 100);
                                    if !neighbors.is_empty() {
                                        let mut results = Vec::new();
                                        for neighbor in neighbors {
                                            let id = neighbor.d_id as i64;
                                            let dist = neighbor.distance;
                                            // HNSW の DistCosine は通常 1 - cos_sim
                                            let sim: f32 = 1.0 - dist;
                                            
                                            if let Ok(row) = sqlx::query("SELECT content FROM items WHERE id = ?").bind(id).fetch_one(&state.db_pool).await {
                                                results.push(serde_json::json!({
                                                    "id": id,
                                                    "content": row.get::<String, _>(0),
                                                    "similarity": sim.clamp(0.0, 1.0)
                                                }));
                                            }
                                        }
                                        search_result = Some(serde_json::json!({ "content": results }));
                                    }
                                }

                                if search_result.is_none() {
                                    // sqlite-vec の MATCH (BM25等ではなくベクトル近傍検索) を使用
                                    let rows = sqlx::query(
                                        "SELECT items.id, items.content, v.distance 
                                         FROM items 
                                         JOIN vec_items v ON items.id = v.id 
                                         WHERE v.embedding MATCH ? AND k = ? 
                                         ORDER BY distance LIMIT ?",
                                    )
                                    .bind(serde_json::to_string(&query_lsa_f32).unwrap_or("[]".to_string()))
                                    .bind(limit)
                                    .bind(limit)
                                    .fetch_all(&state.db_pool)
                                    .await
                                    .unwrap_or_default();
                                    
                                    let res: Vec<_> = rows.iter().map(|r| {
                                        let id = r.get::<i64, _>(0);
                                        let content = r.get::<String, _>(1);
                                        let distance = r.get::<f64, _>(2);
                                        // sqlite-vec の distance は L2 距離の 2 乗
                                        // 正規化ベクトル [u, v] において:
                                        // ||u-v||^2 = ||u||^2 + ||v||^2 - 2*u*v = 1 + 1 - 2*cos_sim = 2 - 2*cos_sim
                                        // よって cos_sim = 1.0 - (distance / 2.0)
                                        let sim = 1.0 - (distance / 2.0);
                                        serde_json::json!({
                                            "id": id,
                                            "content": content,
                                            "similarity": sim.clamp(0.0, 1.0)
                                        })
                                    }).collect();
                                    
                                    search_result = Some(serde_json::json!({ "content": res }));
                                }
                            }
                        } else {
                            search_result = Some(serde_json::json!({ "error": "LSA query projection failed" }));
                        }
                    }

                    if search_result.is_none() {
                        // LSA モデルがない、または検索結果が得られなかった場合は LIKE 検索でフォールバック
                        let rows = sqlx::query("SELECT id, content FROM items WHERE content LIKE ? LIMIT ?")
                            .bind(format!("%{}%", content))
                            .bind(limit)
                            .fetch_all(&state.db_pool)
                            .await
                            .unwrap_or_default();
                        let res: Vec<_> = rows.iter().map(|r| serde_json::json!({ "id": r.get::<i64,_>(0), "content": r.get::<String,_>(1), "similarity": 0.0 })).collect();
                        search_result = Some(serde_json::json!({ "content": res }));
                    }
                    search_result
                }
                "lsa_search" => {
                    let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
                    let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10);

                    let lsa_guard = state.lsa_model.read().await;
                    if let Some(model) = lsa_guard.as_ref() {
                        // クエリのベクトル化 (TF)
                        let mut query_counts = HashMap::new();
                        let tokens = state.tokenizer.tokenize_to_vec(query).unwrap_or_default();
                        for token in tokens {
                            if let Some(&id) = model.vocabulary.get(&token) {
                                *query_counts.entry(id).or_insert(0.0) += 1.0;
                            }
                        }
                        
                        let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
                        for (id, count) in query_counts {
                            query_vec[id] = count;
                        }

                        // LSA 空間への射影
                        if let Ok(query_lsa) = model.project_query(&query_vec) {
                            // DB から全ベクトルを取得して比較 (件数が少ない想定)
                            // 本来はアイテム数が多い場合は BLOB を全件回すと遅いため、インメモリキャッシュ等を検討
                            let rows = sqlx::query("SELECT id, vector FROM items_lsa")
                                .fetch_all(&state.db_pool)
                                .await
                                .unwrap_or_default();
                            
                            let mut results = Vec::new();
                            for row in rows {
                                let id: i64 = row.get(0);
                                let vector_blob: Vec<u8> = row.get(1);
                                if let Ok(vector_f64) = bincode::deserialize::<Vec<f64>>(&vector_blob) {
                                    let doc_vec = ndarray::Array1::from_vec(vector_f64);
                                    let sim = crate::utils::lsa::LsaModel::cosine_similarity(&query_lsa, &doc_vec);
                                    results.push((id, sim));
                                }
                            }
                            
                            results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                            results.truncate(limit as usize);
                            
                            let mut filtered_results = Vec::new();
                            for (id, sim) in results {
                                if let Ok(doc_row) = sqlx::query("SELECT content FROM items WHERE id = ?").bind(id).fetch_one(&state.db_pool).await {
                                    let content: String = doc_row.get(0);
                                    filtered_results.push(serde_json::json!({
                                        "id": id,
                                        "content": content,
                                        "similarity": sim
                                    }));
                                }
                            }
                            
                            Some(serde_json::json!({ "content": filtered_results }))
                        } else {
                            Some(serde_json::json!({ "error": "Query projection failed" }))
                        }
                    } else {
                        Some(serde_json::json!({ "error": "LSA model not initialized or no data available" }))
                    }
                }
                "update_item" => {
                    let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
                    let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
                    let path = args.get("path").and_then(|v| v.as_str());

                    async fn update_item_inner(
                        state: &AppState,
                        id: i64,
                        content: &str,
                        path: Option<&str>,
                    ) -> Result<(), String> {
                        let mut tx =
                            state.db_pool.begin().await.map_err(|e| {
                                format!("Failed to begin transaction: {}", e)
                            })?;
                        sqlx::query("UPDATE items SET content = ?, path = ? WHERE id = ?")
                            .bind(content)
                            .bind(path)
                            .bind(id)
                            .execute(&mut *tx)
                            .await
                            .map_err(|e| format!("Failed to update item: {}", e))?;

                        // LSA ベクトルの更新
                        let lsa_guard = state.lsa_model.read().await;
                        if let Some(model) = lsa_guard.as_ref() {
                            let mut query_counts = HashMap::new();
                            let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
                            for token in tokens {
                                if let Some(&tid) = model.vocabulary.get(&token) {
                                    *query_counts.entry(tid).or_insert(0.0) += 1.0;
                                }
                            }
                            let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
                            for (tid, count) in query_counts {
                                query_vec[tid] = count;
                            }

                            if let Ok(projected) = model.project_query(&query_vec) {
                                let vector_blob = bincode::serialize(&projected.to_vec()).unwrap_or_default();
                                sqlx::query("INSERT OR REPLACE INTO items_lsa (id, vector) VALUES (?, ?)")
                                    .bind(id)
                                    .bind(vector_blob)
                                    .execute(&mut *tx)
                                    .await
                                    .map_err(|e| format!("Failed to update LSA vector: {}", e))?;
                            }
                        }

                        tx.commit()
                            .await
                            .map_err(|e| format!("Failed to commit transaction: {}", e))?;
                        Ok(())
                    }

                    if let Err(e) = update_item_inner(&state, id, content, path).await
                    {
                        Some(serde_json::json!({ "error": e }))
                    } else {
                        let _ = state.tx.send("data_changed".to_string());
                        Some(
                            serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully updated item {} (LSA)", id) }] }),
                        )
                    }
                }
                "delete_item" => {
                    let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);

                    // Deletes an item together with ALL of its derived vector rows
                    // (`vec_items` and `items_lsa`) inside a single transaction, so
                    // the search indexes never reference a missing item. Returns a
                    // human-readable error string on any failure.
                    async fn delete_item_inner(state: &AppState, id: i64) -> Result<(), String> {
                        let mut tx = state
                            .db_pool
                            .begin()
                            .await
                            .map_err(|e| format!("Failed to begin transaction: {}", e))?;
                        sqlx::query("DELETE FROM items WHERE id = ?")
                            .bind(id)
                            .execute(&mut *tx)
                            .await
                            .map_err(|e| format!("Failed to delete item: {}", e))?;
                        sqlx::query("DELETE FROM vec_items WHERE id = ?")
                            .bind(id)
                            .execute(&mut *tx)
                            .await
                            .map_err(|e| format!("Failed to delete vector: {}", e))?;
                        // Fix: also purge the LSA vector. Both `update_item` and
                        // `lsa_retrain` write to `items_lsa`, so leaving this row
                        // behind made `lsa_search` scan stale vectors for deleted
                        // items indefinitely.
                        sqlx::query("DELETE FROM items_lsa WHERE id = ?")
                            .bind(id)
                            .execute(&mut *tx)
                            .await
                            .map_err(|e| format!("Failed to delete LSA vector: {}", e))?;
                        tx.commit()
                            .await
                            .map_err(|e| format!("Failed to commit transaction: {}", e))?;
                        Ok(())
                    }

                    if let Err(e) = delete_item_inner(&state, id).await {
                        Some(serde_json::json!({ "error": e }))
                    } else {
                        let _ = state.tx.send("data_changed".to_string());
                        Some(
                            serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully deleted item {}", id) }] }),
                        )
                    }
                }
                "lsa_retrain" => {
                    log::info!("Manual LSA retrain triggered.");
                    let state_clone = state.clone();
                    tokio::spawn(async move {
                        if let Ok(rows) = sqlx::query("SELECT id, content FROM items").fetch_all(&state_clone.db_pool).await {
                            if !rows.is_empty() {
                                let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new();
                                let mut ids = Vec::new();
                                for row in rows {
                                    let id: i64 = row.get(0);
                                    let content: String = row.get(1);
                                    let tokens = state_clone.tokenizer.tokenize_to_vec(&content).unwrap_or_default();
                                    builder.add_document(tokens);
                                    ids.push(id);
                                }
                                let (matrix, idfs) = builder.build_matrix();
                                match crate::utils::lsa::LsaModel::train(&matrix, builder.vocabulary, idfs, 50) {
                                    Ok(model) => {
                                        let mut tx = state_clone.db_pool.begin().await.unwrap();
                                        sqlx::query("DELETE FROM items_lsa").execute(&mut *tx).await.unwrap();
                                        sqlx::query("DELETE FROM vec_items").execute(&mut *tx).await.unwrap();
                                        
                                        for (i, &id) in ids.iter().enumerate() {
                                            let mut doc_tf = ndarray::Array1::zeros(model.vocabulary.len());
                                            for (&tid, &count) in &builder.counts[i] {
                                                doc_tf[tid] = count;
                                            }
                                            if let Ok(projected) = model.project_query(&doc_tf) {
                                                let mut proj_f32: Vec<f32> = projected.iter().map(|&x| x as f32).collect();
                                                if proj_f32.len() < 50 { proj_f32.resize(50, 0.0); } else { proj_f32.truncate(50); }
                                                
                                                let vector_blob = bincode::serialize(&proj_f32).unwrap_or_default();
                                                sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)")
                                                    .bind(id)
                                                    .bind(vector_blob)
                                                    .execute(&mut *tx)
                                                    .await
                                                    .unwrap();

                                                sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
                                                    .bind(id)
                                                    .bind(serde_json::to_string(&proj_f32).unwrap_or("[]".to_string()))
                                                    .execute(&mut *tx)
                                                    .await
                                                    .unwrap();
                                            }
                                        }
                                        tx.commit().await.unwrap();

                                        let mut lsa = state_clone.lsa_model.write().await;
                                        *lsa = Some(model);

                                        // HNSW インデックスの再構築
                                        let hnsw: Hnsw<f32, DistCosine> = Hnsw::new(16, ids.len().max(100), 16, 200, DistCosine {});
                                        // 登録済みの全ベクトルを HNSW に入れ直す
                                        // (簡易実装:DBから再度引き直すか、現在のループで生成したものを入れる)
                                        // ここでは sync_all_vectors(state, Some(hnsw)) を呼ぶのが楽
                                        log::info!("Manual LSA retrain completed successfully. Rebuilding HNSW...");
                                        sync_all_vectors(state_clone.clone(), Some(hnsw)).await;
                                    }
                                    Err(e) => log::error!("Manual LSA training failed: {}", e),
                                }
                            }
                        }
                    });
                    Some(serde_json::json!({ "content": [{ "type": "text", "text": "LSA retrain started in background." }] }))
                }
                _ => Some(serde_json::json!({ "error": "Unknown tool" })),
            }
        }
        _ => Some(serde_json::json!({ "error": "Not implemented" })),
    };

    // Notifications (id == null) MUST NOT receive a response
    if req.id.is_none() || req.id.as_ref().is_some_and(|v| v.is_null()) {
        log::info!("MCP Notification received: {} (No response sent)", method);
        return axum::http::StatusCode::NO_CONTENT.into_response();
    }

    if let Some(id_val) = req.id {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0",
            result,
            error: None,
            id: Some(id_val),
        };

        if let Some(sid) = query.session_id {
            // MCP Client (SSE Mode)
            let resp_str = serde_json::to_string(&resp).unwrap();
            log::info!("Sending MCP Response (Session: {}, ID: {:?}): {}", sid, resp.id, resp_str);
            let sessions = state.sessions.read().await;
            if let Some(tx) = sessions.get(&sid) {
                let _ = tx.send(resp_str);
            }
            axum::http::StatusCode::ACCEPTED.into_response()
        } else {
            // App UI (Direct Mode)
            Json(resp).into_response()
        }
    } else {
        axum::http::StatusCode::NO_CONTENT.into_response()
    }
}

#[cfg(test)]
mod tests {
    /// Splits `text` into chunks of at most `size` characters, mirroring
    /// the char-based chunking used when ingesting long documents.
    fn split_chars(text: &str, size: usize) -> Vec<String> {
        text.chars()
            .collect::<Vec<char>>()
            .chunks(size)
            .map(|c| c.iter().collect())
            .collect()
    }

    #[test]
    fn test_text_chunking_logic() {
        // The ingestion pipeline splits text into 800-character chunks.
        let chunk_size = 800;

        // 1. Exactly 800 chars -> one full chunk.
        let exact = split_chars(&"a".repeat(800), chunk_size);
        assert_eq!(exact.len(), 1);
        assert_eq!(exact[0].len(), 800);

        // 2. 801 chars -> one full chunk plus a single-char remainder.
        let over = split_chars(&"a".repeat(801), chunk_size);
        assert_eq!(over.len(), 2);
        assert_eq!(over[0].len(), 800);
        assert_eq!(over[1].len(), 1);

        // 3. 1600 chars -> exactly two chunks.
        assert_eq!(split_chars(&"a".repeat(1600), chunk_size).len(), 2);

        // 4. Empty input -> no chunks at all.
        assert_eq!(split_chars("", chunk_size).len(), 0);
    }
}