use axum::{
extract::State,
response::{sse::{Event, Sse}, IntoResponse},
routing::{get, post},
Router, Json,
};
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use tokio::sync::broadcast;
use futures::stream::Stream;
use tokio_stream::StreamExt;
// use tower_http::cors::CorsLayer;
/// Shared state handed to every axum handler through the `State` extractor.
#[derive(Clone)]
pub struct AppState {
    /// Last recorded llama-server state. `run_server` initialises it to `"unknown"`;
    /// intended values are `"running"`, `"stopped"` or `"error"`.
    pub llama_status: Arc<RwLock<String>>,
    /// SQLite connection pool used by the request handlers.
    pub db_pool: sqlx::SqlitePool,
    /// Broadcast sender; every `/sse` client subscribes to it and receives what is sent.
    pub tx: broadcast::Sender<String>,
}
// extract::Extension,
use crate::db;
use sqlx::Row;
pub async fn run_server(port: u16, db_path: &str, vec0_path: &str) {
// DBプールを初期化
let db_pool = db::init_pool(db_path, vec0_path.to_owned()).await.expect("DB pool init failed");
let (tx, _rx) = broadcast::channel(100);
let llama_status = Arc::new(RwLock::new("unknown".to_string()));
// llama-server状態監視タスク(ダミー: 3秒ごとに"running"に)
let llama_status_clone = llama_status.clone();
tokio::spawn(async move {
loop {
// TODO: 実際は/health等で死活監視
{
let mut status = llama_status_clone.write().await;
*status = "running".to_string();
}
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
}
});
let app_state = AppState { db_pool, tx, llama_status };
let app = Router::new()
.route("/sse", get(sse_handler))
.route("/message", post(message_handler))
.route("/messages", post(messages_handler))
.route("/llama_status", get(llama_status_handler))
.with_state(app_state);
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
/// `GET /llama_status`: returns the last recorded llama-server state as `{"status": ...}`.
async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    // Copy the value out and release the read lock before building the response.
    let snapshot = {
        let guard = state.llama_status.read().await;
        guard.clone()
    };
    Json(serde_json::json!({ "status": snapshot }))
}
/// JSON-RPC request body accepted by `POST /messages`.
///
/// The `jsonrpc` version member is not read by the handler, so it is not declared here;
/// serde ignores unknown members because `deny_unknown_fields` is not set.
#[derive(Deserialize)]
struct JsonRpcRequest {
    /// Name of the method to invoke (e.g. `"search_text"`).
    method: String,
    /// Method-specific parameters; `None` when the member is absent or `null`.
    params: Option<serde_json::Value>,
    /// Request id, echoed back unchanged in the response.
    id: serde_json::Value,
}
/// JSON-RPC response body returned by `POST /messages`.
///
/// The handler in this module reports failures as an `{"error": ...}` object placed
/// inside `result`, not as a separate top-level `error` member.
#[derive(Serialize)]
struct JsonRpcResponse {
    /// Protocol version marker (the handler always sets `"2.0"`).
    jsonrpc: &'static str,
    /// Method result payload (or an `{"error": ...}` object).
    result: serde_json::Value,
    /// Id copied from the corresponding request.
    id: serde_json::Value,
}
/// Parameters of the `search_text` JSON-RPC method.
#[derive(Deserialize)]
struct SearchTextParams {
    /// Text to look for inside `items.content`.
    content: String,
    /// Maximum number of rows to return; the handler uses 10 when omitted.
    limit: Option<u32>,
}
async fn messages_handler(
State(state): State<AppState>,
Json(req): Json<JsonRpcRequest>,
) -> impl IntoResponse {
if req.method == "search_text" {
let params: SearchTextParams = serde_json::from_value(req.params.unwrap_or_default()).unwrap();
// 仮: embedding生成は省略し、content LIKE検索でダミー返却
let rows = sqlx::query("SELECT id, content, 0.0 as distance FROM items WHERE content LIKE ? LIMIT ?")
.bind(format!("%{}%", params.content))
.bind(params.limit.unwrap_or(10))
.fetch_all(&state.db_pool)
.await
.unwrap_or_default();
let results: Vec<_> = rows.iter().map(|r| serde_json::json!({
"id": r.get::<i64,_>(0),
"content": r.get::<String,_>(1),
"distance": r.get::<f64,_>(2),
})).collect();
let resp = JsonRpcResponse {
jsonrpc: "2.0",
result: serde_json::json!({"content": results}),
id: req.id,
};
return axum::Json(resp);
}
// 未実装メソッド
axum::Json(JsonRpcResponse {
jsonrpc: "2.0",
result: serde_json::json!({"error": "method not found"}),
id: req.id,
})
}
async fn sse_handler(
State(state): State<AppState>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let rx = state.tx.subscribe();
let stream = tokio_stream::wrappers::BroadcastStream::new(rx).map(|msg| {
match msg {
Ok(msg) => Ok(Event::default().data(msg)),
Err(_) => Ok(Event::default().event("error").data("stream error")),
}
});
Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}
/// Request body of `POST /message`.
#[derive(Deserialize)]
struct MessageInput {
    /// Text to broadcast to all `/sse` subscribers.
    message: String,
}
async fn message_handler(
State(state): State<AppState>,
Json(input): Json<MessageInput>,
) -> impl IntoResponse {
let _ = state.tx.send(input.message);
"Message sent"
}