// Newer
// Older
// TelosDB / src / backend / src / mcp / mod.rs
pub mod types;
pub mod handlers;
pub mod system;
pub mod tools;

pub use types::AppState;
use axum::{
    routing::{get, post},
    Router,
};
use std::sync::atomic::{AtomicBool, AtomicU64};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tower_http::cors::{Any, CorsLayer};
use std::collections::HashMap;

/// Assembles the axum [`Router`] that exposes the MCP endpoints.
///
/// Registered routes:
/// - `GET  /sse`             -> `handlers::sse_handler`
/// - `POST /messages`        -> `mcp_messages_handler`
/// - `GET  /llama_status`    -> `handlers::llama_status_handler`
/// - `GET  /doc_count`       -> `handlers::doc_count_handler`
/// - `GET  /model_name`      -> `handlers::model_name_handler`
/// - `GET  /indexing_status` -> `handlers::indexing_status_handler`
///
/// A CORS layer that allows any origin, method and header wraps every route,
/// and `state` is attached as the shared application state.
pub fn create_mcp_app(state: AppState) -> Router {
    // CORS policy: nothing is restricted (any origin / method / header).
    let permissive_cors = CorsLayer::new()
        .allow_headers(Any)
        .allow_methods(Any)
        .allow_origin(Any);

    // MCP transport endpoints: the SSE stream plus the JSON-RPC message sink.
    let transport_routes = Router::new()
        .route("/sse", get(handlers::sse_handler))
        .route("/messages", post(mcp_messages_handler));

    // Plain status endpoints.
    let all_routes = transport_routes
        .route("/llama_status", get(handlers::llama_status_handler))
        .route("/doc_count", get(handlers::doc_count_handler))
        .route("/model_name", get(handlers::model_name_handler))
        .route("/indexing_status", get(handlers::indexing_status_handler));

    all_routes.layer(permissive_cors).with_state(state)
}

pub async fn run_server(
    port: u16,
    db_pool: sqlx::SqlitePool,
    llama_status: Arc<RwLock<String>>,
    model_name: String,
) {
    let (tx, _rx) = broadcast::channel(100);
    let tokenizer = Arc::new(crate::utils::tokenizer::JapaneseTokenizer::new().unwrap());
    
    let state = AppState {
        db_pool,
        tx,
        llama_status,
        model_name,
        sessions: Arc::new(RwLock::new(HashMap::new())),
        tokenizer,
        lsa_model: Arc::new(RwLock::new(None)),
        hnsw_index: Arc::new(RwLock::new(None)),
        changes_since_train: Arc::new(AtomicU64::new(0)),
        retrain_scheduled: Arc::new(AtomicBool::new(false)),
        indexing_status: Arc::new(RwLock::new("idle".to_string())),
    };

    // 初期化時に LSA トレーニングとベクトル同期(HNSW構築含む)をバックグラウンドで開始
    let state_init = state.clone();
    tokio::spawn(async move {
        system::train_lsa_and_sync_hnsw(state_init).await;
    });

    let app = create_mcp_app(state);
    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap();
    log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

// ----------------------------------------------------------------------------
// Main Message Handler (Dispatching to Tools)
// ----------------------------------------------------------------------------
use axum::{
    extract::{Query, State},
    response::{IntoResponse, Response},
    Json,
    body::Bytes,
};
use crate::mcp::types::{JsonRpcRequest, JsonRpcResponse, MessageQuery};

pub async fn mcp_messages_handler(
    State(state): State<AppState>,
    Query(query): Query<MessageQuery>,
    body: Bytes,
) -> Response {
    // 1. Raw body logging for diagnostics
    let body_str = String::from_utf8_lossy(&body);
    log::info!("MCP Incoming RAW Request: {}", body_str);

    // 2. Manual JSON parsing to avoid silent extraction errors
    let req: JsonRpcRequest = match serde_json::from_slice(&body) {
        Ok(r) => r,
        Err(e) => {
            log::error!("MCP JSON Deserialization Error: {}. Raw: {}", e, body_str);
            return (
                axum::http::StatusCode::UNPROCESSABLE_ENTITY,
                Json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "error": { "code": -32700, "message": "Parse error" },
                    "id": null
                }))
            ).into_response();
        }
    };

    let method = req.method.as_str();
    let actual_method = if method == "tools/call" {
        req.params
            .as_ref()
            .and_then(|p| p.get("name"))
            .and_then(|n| n.as_str())
            .unwrap_or("unknown")
    } else {
        method
    };

    log::info!("MCP Request: {} (Actual: {}, Session: {:?})", method, actual_method, query.session_id);

    // tools/call をブロードキャストして UI の MCP ACTIVITY に通知
    if method == "tools/call" || matches!(actual_method, "get_item_by_id" | "add_item_text" | "search_text" | "lsa_search" | "update_item" | "delete_item" | "list_documents" | "get_document_count" | "get_document" | "delete_document" | "lsa_retrain") {
        let _ = state.tx.send(format!("mcp:call:{}", actual_method));
    }

    let result = match method {
        "initialize" => {
            let client_version = req.params.as_ref()
                .and_then(|p| p.as_object())
                .and_then(|p| p.get("protocolVersion"))
                .and_then(|v| v.as_str())
                .unwrap_or("2024-11-05");
            
            log::info!("MCP Handshake: Client requested protocol version {}", client_version);
            
            Some(serde_json::json!({
                "protocolVersion": client_version,
                "capabilities": { 
                    "tools": { "listChanged": false },
                    "resources": { "listChanged": false, "subscribe": false },
                    "prompts": { "listChanged": false },
                    "logging": {}
                },
                "serverInfo": { "name": "TelosDB", "version": "0.3.0" }
            }))
        },
        "notifications/initialized" => {
            log::info!("MCP Client Initialized");
            None
        }
        "resources/list" => Some(serde_json::json!({ "resources": [] })),
        "prompts/list" => Some(serde_json::json!({ "prompts": [] })),
        "tools/list" => Some(serde_json::json!({
            "tools": [
                {
                    "name": "get_item_by_id",
                    "description": "Get a specific document chunk by ID",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" }
                        },
                        "required": ["id"]
                    }
                },
                {
                    "name": "add_item_text",
                    "description": "Add or overweight a document path. Chunks are generated automatically.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "path": { "type": "string" },
                            "mime": { "type": "string" }
                        },
                        "required": ["content", "path"]
                    }
                },
                {
                    "name": "search_text",
                    "description": "Search document chunks using Hybrid Hybrid/Vector search (BM25 + LSA/HNSW)",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "limit": { "type": "integer", "default": 10 },
                            "min_score": { "type": "number", "default": 0.3, "description": "Minimum similarity (0-1). Results below this are dropped. Default 0.3." }
                        },
                        "required": ["content"]
                    }
                },
                {
                    "name": "update_item",
                    "description": "Update an existing chunk's content",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" },
                            "content": { "type": "string" }
                        },
                        "required": ["id", "content"]
                    }
                },
                {
                    "name": "delete_item",
                    "description": "Delete a chunk by ID",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" }
                        },
                        "required": ["id"]
                    }
                },
                {
                    "name": "list_documents",
                    "description": "List all documents (path, mime, chunk count)",
                    "inputSchema": { "type": "object", "properties": {} }
                },
                {
                    "name": "get_document_count",
                    "description": "Get the total count of documents stored in the database",
                    "inputSchema": { "type": "object", "properties": {} }
                },
                {
                    "name": "get_document",
                    "description": "Get full document content by document ID",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "document_id": { "type": "integer" },
                            "id": { "type": "integer" }
                        }
                    }
                },
                {
                    "name": "delete_document",
                    "description": "Delete a document and all its chunks",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "document_id": { "type": "integer" },
                            "id": { "type": "integer" }
                        }
                    }
                },
                {
                    "name": "lsa_retrain",
                    "description": "Manually trigger LSA model retraining and vector rebuild",
                    "inputSchema": { "type": "object", "properties": {} }
                }
            ]
        })),
        "tools/call" | "get_item_by_id" | "add_item_text" | "search_text" | "lsa_search" | "update_item" | "delete_item" | "list_documents" | "get_document_count" | "get_document" | "delete_document" | "lsa_retrain" => {
            let empty_map = serde_json::Map::new();
            let mut args = req.params.as_ref().and_then(|p| p.as_object()).unwrap_or(&empty_map);
            
            // tools/call の場合は、実際の引数は "arguments" フィールドにある
            if method == "tools/call" {
                if let Some(inner_args) = args.get("arguments").and_then(|a| a.as_object()) {
                    args = inner_args;
                }
            }
            
            tools::dispatch_tool(&state, method, actual_method, args).await
        }
        _ => Some(serde_json::json!({
            "content": [{ "type": "text", "text": format!("Method not implemented: {}", method) }],
            "isError": true
        })),
    };

    // Notifications (id == null) MUST NOT receive a response
    if req.id.is_none() || req.id.as_ref().is_some_and(|v| v.is_null()) {
        log::info!("MCP Notification received: {} (No response sent)", method);
        return axum::http::StatusCode::NO_CONTENT.into_response();
    }

    if let Some(id_val) = req.id {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0",
            result,
            error: None,
            id: Some(id_val),
        };

        if let Some(sid) = query.session_id {
            // MCP SPEC over SSE: "The server SHOULD send responses via the SSE stream... if a session ID is provided"
            // オリジナルの動作に戻し、SSE経由で配送する。
            let resp_str = serde_json::to_string(&resp).unwrap();
            log::info!("Sending MCP Response via SSE (Session: {}, ID: {:?})", sid, resp.id);
            let sessions = state.sessions.read().await;
            if let Some(tx) = sessions.get(&sid) {
                let _ = tx.send(resp_str);
            }
            axum::http::StatusCode::ACCEPTED.into_response()
        } else {
            // Direct request (not via SSE)
            log::info!("Sending Direct MCP Response (ID: {:?})", resp.id);
            Json(resp).into_response()
        }
    } else {
        axum::http::StatusCode::NO_CONTENT.into_response()
    }
}

#[cfg(test)]
mod tests {
    /// Splits `text` into consecutive pieces of at most `max_chars` characters each.
    fn split_by_chars(text: &str, max_chars: usize) -> Vec<String> {
        let all_chars: Vec<char> = text.chars().collect();
        let mut pieces: Vec<String> = Vec::new();
        for window in all_chars.chunks(max_chars) {
            pieces.push(window.iter().collect());
        }
        pieces
    }

    /// Checks that text is split into pieces of 800 characters.
    #[test]
    fn test_text_chunking_logic() {
        const CHUNK_SIZE: usize = 800;

        // Length (as reported by `String::len`) of every piece produced for an
        // input made of `input_len` repetitions of "a".
        let piece_lengths = |input_len: usize| -> Vec<usize> {
            split_by_chars(&"a".repeat(input_len), CHUNK_SIZE)
                .iter()
                .map(String::len)
                .collect()
        };

        // 1. Exactly 800 characters: a single full piece.
        assert_eq!(piece_lengths(800), vec![800]);

        // 2. 801 characters: one full piece plus a one-character remainder.
        assert_eq!(piece_lengths(801), vec![800, 1]);

        // 3. 1600 characters: two pieces.
        assert_eq!(piece_lengths(1600).len(), 2);

        // 4. Empty string: no pieces at all.
        assert!(piece_lengths(0).is_empty());
    }
}