mod handlers;
mod tools;
pub mod types;
use crate::AppState;
use axum::{
extract::State,
response::{sse::{Event, KeepAlive, Sse}, IntoResponse},
routing::{get, post},
Json, Router,
};
use futures::stream::{self, Stream, StreamExt};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tauri::Emitter;
use types::JsonRpcRequest;
/// Counts one open MCP SSE connection for as long as the value is alive.
///
/// Creating a guard increments the shared connection counter; dropping it
/// decrements the counter. Each transition prints the new total and emits a
/// Tauri `mcp-connection-update` event carrying that total.
#[must_use = "the connection is only counted while the guard is alive"]
struct ConnectionGuard {
    // Shared number of currently open connections.
    count: Arc<AtomicUsize>,
    // Handle used to emit `mcp-connection-update` events.
    app_handle: tauri::AppHandle,
}

impl ConnectionGuard {
    /// Registers a new connection and announces the updated total.
    fn new(count: Arc<AtomicUsize>, app_handle: tauri::AppHandle) -> Self {
        // `fetch_add` returns the previous value, so add one for the new total.
        let new_count = count.fetch_add(1, Ordering::SeqCst) + 1;
        println!("MCP Client connected. Total: {}", new_count);
        // Best-effort notification: a failed emit must not affect the connection.
        let _ = app_handle.emit("mcp-connection-update", new_count);
        Self { count, app_handle }
    }
}

impl Drop for ConnectionGuard {
    /// Unregisters the connection and announces the updated total.
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value. Saturate instead of plain `- 1`
        // so an out-of-balance counter can never trigger an overflow panic inside
        // `drop` (a panic during unwinding would abort the process).
        let new_count = self
            .count
            .fetch_sub(1, Ordering::SeqCst)
            .saturating_sub(1);
        println!("MCP Client disconnected. Total: {}", new_count);
        // Best-effort notification, as in `new`.
        let _ = self.app_handle.emit("mcp-connection-update", new_count);
    }
}
pub async fn start_mcp_server(state: Arc<AppState>, port: u16) {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
let app = Router::new()
.route("/sse", get(sse_handler))
.route("/messages", post(message_handler))
.layer(axum::extract::DefaultBodyLimit::disable())
.layer(cors)
.with_state(state);
log::info!("Starting MCP server on 0.0.0.0:{}", port);
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port))
.await
.map_err(|e| {
log::error!("Failed to bind MCP port {}: {}", port, e);
e
})
.expect("CRITICAL: Failed to bind MCP port");
if let Err(e) = axum::serve(listener, app).await {
log::error!("MCP server error: {}", e);
}
}
async fn sse_handler(
State(state): State<Arc<AppState>>,
) -> Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>> {
let guard = ConnectionGuard::new(state.connection_count.clone(), state.app_handle.clone());
// End point invitation
let invite =
stream::once(async { Ok(Event::default().event("endpoint").data("/messages")) });
// Stream from broadcast channel
let receiver = state.mcp_tx.subscribe();
let push_stream = stream::unfold((receiver, guard), |(mut rx, guard)| async move {
match rx.recv().await {
Ok(json) => {
let event = Event::default().event("message").data(json.to_string());
Some((Ok(event), (rx, guard)))
}
Err(_) => {
// Ignore Lagged or Closed for simple implementation
None
}
}
});
let combined_stream = invite.chain(push_stream);
Sse::new(combined_stream).keep_alive(KeepAlive::default())
}
pub async fn message_handler(
State(state): State<Arc<AppState>>,
Json(payload): Json<JsonRpcRequest>,
) -> impl IntoResponse {
println!("MCP Received: {:?}", payload);
// 1. Handle Notifications (id is null/None)
if payload.id.is_none() {
match payload.method.as_str() {
"notifications/initialized" => {
println!("MCP: Server initialized by client");
}
_ => {
println!("MCP: Unhandled notification: {}", payload.method);
}
}
// Notifications do not have JSON-RPC responses
return axum::http::StatusCode::ACCEPTED.into_response();
}
// 2. Handle Requests (id is present)
let result_val = match payload.method.as_str() {
"initialize" => handle_initialize().await,
"tools/list" => tools::handle_tools_list().await,
"tools/call" => {
let empty_value = serde_json::json!({});
let params = payload.params.as_ref().unwrap_or(&empty_value);
tools::handle_tools_call(&state, params).await
}
_ => Err(anyhow::anyhow!("Method not found: {}", payload.method)),
};
// Prepare JSON-RPC Response
let response = match result_val {
Ok(res) => serde_json::json!({
"jsonrpc": "2.0",
"result": res,
"id": payload.id.clone().unwrap()
}),
Err(e) => serde_json::json!({
"jsonrpc": "2.0",
"error": { "code": -32000, "message": e.to_string() },
"id": payload.id.clone().unwrap()
}),
};
// Push to SSE stream
let _ = state.mcp_tx.send(response.clone());
Json(response).into_response()
}
/// Builds the `result` payload for the MCP `initialize` request.
///
/// The payload does three things:
/// * advertises protocol version `2024-11-05`;
/// * declares an (empty) `tools` capability;
/// * identifies this server by name and version.
///
/// It always succeeds. The `Result` return type lets it share a `match` with
/// the other request handlers in `message_handler`.
async fn handle_initialize() -> anyhow::Result<serde_json::Value> {
    let capabilities = serde_json::json!({ "tools": {} });
    let server_info = serde_json::json!({
        "name": "SQLite Vector MCP Server",
        "version": "0.1.1"
    });

    let init_result = serde_json::json!({
        "protocolVersion": "2024-11-05",
        "capabilities": capabilities,
        "serverInfo": server_info
    });
    Ok(init_result)
}