Newer
Older
TelosDB / src-backend / src / mcp / mod.rs
mod handlers;
mod tools;
pub mod types;

use crate::AppState;
use axum::{
    extract::State,
    response::{sse::{Event, KeepAlive, Sse}, IntoResponse},
    routing::{get, post},
    Json, Router,
};
use futures::stream::{self, Stream, StreamExt};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tauri::Emitter;
use types::JsonRpcRequest;

/// Tracks one connected MCP SSE client for as long as the guard is alive.
///
/// Creating a guard increments the shared `count`; dropping it decrements it. After each change
/// the new total is printed to stdout and emitted through the Tauri app handle as an
/// `mcp-connection-update` event.
struct ConnectionGuard {
    /// Shared number of currently connected clients.
    count: Arc<AtomicUsize>,
    /// Handle used to emit connection-count events.
    app_handle: tauri::AppHandle,
}

impl ConnectionGuard {
    /// Registers a new client and reports the updated total.
    fn new(count: Arc<AtomicUsize>, app_handle: tauri::AppHandle) -> Self {
        let previous = count.fetch_add(1, Ordering::SeqCst);
        let total = previous + 1;
        println!("MCP Client connected. Total: {}", total);

        let guard = Self { count, app_handle };
        guard.publish(total);
        guard
    }

    /// Emits `total` as an `mcp-connection-update` event; emit errors are ignored.
    fn publish(&self, total: usize) {
        let _ = self.app_handle.emit("mcp-connection-update", total);
    }
}

impl Drop for ConnectionGuard {
    /// Unregisters the client and reports the updated total.
    fn drop(&mut self) {
        let previous = self.count.fetch_sub(1, Ordering::SeqCst);
        let total = previous - 1;
        println!("MCP Client disconnected. Total: {}", total);
        self.publish(total);
    }
}

pub async fn start_mcp_server(state: Arc<AppState>, port: u16) {
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(message_handler))
        .layer(axum::extract::DefaultBodyLimit::disable())
        .layer(cors)
        .with_state(state);

    log::info!("Starting MCP server on 0.0.0.0:{}", port);
    let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
        .await
        .map_err(|e| {
            log::error!("Failed to bind MCP port {}: {}", port, e);
            e
        })
        .expect("CRITICAL: Failed to bind MCP port");
    
    if let Err(e) = axum::serve(listener, app).await {
        log::error!("MCP server error: {}", e);
    }
}

async fn sse_handler(
    State(state): State<Arc<AppState>>,
) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
    let guard = ConnectionGuard::new(state.connection_count.clone(), state.app_handle.clone());

    // End point invitation
    let invite =
        stream::once(async { Ok(Event::default().event("endpoint").data("/messages")) });

    // Stream from broadcast channel
    let receiver = state.mcp_tx.subscribe();
    let push_stream = stream::unfold((receiver, guard), |(mut rx, guard)| async move {
        match rx.recv().await {
            Ok(json) => {
                let event = Event::default().event("message").data(json.to_string());
                Some((Ok(event), (rx, guard)))
            }
            Err(_) => {
                // Ignore Lagged or Closed for simple implementation
                None
            }
        }
    });

    let combined_stream = invite.chain(push_stream);

    Sse::new(combined_stream).keep_alive(KeepAlive::default())
}

/// Handles one JSON-RPC 2.0 message POSTed by an MCP client to `/messages`.
///
/// * Notifications (no `id`) are logged and answered with `202 Accepted` and an empty body.
/// * Requests are dispatched by `method`: `initialize`, `ping`, `tools/list` and `tools/call`.
///   The JSON-RPC response is pushed to the SSE broadcast channel and is also returned as the
///   HTTP body.
///
/// Error codes follow JSON-RPC 2.0:
/// * An unknown method yields `-32601` (Method not found).
/// * A failing handler yields `-32000` (implementation-defined server error), carrying the
///   error's message.
///
/// NOTE(review): the response is broadcast to every SSE subscriber, not only the client that
/// sent the request. There is no per-session routing here.
pub async fn message_handler(
    State(state): State<Arc<AppState>>,
    Json(payload): Json<JsonRpcRequest>,
) -> impl IntoResponse {
    // JSON-RPC 2.0: the method does not exist / is not available.
    const METHOD_NOT_FOUND: i64 = -32601;
    // JSON-RPC 2.0: implementation-defined server error.
    const SERVER_ERROR: i64 = -32000;

    println!("MCP Received: {:?}", payload);

    // 1. Notifications (id is null/None) never receive a JSON-RPC response.
    let Some(id) = payload.id.clone() else {
        match payload.method.as_str() {
            "notifications/initialized" => println!("MCP: Server initialized by client"),
            other => println!("MCP: Unhandled notification: {}", other),
        }
        return axum::http::StatusCode::ACCEPTED.into_response();
    };

    // 2. Requests: `None` means no handler exists for this method.
    let dispatched: Option<anyhow::Result<serde_json::Value>> = match payload.method.as_str() {
        "initialize" => Some(handle_initialize().await),
        // MCP liveness check; the expected result is an empty object.
        "ping" => Some(Ok(serde_json::json!({}))),
        "tools/list" => Some(tools::handle_tools_list().await),
        "tools/call" => {
            let empty_value = serde_json::json!({});
            let params = payload.params.as_ref().unwrap_or(&empty_value);
            Some(tools::handle_tools_call(&state, params).await)
        }
        _ => None,
    };

    // Prepare the JSON-RPC response.
    let response = match dispatched {
        Some(Ok(result)) => serde_json::json!({
            "jsonrpc": "2.0",
            "result": result,
            "id": id
        }),
        Some(Err(e)) => serde_json::json!({
            "jsonrpc": "2.0",
            "error": { "code": SERVER_ERROR, "message": e.to_string() },
            "id": id
        }),
        None => serde_json::json!({
            "jsonrpc": "2.0",
            "error": {
                "code": METHOD_NOT_FOUND,
                "message": format!("Method not found: {}", payload.method)
            },
            "id": id
        }),
    };

    // Push to the SSE stream. A send error is ignored; presumably it only means no SSE client
    // is currently subscribed (tokio broadcast semantics — confirm the channel type).
    let _ = state.mcp_tx.send(response.clone());

    Json(response).into_response()
}

/// Builds the result object for the MCP `initialize` request.
///
/// The result advertises:
/// * protocol revision `2024-11-05`;
/// * a `tools` capability with no sub-options;
/// * the server's name and version.
///
/// It never fails. The `Result` return type only keeps it uniform with the other request
/// handlers in the dispatch `match`.
async fn handle_initialize() -> anyhow::Result<serde_json::Value> {
    const PROTOCOL_VERSION: &str = "2024-11-05";

    let capabilities = serde_json::json!({ "tools": {} });
    let server_info = serde_json::json!({
        "name": "SQLite Vector MCP Server",
        "version": "0.1.1",
    });

    Ok(serde_json::json!({
        "protocolVersion": PROTOCOL_VERSION,
        "capabilities": capabilities,
        "serverInfo": server_info,
    }))
}