Newer
Older
TelosDB / src / backend / src / mcp / handlers.rs
use axum::{
    extract::{Query, State},
    response::{
        sse::{Event, Sse},
        IntoResponse,
    },
    Json,
};
use futures::stream::Stream;
use std::convert::Infallible;
use serde::Deserialize;
use crate::mcp::types::AppState;

/// `GET` handler returning the number of stored documents as `{"count": N}`.
///
/// If `crate::db::get_document_count` does not yield a value, the count is reported as 0.
pub async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
    let pool = &state.db_pool;
    let total = crate::db::get_document_count(pool).await.unwrap_or(0);
    let body = serde_json::json!({ "count": total });
    Json(body)
}

/// `GET` handler returning the configured model name as `{"model_name": ...}`.
pub async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    let body = serde_json::json!({ "model_name": state.model_name });
    Json(body)
}

/// `GET` handler reporting the build edition.
///
/// Without the `pro` feature the body is `{"edition": ...}`. With `pro` it also includes
/// `embedding_loaded`, which is true when `state.embedding_model` currently holds a value.
pub async fn edition_handler(State(state): State<AppState>) -> impl IntoResponse {
    #[cfg(not(feature = "pro"))]
    let payload = serde_json::json!({ "edition": state.edition });

    #[cfg(feature = "pro")]
    let payload = {
        let embedding_loaded = state.embedding_model.read().await.is_some();
        serde_json::json!({ "edition": state.edition, "embedding_loaded": embedding_loaded })
    };

    Json(payload)
}

/// Handler that re-synchronises the `items_fts` index and reports `{"synced": N}`.
///
/// If `crate::db::heal_items_fts` does not yield a value, 0 is logged and returned.
pub async fn heal_handler(State(state): State<AppState>) -> impl IntoResponse {
    let synced_rows = crate::db::heal_items_fts(&state.db_pool)
        .await
        .unwrap_or(0);
    log::info!("[heal] items_fts: synced {} rows.", synced_rows);
    let body = serde_json::json!({ "synced": synced_rows });
    Json(body)
}

/// `GET` handler returning `{"status": ..., "ingestion": ...}`.
///
/// `status` is a snapshot of `state.indexing_status` (an async lock). `ingestion` is a snapshot
/// of `state.watch_ingestion_status`, read through a synchronous `read()`. If that `read()`
/// returns an error (presumably a poisoned std lock), the type's `Default` value is reported.
pub async fn indexing_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let status = state.indexing_status.read().await.clone();
    let ingestion = match state.watch_ingestion_status.read() {
        Ok(guard) => guard.clone(),
        Err(_) => Default::default(),
    };
    Json(serde_json::json!({ "status": status, "ingestion": ingestion }))
}

pub async fn version_handler() -> impl IntoResponse {
    Json(serde_json::json!({ "version": env!("CARGO_PKG_VERSION") }))
}

fn settings_default() -> serde_json::Value {
    serde_json::json!({
        "min_score": 0.3,
        "limit": 5,
        "run_on_login": false,
        "monitor_paths": [],
        "watch_extensions": ["txt", "md", "json", "html", "css", "js", "mjs", "ts", "rs"]
    })
}

pub async fn settings_get_handler(State(state): State<AppState>) -> impl IntoResponse {
    log::info!("[server] GET /settings 受信");
    let path = state.app_data_dir.join("settings.json");
    let default = settings_default();
    if !path.exists() {
        log::info!("settings_get: no file at {:?}, returning default", path);
        return Json(default);
    }
    match tokio::fs::read_to_string(&path).await {
        Ok(s) => {
            log::info!("settings_get: read from {:?}, len={}", path, s.len());
            match serde_json::from_str::<serde_json::Value>(&s) {
                Ok(loaded) => {
                    let empty: serde_json::Map<String, serde_json::Value> = serde_json::Map::new();
                    let obj = loaded.as_object().unwrap_or(&empty);
                    let run_on_login = obj
                        .get("run_on_login")
                        .and_then(|v| v.as_bool().or_else(|| v.as_i64().map(|n| n != 0)))
                        .unwrap_or(false);
                    let monitor_paths = obj.get("monitor_paths")
                        .and_then(|v| v.as_array())
                        .cloned()
                        .unwrap_or_else(|| vec![]);
                    let watch_extensions: Vec<serde_json::Value> = obj.get("watch_extensions")
                        .and_then(|v| v.as_array())
                        .cloned()
                        .unwrap_or_else(|| {
                            crate::mcp::watch::DEFAULT_WATCH_EXTENSIONS
                                .iter()
                                .map(|s| serde_json::Value::String((*s).to_string()))
                                .collect()
                        });
                    let merged = serde_json::json!({
                        "min_score": obj.get("min_score").and_then(|v| v.as_f64()).unwrap_or(0.3),
                        "limit": obj.get("limit").and_then(|v| v.as_i64()).unwrap_or(5),
                        "run_on_login": run_on_login,
                        "monitor_paths": monitor_paths,
                        "watch_extensions": watch_extensions
                    });
                    log::info!("settings_get: returning run_on_login={}", run_on_login);
                    Json(merged)
                }
                Err(_) => Json(default),
            }
        }
        Err(e) => {
            log::warn!("settings_get: read error {:?}", e);
            Json(default)
        }
    }
}

pub async fn settings_post_handler(
    State(state): State<AppState>,
    axum::Json(payload): axum::Json<serde_json::Value>,
) -> impl IntoResponse {
    log::info!("[server] POST /settings 受信: payload = {:?}", payload);
    let dir = &state.app_data_dir;
    if let Err(e) = tokio::fs::create_dir_all(dir).await {
        log::error!("settings_post: create_dir_all {:?}", e);
        return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "failed to create dir"})));
    }
    let mut to_write = if payload.get("min_score").is_some() || payload.get("limit").is_some() || payload.get("run_on_login").is_some() || payload.get("monitor_paths").is_some() || payload.get("watch_extensions").is_some() {
        payload.clone()
    } else if let Some(inner) = payload.get("settings").and_then(|v| v.as_object()) {
        serde_json::to_value(inner).unwrap_or_else(|_| payload.clone())
    } else {
        payload.clone()
    };
    // remove_from_index_paths は保存しない(解除時の一回限りの指示)
    if let Some(obj) = to_write.as_object_mut() {
        obj.remove("remove_from_index_paths");
    }
    let path = dir.join("settings.json");
    let s = match serde_json::to_string_pretty(&to_write) {
        Ok(s) => s,
        Err(e) => {
            log::error!("settings_post: to_string {:?}", e);
            return (axum::http::StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "invalid json"})));
        }
    };
    if let Err(e) = tokio::fs::write(&path, &s).await {
        log::error!("settings_post: write {:?}", e);
        return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "failed to write"})));
    }
    log::info!("settings_post: wrote to {:?}, run_on_login={:?}", path, to_write.get("run_on_login"));

    // フォルダ監視: 保存した monitor_paths と watch_extensions をワッチャーに送り再起動(Phase 3)
    if let Some(ref tx) = state.watcher_restart_tx {
        let (monitor_paths, category_map) = to_write
            .get("monitor_paths")
            .and_then(|a| a.as_array())
            .map(|a| crate::mcp::types::parse_monitor_paths(a))
            .unwrap_or_default();
        let watch_extensions: Vec<String> = to_write
            .get("watch_extensions")
            .and_then(|a| a.as_array())
            .map(|a| a.iter().filter_map(|v| v.as_str().map(String::from)).collect())
            .unwrap_or_else(|| {
                crate::mcp::watch::DEFAULT_WATCH_EXTENSIONS
                    .iter()
                    .map(|s| s.to_string())
                    .collect()
            });
        let config = crate::mcp::types::WatcherConfig {
            paths: monitor_paths.clone(),
            extensions: watch_extensions,
            category_map: category_map.clone(),
        };
        if tx.send(config).is_err() {
            log::warn!("settings_post: watcher channel closed");
        } else {
            log::info!("settings_post: sent watcher config ({} paths) to watcher", monitor_paths.len());
        }

        // カテゴリマップに基づき既存ドキュメントのカテゴリを一括更新
        for (dir, cat) in &category_map {
            let prefix = crate::db::normalize_document_path(&dir.to_string_lossy());
            if prefix.is_empty() { continue; }
            let like_pat = format!("{}/%", crate::db::escape_like(&prefix));
            match sqlx::query("UPDATE documents SET category = ? WHERE (path = ? OR path LIKE ? ESCAPE '\\') AND COALESCE(category, '') != ?")
                .bind(cat)
                .bind(&prefix)
                .bind(&like_pat)
                .bind(cat)
                .execute(&state.db_pool)
                .await
            {
                Ok(r) if r.rows_affected() > 0 => log::info!("settings_post: updated category '{}' for {} docs under {}", cat, r.rows_affected(), prefix),
                Ok(_) => {}
                Err(e) => log::warn!("settings_post: category update for {}: {}", prefix, e),
            }
        }
        // カテゴリなしのパス → 空文字にリセット
        for path in &monitor_paths {
            if !category_map.contains_key(path) {
                let prefix = crate::db::normalize_document_path(&path.to_string_lossy());
                if prefix.is_empty() { continue; }
                let like_pat = format!("{}/%", crate::db::escape_like(&prefix));
                let _ = sqlx::query("UPDATE documents SET category = '' WHERE (path = ? OR path LIKE ? ESCAPE '\\') AND COALESCE(category, '') != ''")
                    .bind(&prefix)
                    .bind(&like_pat)
                    .execute(&state.db_pool)
                    .await;
            }
        }
    }

    // モニター解除時に「インデックスからも削除」を選んだパスを処理
    if let Some(arr) = payload.get("remove_from_index_paths").and_then(|v| v.as_array()) {
        for path_value in arr {
            if let Some(path_str) = path_value.as_str() {
                match crate::mcp::tools::items::delete_documents_under_path_prefix(&state, path_str).await {
                    Ok(n) => log::info!("settings_post: removed {} documents under path {}", n, path_str),
                    Err(e) => log::warn!("settings_post: delete under path {}: {}", path_str, e),
                }
            }
        }
    }

    (axum::http::StatusCode::OK, Json(serde_json::json!({"ok": true})))
}

/// Query string accepted by the SSE endpoint (`?sessionId=...`).
///
/// `sse_handler` binds this as `_query` and never reads it (a fresh session id is always
/// generated). `dead_code` is allowed because the field is otherwise unread.
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct SseQuery {
    /// Client-supplied session id, taken from the camelCase `sessionId` parameter.
    #[serde(rename = "sessionId")]
    pub session_id: Option<String>,
}

/// Opens an MCP SSE stream for a new session.
///
/// The handler:
/// - generates a session id and registers a per-session sender in `state.sessions`;
/// - emits an initial `endpoint` event pointing the client at `/messages?sessionId=<id>`;
/// - forwards per-session messages as `message` events;
/// - forwards broadcast notifications from `state.tx` as `update` events.
///
/// The session entry is removed whenever the stream is dropped (client disconnect) or ends,
/// via a drop guard carried in the stream state. Previously removal happened only in the
/// `select!` `else` branch, which needs both channels closed. That never happens while the
/// session's own sender is still in the map, so disconnected sessions leaked.
pub async fn sse_handler(
    State(state): State<AppState>,
    Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    /// Runs the wrapped closure exactly once, when dropped.
    struct OnDrop<F: FnOnce()>(Option<F>);
    impl<F: FnOnce()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            if let Some(f) = self.0.take() {
                f();
            }
        }
    }

    // Generate a simple session ID
    let session_id = uuid::Uuid::new_v4().to_string();
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();

    log::info!("New MCP SSE Session: {}", session_id);

    // Register session
    state.sessions.write().await.insert(session_id.clone(), tx);

    // Initial endpoint event
    let endpoint_url = format!("/messages?sessionId={}", session_id);
    let endpoint_event = Event::default().event("endpoint").data(endpoint_url);

    let sessions = state.sessions.clone();
    let sid = session_id;
    let cleanup = OnDrop(Some(move || {
        log::info!("MCP SSE Session Closed: {}", sid);
        // `Drop` cannot await the async lock, so hand the removal to the runtime.
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            handle.spawn(async move {
                sessions.write().await.remove(&sid);
            });
        }
    }));
    let global_rx = state.tx.subscribe();

    let stream = futures::stream::unfold(
        (rx, Some(endpoint_event), cleanup, global_rx),
        |(mut rx, mut initial, cleanup, mut grx)| async move {
            if let Some(event) = initial.take() {
                return Some((Ok(event), (rx, None, cleanup, grx)));
            }

            tokio::select! {
                Some(msg) = rx.recv() => {
                    Some((Ok(Event::default().event("message").data(msg)), (rx, None, cleanup, grx)))
                }
                Ok(msg) = grx.recv() => {
                    // Global notification (e.g. data update)
                    Some((Ok(Event::default().event("update").data(msg)), (rx, None, cleanup, grx)))
                }
                // Both channels closed: end the stream. Dropping `cleanup` at the end of this
                // future deregisters the session.
                else => None,
            }
        },
    );

    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Plan auto_start: launching at login is off by default (scope 01).
    #[test]
    fn settings_default_has_run_on_login_false() {
        let run_on_login = settings_default()["run_on_login"].as_bool();
        assert_eq!(
            run_on_login,
            Some(false),
            "初回・既存ユーザーアップデート後は自動起動はオフ"
        );
    }

    /// Plan folder_monitor: no monitored paths by default (scope 01).
    #[test]
    fn settings_default_has_monitor_paths_empty() {
        let defaults = settings_default();
        let paths = defaults["monitor_paths"].as_array();
        assert!(paths.is_some(), "monitor_paths は配列");
        let count = paths.unwrap().len();
        assert_eq!(count, 0, "初回・既存ユーザーアップデート後は監視パスは空");
    }

    /// A default set of ingestible file extensions is present.
    #[test]
    fn settings_default_has_watch_extensions() {
        let defaults = settings_default();
        let extensions = defaults["watch_extensions"].as_array();
        assert!(extensions.is_some(), "watch_extensions は配列");
        let has_any = extensions.map_or(false, |list| !list.is_empty());
        assert!(has_any, "デフォルトで拡張子が入っている");
    }
}