Newer
Older
TelosDB / src-tauri / src / mcp.rs
use crate::entities::items;
use crate::AppState;
use axum::{
    extract::State,
    response::sse::{Event, Sse},
    routing::{get, post},
    Json, Router,
};
use futures::stream::{self, Stream};
use sea_orm::*;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;

#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
    pub jsonrpc: String,
    pub method: String,
    pub params: serde_json::Value,
    pub id: Option<serde_json::Value>,
}

#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
    pub jsonrpc: String,
    pub result: Option<serde_json::Value>,
    pub error: Option<serde_json::Value>,
    pub id: serde_json::Value,
}

/// Serves the MCP HTTP endpoints (`GET /sse`, `POST /messages`) on `port`.
///
/// Runs until the server stops. A bind failure (e.g. the port is already in use)
/// or a serve error is reported on stderr and the function returns, instead of
/// panicking inside the task that runs it.
pub async fn start_mcp_server(state: Arc<AppState>, port: u16) {
    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(message_handler))
        .with_state(state);

    // NOTE(review): 0.0.0.0 exposes these unauthenticated endpoints (which write to
    // the database and run the LLM) on every network interface; consider
    // 127.0.0.1 if only local clients need access.
    let addr = format!("0.0.0.0:{}", port);
    let listener = match tokio::net::TcpListener::bind(addr.as_str()).await {
        Ok(listener) => listener,
        Err(e) => {
            eprintln!("MCP server: failed to bind {}: {}", addr, e);
            return;
        }
    };

    if let Err(e) = axum::serve(listener, app).await {
        eprintln!("MCP server: stopped with error: {}", e);
    }
}

async fn sse_handler(
    State(_state): State<Arc<AppState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    // Basic SSE handler for MCP
    // Real MCP SDK would send 'endpoint' event here
    let stream = stream::once(async { Ok(Event::default().event("endpoint").data("/messages")) });

    Sse::new(stream)
}

pub async fn message_handler(
    State(state): State<Arc<AppState>>,
    Json(payload): Json<JsonRpcRequest>,
) -> Json<JsonRpcResponse> {
    let result = match payload.method.as_str() {
        "tools/list" => Some(serde_json::json!({
            "tools": [
                {
                    "name": "add_item_text",
                    "description": "Add item from text",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "path": { "type": "string" }
                        }
                    }
                },
                {
                    "name": "search_text",
                    "description": "Search items by text",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "limit": { "type": "number" }
                        }
                    }
                },
                {
                    "name": "add_item",
                    "description": "Add item with vector",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "content": { "type": "string" },
                            "vector": { "type": "array", "items": { "type": "number" } },
                            "path": { "type": "string" }
                        }
                    }
                },
                {
                    "name": "search_vector",
                    "description": "Search items by vector",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "vector": { "type": "array", "items": { "type": "number" } },
                            "limit": { "type": "number" }
                        }
                    }
                },
                {
                    "name": "llm_generate",
                    "description": "Generate text via LLM",
                    "inputSchema": {
                        "type": "object",
                        "properties": {
                            "prompt": { "type": "string" },
                            "n_predict": { "type": "number" },
                            "temperature": { "type": "number" }
                        }
                    }
                }
            ]
        })),
        "tools/call" => {
            let tool_name = payload.params["name"].as_str().unwrap_or("");
            let args = &payload.params["arguments"];

            match tool_name {
                "add_item_text" => {
                    let content = args["content"].as_str().unwrap_or("");
                    let path = args["path"].as_str().unwrap_or("");
                    match handle_add_item_text(&state, content, path).await {
                        Ok(res) => Some(res),
                        Err(e) => {
                            return Json(JsonRpcResponse {
                                jsonrpc: "2.0".to_string(),
                                result: None,
                                error: Some(
                                    serde_json::json!({ "code": -32000, "message": e.to_string() }),
                                ),
                                id: payload.id.unwrap_or(serde_json::Value::Null),
                            })
                        }
                    }
                }
                "add_item" => {
                    let content = args["content"].as_str().unwrap_or("");
                    let vector: Vec<f32> = args["vector"]
                        .as_array()
                        .unwrap_or(&vec![])
                        .iter()
                        .map(|v| v.as_f64().unwrap_or(0.0) as f32)
                        .collect();
                    let path = args["path"].as_str().unwrap_or("");
                    match handle_add_item(&state, content, vector, path).await {
                        Ok(res) => Some(res),
                        Err(e) => {
                            return Json(JsonRpcResponse {
                                jsonrpc: "2.0".to_string(),
                                result: None,
                                error: Some(
                                    serde_json::json!({ "code": -32000, "message": e.to_string() }),
                                ),
                                id: payload.id.unwrap_or(serde_json::Value::Null),
                            })
                        }
                    }
                }
                "search_text" => {
                    let content = args["content"].as_str().unwrap_or("");
                    let limit = args["limit"].as_u64().unwrap_or(10) as usize;
                    match handle_search_text(&state, content, limit).await {
                        Ok(res) => Some(res),
                        Err(e) => {
                            return Json(JsonRpcResponse {
                                jsonrpc: "2.0".to_string(),
                                result: None,
                                error: Some(
                                    serde_json::json!({ "code": -32000, "message": e.to_string() }),
                                ),
                                id: payload.id.unwrap_or(serde_json::Value::Null),
                            })
                        }
                    }
                }
                "search_vector" => {
                    let vector: Vec<f32> = args["vector"]
                        .as_array()
                        .unwrap_or(&vec![])
                        .iter()
                        .map(|v| v.as_f64().unwrap_or(0.0) as f32)
                        .collect();
                    let limit = args["limit"].as_u64().unwrap_or(10) as usize;
                    match handle_search_vector(&state, vector, limit).await {
                        Ok(res) => Some(res),
                        Err(e) => {
                            return Json(JsonRpcResponse {
                                jsonrpc: "2.0".to_string(),
                                result: None,
                                error: Some(
                                    serde_json::json!({ "code": -32000, "message": e.to_string() }),
                                ),
                                id: payload.id.unwrap_or(serde_json::Value::Null),
                            })
                        }
                    }
                }
                "llm_generate" => {
                    let prompt = args["prompt"].as_str().unwrap_or("");
                    let n_predict = args["n_predict"].as_i64().unwrap_or(128) as i32;
                    let temperature = args["temperature"].as_f64().unwrap_or(0.7) as f32;
                    match handle_llm_generate(&state, prompt, n_predict, temperature).await {
                        Ok(res) => Some(res),
                        Err(e) => {
                            return Json(JsonRpcResponse {
                                jsonrpc: "2.0".to_string(),
                                result: None,
                                error: Some(
                                    serde_json::json!({ "code": -32000, "message": e.to_string() }),
                                ),
                                id: payload.id.unwrap_or(serde_json::Value::Null),
                            })
                        }
                    }
                }
                _ => Some(serde_json::json!({ "error": "Unknown tool" })),
            }
        }
        _ => Some(serde_json::json!({ "error": "Method not found" })),
    };

    Json(JsonRpcResponse {
        jsonrpc: "2.0".to_string(),
        result,
        error: None,
        id: payload.id.unwrap_or(serde_json::Value::Null),
    })
}

async fn handle_add_item_text(
    state: &AppState,
    content: &str,
    path: &str,
) -> anyhow::Result<serde_json::Value> {
    let embedding = state.llama.get_embedding(content).await?;

    // SeaORM insert
    let new_item = items::ActiveModel {
        content: Set(content.to_owned()),
        path: Set(Some(path.to_owned())),
        ..Default::default()
    };

    let db = &state.db;
    let res = new_item.insert(db).await?;
    let id = res.id;

    // vec0 table insert (SeaORM raw SQL for now as it's a virtual table)
    let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();

    db.execute(Statement::from_sql_and_values(
        DatabaseBackend::Sqlite,
        "INSERT INTO vec_items (id, embedding) VALUES (?, ?)",
        [id.into(), embedding_bytes.into()],
    ))
    .await?;

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": format!("Added item with id {}", id) }]
    }))
}

/// Inserts an item together with a caller-supplied embedding `vector`.
///
/// The `items` row and its `vec_items` row are written in one transaction. If the
/// vector insert fails (for example because the vector store rejects the
/// embedding), the transaction is dropped uncommitted. No `items` row is then left
/// behind without an embedding.
///
/// Returns an MCP text content block reporting the new item's id.
///
/// # Errors
/// Returns any database error from starting the transaction, either insert, or the commit.
async fn handle_add_item(
    state: &AppState,
    content: &str,
    vector: Vec<f32>,
    path: &str,
) -> anyhow::Result<serde_json::Value> {
    let txn = state.db.begin().await?;

    let new_item = items::ActiveModel {
        content: Set(content.to_owned()),
        path: Set(Some(path.to_owned())),
        ..Default::default()
    };
    let id = new_item.insert(&txn).await?.id;

    // vec_items is a virtual table, so it is written with raw SQL. The embedding is
    // passed as packed little-endian f32 bytes.
    let embedding_bytes: Vec<u8> = vector.iter().flat_map(|f| f.to_le_bytes()).collect();
    txn.execute(Statement::from_sql_and_values(
        DatabaseBackend::Sqlite,
        "INSERT INTO vec_items (id, embedding) VALUES (?, ?)",
        [id.into(), embedding_bytes.into()],
    ))
    .await?;

    txn.commit().await?;

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": format!("Added item with id {}", id) }]
    }))
}

/// Embeds `content` with the local model and runs a nearest-neighbour search with it.
///
/// This is a thin wrapper over `handle_search_vector`; `limit` is forwarded unchanged
/// and the search result is returned as-is.
async fn handle_search_text(
    state: &AppState,
    content: &str,
    limit: usize,
) -> anyhow::Result<serde_json::Value> {
    let query_vector = state.llama.get_embedding(content).await?;
    let response = handle_search_vector(state, query_vector, limit).await?;
    Ok(response)
}

/// Returns up to `limit` items whose stored embeddings are nearest to `vector`.
///
/// The KNN lookup runs on `vec_items` inside a CTE with an explicit `k = ?`
/// constraint, and only then joins `items` for the row data. sqlite-vec needs
/// either `LIMIT` or `k = ?` on a KNN query. A `LIMIT` on an outer query that
/// joins another table is not guaranteed to reach the virtual table, whereas `k`
/// is an ordinary constraint on it.
///
/// The result is an MCP text content block. Its text is a pretty-printed JSON array
/// of `{id, content, path, created_at, updated_at, distance}` objects, ordered by
/// ascending distance.
///
/// # Errors
/// Returns database errors and column-decoding errors.
async fn handle_search_vector(
    state: &AppState,
    vector: Vec<f32>,
    limit: usize,
) -> anyhow::Result<serde_json::Value> {
    // The query vector uses the same packed little-endian f32 encoding as the inserts.
    let embedding_bytes: Vec<u8> = vector.iter().flat_map(|f| f.to_le_bytes()).collect();

    let rows = state
        .db
        .query_all(Statement::from_sql_and_values(
            DatabaseBackend::Sqlite,
            "WITH knn AS (
                 SELECT id, distance
                 FROM vec_items
                 WHERE embedding MATCH ?
                   AND k = ?
             )
             SELECT i.id, i.content, i.path, i.created_at, i.updated_at, knn.distance
             FROM knn
             JOIN items i ON knn.id = i.id
             ORDER BY knn.distance",
            [embedding_bytes.into(), (limit as i64).into()],
        ))
        .await?;

    let hits = rows
        .iter()
        .map(|row| -> anyhow::Result<serde_json::Value> {
            Ok(serde_json::json!({
                "id": row.try_get::<i32>("", "id")?,
                "content": row.try_get::<String>("", "content")?,
                "path": row.try_get::<Option<String>>("", "path")?,
                "created_at": row.try_get::<String>("", "created_at")?,
                "updated_at": row.try_get::<String>("", "updated_at")?,
                "distance": row.try_get::<f64>("", "distance")?
            }))
        })
        .collect::<anyhow::Result<Vec<_>>>()?;

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": serde_json::to_string_pretty(&hits)? }]
    }))
}

/// Runs a completion for `prompt` on the local model.
///
/// `n_predict` and `temperature` are passed straight to `state.llama.completion`.
/// The generated text is wrapped in a single MCP text content block.
async fn handle_llm_generate(
    state: &AppState,
    prompt: &str,
    n_predict: i32,
    temperature: f32,
) -> anyhow::Result<serde_json::Value> {
    let generated = state.llama.completion(prompt, n_predict, temperature).await?;
    let text_block = serde_json::json!({ "type": "text", "text": generated });
    Ok(serde_json::json!({ "content": [text_block] }))
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A well-formed `tools/list` request must deserialize from its wire form.
    #[test]
    fn test_list_tools_format() {
        let req: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "method": "tools/list",
            "params": {},
            "id": 1
        }))
        .expect("a well-formed tools/list request must deserialize");

        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, "tools/list");
        assert_eq!(req.params, json!({}));
        assert_eq!(req.id, Some(json!(1)));
    }

    /// A success response carries the version, result and id, and no non-null error.
    #[test]
    fn test_success_response_serialization() {
        let resp = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            result: Some(json!({ "ok": true })),
            error: None,
            id: json!(7),
        };

        let value = serde_json::to_value(&resp).expect("response must serialize");
        assert_eq!(value["jsonrpc"], "2.0");
        assert_eq!(value["result"]["ok"], true);
        assert_eq!(value["id"], 7);
        assert!(value.get("error").map_or(true, |e| e.is_null()));
    }

    /// Search results are wrapped as pretty-printed JSON inside a text content block.
    #[test]
    fn test_search_response_structure() {
        let results = vec![json!({
            "id": 1,
            "content": "test content",
            "path": "/test/path",
            "created_at": "2024-02-07 15:00:00",
            "updated_at": "2024-02-07 15:00:00",
            "distance": 0.1
        })];

        let response_body = json!({
            "content": [{ "type": "text", "text": serde_json::to_string_pretty(&results).unwrap() }]
        });

        assert!(response_body.get("content").is_some());
        let text = response_body["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("created_at"));
        assert!(text.contains("2024-02-07 15:00:00"));
    }
}