// TelosDB / src / backend / src / mcp / handlers.rs
use axum::{
    extract::{Query, State},
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    Json,
};
use futures::stream::Stream;
use std::convert::Infallible;
use serde::Deserialize;
use crate::mcp::types::AppState;

/// Returns a snapshot of `state.llama_status` as `{"status": ...}`.
///
/// The read lock is released before the JSON body is built.
pub async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let guard = state.llama_status.read().await;
    let snapshot = (*guard).clone();
    drop(guard);
    Json(serde_json::json!({ "status": snapshot }))
}

pub async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM documents")
        .fetch_one(&state.db_pool)
        .await
        .unwrap_or(0);
    Json(serde_json::json!({ "count": count }))
}

/// Returns the configured model name (`state.model_name`) as `{"model_name": ...}`.
pub async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    let body = serde_json::json!({ "model_name": &state.model_name });
    Json(body)
}

/// Returns a snapshot of `state.indexing_status` as `{"status": ...}`.
///
/// The read lock is released before the JSON body is built.
pub async fn indexing_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let guard = state.indexing_status.read().await;
    let snapshot = (*guard).clone();
    drop(guard);
    Json(serde_json::json!({ "status": snapshot }))
}

pub async fn version_handler() -> impl IntoResponse {
    Json(serde_json::json!({ "version": env!("CARGO_PKG_VERSION") }))
}

/// Builds the default settings object.
///
/// Keys, in insertion order: `min_score` (0.3), `limit` (10), `run_on_login` (false),
/// `monitor_paths` (empty array).
fn settings_default() -> serde_json::Value {
    let mut defaults = serde_json::Map::new();
    defaults.insert("min_score".to_owned(), serde_json::json!(0.3));
    defaults.insert("limit".to_owned(), serde_json::json!(10));
    defaults.insert("run_on_login".to_owned(), serde_json::json!(false));
    defaults.insert("monitor_paths".to_owned(), serde_json::json!([]));
    serde_json::Value::Object(defaults)
}

pub async fn settings_get_handler(State(state): State<AppState>) -> impl IntoResponse {
    log::info!("[server] GET /settings 受信");
    let path = state.app_data_dir.join("settings.json");
    let default = settings_default();
    if !path.exists() {
        log::info!("settings_get: no file at {:?}, returning default", path);
        return Json(default);
    }
    match tokio::fs::read_to_string(&path).await {
        Ok(s) => {
            log::info!("settings_get: read from {:?}, len={}", path, s.len());
            match serde_json::from_str::<serde_json::Value>(&s) {
                Ok(loaded) => {
                    let empty: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
                    let obj = loaded.as_object().unwrap_or(&empty);
                    let run_on_login = obj
                        .get("run_on_login")
                        .and_then(|v| v.as_bool().or_else(|| v.as_i64().map(|n| n != 0)))
                        .unwrap_or(false);
                    let monitor_paths = obj.get("monitor_paths")
                        .and_then(|v| v.as_array())
                        .cloned()
                        .unwrap_or_else(|| vec![]);
                    let merged = serde_json::json!({
                        "min_score": obj.get("min_score").and_then(|v| v.as_f64()).unwrap_or(0.3),
                        "limit": obj.get("limit").and_then(|v| v.as_i64()).unwrap_or(10),
                        "run_on_login": run_on_login,
                        "monitor_paths": monitor_paths
                    });
                    log::info!("settings_get: returning run_on_login={}", run_on_login);
                    Json(merged)
                }
                Err(_) => Json(default),
            }
        }
        Err(e) => {
            log::warn!("settings_get: read error {:?}", e);
            Json(default)
        }
    }
}

pub async fn settings_post_handler(
    State(state): State<AppState>,
    axum::Json(payload): axum::Json<serde_json::Value>,
) -> impl IntoResponse {
    log::info!("[server] POST /settings 受信: payload = {:?}", payload);
    let dir = &state.app_data_dir;
    if let Err(e) = tokio::fs::create_dir_all(dir).await {
        log::error!("settings_post: create_dir_all {:?}", e);
        return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "failed to create dir"})));
    }
    let to_write = if payload.get("min_score").is_some() || payload.get("limit").is_some() || payload.get("run_on_login").is_some() || payload.get("monitor_paths").is_some() {
        payload.clone()
    } else if let Some(inner) = payload.get("settings").and_then(|v| v.as_object()) {
        serde_json::to_value(inner).unwrap_or(payload)
    } else {
        payload
    };
    let path = dir.join("settings.json");
    let s = match serde_json::to_string_pretty(&to_write) {
        Ok(s) => s,
        Err(e) => {
            log::error!("settings_post: to_string {:?}", e);
            return (axum::http::StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "invalid json"})));
        }
    };
    if let Err(e) = tokio::fs::write(&path, &s).await {
        log::error!("settings_post: write {:?}", e);
        return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "failed to write"})));
    }
    log::info!("settings_post: wrote to {:?}, run_on_login={:?}", path, to_write.get("run_on_login"));
    (axum::http::StatusCode::OK, Json(serde_json::json!({"ok": true})))
}

/// Query string accepted by the SSE endpoint (`?sessionId=...`).
///
/// `session_id` is parsed but currently unused by `sse_handler`, which always issues a
/// fresh id; hence the `dead_code` allowance.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
pub struct SseQuery {
    /// Client-supplied session id (wire name `sessionId`).
    pub session_id: Option<String>,
}

pub async fn sse_handler(
    State(state): State<AppState>,
    Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // Generate a simple session ID
    let session_id = uuid::Uuid::new_v4().to_string();
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();

    log::info!("New MCP SSE Session: {}", session_id);

    // Register session
    state.sessions.write().await.insert(session_id.clone(), tx);

    // Initial endpoint event
    let endpoint_url = format!("/messages?sessionId={}", session_id);
    let endpoint_event = Event::default().event("endpoint").data(endpoint_url);

    let session_id_for_close = session_id.clone();
    let sessions_for_close = state.sessions.clone();
    let global_rx = state.tx.subscribe();

    let stream = futures::stream::unfold(
        (
            rx,
            Some(endpoint_event),
            session_id_for_close,
            sessions_for_close,
            global_rx,
        ),
        |(mut rx, mut initial, sid, smap, mut grx)| async move {
            if let Some(event) = initial.take() {
                return Some((Ok(event), (rx, None, sid, smap, grx)));
            }

            tokio::select! {
                Some(msg) = rx.recv() => {
                    Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap, grx)))
                }
                Ok(msg) = grx.recv() => {
                    // Global notification (e.g. data update)
                    Some((Ok(Event::default().event("update").data(msg)), (rx, None, sid, smap, grx)))
                }
                else => {
                    log::info!("MCP SSE Session Closed: {}", sid);
                    smap.write().await.remove(&sid);
                    None
                }
            }
        },
    );

    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}