use axum::{
extract::{State, Query},
response::{sse::{Event, Sse}, IntoResponse},
routing::{get, post},
Router, Json,
};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use tokio::sync::broadcast;
use futures::stream::Stream;
use tokio_stream::StreamExt;
use tower_http::cors::{Any, CorsLayer};
use crate::db;
use sqlx::Row;
/// Shared application state handed to every axum handler through `with_state`.
///
/// The struct is `Clone`; the `Arc`-wrapped fields are shared between clones
/// rather than copied.
#[derive(Clone)]
pub struct AppState {
/// SQLite connection pool used by the tool handlers for item and vector queries.
pub db_pool: sqlx::SqlitePool,
/// Broadcast sender for `String` messages, created in `run_server` with capacity 100.
/// NOTE(review): no handler visible in this file sends on or subscribes to it —
/// confirm it is used elsewhere.
pub tx: broadcast::Sender<String>,
/// Latest llama-server status string; overwritten by the monitor task spawned in
/// `run_server` and read by `llama_status_handler`.
pub llama_status: Arc<RwLock<String>>,
/// Open MCP SSE sessions: session id -> sender that feeds that session's event stream.
/// `sse_handler` registers entries; `mcp_messages_handler` looks them up to deliver responses.
pub sessions: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<String>>>>,
}
pub async fn run_server(port: u16, db_path: &str, vec0_path: &str, llama_status: Arc<RwLock<String>>) {
let db_pool = db::init_pool(db_path, vec0_path.to_owned()).await.expect("DB pool init failed");
let (tx, _rx) = broadcast::channel(100);
let sessions = Arc::new(RwLock::new(HashMap::new()));
// llama-server status monitor
let llama_status_clone = llama_status.clone();
tokio::spawn(async move {
let client = reqwest::Client::new();
loop {
let status = match client.get("http://127.0.0.1:8080/health").send().await {
Ok(resp) if resp.status().is_success() => "running".to_string(),
Ok(_) => "error".to_string(),
Err(_) => "stopped".to_string(),
};
{
let mut s = llama_status_clone.write().await;
if *s != status {
log::info!("llama-server status changed: {} -> {}", *s, status);
*s = status;
}
}
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
}
});
let app_state = AppState { db_pool, tx, llama_status: llama_status.clone(), sessions };
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
let app = Router::new()
.route("/sse", get(sse_handler))
.route("/messages", post(mcp_messages_handler))
.route("/llama_status", get(llama_status_handler))
.layer(cors)
.with_state(app_state);
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
/// `GET /llama_status`: reports the most recently recorded llama-server status
/// as `{"status": "<value>"}`.
async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    // Copy the value out so the read guard is released before the body is built.
    let current = {
        let guard = state.llama_status.read().await;
        (*guard).clone()
    };
    let body = serde_json::json!({ "status": current });
    Json(body)
}
/// Query-string parameters accepted by `GET /sse`.
#[derive(Deserialize)]
struct SseQuery {
/// Optional session id supplied by the client. `sse_handler` extracts it but
/// does not read it; a fresh id is generated for every connection.
session_id: Option<String>,
}
async fn sse_handler(
State(state): State<AppState>,
Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
// Generate a simple session ID
let session_id = uuid::Uuid::new_v4().to_string();
let (tx, mut rx) = mpsc::unbounded_channel::<String>();
log::info!("New MCP SSE Session: {}", session_id);
// Register session
state.sessions.write().await.insert(session_id.clone(), tx);
// Initial endpoint event
let endpoint_url = format!("/messages?session_id={}", session_id);
let endpoint_event = Event::default().event("endpoint").data(endpoint_url);
let session_id_for_close = session_id.clone();
let sessions_for_close = state.sessions.clone();
let stream = futures::stream::unfold(
(rx, Some(endpoint_event), session_id_for_close, sessions_for_close),
|(mut rx, mut initial, sid, smap)| async move {
if let Some(event) = initial.take() {
return Some((Ok(event), (rx, None, sid, smap)));
}
tokio::select! {
Some(msg) = rx.recv() => {
Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap)))
}
else => {
log::info!("MCP SSE Session Closed: {}", sid);
smap.write().await.remove(&sid);
None
}
}
},
);
Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}
/// Incoming JSON-RPC request body accepted by `POST /messages`.
#[derive(Deserialize)]
struct JsonRpcRequest {
/// Protocol version string sent by the client. The field is required for
/// deserialization; the handler does not inspect its value.
jsonrpc: String,
/// Method name used for dispatch, e.g. `initialize`, `tools/list`, `tools/call`.
method: String,
/// Optional method parameters, kept as raw JSON.
params: Option<serde_json::Value>,
/// Request id. When absent the handler sends no JSON-RPC response and
/// answers with HTTP 204.
id: Option<serde_json::Value>,
}
/// Outgoing JSON-RPC response envelope.
#[derive(Serialize)]
struct JsonRpcResponse {
/// Protocol version tag; the handler always sets it to `"2.0"`.
jsonrpc: &'static str,
/// Result payload; left out of the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<serde_json::Value>,
/// Error payload; left out of the serialized JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<serde_json::Value>,
/// Id copied from the request. Always serialized (as `null` when `None`).
id: Option<serde_json::Value>,
}
/// Query-string parameters accepted by `POST /messages`.
#[derive(Deserialize)]
struct MessageQuery {
/// SSE session to deliver the response to. When present the response is pushed
/// over that session; when absent it is returned directly in the HTTP response body.
session_id: Option<String>,
}
/// Requests an embedding vector for `content` from the local llama-server.
///
/// Sends `{"content": <content>}` to `http://127.0.0.1:8080/embedding` and reads
/// the `embedding` array from the JSON response.
///
/// # Errors
///
/// Returns a human-readable message when:
/// * the request cannot be sent or the server answers with a non-success status,
/// * the body is not valid JSON or has no `embedding` array,
/// * the array is empty or contains a non-numeric element.
///
/// Malformed elements are rejected instead of being replaced by `0.0`, so a
/// response in an unexpected shape can never be stored as a bogus vector.
async fn get_embedding(content: &str) -> Result<Vec<f32>, String> {
    let client = reqwest::Client::new();
    let resp = client
        .post("http://127.0.0.1:8080/embedding")
        .json(&serde_json::json!({ "content": content }))
        .send()
        .await
        .map_err(|e| e.to_string())?
        // Surface HTTP-level failures (e.g. 503 while the model loads) directly
        // instead of as a confusing "missing field" error further down.
        .error_for_status()
        .map_err(|e| e.to_string())?;

    let json: serde_json::Value = resp.json().await.map_err(|e| e.to_string())?;
    let values = json["embedding"]
        .as_array()
        .ok_or("No embedding field in llama-server response")?;
    if values.is_empty() {
        return Err("Empty embedding in llama-server response".to_string());
    }

    values
        .iter()
        .map(|v| {
            v.as_f64()
                .map(|x| x as f32)
                .ok_or_else(|| "Non-numeric value in llama-server embedding".to_string())
        })
        .collect()
}
async fn mcp_messages_handler(
State(state): State<AppState>,
Query(query): Query<MessageQuery>,
Json(req): Json<JsonRpcRequest>,
) -> impl IntoResponse {
let method = req.method.as_str();
log::info!("MCP Request: {} (Session: {:?})", method, query.session_id);
let result = match method {
"initialize" => Some(serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": { "tools": {} },
"serverInfo": { "name": "TelosDB", "version": "0.1.0" }
})),
"notifications/initialized" => None,
"tools/list" => Some(serde_json::json!({
"tools": [
{
"name": "add_item_text",
"description": "Store text with auto-generated embeddings.",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"path": { "type": "string" }
},
"required": ["content"]
}
},
{
"name": "search_text",
"description": "Semantic search using vector embeddings.",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"limit": { "type": "number" }
},
"required": ["content"]
}
},
{
"name": "update_item",
"description": "Update existing text and its embedding.",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" },
"content": { "type": "string" },
"path": { "type": "string" }
},
"required": ["id", "content"]
}
},
{
"name": "delete_item",
"description": "Delete item by ID.",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" }
},
"required": ["id"]
}
}
]
})),
"search_text" | "tools/call" => {
let p = req.params.clone().unwrap_or_default();
let (actual_method, args) = if method == "tools/call" {
(p.get("name").and_then(|v| v.as_str()).unwrap_or(""), p.get("arguments").cloned().unwrap_or_default())
} else {
(method, p)
};
match actual_method {
"add_item_text" => {
let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
let path = args.get("path").and_then(|v| v.as_str());
match get_embedding(content).await {
Ok(emb) => {
let mut tx = state.db_pool.begin().await.unwrap();
let res = sqlx::query("INSERT INTO items (content, path) VALUES (?, ?)")
.bind(content)
.bind(path)
.execute(&mut *tx)
.await
.unwrap();
let id = res.last_insert_rowid();
sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
.bind(id)
.bind(serde_json::to_string(&emb).unwrap())
.execute(&mut *tx)
.await
.unwrap();
tx.commit().await.unwrap();
Some(serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully added item with ID: {}", id) }] }))
}
Err(e) => Some(serde_json::json!({ "error": format!("Embedding failed: {}", e) }))
}
},
"search_text" => {
let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as u32;
match get_embedding(content).await {
Ok(emb) => {
let rows = sqlx::query(
"SELECT items.id, items.content, v.distance
FROM items
JOIN vec_items v ON items.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY distance LIMIT ?"
)
.bind(serde_json::to_string(&emb).unwrap())
.bind(limit)
.bind(limit)
.fetch_all(&state.db_pool)
.await
.unwrap_or_default();
let is_mcp_output = method == "tools/call";
if is_mcp_output {
let txt = if rows.is_empty() { "No results.".to_string() } else {
rows.iter().map(|r| format!("[ID: {}, Distance: {:.4}]\n{}", r.get::<i64, _>(0), r.get::<f64, _>(2), r.get::<String, _>(1))).collect::<Vec<_>>().join("\n\n---\n\n")
};
Some(serde_json::json!({ "content": [{ "type": "text", "text": txt }] }))
} else {
let res: Vec<_> = rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>(0),
"content": r.get::<String,_>(1),
"distance": r.get::<f64, _>(2)
})).collect();
Some(serde_json::json!({ "content": res }))
}
}
Err(e) => {
// Fallback to LIKE if llama-server is not running
let rows = sqlx::query("SELECT id, content FROM items WHERE content LIKE ? LIMIT ?")
.bind(format!("%{}%", content))
.bind(limit)
.fetch_all(&state.db_pool)
.await
.unwrap_or_default();
let txt = format!("(Fallback SEARCH due to embedding error: {})\n\n", e);
let results = rows.iter().map(|r| format!("ID: {}, Content: {}", r.get::<i64, _>(0), r.get::<String, _>(1))).collect::<Vec<_>>().join("\n\n");
Some(serde_json::json!({ "content": [{ "type": "text", "text": txt + &results }] }))
}
}
},
"update_item" => {
let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
let path = args.get("path").and_then(|v| v.as_str());
match get_embedding(content).await {
Ok(emb) => {
let mut tx = state.db_pool.begin().await.unwrap();
sqlx::query("UPDATE items SET content = ?, path = ? WHERE id = ?")
.bind(content)
.bind(path)
.bind(id)
.execute(&mut *tx)
.await
.unwrap();
sqlx::query("UPDATE vec_items SET embedding = ? WHERE id = ?")
.bind(serde_json::to_string(&emb).unwrap())
.bind(id)
.execute(&mut *tx)
.await
.unwrap();
tx.commit().await.unwrap();
Some(serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully updated item {}", id) }] }))
}
Err(e) => Some(serde_json::json!({ "error": format!("Embedding failed: {}", e) }))
}
},
"delete_item" => {
let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
let mut tx = state.db_pool.begin().await.unwrap();
sqlx::query("DELETE FROM items WHERE id = ?").bind(id).execute(&mut *tx).await.unwrap();
sqlx::query("DELETE FROM vec_items WHERE id = ?").bind(id).execute(&mut *tx).await.unwrap();
tx.commit().await.unwrap();
Some(serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully deleted item {}", id) }] }))
},
_ => Some(serde_json::json!({ "error": "Unknown tool" })),
}
},
_ => Some(serde_json::json!({ "error": "Not implemented" })),
};
if let Some(id_val) = req.id {
let resp = JsonRpcResponse { jsonrpc: "2.0", result, error: None, id: Some(id_val) };
if let Some(sid) = query.session_id {
// MCP Client (SSE Mode)
let resp_str = serde_json::to_string(&resp).unwrap();
let sessions = state.sessions.read().await;
if let Some(tx) = sessions.get(&sid) {
let _ = tx.send(resp_str);
}
axum::http::StatusCode::ACCEPTED.into_response()
} else {
// App UI (Direct Mode)
Json(resp).into_response()
}
} else {
axum::http::StatusCode::NO_CONTENT.into_response()
}
}