pub mod types;
pub mod handlers;
pub mod system;
pub mod tools;
pub use types::AppState;
use axum::{
routing::{get, post},
Router,
};
use std::sync::atomic::{AtomicBool, AtomicU64};
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
use tower_http::cors::{Any, CorsLayer};
use std::collections::HashMap;
/// Builds the axum [`Router`] that exposes the MCP server endpoints.
///
/// Routes:
/// - `GET  /sse`             — SSE stream ([`handlers::sse_handler`])
/// - `POST /messages`        — JSON-RPC messages ([`mcp_messages_handler`])
/// - `GET  /llama_status`, `/doc_count`, `/model_name`, `/indexing_status` — status endpoints
///
/// Every route is wrapped in a CORS layer that accepts any origin, method and header.
///
/// NOTE(review): `allow_origin(Any)` lets any web page open in the user's browser make
/// cross-origin requests to this server and read the responses, including the mutating
/// tools reachable through `/messages`. Consider restricting origins to the app's own UI.
pub fn create_mcp_app(state: AppState) -> Router {
    let permissive_cors = CorsLayer::new()
        .allow_headers(Any)
        .allow_methods(Any)
        .allow_origin(Any);

    // MCP transport endpoints.
    let mcp_routes = Router::new()
        .route("/sse", get(handlers::sse_handler))
        .route("/messages", post(mcp_messages_handler));

    // Status endpoints polled by the UI.
    let app = mcp_routes
        .route("/llama_status", get(handlers::llama_status_handler))
        .route("/doc_count", get(handlers::doc_count_handler))
        .route("/model_name", get(handlers::model_name_handler))
        .route("/indexing_status", get(handlers::indexing_status_handler));

    app.layer(permissive_cors).with_state(state)
}
pub async fn run_server(
port: u16,
db_pool: sqlx::SqlitePool,
llama_status: Arc<RwLock<String>>,
model_name: String,
) {
let (tx, _rx) = broadcast::channel(100);
let tokenizer = Arc::new(crate::utils::tokenizer::JapaneseTokenizer::new().unwrap());
let state = AppState {
db_pool,
tx,
llama_status,
model_name,
sessions: Arc::new(RwLock::new(HashMap::new())),
tokenizer,
lsa_model: Arc::new(RwLock::new(None)),
hnsw_index: Arc::new(RwLock::new(None)),
changes_since_train: Arc::new(AtomicU64::new(0)),
retrain_scheduled: Arc::new(AtomicBool::new(false)),
indexing_status: Arc::new(RwLock::new("idle".to_string())),
};
// 初期化時に LSA トレーニングとベクトル同期(HNSW構築含む)をバックグラウンドで開始
let state_init = state.clone();
tokio::spawn(async move {
system::train_lsa_and_sync_hnsw(state_init).await;
});
let app = create_mcp_app(state);
let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
axum::serve(listener, app).await.unwrap();
}
// ----------------------------------------------------------------------------
// Main Message Handler (Dispatching to Tools)
// ----------------------------------------------------------------------------
use axum::{
extract::{Query, State},
response::{IntoResponse, Response},
Json,
body::Bytes,
};
use crate::mcp::types::{JsonRpcRequest, JsonRpcResponse, MessageQuery};
pub async fn mcp_messages_handler(
State(state): State<AppState>,
Query(query): Query<MessageQuery>,
body: Bytes,
) -> Response {
// 1. Raw body logging for diagnostics
let body_str = String::from_utf8_lossy(&body);
log::info!("MCP Incoming RAW Request: {}", body_str);
// 2. Manual JSON parsing to avoid silent extraction errors
let req: JsonRpcRequest = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => {
log::error!("MCP JSON Deserialization Error: {}. Raw: {}", e, body_str);
return (
axum::http::StatusCode::UNPROCESSABLE_ENTITY,
Json(serde_json::json!({
"jsonrpc": "2.0",
"error": { "code": -32700, "message": "Parse error" },
"id": null
}))
).into_response();
}
};
let method = req.method.as_str();
let actual_method = if method == "tools/call" {
req.params
.as_ref()
.and_then(|p| p.get("name"))
.and_then(|n| n.as_str())
.unwrap_or("unknown")
} else {
method
};
log::info!("MCP Request: {} (Actual: {}, Session: {:?})", method, actual_method, query.session_id);
// tools/call をブロードキャストして UI の MCP ACTIVITY に通知
if method == "tools/call" || matches!(actual_method, "get_item_by_id" | "add_item_text" | "search_text" | "lsa_search" | "update_item" | "delete_item" | "list_documents" | "get_document_count" | "get_document" | "delete_document" | "lsa_retrain") {
let _ = state.tx.send(format!("mcp:call:{}", actual_method));
}
let result = match method {
"initialize" => {
let client_version = req.params.as_ref()
.and_then(|p| p.as_object())
.and_then(|p| p.get("protocolVersion"))
.and_then(|v| v.as_str())
.unwrap_or("2024-11-05");
log::info!("MCP Handshake: Client requested protocol version {}", client_version);
Some(serde_json::json!({
"protocolVersion": client_version,
"capabilities": {
"tools": { "listChanged": false },
"resources": { "listChanged": false, "subscribe": false },
"prompts": { "listChanged": false },
"logging": {}
},
"serverInfo": { "name": "TelosDB", "version": "0.3.0" }
}))
},
"notifications/initialized" => {
log::info!("MCP Client Initialized");
None
}
"resources/list" => Some(serde_json::json!({ "resources": [] })),
"prompts/list" => Some(serde_json::json!({ "prompts": [] })),
"tools/list" => Some(serde_json::json!({
"tools": [
{
"name": "get_item_by_id",
"description": "Get a specific document chunk by ID",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" }
},
"required": ["id"]
}
},
{
"name": "add_item_text",
"description": "Add or overweight a document path. Chunks are generated automatically.",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"path": { "type": "string" },
"mime": { "type": "string" }
},
"required": ["content", "path"]
}
},
{
"name": "search_text",
"description": "Search document chunks using Hybrid Hybrid/Vector search (BM25 + LSA/HNSW)",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"limit": { "type": "integer", "default": 10 },
"min_score": { "type": "number", "default": 0.3, "description": "Minimum similarity (0-1). Results below this are dropped. Default 0.3." }
},
"required": ["content"]
}
},
{
"name": "update_item",
"description": "Update an existing chunk's content",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" },
"content": { "type": "string" }
},
"required": ["id", "content"]
}
},
{
"name": "delete_item",
"description": "Delete a chunk by ID",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" }
},
"required": ["id"]
}
},
{
"name": "list_documents",
"description": "List all documents (path, mime, chunk count)",
"inputSchema": { "type": "object", "properties": {} }
},
{
"name": "get_document_count",
"description": "Get the total count of documents stored in the database",
"inputSchema": { "type": "object", "properties": {} }
},
{
"name": "get_document",
"description": "Get full document content by document ID",
"inputSchema": {
"type": "object",
"properties": {
"document_id": { "type": "integer" },
"id": { "type": "integer" }
}
}
},
{
"name": "delete_document",
"description": "Delete a document and all its chunks",
"inputSchema": {
"type": "object",
"properties": {
"document_id": { "type": "integer" },
"id": { "type": "integer" }
}
}
},
{
"name": "lsa_retrain",
"description": "Manually trigger LSA model retraining and vector rebuild",
"inputSchema": { "type": "object", "properties": {} }
}
]
})),
"tools/call" | "get_item_by_id" | "add_item_text" | "search_text" | "lsa_search" | "update_item" | "delete_item" | "list_documents" | "get_document_count" | "get_document" | "delete_document" | "lsa_retrain" => {
let empty_map = serde_json::Map::new();
let mut args = req.params.as_ref().and_then(|p| p.as_object()).unwrap_or(&empty_map);
// tools/call の場合は、実際の引数は "arguments" フィールドにある
if method == "tools/call" {
if let Some(inner_args) = args.get("arguments").and_then(|a| a.as_object()) {
args = inner_args;
}
}
tools::dispatch_tool(&state, method, actual_method, args).await
}
_ => Some(serde_json::json!({
"content": [{ "type": "text", "text": format!("Method not implemented: {}", method) }],
"isError": true
})),
};
// Notifications (id == null) MUST NOT receive a response
if req.id.is_none() || req.id.as_ref().is_some_and(|v| v.is_null()) {
log::info!("MCP Notification received: {} (No response sent)", method);
return axum::http::StatusCode::NO_CONTENT.into_response();
}
if let Some(id_val) = req.id {
let resp = JsonRpcResponse {
jsonrpc: "2.0",
result,
error: None,
id: Some(id_val),
};
if let Some(sid) = query.session_id {
// MCP SPEC over SSE: "The server SHOULD send responses via the SSE stream... if a session ID is provided"
// オリジナルの動作に戻し、SSE経由で配送する。
let resp_str = serde_json::to_string(&resp).unwrap();
log::info!("Sending MCP Response via SSE (Session: {}, ID: {:?})", sid, resp.id);
let sessions = state.sessions.read().await;
if let Some(tx) = sessions.get(&sid) {
let _ = tx.send(resp_str);
}
axum::http::StatusCode::ACCEPTED.into_response()
} else {
// Direct request (not via SSE)
log::info!("Sending Direct MCP Response (ID: {:?})", resp.id);
Json(resp).into_response()
}
} else {
axum::http::StatusCode::NO_CONTENT.into_response()
}
}
#[cfg(test)]
mod tests {
    /// Splits `text` into pieces of at most `size` characters (not bytes).
    fn split_by_chars(text: &str, size: usize) -> Vec<String> {
        let chars: Vec<char> = text.chars().collect();
        chars.chunks(size).map(|piece| piece.iter().collect()).collect()
    }

    /// Checks splitting text into 800-character chunks at and around the boundary.
    ///
    /// NOTE(review): the chunking is re-implemented locally in this module and no crate
    /// function is called. If the real chunker lives elsewhere (e.g. behind `tools`),
    /// this test does not exercise it; consider calling it directly.
    #[test]
    fn test_text_chunking_logic() {
        const CHUNK_SIZE: usize = 800;

        // Exactly 800 characters -> a single full chunk.
        let single = split_by_chars(&"a".repeat(800), CHUNK_SIZE);
        assert_eq!(single.len(), 1);
        assert_eq!(single[0].len(), 800);

        // 801 characters -> one full chunk plus a 1-character remainder.
        let overflow = split_by_chars(&"a".repeat(801), CHUNK_SIZE);
        assert_eq!(overflow.len(), 2);
        assert_eq!(overflow[0].len(), 800);
        assert_eq!(overflow[1].len(), 1);

        // 1600 characters -> exactly two full chunks.
        assert_eq!(split_by_chars(&"a".repeat(1600), CHUNK_SIZE).len(), 2);

        // Empty input -> no chunks at all.
        assert!(split_by_chars("", CHUNK_SIZE).is_empty());
    }
}