Newer
Older
TelosDB / src / backend / src / mcp / handlers.rs
use axum::{
    extract::{Query, State},
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    Json,
};
use futures::stream::Stream;
use std::convert::Infallible;
use serde::Deserialize;
use crate::mcp::types::AppState;

/// Reports the current llama runtime status as JSON: `{"status": <value>}`.
///
/// The status is cloned out of `state.llama_status` under a read lock, so the
/// lock is released before the response is built.
pub async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let current = {
        let guard = state.llama_status.read().await;
        guard.clone()
    };
    Json(serde_json::json!({ "status": current }))
}

pub async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM documents")
        .fetch_one(&state.db_pool)
        .await
        .unwrap_or(0);
    Json(serde_json::json!({ "count": count }))
}

/// Reports the configured model name as JSON: `{"model_name": <value>}`.
pub async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    let payload = serde_json::json!({ "model_name": state.model_name });
    Json(payload)
}

/// Reports the current indexing status as JSON: `{"status": <value>}`.
///
/// A snapshot is cloned out of `state.indexing_status` under a read lock.
pub async fn indexing_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let lock = &state.indexing_status;
    let snapshot = lock.read().await.clone();
    let body = serde_json::json!({ "status": snapshot });
    Json(body)
}

/// Query string accepted by [`sse_handler`].
///
/// The handler currently ignores this value and always mints a fresh session
/// id, so the field is deserialized but never read.
// `session_id` is parsed from the request but not yet consumed by the handler.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct SseQuery {
    /// Optional client-supplied session id (the `sessionId` query parameter).
    #[serde(rename = "sessionId")]
    pub session_id: Option<String>,
}

pub async fn sse_handler(
    State(state): State<AppState>,
    Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // Generate a simple session ID
    let session_id = uuid::Uuid::new_v4().to_string();
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();

    log::info!("New MCP SSE Session: {}", session_id);

    // Register session
    state.sessions.write().await.insert(session_id.clone(), tx);

    // Initial endpoint event
    let endpoint_url = format!("/messages?sessionId={}", session_id);
    let endpoint_event = Event::default().event("endpoint").data(endpoint_url);

    let session_id_for_close = session_id.clone();
    let sessions_for_close = state.sessions.clone();
    let global_rx = state.tx.subscribe();

    let stream = futures::stream::unfold(
        (
            rx,
            Some(endpoint_event),
            session_id_for_close,
            sessions_for_close,
            global_rx,
        ),
        |(mut rx, mut initial, sid, smap, mut grx)| async move {
            if let Some(event) = initial.take() {
                return Some((Ok(event), (rx, None, sid, smap, grx)));
            }

            tokio::select! {
                Some(msg) = rx.recv() => {
                    Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap, grx)))
                }
                Ok(msg) = grx.recv() => {
                    // Global notification (e.g. data update)
                    Some((Ok(Event::default().event("update").data(msg)), (rx, None, sid, smap, grx)))
                }
                else => {
                    log::info!("MCP SSE Session Closed: {}", sid);
                    smap.write().await.remove(&sid);
                    None
                }
            }
        },
    );

    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}