Newer
Older
TelosDB / src-tauri / src / mcp.rs
use axum::{
    extract::State,
    response::{sse::{Event, Sse}, IntoResponse},
    routing::{get, post},
    Router, Json,
};
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use tokio::sync::broadcast;
use futures::stream::Stream;
use tokio_stream::StreamExt;

// use tower_http::cors::CorsLayer;

/// State shared by every MCP HTTP handler; axum clones it into each request via
/// `State<AppState>`.
#[derive(Clone)]
pub struct AppState {
    /// llama-server status string served by `GET /llama_status`.
    /// Intended values are "running" / "stopped" / "error"; `run_server` initialises
    /// it to "unknown" before the first status update.
    pub llama_status: Arc<RwLock<String>>,
    /// Sending half of the broadcast channel whose messages are streamed on `/sse`.
    pub tx: broadcast::Sender<String>,
    /// SQLite connection pool used by the JSON-RPC `search_text` method.
    pub db_pool: sqlx::SqlitePool,
}
    // extract::Extension,
use crate::db;
use sqlx::Row;

pub async fn run_server(port: u16, db_path: &str, vec0_path: &str) {
    // DBプールを初期化
    let db_pool = db::init_pool(db_path, vec0_path.to_owned()).await.expect("DB pool init failed");

    let (tx, _rx) = broadcast::channel(100);
    let llama_status = Arc::new(RwLock::new("unknown".to_string()));
    // llama-server状態監視タスク(ダミー: 3秒ごとに"running"に)
    let llama_status_clone = llama_status.clone();
    tokio::spawn(async move {
        loop {
            // TODO: 実際は/health等で死活監視
            {
                let mut status = llama_status_clone.write().await;
                *status = "running".to_string();
            }
            tokio::time::sleep(std::time::Duration::from_secs(3)).await;
        }
    });

    let app_state = AppState { db_pool, tx, llama_status };

    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/message", post(message_handler))
        .route("/messages", post(messages_handler))
        .route("/llama_status", get(llama_status_handler))
        .with_state(app_state);

    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .unwrap();
    log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}
/// `GET /llama_status`: reports the current llama-server status as `{"status": "<value>"}`.
async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    // Copy the value out so the read lock is released before building the response.
    let current: String = {
        let guard = state.llama_status.read().await;
        String::clone(&guard)
    };
    Json(serde_json::json!({ "status": current }))
}

#[derive(Deserialize)]
struct JsonRpcRequest {
    // jsonrpc: String, // 未使用フィールドのためコメントアウト
    method: String,
    params: Option<serde_json::Value>,
    id: serde_json::Value,
}

#[derive(Serialize)]
struct JsonRpcResponse {
    jsonrpc: &'static str,
    result: serde_json::Value,
    id: serde_json::Value,
}

// search_text用パラメータ
#[derive(Deserialize)]
struct SearchTextParams {
    content: String,
    limit: Option<u32>,
}

async fn messages_handler(
    State(state): State<AppState>,
    Json(req): Json<JsonRpcRequest>,
) -> impl IntoResponse {
    if req.method == "search_text" {
        let params: SearchTextParams = serde_json::from_value(req.params.unwrap_or_default()).unwrap();
        // 仮: embedding生成は省略し、content LIKE検索でダミー返却
        let rows = sqlx::query("SELECT id, content, 0.0 as distance FROM items WHERE content LIKE ? LIMIT ?")
            .bind(format!("%{}%", params.content))
            .bind(params.limit.unwrap_or(10))
            .fetch_all(&state.db_pool)
            .await
            .unwrap_or_default();
        let results: Vec<_> = rows.iter().map(|r| serde_json::json!({
            "id": r.get::<i64,_>(0),
            "content": r.get::<String,_>(1),
            "distance": r.get::<f64,_>(2),
        })).collect();
        let resp = JsonRpcResponse {
            jsonrpc: "2.0",
            result: serde_json::json!({"content": results}),
            id: req.id,
        };
        return axum::Json(resp);
    }
    // 未実装メソッド
    axum::Json(JsonRpcResponse {
        jsonrpc: "2.0",
        result: serde_json::json!({"error": "method not found"}),
        id: req.id,
    })
}

async fn sse_handler(
    State(state): State<AppState>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let rx = state.tx.subscribe();
    let stream = tokio_stream::wrappers::BroadcastStream::new(rx).map(|msg| {
        match msg {
            Ok(msg) => Ok(Event::default().data(msg)),
            Err(_) => Ok(Event::default().event("error").data("stream error")),
        }
    });

    Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}

/// Request body for `POST /message`.
#[derive(Deserialize)]
struct MessageInput {
    /// Text to publish on the broadcast channel consumed by `/sse` clients.
    message: String,
}

/// `POST /message`: publishes the posted text on the broadcast channel.
///
/// Delivery is best-effort: any error from `send` is ignored and the reply is always
/// "Message sent".
async fn message_handler(
    State(state): State<AppState>,
    Json(MessageInput { message }): Json<MessageInput>,
) -> impl IntoResponse {
    if let Err(_undelivered) = state.tx.send(message) {
        // Best-effort broadcast: nothing to recover, so the failure is dropped.
    }
    "Message sent"
}