use axum::{
extract::{Query, State},
response::{
sse::{Event, Sse},
IntoResponse,
},
Json,
};
use futures::stream::Stream;
use std::convert::Infallible;
use serde::Deserialize;
use crate::mcp::types::AppState;
/// Reports the current llama status as `{"status": <status>}`.
///
/// The status is cloned out of `state.llama_status` inside a small scope so the
/// read lock is released before the response body is built.
pub async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let snapshot = {
        let guard = state.llama_status.read().await;
        guard.clone()
    };
    Json(serde_json::json!({ "status": snapshot }))
}
pub async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM documents")
.fetch_one(&state.db_pool)
.await
.unwrap_or(0);
Json(serde_json::json!({ "count": count }))
}
/// Reports the configured model name as `{"model_name": <name>}`.
pub async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    let body = serde_json::json!({ "model_name": state.model_name });
    Json(body)
}
/// Reports the current indexing status as `{"status": <status>}`.
///
/// The value is cloned under a scoped read lock, and the lock is released
/// before the JSON body is assembled.
pub async fn indexing_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    let current = {
        let guard = state.indexing_status.read().await;
        guard.clone()
    };
    let body = serde_json::json!({ "status": current });
    Json(body)
}
pub async fn version_handler() -> impl IntoResponse {
Json(serde_json::json!({ "version": env!("CARGO_PKG_VERSION") }))
}
/// Built-in settings, used when no valid `settings.json` is available.
///
/// Contains exactly the keys `min_score` (0.3), `limit` (10),
/// `run_on_login` (false) and `monitor_paths` (empty array).
fn settings_default() -> serde_json::Value {
    let mut defaults = serde_json::Map::new();
    defaults.insert("min_score".to_owned(), serde_json::Value::from(0.3));
    defaults.insert("limit".to_owned(), serde_json::Value::from(10));
    defaults.insert("run_on_login".to_owned(), serde_json::Value::Bool(false));
    defaults.insert("monitor_paths".to_owned(), serde_json::Value::Array(Vec::new()));
    serde_json::Value::Object(defaults)
}
/// Handles `GET /settings` by returning the stored settings merged over the defaults.
///
/// The response always has the keys `min_score`, `limit`, `run_on_login` and
/// `monitor_paths`. A missing, unreadable, or unparseable `settings.json`
/// yields [`settings_default`]. A field that is missing or has the wrong type
/// falls back to its default value.
pub async fn settings_get_handler(State(state): State<AppState>) -> impl IntoResponse {
    log::info!("[server] GET /settings 受信");
    let path = state.app_data_dir.join("settings.json");
    // Read directly instead of probing with `Path::exists` first. `exists` is a
    // blocking syscall on the async executor, and it races with concurrent writes.
    let contents = match tokio::fs::read_to_string(&path).await {
        Ok(s) => s,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
            log::info!("settings_get: no file at {:?}, returning default", path);
            return Json(settings_default());
        }
        Err(e) => {
            log::warn!("settings_get: read error {:?}", e);
            return Json(settings_default());
        }
    };
    log::info!("settings_get: read from {:?}, len={}", path, contents.len());
    let loaded: serde_json::Value = match serde_json::from_str(&contents) {
        Ok(v) => v,
        Err(e) => {
            // Previously swallowed silently; log it so a corrupt file is diagnosable.
            log::warn!("settings_get: invalid JSON in {:?}: {}", path, e);
            return Json(settings_default());
        }
    };
    let merged = merge_with_default_settings(&loaded);
    log::info!("settings_get: returning run_on_login={}", merged["run_on_login"]);
    Json(merged)
}

/// Overlays the recognised, correctly typed fields of `loaded` onto [`settings_default`].
///
/// - `min_score`: any JSON number, read as `f64`.
/// - `limit`: an integer that fits in `i64`.
/// - `run_on_login`: a boolean, or an integer where non-zero means `true`.
/// - `monitor_paths`: an array, copied as-is.
///
/// Unknown keys are dropped. A `loaded` value that is not an object yields the defaults.
fn merge_with_default_settings(loaded: &serde_json::Value) -> serde_json::Value {
    let mut merged = settings_default();
    let obj = match loaded.as_object() {
        Some(obj) => obj,
        None => return merged,
    };
    if let Some(min_score) = obj.get("min_score").and_then(serde_json::Value::as_f64) {
        merged["min_score"] = min_score.into();
    }
    if let Some(limit) = obj.get("limit").and_then(serde_json::Value::as_i64) {
        merged["limit"] = limit.into();
    }
    // Accept integer flags (e.g. `0`/`1`) as well as real booleans.
    if let Some(run_on_login) = obj
        .get("run_on_login")
        .and_then(|v| v.as_bool().or_else(|| v.as_i64().map(|n| n != 0)))
    {
        merged["run_on_login"] = run_on_login.into();
    }
    if let Some(paths) = obj.get("monitor_paths").filter(|v| v.is_array()) {
        merged["monitor_paths"] = paths.clone();
    }
    merged
}
pub async fn settings_post_handler(
State(state): State<AppState>,
axum::Json(payload): axum::Json<serde_json::Value>,
) -> impl IntoResponse {
log::info!("[server] POST /settings 受信: payload = {:?}", payload);
let dir = &state.app_data_dir;
if let Err(e) = tokio::fs::create_dir_all(dir).await {
log::error!("settings_post: create_dir_all {:?}", e);
return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "failed to create dir"})));
}
let to_write = if payload.get("min_score").is_some() || payload.get("limit").is_some() || payload.get("run_on_login").is_some() || payload.get("monitor_paths").is_some() {
payload.clone()
} else if let Some(inner) = payload.get("settings").and_then(|v| v.as_object()) {
serde_json::to_value(inner).unwrap_or(payload)
} else {
payload
};
let path = dir.join("settings.json");
let s = match serde_json::to_string_pretty(&to_write) {
Ok(s) => s,
Err(e) => {
log::error!("settings_post: to_string {:?}", e);
return (axum::http::StatusCode::BAD_REQUEST, Json(serde_json::json!({"error": "invalid json"})));
}
};
if let Err(e) = tokio::fs::write(&path, &s).await {
log::error!("settings_post: write {:?}", e);
return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "failed to write"})));
}
log::info!("settings_post: wrote to {:?}, run_on_login={:?}", path, to_write.get("run_on_login"));
(axum::http::StatusCode::OK, Json(serde_json::json!({"ok": true})))
}
/// Query string accepted by [`sse_handler`].
///
/// The handler always creates a fresh session id and never reads
/// `session_id`, hence the `dead_code` allowance below.
#[derive(Deserialize)]
#[allow(dead_code)] // `session_id` is parsed but not read anywhere yet.
pub struct SseQuery {
    /// Client-supplied `?sessionId=` value, if any (currently unused).
    #[serde(rename = "sessionId")]
    pub session_id: Option<String>,
}
/// Runs a one-shot cleanup action when dropped.
///
/// When the client disconnects, the stream returned by [`sse_handler`] is
/// dropped without being polled again. Session deregistration therefore has
/// to happen on drop rather than in the stream's terminal branch.
struct SessionCleanup(Option<Box<dyn FnOnce() + Send>>);

impl Drop for SessionCleanup {
    fn drop(&mut self) {
        if let Some(cleanup) = self.0.take() {
            cleanup();
        }
    }
}

/// Per-connection state threaded through `futures::stream::unfold` in [`sse_handler`].
struct SseConnection<B> {
    session_id: String,
    /// Receives messages for this session; the matching sender is stored in `state.sessions`.
    mailbox: tokio::sync::mpsc::UnboundedReceiver<String>,
    /// Subscription obtained from `state.tx.subscribe()`.
    global: B,
    /// The `endpoint` event, emitted once before anything else.
    pending: Option<Event>,
    mailbox_open: bool,
    global_open: bool,
    /// Removes this session from `state.sessions` when the connection state is dropped.
    _cleanup: SessionCleanup,
}

/// Opens an MCP SSE session.
///
/// Registers a fresh session id in `state.sessions`, then emits:
/// - an `endpoint` event carrying `/messages?sessionId=<id>`;
/// - a `message` event for each message sent to this session's mailbox;
/// - an `update` event for each message on the global broadcast channel.
///
/// The session is removed from `state.sessions` when the stream ends or is
/// dropped, for example on client disconnect.
pub async fn sse_handler(
    State(state): State<AppState>,
    Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // Generate a simple session ID
    let session_id = uuid::Uuid::new_v4().to_string();
    let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();
    log::info!("New MCP SSE Session: {}", session_id);
    // Register session
    state.sessions.write().await.insert(session_id.clone(), tx);

    let cleanup = {
        let sessions = state.sessions.clone();
        let sid = session_id.clone();
        move || {
            log::info!("MCP SSE Session Closed: {}", sid);
            // `Drop` cannot `.await` the async lock, so remove the entry from a
            // spawned task. Outside a runtime (e.g. during shutdown), skip it
            // rather than panic inside `drop`.
            if let Ok(handle) = tokio::runtime::Handle::try_current() {
                handle.spawn(async move {
                    sessions.write().await.remove(&sid);
                });
            }
        }
    };

    // Initial endpoint event
    let endpoint_url = format!("/messages?sessionId={}", session_id);
    let conn = SseConnection {
        pending: Some(Event::default().event("endpoint").data(endpoint_url)),
        mailbox: rx,
        global: state.tx.subscribe(),
        mailbox_open: true,
        global_open: true,
        _cleanup: SessionCleanup(Some(Box::new(cleanup))),
        session_id,
    };

    let stream = futures::stream::unfold(conn, |mut conn| async move {
        if let Some(event) = conn.pending.take() {
            return Some((Ok::<Event, Infallible>(event), conn));
        }
        // Loop so that non-terminal conditions (a lagged broadcast receiver, one
        // source closing) are handled without stalling the other source.
        loop {
            tokio::select! {
                msg = conn.mailbox.recv(), if conn.mailbox_open => match msg {
                    Some(msg) => {
                        let event = Event::default().event("message").data(msg);
                        return Some((Ok(event), conn));
                    }
                    // Sender deregistered; keep serving broadcast updates.
                    None => conn.mailbox_open = false,
                },
                update = conn.global.recv(), if conn.global_open => match update {
                    Ok(msg) => {
                        // Global notification (e.g. data update)
                        let event = Event::default().event("update").data(msg);
                        return Some((Ok(event), conn));
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                        log::warn!(
                            "MCP SSE Session {} lagged, skipped {} updates",
                            conn.session_id,
                            skipped
                        );
                    }
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                        conn.global_open = false;
                    }
                },
                // Both sources are closed: end the stream (dropping `conn` runs the cleanup).
                else => return None,
            }
        }
    });
    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}