// TelosDB / src-tauri / src / mcp.rs
// use crate::db;
use axum::{
    extract::{Query, State},
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    routing::{get, post},
    Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use chrono::Utc;
use sqlx::Row;
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::{mpsc, RwLock};
use tower_http::cors::{Any, CorsLayer};

/// Shared state cloned into every axum handler of the MCP server.
#[derive(Clone)]
pub struct AppState {
    /// Connection pool used for all queries against `items` / `vec_items`.
    pub db_pool: sqlx::SqlitePool,
    /// Model name reported verbatim by `GET /model_name`.
    pub model_name: String,
    /// Last observed llama-server health string, written by the status monitor
    /// and read by `GET /llama_status`.
    pub llama_status: Arc<RwLock<String>>,
    /// Broadcast channel whose messages are forwarded to every SSE client as
    /// `update` events.
    pub tx: broadcast::Sender<String>,
    /// MCP sessions map: SSE session id -> sender feeding that session's stream.
    pub sessions: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<String>>>>,
}

pub async fn run_server(
    port: u16,
    db_pool: sqlx::SqlitePool,
    llama_status: Arc<RwLock<String>>,
    model_name: String,
) {
    let (tx, _rx) = broadcast::channel(100);
    let sessions = Arc::new(RwLock::new(HashMap::new()));

    // llama-server status monitor
    let llama_status_clone = llama_status.clone();
    tokio::spawn(async move {
        let client = reqwest::Client::new();
        loop {
            let status = match client.get("http://127.0.0.1:8080/health").send().await {
                Ok(resp) if resp.status().is_success() => "running".to_string(),
                Ok(_) => "error".to_string(),
                Err(_) => "stopped".to_string(),
            };
            {
                let mut s = llama_status_clone.write().await;
                if *s != status {
                    log::info!("llama-server status changed: {} -> {}", *s, status);
                    *s = status;
                }
            }
            tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        }
    });

    let app_state = AppState {
        db_pool,
        tx,
        llama_status: llama_status.clone(),
        model_name,
        sessions,
    };

    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(mcp_messages_handler))
        .route("/llama_status", get(llama_status_handler))
        .route("/doc_count", get(doc_count_handler))
        .route("/model_name", get(model_name_handler))
        .layer(cors)
        .with_state(app_state);

    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap();
    log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

/// `GET /llama_status`: reports the most recent llama-server health string as
/// `{"status": "..."}`.
async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let guard = state.llama_status.read().await;
    let snapshot = guard.as_str().to_owned();
    // Release the read lock before building the response body.
    drop(guard);
    Json(serde_json::json!({ "status": snapshot }))
}

async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
    let row = sqlx::query("SELECT COUNT(*) FROM items")
        .fetch_one(&state.db_pool)
        .await
        .unwrap();
    let count: i64 = row.get(0);
    Json(serde_json::json!({ "count": count }))
}

async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    Json(serde_json::json!({ "model_name": state.model_name }))
}

/// Query string accepted by `GET /sse`.
///
/// A client-supplied `session_id` is parsed but not used: the SSE handler
/// always mints a fresh UUID for each connection.
// The field is never read, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Deserialize)]
struct SseQuery {
    /// Optional client-proposed session id (currently ignored).
    session_id: Option<String>,
}

async fn sse_handler(
    State(state): State<AppState>,
    Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // Generate a simple session ID
    let session_id = uuid::Uuid::new_v4().to_string();
    let (tx, rx) = mpsc::unbounded_channel::<String>();

    log::info!("New MCP SSE Session: {}", session_id);

    // Register session
    state.sessions.write().await.insert(session_id.clone(), tx);

    // Initial endpoint event
    let endpoint_url = format!("/messages?session_id={}", session_id);
    let endpoint_event = Event::default().event("endpoint").data(endpoint_url);

    let session_id_for_close = session_id.clone();
    let sessions_for_close = state.sessions.clone();
    let global_rx = state.tx.subscribe();

    let stream = futures::stream::unfold(
        (
            rx,
            Some(endpoint_event),
            session_id_for_close,
            sessions_for_close,
            global_rx,
        ),
        |(mut rx, mut initial, sid, smap, mut grx)| async move {
            if let Some(event) = initial.take() {
                return Some((Ok(event), (rx, None, sid, smap, grx)));
            }

            tokio::select! {
                Some(msg) = rx.recv() => {
                    Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap, grx)))
                }
                Ok(msg) = grx.recv() => {
                    // Global notification (e.g. data update)
                    Some((Ok(Event::default().event("update").data(msg)), (rx, None, sid, smap, grx)))
                }
                else => {
                    log::info!("MCP SSE Session Closed: {}", sid);
                    smap.write().await.remove(&sid);
                    None
                }
            }
        },
    );

    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}

/// Query string of `POST /messages`.
///
/// `session_id` is set by MCP clients connected over SSE (the value announced
/// in the `endpoint` event) and absent for direct calls from the app UI.
#[derive(Deserialize)]
struct MessageQuery {
    session_id: Option<String>,
}

/// Incoming JSON-RPC 2.0 message (request or notification).
///
/// A missing or `null` `id` marks a notification.
#[derive(Serialize, Deserialize)]
struct JsonRpcRequest {
    jsonrpc: String,
    method: String,
    params: Option<serde_json::Value>,
    id: Option<serde_json::Value>,
}

/// Outgoing JSON-RPC 2.0 response.
///
/// `result` and `error` are omitted from the JSON when `None`; `id` is always
/// serialized (as `null` when `None`).
#[derive(Serialize)]
struct JsonRpcResponse {
    jsonrpc: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<serde_json::Value>,
    id: Option<serde_json::Value>,
}

/// Lets handlers return a `JsonRpcResponse` directly as an HTTP JSON body.
impl IntoResponse for JsonRpcResponse {
    fn into_response(self) -> axum::response::Response {
        let body = Json(self);
        body.into_response()
    }
}

/// Requests an embedding for `content` from llama-server's OpenAI-compatible
/// `/v1/embeddings` endpoint on `127.0.0.1:8080`.
///
/// Expects a response shaped like `{"data": [{"embedding": [...]}]}` and
/// returns the first embedding as `f32` values.
///
/// # Errors
/// Returns a human-readable message when:
/// - the request fails,
/// - the server answers with a non-success HTTP status,
/// - the body is not JSON or has no embedding array,
/// - the embedding is empty or contains non-numeric entries.
///
/// Non-numeric entries are rejected rather than replaced, so a malformed
/// vector is never stored silently.
async fn get_embedding(content: &str) -> Result<Vec<f32>, String> {
    let payload = serde_json::json!({
        "input": [content],
        "model": "default"
    });
    // Debug level: the payload carries user text and the response thousands of floats.
    log::debug!("Sending embedding request: {}", payload);

    let client = reqwest::Client::new();
    let resp = client
        .post("http://127.0.0.1:8080/v1/embeddings")
        .json(&payload)
        .send()
        .await
        .map_err(|e| e.to_string())?;

    let status = resp.status();
    let body_text = resp.text().await.map_err(|e| e.to_string())?;
    log::debug!("Received embedding response ({}): {}", status, body_text);

    if !status.is_success() {
        return Err(format!(
            "llama-server returned HTTP {}: {}",
            status, body_text
        ));
    }

    let json: serde_json::Value = serde_json::from_str(&body_text).map_err(|e| e.to_string())?;

    // Parse OpenAI-compatible response: {"data": [{"embedding": [...]}]}
    let values = json["data"][0]["embedding"]
        .as_array()
        .ok_or_else(|| format!("No embedding found in llama-server response: {}", json))?;

    if values.is_empty() {
        return Err("llama-server returned an empty embedding".to_string());
    }

    values
        .iter()
        .map(|v| {
            v.as_f64()
                .map(|f| f as f32)
                .ok_or_else(|| format!("Non-numeric value in embedding: {}", v))
        })
        .collect()
}

async fn mcp_messages_handler(
    State(state): State<AppState>,
    Query(query): Query<MessageQuery>,
    Json(req): Json<JsonRpcRequest>,
) -> impl IntoResponse {
    let method = req.method.as_str();
    log::info!("MCP Request: {} (Session: {:?})", method, query.session_id);

    // 受信データを構造化JSONで出力(timestamp と source を含む)
    let structured = serde_json::json!({
        "timestamp": Utc::now().to_rfc3339(),
        "source": "mcp",
        "session": query.session_id,
        "method": method,
        "id": req.id,
        "params": req.params,
    });
    log::info!("{}", serde_json::to_string(&structured).unwrap_or_else(|_| "{\"error\":\"serialize_failed\"}".to_string()));

    let result: Option<serde_json::Value> = match method {
        "initialize" => {
            let client_version = req.params.as_ref()
                .and_then(|p| p.get("protocolVersion"))
                .and_then(|v| v.as_str())
                .unwrap_or("2024-11-05");
            
            log::info!("MCP Handshake: Client requested protocol version {}", client_version);
            
            Some(serde_json::json!({
                "protocolVersion": client_version,
                "capabilities": { 
                    "tools": { "listChanged": false },
                    "resources": { "listChanged": false, "subscribe": false },
                    "prompts": { "listChanged": false },
                    "logging": {}
                },
                "serverInfo": { "name": "TelosDB", "version": "0.1.0" }
            }))
        },
        "notifications/initialized" => None,
        "tools/list" => Some(serde_json::json!({
            "tools": [
                {
                    "name": "add_item_text",
                    "description": "Store text with auto-generated embeddings.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "path": { "type": "string" }
                        },
                        "required": ["content"]
                    }
                },
                {
                    "name": "search_text",
                    "description": "Semantic search using vector embeddings.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "limit": { "type": "number" }
                        },
                        "required": ["content"]
                    }
                },
                {
                    "name": "update_item",
                    "description": "Update existing text and its embedding.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" },
                            "content": { "type": "string" },
                            "path": { "type": "string" }
                        },
                        "required": ["id", "content"]
                    }
                },
                {
                    "name": "delete_item",
                    "description": "Delete item by ID.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" }
                        },
                        "required": ["id"]
                    }
                },
                {
                    "name": "get_item_by_id",
                    "description": "Get text content by item ID.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "id": { "type": "integer" }
                        },
                        "required": ["id"]
                    }
                }
            ]
        })),
        "search_text" | "tools/call" | "add_item_text" | "update_item" | "delete_item" | "get_item_by_id" => {
            let p = req.params.clone().unwrap_or_default();
            let (actual_method, args) = if method == "tools/call" {
                (
                    p.get("name").and_then(|v| v.as_str()).unwrap_or(""),
                    p.get("arguments").cloned().unwrap_or_default(),
                )
            } else {
                (method, p)
            };

            // UIへの通知(ツール呼び出し開始)
            let _ = state.tx.send(format!("mcp:call:{}", actual_method));

            match actual_method {
                "get_item_by_id" => {
                    let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
                    let row = sqlx::query("SELECT id, content, path FROM items WHERE id = ?")
                        .bind(id)
                        .fetch_optional(&state.db_pool)
                        .await
                        .unwrap_or(None);
                    if let Some(row) = row {
                        let content: String = row.get("content");
                        let path: Option<String> = row.try_get("path").ok();
                        Some(serde_json::json!({
                            "id": id,
                            "content": content,
                            "path": path
                        }))
                    } else {
                        Some(serde_json::json!({ "error": format!("Item not found: {}", id) }))
                    }
                }
                "add_item_text" => {
                    let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
                    let path = args.get("path").and_then(|v| v.as_str());

                    log::info!(
                        "Executing add_item_text: content='{}', path='{:?}'",
                        content,
                        path
                    );

                    match get_embedding(content).await {
                        Ok(emb) => {
                            async fn add_item_inner(
                                state: &AppState,
                                content: &str,
                                path: Option<&str>,
                                emb: Vec<f32>,
                            ) -> Result<i64, String> {
                                let mut tx =
                                    state.db_pool.begin().await.map_err(|e| {
                                        format!("Failed to begin transaction: {}", e)
                                    })?;
                                let res =
                                    sqlx::query("INSERT INTO items (content, path) VALUES (?, ?)")
                                        .bind(content)
                                        .bind(path)
                                        .execute(&mut *tx)
                                        .await
                                        .map_err(|e| format!("Failed to insert item: {}", e))?;
                                let id = res.last_insert_rowid();

                                sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
                                    .bind(id)
                                    .bind(serde_json::to_string(&emb).unwrap_or("[]".to_string()))
                                    .execute(&mut *tx)
                                    .await
                                    .map_err(|e| format!("Failed to insert vector: {}", e))?;

                                tx.commit()
                                    .await
                                    .map_err(|e| format!("Failed to commit transaction: {}", e))?;
                                Ok(id)
                            }

                            match add_item_inner(&state, content, path, emb).await {
                                Ok(id) => {
                                    let _ = state.tx.send("data_changed".to_string());
                                    log::info!("Successfully added item ID: {}", id);
                                    Some(
                                        serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully added item with ID: {}", id) }] }),
                                    )
                                }
                                Err(e) => {
                                    log::error!("Failed to add item: {}", e);
                                    Some(serde_json::json!({ "error": e }))
                                }
                            }
                        }
                        Err(e) => {
                            log::error!("Embedding failed in add_item_text: {}", e);
                            Some(serde_json::json!({ "error": format!("Embedding failed: {}", e) }))
                        }
                    }
                }
                "search_text" => {
                    let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
                    let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as u32;

                    match get_embedding(content).await {
                        Ok(emb) => {
                            let rows = sqlx::query(
                                "SELECT items.id, items.content, v.distance 
                                 FROM items 
                                 JOIN vec_items v ON items.id = v.id 
                                 WHERE v.embedding MATCH ? AND k = ? 
                                 ORDER BY distance LIMIT ?",
                            )
                            .bind(serde_json::to_string(&emb).unwrap_or("[]".to_string()))
                            .bind(limit)
                            .bind(limit)
                            .fetch_all(&state.db_pool)
                            .await
                            .unwrap_or_default();

                            log::info!("Search query: '{}'", content);
                            log::info!("Embedding (first 5): {:?}", &emb[..5.min(emb.len())]);

                            // Log results for debugging regardless of output format
                            for r in &rows {
                                let id = r.get::<i64, _>(0);
                                let d = r.get::<f64, _>(2);
                                log::info!("Result ID: {}, Distance: {}", id, d);
                            }

                            let is_mcp_output = method == "tools/call";
                            if is_mcp_output {
                                let txt = if rows.is_empty() {
                                    "No results.".to_string()
                                } else {
                                    rows.iter()
                                        .map(|r| {
                                            format!(
                                                "[ID: {}, Distance: {:.4}]\n{}",
                                                r.get::<i64, _>(0),
                                                r.get::<f64, _>(2),
                                                r.get::<String, _>(1)
                                            )
                                        })
                                        .collect::<Vec<_>>()
                                        .join("\n\n---\n\n")
                                };
                                Some(
                                    serde_json::json!({ "content": [{ "type": "text", "text": txt }] }),
                                )
                            } else {
                                let res: Vec<_> = rows
                                    .iter()
                                    .map(|r| {
                                        serde_json::json!({
                                            "id": r.get::<i64,_>(0),
                                            "content": r.get::<String,_>(1),
                                            "distance": r.get::<f64, _>(2)
                                        })
                                    })
                                    .collect();
                                Some(serde_json::json!({ "content": res }))
                            }
                        }
                        Err(e) => {
                            log::warn!(
                                "Embedding failed in search_text, falling back to LIKE: {}",
                                e
                            );
                            // Fallback to LIKE if llama-server is not running
                            let rows = sqlx::query(
                                "SELECT id, content FROM items WHERE content LIKE ? LIMIT ?",
                            )
                            .bind(format!("%{}%", content))
                            .bind(limit)
                            .fetch_all(&state.db_pool)
                            .await
                            .unwrap_or_default();

                            let txt =
                                format!("(Fallback SEARCH due to embedding error: {})\n\n", e);
                            let results = rows
                                .iter()
                                .map(|r| {
                                    format!(
                                        "ID: {}, Content: {}",
                                        r.get::<i64, _>(0),
                                        r.get::<String, _>(1)
                                    )
                                })
                                .collect::<Vec<_>>()
                                .join("\n\n");
                            Some(
                                serde_json::json!({ "content": [{ "type": "text", "text": txt + &results }] }),
                            )
                        }
                    }
                }
                "update_item" => {
                    let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
                    let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
                    let path = args.get("path").and_then(|v| v.as_str());

                    match get_embedding(content).await {
                        Ok(emb) => {
                            async fn update_item_inner(
                                state: &AppState,
                                id: i64,
                                content: &str,
                                path: Option<&str>,
                                emb: Vec<f32>,
                            ) -> Result<(), String> {
                                let mut tx =
                                    state.db_pool.begin().await.map_err(|e| {
                                        format!("Failed to begin transaction: {}", e)
                                    })?;
                                sqlx::query("UPDATE items SET content = ?, path = ? WHERE id = ?")
                                    .bind(content)
                                    .bind(path)
                                    .bind(id)
                                    .execute(&mut *tx)
                                    .await
                                    .map_err(|e| format!("Failed to update item: {}", e))?;

                                sqlx::query("UPDATE vec_items SET embedding = ? WHERE id = ?")
                                    .bind(serde_json::to_string(&emb).unwrap_or("[]".to_string()))
                                    .bind(id)
                                    .execute(&mut *tx)
                                    .await
                                    .map_err(|e| format!("Failed to update vector: {}", e))?;

                                tx.commit()
                                    .await
                                    .map_err(|e| format!("Failed to commit transaction: {}", e))?;
                                Ok(())
                            }

                            if let Err(e) = update_item_inner(&state, id, content, path, emb).await
                            {
                                Some(serde_json::json!({ "error": e }))
                            } else {
                                let _ = state.tx.send("data_changed".to_string());
                                Some(
                                    serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully updated item {}", id) }] }),
                                )
                            }
                        }
                        Err(e) => {
                            Some(serde_json::json!({ "error": format!("Embedding failed: {}", e) }))
                        }
                    }
                }
                "delete_item" => {
                    let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);

                    async fn delete_item_inner(state: &AppState, id: i64) -> Result<(), String> {
                        let mut tx = state
                            .db_pool
                            .begin()
                            .await
                            .map_err(|e| format!("Failed to begin transaction: {}", e))?;
                        sqlx::query("DELETE FROM items WHERE id = ?")
                            .bind(id)
                            .execute(&mut *tx)
                            .await
                            .map_err(|e| format!("Failed to delete item: {}", e))?;
                        sqlx::query("DELETE FROM vec_items WHERE id = ?")
                            .bind(id)
                            .execute(&mut *tx)
                            .await
                            .map_err(|e| format!("Failed to delete vector: {}", e))?;
                        tx.commit()
                            .await
                            .map_err(|e| format!("Failed to commit transaction: {}", e))?;
                        Ok(())
                    }

                    if let Err(e) = delete_item_inner(&state, id).await {
                        Some(serde_json::json!({ "error": e }))
                    } else {
                        let _ = state.tx.send("data_changed".to_string());
                        Some(
                            serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully deleted item {}", id) }] }),
                        )
                    }
                }
                _ => Some(serde_json::json!({ "error": "Unknown tool" })),
            }
        }
        _ => Some(serde_json::json!({ "error": "Not implemented" })),
    };

    // Notifications (id == null) MUST NOT receive a response
    if req.id.is_none() || req.id.as_ref().map_or(false, |v| v.is_null()) {
        log::info!("MCP Notification received: {} (No response sent)", method);
        return axum::http::StatusCode::NO_CONTENT.into_response();
    }

    if let Some(id_val) = req.id {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0",
            result,
            error: None,
            id: Some(id_val),
        };

        if let Some(sid) = query.session_id {
            // MCP Client (SSE Mode)
            let resp_str = serde_json::to_string(&resp).unwrap();
            log::info!("Sending MCP Response (Session: {}, ID: {:?}): {}", sid, resp.id, resp_str);
            let sessions = state.sessions.read().await;
            if let Some(tx) = sessions.get(&sid) {
                let _ = tx.send(resp_str);
            }
            axum::http::StatusCode::ACCEPTED.into_response()
        } else {
            // App UI (Direct Mode)
            resp.into_response()
        }
    } else {
        axum::http::StatusCode::NO_CONTENT.into_response()
    }
}