// TelosDB / src-tauri / src / mcp.rs
use axum::{
    extract::{State, Query},
    response::{sse::{Event, Sse}, IntoResponse},
    routing::{get, post},
    Router, Json,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use tokio::sync::broadcast;
use futures::stream::Stream;
use tokio_stream::StreamExt;
use tower_http::cors::{Any, CorsLayer};
use crate::db;
use sqlx::Row;

/// Shared state handed to every route handler through `Router::with_state`.
///
/// The struct is `Clone`; the status string and the session map sit behind `Arc`,
/// so clones refer to the same underlying values.
#[derive(Clone)]
pub struct AppState {
    /// SQLite connection pool created by `db::init_pool` in `run_server`; the
    /// message handler runs its search query against it.
    pub db_pool: sqlx::SqlitePool,
    /// Sending half of a broadcast channel (capacity 100) created in `run_server`.
    /// NOTE(review): nothing in the visible part of this file sends on it.
    pub tx: broadcast::Sender<String>,
    /// Last llama-server status recorded by the monitor task in `run_server`
    /// ("running", "error" or "stopped"); read by `llama_status_handler`.
    pub llama_status: Arc<RwLock<String>>,
    /// MCP sessions map: SSE session id -> sender whose messages are emitted as
    /// `message` events on that session's stream.
    pub sessions: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<String>>>>,
}

pub async fn run_server(port: u16, db_path: &str, vec0_path: &str, llama_status: Arc<RwLock<String>>) {
    let db_pool = db::init_pool(db_path, vec0_path.to_owned()).await.expect("DB pool init failed");
    let (tx, _rx) = broadcast::channel(100);
    let sessions = Arc::new(RwLock::new(HashMap::new()));
    
    // llama-server status monitor
    let llama_status_clone = llama_status.clone();
    tokio::spawn(async move {
        let client = reqwest::Client::new();
        loop {
            let status = match client.get("http://127.0.0.1:8080/health").send().await {
                Ok(resp) if resp.status().is_success() => "running".to_string(),
                Ok(_) => "error".to_string(),
                Err(_) => "stopped".to_string(),
            };
            {
                let mut s = llama_status_clone.write().await;
                if *s != status {
                    log::info!("llama-server status changed: {} -> {}", *s, status);
                    *s = status;
                }
            }
            tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        }
    });

    let app_state = AppState { db_pool, tx, llama_status: llama_status.clone(), sessions };

    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(mcp_messages_handler))
        .route("/llama_status", get(llama_status_handler))
        .layer(cors)
        .with_state(app_state);

    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap();
    log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

/// `GET /llama_status`: returns the currently stored llama-server status as the
/// JSON object `{"status": "<value>"}`.
async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    // Copy the string out so the read lock is released before the body is built.
    let current: String = {
        let guard = state.llama_status.read().await;
        (*guard).clone()
    };
    let body = serde_json::json!({ "status": current });
    Json(body)
}

/// Query-string parameters accepted by `GET /sse`.
#[derive(Deserialize)]
struct SseQuery {
    /// Optional `session_id` query parameter. `sse_handler` binds the parsed query
    /// to an unused variable and always generates a fresh id, so this value is
    /// parsed but has no effect.
    session_id: Option<String>,
}

async fn sse_handler(
    State(state): State<AppState>,
    Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // Generate a simple session ID
    let session_id = uuid::Uuid::new_v4().to_string();
    let (tx, mut rx) = mpsc::unbounded_channel::<String>();
    
    log::info!("New MCP SSE Session: {}", session_id);
    
    // Register session
    state.sessions.write().await.insert(session_id.clone(), tx);

    // Initial endpoint event
    let endpoint_url = format!("/messages?session_id={}", session_id);
    let endpoint_event = Event::default().event("endpoint").data(endpoint_url);

    let session_id_for_close = session_id.clone();
    let sessions_for_close = state.sessions.clone();

    let stream = futures::stream::unfold(
        (rx, Some(endpoint_event), session_id_for_close, sessions_for_close),
        |(mut rx, mut initial, sid, smap)| async move {
            if let Some(event) = initial.take() {
                return Some((Ok(event), (rx, None, sid, smap)));
            }
            
            tokio::select! {
                Some(msg) = rx.recv() => {
                    Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap)))
                }
                else => {
                    log::info!("MCP SSE Session Closed: {}", sid);
                    smap.write().await.remove(&sid);
                    None
                }
            }
        },
    );

    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}

/// Incoming JSON-RPC message body for `POST /messages`.
#[derive(Deserialize)]
struct JsonRpcRequest {
    /// Protocol version string. Must be present for deserialization to succeed;
    /// the visible handler does not inspect its value.
    jsonrpc: String,
    /// Method name the handler dispatches on (e.g. "initialize", "tools/list").
    method: String,
    /// Method parameters; may be omitted.
    params: Option<serde_json::Value>,
    /// Request id. When absent the message is treated as a notification and
    /// no JSON-RPC response is produced.
    id: Option<serde_json::Value>,
}

/// Outgoing JSON-RPC response.
///
/// `result` and `error` are left out of the serialized JSON when they are `None`;
/// `id` has no skip attribute and is always serialized.
#[derive(Serialize)]
struct JsonRpcResponse {
    /// Protocol version string; the handler sets it to "2.0".
    jsonrpc: &'static str,
    /// Result payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<serde_json::Value>,
    /// Error payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<serde_json::Value>,
    /// Id copied from the request this response answers.
    id: Option<serde_json::Value>,
}

/// Query-string parameters accepted by `POST /messages`.
#[derive(Deserialize)]
struct MessageQuery {
    /// SSE session to answer on. When present, the response is pushed to that
    /// session's event stream; when absent, it is returned in the HTTP body.
    session_id: Option<String>,
}

async fn mcp_messages_handler(
    State(state): State<AppState>,
    Query(query): Query<MessageQuery>,
    Json(req): Json<JsonRpcRequest>,
) -> impl IntoResponse {
    let method = req.method.as_str();
    log::info!("MCP Request: {} (Session: {:?})", method, query.session_id);

    let result = match method {
        "initialize" => Some(serde_json::json!({
            "protocolVersion": "2024-11-05",
            "capabilities": { "tools": {} },
            "serverInfo": { "name": "TelosDB", "version": "0.1.0" }
        })),
        "notifications/initialized" => None,
        "tools/list" => Some(serde_json::json!({
            "tools": [
                {
                    "name": "search_text",
                    "description": "Semantic search using llama-server vector embeddings.",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string", "description": "Query" },
                            "limit": { "type": "number" }
                        },
                        "required": ["content"]
                    }
                }
            ]
        })),
        "search_text" | "tools/call" => {
            let (content, limit, is_mcp_output) = if method == "search_text" {
                let p = req.params.unwrap_or_default();
                (
                    p.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(),
                    p.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as u32,
                    false
                )
            } else {
                let p = req.params.unwrap_or_default();
                let args = p.get("arguments").cloned().unwrap_or_default();
                (
                    args.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string(),
                    args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as u32,
                    true
                )
            };

            let rows = sqlx::query("SELECT id, content FROM items WHERE content LIKE ? LIMIT ?")
                .bind(format!("%{}%", content))
                .bind(limit)
                .fetch_all(&state.db_pool)
                .await
                .unwrap_or_default();

            if is_mcp_output {
                let txt = if rows.is_empty() { "No results.".to_string() } else {
                    rows.iter().map(|r| format!("ID: {}, Content: {}", r.get::<i64, _>(0), r.get::<String, _>(1))).collect::<Vec<_>>().join("\n\n")
                };
                Some(serde_json::json!({ "content": [{ "type": "text", "text": txt }] }))
            } else {
                let res: Vec<_> = rows.iter().map(|r| serde_json::json!({
                    "id": r.get::<i64,_>(0),
                    "content": r.get::<String,_>(1),
                    "distance": 0.0
                })).collect();
                Some(serde_json::json!({ "content": res }))
            }
        },
        _ => Some(serde_json::json!({ "error": "Not implemented" })),
    };

    if let Some(id_val) = req.id {
        let resp = JsonRpcResponse { jsonrpc: "2.0", result, error: None, id: Some(id_val) };

        if let Some(sid) = query.session_id {
            // MCP Client (SSE Mode): Return 202 and send response via SSE
            let resp_str = serde_json::to_string(&resp).unwrap();
            let sessions = state.sessions.read().await;
            if let Some(tx) = sessions.get(&sid) {
                let _ = tx.send(resp_str);
            }
            axum::http::StatusCode::ACCEPTED.into_response()
        } else {
            // App UI (Direct Mode): Return Json response directly
            Json(resp).into_response()
        }
    } else {
        // Notification: No response needed
        axum::http::StatusCode::NO_CONTENT.into_response()
    }
}