use axum::{
extract::{Query, State},
response::{
sse::{Event, Sse},
IntoResponse,
},
Json,
};
use futures::stream::Stream;
use std::convert::Infallible;
use serde::Deserialize;
use crate::mcp::types::AppState;
/// Reports the current llama status as JSON: `{ "status": <status> }`.
///
/// The read lock is held only long enough to clone the value out.
pub async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let current = {
        let guard = state.llama_status.read().await;
        (*guard).clone()
    };
    Json(serde_json::json!({ "status": current }))
}
pub async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM documents")
.fetch_one(&state.db_pool)
.await
.unwrap_or(0);
Json(serde_json::json!({ "count": count }))
}
/// Reports the configured model name as JSON: `{ "model_name": <name> }`.
pub async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    let payload = serde_json::json!({ "model_name": &state.model_name });
    Json(payload)
}
/// Reports the current indexing status as JSON: `{ "status": <status> }`.
///
/// The value is snapshotted and the read lock released before the response
/// body is built.
pub async fn indexing_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let snapshot = state.indexing_status.read().await;
    let status = (*snapshot).clone();
    drop(snapshot);
    Json(serde_json::json!({ "status": status }))
}
/// Query string accepted by the SSE endpoint (`?sessionId=...`).
// `session_id` is deserialized but not read: `sse_handler` binds the query
// as `_query` and always mints a fresh session ID, so the lint is silenced.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct SseQuery {
    /// Client-supplied session identifier, if present.
    #[serde(rename = "sessionId")]
    pub session_id: Option<String>,
}
pub async fn sse_handler(
State(state): State<AppState>,
Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
// Generate a simple session ID
let session_id = uuid::Uuid::new_v4().to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();
log::info!("New MCP SSE Session: {}", session_id);
// Register session
state.sessions.write().await.insert(session_id.clone(), tx);
// Initial endpoint event
let endpoint_url = format!("/messages?sessionId={}", session_id);
let endpoint_event = Event::default().event("endpoint").data(endpoint_url);
let session_id_for_close = session_id.clone();
let sessions_for_close = state.sessions.clone();
let global_rx = state.tx.subscribe();
let stream = futures::stream::unfold(
(
rx,
Some(endpoint_event),
session_id_for_close,
sessions_for_close,
global_rx,
),
|(mut rx, mut initial, sid, smap, mut grx)| async move {
if let Some(event) = initial.take() {
return Some((Ok(event), (rx, None, sid, smap, grx)));
}
tokio::select! {
Some(msg) = rx.recv() => {
Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap, grx)))
}
Ok(msg) = grx.recv() => {
// Global notification (e.g. data update)
Some((Ok(Event::default().event("update").data(msg)), (rx, None, sid, smap, grx)))
}
else => {
log::info!("MCP SSE Session Closed: {}", sid);
smap.write().await.remove(&sid);
None
}
}
},
);
Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}