use crate::entities::items;
use crate::AppState;
use axum::{
extract::State,
response::sse::{Event, Sse},
routing::{get, post},
Json, Router,
};
use futures::stream::{self, Stream};
use sea_orm::*;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
#[derive(Debug, Deserialize)]
pub struct JsonRpcRequest {
pub jsonrpc: String,
pub method: String,
pub params: serde_json::Value,
pub id: Option<serde_json::Value>,
}
#[derive(Debug, Serialize)]
pub struct JsonRpcResponse {
pub jsonrpc: String,
pub result: Option<serde_json::Value>,
pub error: Option<serde_json::Value>,
pub id: serde_json::Value,
}
/// Serves the MCP HTTP endpoints (`GET /sse`, `POST /messages`) on `0.0.0.0:port`
/// until the server future completes.
///
/// # Panics
/// Panics (with the address and underlying I/O error in the message) if the port cannot
/// be bound or if the server stops with an I/O error.
pub async fn start_mcp_server(state: Arc<AppState>, port: u16) {
    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(message_handler))
        .with_state(state);
    // NOTE(review): 0.0.0.0 exposes these unauthenticated endpoints (which write to the
    // database and run LLM generation) on every interface — confirm this is intended.
    let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .unwrap_or_else(|e| panic!("failed to bind MCP server to {addr}: {e}"));
    axum::serve(listener, app)
        .await
        .unwrap_or_else(|e| panic!("MCP server on {addr} stopped with an error: {e}"));
}
/// Opens the SSE channel for an MCP client.
///
/// Sends a single `endpoint` event whose data is `/messages` (the path clients POST
/// JSON-RPC requests to), after which the stream ends. JSON-RPC responses are not pushed
/// over this channel; `message_handler` returns them in the POST response body.
// NOTE(review): the MCP HTTP+SSE transport normally keeps this stream open and delivers
// responses over it — confirm the intended clients tolerate this simplified behavior.
async fn sse_handler(
    State(_state): State<Arc<AppState>>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let endpoint_event = Event::default().event("endpoint").data("/messages");
    let events = stream::iter(std::iter::once(Ok::<Event, Infallible>(endpoint_event)));
    Sse::new(events)
}
pub async fn message_handler(
State(state): State<Arc<AppState>>,
Json(payload): Json<JsonRpcRequest>,
) -> Json<JsonRpcResponse> {
let result = match payload.method.as_str() {
"tools/list" => Some(serde_json::json!({
"tools": [
{
"name": "add_item_text",
"description": "Add item from text",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"path": { "type": "string" }
}
}
},
{
"name": "search_text",
"description": "Search items by text",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"limit": { "type": "number" }
}
}
},
{
"name": "add_item",
"description": "Add item with vector",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"vector": { "type": "array", "items": { "type": "number" } },
"path": { "type": "string" }
}
}
},
{
"name": "search_vector",
"description": "Search items by vector",
"inputSchema": {
"type": "object",
"properties": {
"vector": { "type": "array", "items": { "type": "number" } },
"limit": { "type": "number" }
}
}
},
{
"name": "llm_generate",
"description": "Generate text via LLM",
"inputSchema": {
"type": "object",
"properties": {
"prompt": { "type": "string" },
"n_predict": { "type": "number" },
"temperature": { "type": "number" }
}
}
}
]
})),
"tools/call" => {
let tool_name = payload.params["name"].as_str().unwrap_or("");
let args = &payload.params["arguments"];
match tool_name {
"add_item_text" => {
let content = args["content"].as_str().unwrap_or("");
let path = args["path"].as_str().unwrap_or("");
match handle_add_item_text(&state, content, path).await {
Ok(res) => Some(res),
Err(e) => {
return Json(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(
serde_json::json!({ "code": -32000, "message": e.to_string() }),
),
id: payload.id.unwrap_or(serde_json::Value::Null),
})
}
}
}
"add_item" => {
let content = args["content"].as_str().unwrap_or("");
let vector: Vec<f32> = args["vector"]
.as_array()
.unwrap_or(&vec![])
.iter()
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
.collect();
let path = args["path"].as_str().unwrap_or("");
match handle_add_item(&state, content, vector, path).await {
Ok(res) => Some(res),
Err(e) => {
return Json(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(
serde_json::json!({ "code": -32000, "message": e.to_string() }),
),
id: payload.id.unwrap_or(serde_json::Value::Null),
})
}
}
}
"search_text" => {
let content = args["content"].as_str().unwrap_or("");
let limit = args["limit"].as_u64().unwrap_or(10) as usize;
match handle_search_text(&state, content, limit).await {
Ok(res) => Some(res),
Err(e) => {
return Json(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(
serde_json::json!({ "code": -32000, "message": e.to_string() }),
),
id: payload.id.unwrap_or(serde_json::Value::Null),
})
}
}
}
"search_vector" => {
let vector: Vec<f32> = args["vector"]
.as_array()
.unwrap_or(&vec![])
.iter()
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
.collect();
let limit = args["limit"].as_u64().unwrap_or(10) as usize;
match handle_search_vector(&state, vector, limit).await {
Ok(res) => Some(res),
Err(e) => {
return Json(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(
serde_json::json!({ "code": -32000, "message": e.to_string() }),
),
id: payload.id.unwrap_or(serde_json::Value::Null),
})
}
}
}
"llm_generate" => {
let prompt = args["prompt"].as_str().unwrap_or("");
let n_predict = args["n_predict"].as_i64().unwrap_or(128) as i32;
let temperature = args["temperature"].as_f64().unwrap_or(0.7) as f32;
match handle_llm_generate(&state, prompt, n_predict, temperature).await {
Ok(res) => Some(res),
Err(e) => {
return Json(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result: None,
error: Some(
serde_json::json!({ "code": -32000, "message": e.to_string() }),
),
id: payload.id.unwrap_or(serde_json::Value::Null),
})
}
}
}
_ => Some(serde_json::json!({ "error": "Unknown tool" })),
}
}
_ => Some(serde_json::json!({ "error": "Method not found" })),
};
Json(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
result,
error: None,
id: payload.id.unwrap_or(serde_json::Value::Null),
})
}
async fn handle_add_item_text(
state: &AppState,
content: &str,
path: &str,
) -> anyhow::Result<serde_json::Value> {
let embedding = state.llama.get_embedding(content).await?;
// SeaORM insert
let new_item = items::ActiveModel {
content: Set(content.to_owned()),
path: Set(Some(path.to_owned())),
..Default::default()
};
let db = &state.db;
let res = new_item.insert(db).await?;
let id = res.id;
// vec0 table insert (SeaORM raw SQL for now as it's a virtual table)
let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
db.execute(Statement::from_sql_and_values(
DatabaseBackend::Sqlite,
"INSERT INTO vec_items (id, embedding) VALUES (?, ?)",
[id.into(), embedding_bytes.into()],
))
.await?;
Ok(serde_json::json!({
"content": [{ "type": "text", "text": format!("Added item with id {}", id) }]
}))
}
/// Stores `content`/`path` as a new `items` row and `vector` in the `vec_items` table
/// under the same id.
///
/// Both inserts run in one database transaction. A rejected vector (e.g. empty, or of a
/// length `vec_items` does not accept) therefore no longer leaves an `items` row without
/// an embedding behind.
///
/// Returns an MCP tool result with a single text item naming the new item's id.
///
/// # Errors
/// Fails if the transaction cannot be started, either insert fails, or the commit fails;
/// in each case nothing is persisted.
async fn handle_add_item(
    state: &AppState,
    content: &str,
    vector: Vec<f32>,
    path: &str,
) -> anyhow::Result<serde_json::Value> {
    let txn = state.db.begin().await?;
    let new_item = items::ActiveModel {
        content: Set(content.to_owned()),
        path: Set(Some(path.to_owned())),
        ..Default::default()
    };
    let id = new_item.insert(&txn).await?.id;
    // The vector is stored as the raw little-endian bytes of its f32 values.
    let embedding_bytes: Vec<u8> = vector.iter().flat_map(|f| f.to_le_bytes()).collect();
    // `vec_items` is a virtual table without a SeaORM entity, so insert via raw SQL.
    txn.execute(Statement::from_sql_and_values(
        DatabaseBackend::Sqlite,
        "INSERT INTO vec_items (id, embedding) VALUES (?, ?)",
        [id.into(), embedding_bytes.into()],
    ))
    .await?;
    // If any `?` above returns early, the uncommitted transaction is rolled back when dropped.
    txn.commit().await?;
    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": format!("Added item with id {}", id) }]
    }))
}
/// Turns the query text into an embedding with the llama model, then returns whatever
/// `handle_search_vector` produces for that embedding and `limit`.
///
/// # Errors
/// Propagates embedding failures and any error from `handle_search_vector`.
async fn handle_search_text(
    state: &AppState,
    content: &str,
    limit: usize,
) -> anyhow::Result<serde_json::Value> {
    let query_embedding = state
        .llama
        .get_embedding(content)
        .await?;
    handle_search_vector(state, query_embedding, limit).await
}
async fn handle_search_vector(
state: &AppState,
vector: Vec<f32>,
limit: usize,
) -> anyhow::Result<serde_json::Value> {
let embedding_bytes: Vec<u8> = vector.iter().flat_map(|f| f.to_le_bytes()).collect();
let db = &state.db;
// raw SQL query via SeaORM for vector search
let results = db
.query_all(Statement::from_sql_and_values(
DatabaseBackend::Sqlite,
"SELECT i.id, i.content, i.path, i.created_at, i.updated_at, v.distance
FROM vec_items v
JOIN items i ON v.id = i.id
WHERE embedding MATCH ?
ORDER BY distance
LIMIT ?",
[embedding_bytes.into(), (limit as i64).into()],
))
.await?;
let mut out = Vec::new();
for res in results {
out.push(serde_json::json!({
"id": res.try_get::<i32>("", "id").map_err(|e| anyhow::anyhow!(e))?,
"content": res.try_get::<String>("", "content").map_err(|e| anyhow::anyhow!(e))?,
"path": res.try_get::<Option<String>>("", "path").map_err(|e| anyhow::anyhow!(e))?,
"created_at": res.try_get::<String>("", "created_at").map_err(|e| anyhow::anyhow!(e))?,
"updated_at": res.try_get::<String>("", "updated_at").map_err(|e| anyhow::anyhow!(e))?,
"distance": res.try_get::<f64>("", "distance").map_err(|e| anyhow::anyhow!(e))?
}));
}
Ok(serde_json::json!({
"content": [{ "type": "text", "text": serde_json::to_string_pretty(&out)? }]
}))
}
/// Asks the llama client for a completion of `prompt`, passing `n_predict` and
/// `temperature` through unchanged. The generated text is returned as a single MCP text item.
///
/// # Errors
/// Propagates any error from the llama completion call.
async fn handle_llm_generate(
    state: &AppState,
    prompt: &str,
    n_predict: i32,
    temperature: f32,
) -> anyhow::Result<serde_json::Value> {
    let generated = state.llama.completion(prompt, n_predict, temperature).await?;
    let text_item = serde_json::json!({ "type": "text", "text": generated });
    Ok(serde_json::json!({ "content": [text_item] }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A `tools/list` request in client wire format deserializes into the expected fields.
    #[test]
    fn test_list_tools_format() {
        // Deserialize real JSON instead of building the struct by hand, so the test
        // actually exercises the serde mapping of `JsonRpcRequest`.
        let req: JsonRpcRequest = serde_json::from_value(json!({
            "jsonrpc": "2.0",
            "method": "tools/list",
            "params": {},
            "id": 1
        }))
        .expect("a well-formed tools/list request must deserialize");
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.method, "tools/list");
        assert_eq!(req.params, json!({}));
        assert_eq!(req.id, Some(json!(1)));
    }

    /// Search results wrapped in the MCP text-content envelope keep their timestamp fields.
    #[test]
    fn test_search_response_structure() {
        let results = vec![json!({
            "id": 1,
            "content": "test content",
            "path": "/test/path",
            "created_at": "2024-02-07 15:00:00",
            "updated_at": "2024-02-07 15:00:00",
            "distance": 0.1
        })];
        let response_body = json!({
            "content": [{ "type": "text", "text": serde_json::to_string_pretty(&results).unwrap() }]
        });
        assert!(response_body.get("content").is_some());
        let text = response_body["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("created_at"));
        assert!(text.contains("2024-02-07 15:00:00"));
    }
}