// use crate::db;
use axum::{
extract::{Query, State},
response::{
sse::{Event, Sse},
IntoResponse,
},
routing::{get, post},
Json, Router,
};
use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use chrono::Utc;
use sqlx::Row;
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
use tower_http::cors::{Any, CorsLayer};
use crate::utils::lsa::LsaModel;
use crate::utils::tokenizer::JapaneseTokenizer;
use hnsw_rs::prelude::*;
/// Shared application state handed to every axum handler and background task.
#[derive(Clone)]
pub struct AppState {
// SQLite pool backing the `items`, `vec_items` and `items_lsa` tables.
pub db_pool: sqlx::SqlitePool,
// Broadcast channel for global notifications pushed to all SSE clients
// (e.g. "data_changed", "mcp:call:<tool>").
pub tx: broadcast::Sender<String>,
// Human-readable status of the llama runtime, served by /llama_status.
pub llama_status: Arc<RwLock<String>>,
// Model name reported by /model_name.
pub model_name: String,
// MCP sessions map: session id -> per-client SSE message sender.
pub sessions: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<String>>>>,
// Japanese NLP & LSA
// Tokenizer used to build term vectors from Japanese text.
pub tokenizer: Arc<JapaneseTokenizer>,
// Trained LSA model; None until the initial background training completes.
pub lsa_model: Arc<RwLock<Option<LsaModel>>>,
// HNSW approximate-nearest-neighbour index; None until built.
pub hnsw_index: Arc<RwLock<Option<Hnsw<f32, DistCosine>>>>,
}
/// Boot the MCP HTTP/SSE server on 127.0.0.1:`port` and serve until the
/// process exits.
///
/// Spawns a background task that trains the initial LSA model from the
/// existing `items` rows (heavy work, so it must not block startup) and
/// then builds the HNSW index via [`sync_all_vectors`].
pub async fn run_server(
    port: u16,
    db_pool: sqlx::SqlitePool,
    llama_status: Arc<RwLock<String>>,
    model_name: String,
) {
    let (tx, _rx) = broadcast::channel(100);
    let sessions: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<String>>>> =
        Arc::new(RwLock::new(HashMap::new()));
    let app_state = AppState {
        // The parameters are moved in directly; the previous `.clone()`s
        // were redundant since the originals were never used again.
        db_pool,
        tx,
        llama_status,
        model_name,
        sessions,
        tokenizer: Arc::new(JapaneseTokenizer::new().expect("Failed to init tokenizer")),
        lsa_model: Arc::new(RwLock::new(None)),
        hnsw_index: Arc::new(RwLock::new(None)),
    };
    // Train the initial LSA model from existing data in the background
    // (heavy job, so it runs asynchronously off the startup path).
    let app_state_for_lsa = app_state.clone();
    tokio::spawn(async move {
        log::info!("Starting initial LSA model training...");
        if let Ok(rows) = sqlx::query("SELECT content FROM items")
            .fetch_all(&app_state_for_lsa.db_pool)
            .await
        {
            if !rows.is_empty() {
                let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new();
                for row in rows {
                    let content: String = row.get(0);
                    let tokens = app_state_for_lsa
                        .tokenizer
                        .tokenize_to_vec(&content)
                        .unwrap_or_default();
                    builder.add_document(tokens);
                }
                let (matrix, idfs) = builder.build_matrix();
                // Compress to 50 dimensions.
                match LsaModel::train(&matrix, builder.vocabulary, idfs, 50) {
                    Ok(model) => {
                        // Store the trained model directly. (The previous
                        // Arc::new + (*arc).clone() round-trip performed a
                        // needless full copy of the model.)
                        {
                            let mut lsa = app_state_for_lsa.lsa_model.write().await;
                            *lsa = Some(model);
                        }
                        log::info!(
                            "LSA model trained successfully with {} documents.",
                            builder.counts.len()
                        );
                        // Build the HNSW index.
                        log::info!("Building HNSW index...");
                        let hnsw: Hnsw<f32, DistCosine> =
                            Hnsw::new(16, builder.counts.len().max(100), 16, 200, DistCosine {});
                        // Backfill missing vectors and register everything into HNSW.
                        sync_all_vectors(app_state_for_lsa.clone(), Some(hnsw)).await;
                    }
                    Err(e) => log::error!("LSA training failed: {}", e),
                }
            }
        }
    });
    // Permissive CORS so the local UI can talk to the server from any origin.
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);
    let app = Router::new()
        .route("/sse", get(sse_handler))
        .route("/messages", post(mcp_messages_handler))
        .route("/llama_status", get(llama_status_handler))
        .route("/doc_count", get(doc_count_handler))
        .route("/model_name", get(model_name_handler))
        .layer(cors)
        .with_state(app_state);
    // Binding/serving failures are unrecoverable at startup, so panic with context.
    let listener = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
        .await
        .expect("Failed to bind MCP server port");
    log::info!("MCP Server listening on {}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.expect("MCP server terminated unexpectedly");
}
/// Scan every item in the DB and backfill vectors that are missing or
/// invalid (all zeros) in `vec_items`. When `startup_hnsw` is provided,
/// the fully populated index is installed into `AppState.hnsw_index`
/// once synchronization finishes.
pub async fn sync_all_vectors(state: AppState, mut startup_hnsw: Option<Hnsw<f32, DistCosine>>) {
log::info!("Checking for missing or invalid vectors in vec_items...");
// Find rows that exist in items but are (absent) or (all 0.0) in vec_items.
let rows = match sqlx::query(
"SELECT i.id, i.content, v.embedding
FROM items i
LEFT JOIN vec_items v ON i.id = v.id"
)
.fetch_all(&state.db_pool)
.await {
Ok(rows) => rows,
Err(e) => {
log::error!("Failed to fetch items for sync: {}", e);
return;
}
};
let mut to_sync = Vec::new();
for row in rows {
let id: i64 = row.get(0);
let content: String = row.get(1);
let embedding_str: Option<String> = row.get(2);
let needs_sync = if let Some(s) = embedding_str {
if let Ok(vec) = serde_json::from_str::<Vec<f32>>(&s) {
// An all-zero vector is treated as a dummy/invalid entry.
vec.iter().all(|&x| x == 0.0)
} else {
true // unparsable embedding is also invalid
}
} else {
true // missing entirely
};
if needs_sync {
to_sync.push((id, content));
}
}
if to_sync.is_empty() {
log::info!("All vectors are healthy and synchronized.");
return;
}
log::info!("Found {} items needing vector update. Processing...", to_sync.len());
// Without a trained LSA model there is nothing to recompute with.
let lsa_guard = state.lsa_model.read().await;
let model = match lsa_guard.as_ref() {
Some(m) => m,
None => {
log::warn!("LSA model not available for sync.");
return;
}
};
let mut count = 0;
for (id, content) in to_sync {
// Build a term-frequency vector over the model vocabulary.
let mut query_counts = HashMap::new();
let tokens = state.tokenizer.tokenize_to_vec(&content).unwrap_or_default();
for token in tokens {
if let Some(&tid) = model.vocabulary.get(&token) {
*query_counts.entry(tid).or_insert(0.0) += 1.0;
}
}
let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
for (tid, count) in query_counts {
query_vec[tid] = count;
}
if let Ok(projected) = model.project_query(&query_vec) {
// Pad/truncate to the fixed 50-dimension embedding size.
let mut proj_f32: Vec<f32> = projected.iter().map(|&x| x as f32).collect();
if proj_f32.len() < 50 { proj_f32.resize(50, 0.0); } else { proj_f32.truncate(50); }
let mut tx = match state.db_pool.begin().await {
Ok(t) => t,
Err(_) => continue,
};
// Write into vec_items (virtual table). REPLACE is not reliable on a
// vec0 virtual table and id is the PRIMARY KEY, so DELETE then INSERT.
let _ = sqlx::query("DELETE FROM vec_items WHERE id = ?").bind(id).execute(&mut *tx).await;
let _ = sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
.bind(id)
.bind(serde_json::to_string(&proj_f32).unwrap_or("[]".to_string()))
.execute(&mut *tx)
.await;
// items_lsa (backup)
let vector_blob = bincode::serialize(&proj_f32).unwrap_or_default();
let _ = sqlx::query("INSERT OR REPLACE INTO items_lsa (id, vector) VALUES (?, ?)")
.bind(id)
.bind(vector_blob)
.execute(&mut *tx)
.await;
if let Ok(_) = tx.commit().await {
count += 1;
}
}
}
log::info!("Successfully synchronized {} vectors.", count);
// Install the HNSW index into AppState.
if let Some(hnsw) = startup_hnsw {
// Re-read every row from the DB so already-healthy vectors are
// indexed too (simple implementation).
log::info!("Populating HNSW index from database...");
if let Ok(rows) = sqlx::query("SELECT id, embedding FROM vec_items").fetch_all(&state.db_pool).await {
for row in rows {
let id: i64 = row.get(0);
let embedding_str: String = row.get(1);
if let Ok(vec) = serde_json::from_str::<Vec<f32>>(&embedding_str) {
if vec.len() == 50 {
// NOTE(review): hnsw_rs's `parallel_insert` normally takes a
// slice of (vector, id) pairs; confirm this two-argument call
// matches the crate version in use (cf. `insert` elsewhere).
hnsw.parallel_insert(&vec, id as usize);
}
}
}
}
let mut idx = state.hnsw_index.write().await;
*idx = Some(hnsw);
log::info!("HNSW index is now ready.");
}
}
/// GET /llama_status — report the current llama runtime status string.
async fn llama_status_handler(State(state): State<AppState>) -> impl IntoResponse {
    // Clone inside a narrow scope so the read lock is released immediately.
    let current = { state.llama_status.read().await.clone() };
    Json(serde_json::json!({ "status": current }))
}
async fn doc_count_handler(State(state): State<AppState>) -> impl IntoResponse {
let row = sqlx::query("SELECT COUNT(*) FROM items")
.fetch_one(&state.db_pool)
.await
.unwrap();
let count: i64 = row.get(0);
Json(serde_json::json!({ "count": count }))
}
/// GET /model_name — expose the configured model name to clients.
async fn model_name_handler(State(state): State<AppState>) -> impl IntoResponse {
    let payload = serde_json::json!({ "model_name": state.model_name });
    Json(payload)
}
/// Query string accepted by GET /sse.
/// Currently unused: the server always generates its own session id,
/// ignoring any `session_id` the client may pass (hence `dead_code`).
#[allow(dead_code)]
#[derive(Deserialize)]
struct SseQuery {
session_id: Option<String>,
}
/// GET /sse — open an MCP Server-Sent-Events session.
///
/// Generates a fresh session id, registers a per-session channel in
/// `state.sessions`, and first emits an "endpoint" event telling the
/// client where to POST its JSON-RPC messages. After that, the stream
/// interleaves per-session "message" events with global "update"
/// broadcasts from `state.tx`.
async fn sse_handler(
State(state): State<AppState>,
Query(_query): Query<SseQuery>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
// Generate a simple session ID
let session_id = uuid::Uuid::new_v4().to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<String>();
log::info!("New MCP SSE Session: {}", session_id);
// Register session
state.sessions.write().await.insert(session_id.clone(), tx);
// Initial endpoint event
let endpoint_url = format!("/messages?session_id={}", session_id);
let endpoint_event = Event::default().event("endpoint").data(endpoint_url);
let session_id_for_close = session_id.clone();
let sessions_for_close = state.sessions.clone();
let global_rx = state.tx.subscribe();
// Unfold state tuple: per-session receiver, one-shot initial "endpoint"
// event, session id + sessions map (for cleanup), global broadcast receiver.
let stream = futures::stream::unfold(
(
rx,
Some(endpoint_event),
session_id_for_close,
sessions_for_close,
global_rx,
),
|(mut rx, mut initial, sid, smap, mut grx): (
tokio::sync::mpsc::UnboundedReceiver<String>,
Option<Event>,
String,
Arc<RwLock<HashMap<String, tokio::sync::mpsc::UnboundedSender<String>>>>,
tokio::sync::broadcast::Receiver<String>,
)| async move {
// Emit the "endpoint" event exactly once, before anything else.
if let Some(event) = initial.take() {
return Some((Ok(event), (rx, None, sid, smap, grx)));
}
tokio::select! {
Some(msg) = rx.recv() => {
// Per-session JSON-RPC response relayed from mcp_messages_handler.
Some((Ok(Event::default().event("message").data(msg)), (rx, None, sid, smap, grx)))
}
Ok(msg) = grx.recv() => {
// Global notification (e.g. data update)
Some((Ok(Event::default().event("update").data(msg)), (rx, None, sid, smap, grx)))
}
else => {
// Both channels closed: end the stream and deregister the session.
// NOTE(review): if the client simply disconnects, the stream is
// dropped without ever reaching this branch, so the sessions-map
// entry may leak — confirm, and consider a drop guard for cleanup.
log::info!("MCP SSE Session Closed: {}", sid);
smap.write().await.remove(&sid);
None
}
}
},
);
Sse::new(stream).keep_alive(axum::response::sse::KeepAlive::default())
}
/// Incoming JSON-RPC 2.0 request envelope (MCP messages).
#[derive(Serialize, Deserialize)]
struct JsonRpcRequest {
// Protocol version string, expected to be "2.0" (not validated here).
jsonrpc: String,
// Method name, e.g. "initialize", "tools/list", "tools/call".
method: String,
// Method parameters; shape depends on `method`.
params: Option<serde_json::Value>,
// Request id; absent (or null) marks a notification, which gets no response.
id: Option<serde_json::Value>,
}
/// Outgoing JSON-RPC 2.0 response envelope.
/// Exactly one of `result` / `error` is populated; the other is omitted
/// from the serialized JSON entirely.
#[derive(Serialize)]
struct JsonRpcResponse {
// Always "2.0".
jsonrpc: &'static str,
#[serde(skip_serializing_if = "Option::is_none")]
result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<serde_json::Value>,
// Echoes the request id so the client can correlate responses.
id: Option<serde_json::Value>,
}
/// Query string for POST /messages. When `session_id` is present, the
/// JSON-RPC response is delivered over the matching SSE session instead
/// of the HTTP response body.
#[derive(Deserialize)]
struct MessageQuery {
session_id: Option<String>,
}
impl IntoResponse for JsonRpcResponse {
    /// Serialize the JSON-RPC envelope as an `application/json` body.
    fn into_response(self) -> axum::response::Response {
        let body = Json(self);
        body.into_response()
    }
}
async fn mcp_messages_handler(
State(state): State<AppState>,
Query(query): Query<MessageQuery>,
Json(req): Json<JsonRpcRequest>,
) -> impl IntoResponse {
let method = req.method.as_str();
log::info!("MCP Request: {} (Session: {:?})", method, query.session_id);
// 受信データを構造化JSONで出力(timestamp と source を含む)
let structured = serde_json::json!({
"timestamp": Utc::now().to_rfc3339(),
"source": "mcp",
"session": query.session_id,
"method": method,
"id": req.id,
"params": req.params,
});
log::info!("{}", serde_json::to_string(&structured).unwrap_or_else(|_| "{\"error\":\"serialize_failed\"}".to_string()));
let result: Option<serde_json::Value> = match method {
"initialize" => {
let client_version = req.params.as_ref()
.and_then(|p| p.get("protocolVersion"))
.and_then(|v| v.as_str())
.unwrap_or("2024-11-05");
log::info!("MCP Handshake: Client requested protocol version {}", client_version);
Some(serde_json::json!({
"protocolVersion": client_version,
"capabilities": {
"tools": { "listChanged": false },
"resources": { "listChanged": false, "subscribe": false },
"prompts": { "listChanged": false },
"logging": {}
},
"serverInfo": { "name": "TelosDB", "version": "0.1.0" }
}))
},
"notifications/initialized" => None,
"tools/list" => Some(serde_json::json!({
"tools": [
{
"name": "add_item_text",
"description": "Store text with auto-generated LSA vectors (No LLM required).",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"path": { "type": "string" }
},
"required": ["content"]
}
},
{
"name": "search_text",
"description": "Semantic search using LSA (Latent Semantic Analysis). Lightweight and fast.",
"inputSchema": {
"type": "object",
"properties": {
"content": { "type": "string" },
"limit": { "type": "number" }
},
"required": ["content"]
}
},
{
"name": "lsa_retrain",
"description": "Rebuild the LSA semantic model from all current documents. Use this when you've added many new items.",
"inputSchema": { "type": "object", "properties": {} }
},
{
"name": "update_item",
"description": "Update existing text and its LSA vector.",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" },
"content": { "type": "string" },
"path": { "type": "string" }
},
"required": ["id", "content"]
}
},
{
"name": "delete_item",
"description": "Delete item by ID.",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" }
},
"required": ["id"]
}
},
{
"name": "get_item_by_id",
"description": "Get text content by item ID.",
"inputSchema": {
"type": "object",
"properties": {
"id": { "type": "integer" }
},
"required": ["id"]
}
}
]
})),
"search_text" | "tools/call" | "add_item_text" | "update_item" | "delete_item" | "get_item_by_id" => {
let p = req.params.clone().unwrap_or_default();
let (actual_method, args) = if method == "tools/call" {
(
p.get("name").and_then(|v| v.as_str()).unwrap_or(""),
p.get("arguments").cloned().unwrap_or_default(),
)
} else {
(method, p)
};
// UIへの通知(ツール呼び出し開始)
let _ = state.tx.send(format!("mcp:call:{}", actual_method));
match actual_method {
"get_item_by_id" => {
let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
let row = sqlx::query("SELECT id, content, path FROM items WHERE id = ?")
.bind(id)
.fetch_optional(&state.db_pool)
.await
.unwrap_or(None);
if let Some(row) = row {
let content: String = row.get("content");
let path: Option<String> = row.try_get("path").ok();
Some(serde_json::json!({
"id": id,
"content": content,
"path": path
}))
} else {
Some(serde_json::json!({ "error": format!("Item not found: {}", id) }))
}
}
"add_item_text" => {
let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
let path = args.get("path").and_then(|v| v.as_str());
log::info!(
"Executing add_item_text (LSA-only): content length={}, path='{:?}'",
content.chars().count(),
path
);
// 800文字ずつに分割
let chars: Vec<char> = content.chars().collect();
let chunks: Vec<String> = chars
.chunks(800)
.map(|chunk| chunk.iter().collect::<String>())
.collect();
let mut results = Vec::new();
for (_i, chunk_content) in chunks.iter().enumerate() {
async fn add_item_inner(
state: &AppState,
content: &str,
path: Option<&str>,
) -> Result<i64, String> {
let mut tx =
state.db_pool.begin().await.map_err(|e| {
format!("Failed to begin transaction: {}", e)
})?;
let res =
sqlx::query("INSERT INTO items (content, path) VALUES (?, ?)")
.bind(content)
.bind(path)
.execute(&mut *tx)
.await
.map_err(|e| format!("Failed to insert item: {}", e))?;
let id = res.last_insert_rowid();
// LSA ベクトルの計算
let mut lsa_vector_f32: Vec<f32> = vec![0.0; 50];
let lsa_guard = state.lsa_model.read().await;
if let Some(model) = lsa_guard.as_ref() {
let mut query_counts = HashMap::new();
let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
for token in tokens {
if let Some(&tid) = model.vocabulary.get(&token) {
*query_counts.entry(tid).or_insert(0.0) += 1.0;
}
}
let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
for (tid, count) in query_counts {
query_vec[tid] = count;
}
if let Ok(projected) = model.project_query(&query_vec) {
lsa_vector_f32 = projected.iter().map(|&x| x as f32).collect();
// 50次元に満たない(モデル初期化時のランク制限等)場合はパディング
if lsa_vector_f32.len() < 50 {
lsa_vector_f32.resize(50, 0.0);
} else if lsa_vector_f32.len() > 50 {
lsa_vector_f32.truncate(50);
}
}
}
// sqlite-vec の仮想テーブル (vec_items) に LSA ベクトルを保存
sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
.bind(id)
.bind(serde_json::to_string(&lsa_vector_f32).unwrap_or("[]".to_string()))
.execute(&mut *tx)
.await
.map_err(|e| format!("Failed to insert LSA vector to vec_items: {}", e))?;
// items_lsa にもバックアップ(または生データ)として保存
if let Some(_) = lsa_guard.as_ref() {
let vector_blob = bincode::serialize(&lsa_vector_f32).unwrap_or_default();
sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)")
.bind(id)
.bind(vector_blob)
.execute(&mut *tx)
.await
.map_err(|e| format!("Failed to insert LSA blob: {}", e))?;
}
tx.commit()
.await
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
// HNSW インデックスへの反映
let hnsw_guard = state.hnsw_index.read().await;
if let Some(hnsw) = hnsw_guard.as_ref() {
if lsa_vector_f32.len() == 50 {
hnsw.insert(&lsa_vector_f32, id as usize);
}
}
Ok(id)
}
match add_item_inner(&state, chunk_content, path).await {
Ok(id) => results.push(id),
Err(e) => log::error!("Failed to add chunk: {}", e),
}
}
if !results.is_empty() {
let _ = state.tx.send("data_changed".to_string());
log::info!("Successfully added {} chunks via LSA.", results.len());
Some(
serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully added {} chunks (LSA).", results.len()) }] }),
)
} else {
Some(serde_json::json!({ "error": "Failed to add any chunks." }))
}
}
"search_text" => {
let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10);
// LLM の代わりに内部で LSA クエリを構成
let lsa_guard = state.lsa_model.read().await;
if let Some(model) = lsa_guard.as_ref() {
let mut query_counts = HashMap::new();
let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
for token in tokens {
if let Some(&id) = model.vocabulary.get(&token) {
*query_counts.entry(id).or_insert(0.0) += 1.0;
}
}
let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
for (id, count) in query_counts {
query_vec[id] = count;
}
if let Ok(query_lsa) = model.project_query(&query_vec) {
// クエリが語彙に含まれず零ベクトルになった場合
if query_lsa.iter().all(|&x| x == 0.0) {
return Some(serde_json::json!({ "content": [] }));
}
let mut query_lsa_f32: Vec<f32> = query_lsa.iter().map(|&x| x as f32).collect();
if query_lsa_f32.len() < 50 {
query_lsa_f32.resize(50, 0.0);
} else if query_lsa_f32.len() > 50 {
query_lsa_f32.truncate(50);
}
// HNSW インデックスがあればそれを使う、なければ sqlite-vec でフォールバック
let hnsw_guard = state.hnsw_index.read().await;
if let Some(hnsw) = hnsw_guard.as_ref() {
log::info!("Searching using HNSW index...");
let neighbors = hnsw.search(&query_lsa_f32, limit as usize, 100);
let mut results = Vec::new();
for neighbor in neighbors {
let id = neighbor.d_id as i64;
let dist = neighbor.distance;
// HNSW の DistCosine は通常 1 - cos_sim
let sim = 1.0 - dist;
if let Ok(row) = sqlx::query("SELECT content FROM items WHERE id = ?").bind(id).fetch_one(&state.db_pool).await {
results.push(serde_json::json!({
"id": id,
"content": row.get::<String, _>(0),
"similarity": sim.clamp(0.0, 1.0)
}));
}
}
return Some(serde_json::json!({ "content": results }));
}
// sqlite-vec の MATCH (BM25等ではなくベクトル近傍検索) を使用
let rows = sqlx::query(
"SELECT items.id, items.content, v.distance
FROM items
JOIN vec_items v ON items.id = v.id
WHERE v.embedding MATCH ? AND k = ?
ORDER BY distance LIMIT ?",
)
.bind(serde_json::to_string(&query_lsa_f32).unwrap_or("[]".to_string()))
.bind(limit)
.bind(limit)
.fetch_all(&state.db_pool)
.await
.unwrap_or_default();
let res: Vec<_> = rows.iter().map(|r| {
let id = r.get::<i64, _>(0);
let content = r.get::<String, _>(1);
let distance = r.get::<f64, _>(2);
// sqlite-vec の distance は L2 距離の 2 乗
// 正規化ベクトル [u, v] において:
// ||u-v||^2 = ||u||^2 + ||v||^2 - 2*u*v = 1 + 1 - 2*cos_sim = 2 - 2*cos_sim
// よって cos_sim = 1.0 - (distance / 2.0)
let sim = 1.0 - (distance / 2.0);
serde_json::json!({
"id": id,
"content": content,
"similarity": sim.clamp(0.0, 1.0)
})
}).collect();
Some(serde_json::json!({ "content": res }))
} else {
Some(serde_json::json!({ "error": "LSA query projection failed" }))
}
} else {
// LSA モデルがない場合は LIKE 検索でフォールバック
let rows = sqlx::query("SELECT id, content FROM items WHERE content LIKE ? LIMIT ?")
.bind(format!("%{}%", content))
.bind(limit)
.fetch_all(&state.db_pool)
.await
.unwrap_or_default();
let res: Vec<_> = rows.iter().map(|r| serde_json::json!({ "id": r.get::<i64,_>(0), "content": r.get::<String,_>(1), "similarity": 0.0 })).collect();
Some(serde_json::json!({ "content": res }))
}
}
"lsa_search" => {
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10);
let lsa_guard = state.lsa_model.read().await;
if let Some(model) = lsa_guard.as_ref() {
// クエリのベクトル化 (TF)
let mut query_counts = HashMap::new();
let tokens = state.tokenizer.tokenize_to_vec(query).unwrap_or_default();
for token in tokens {
if let Some(&id) = model.vocabulary.get(&token) {
*query_counts.entry(id).or_insert(0.0) += 1.0;
}
}
let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
for (id, count) in query_counts {
query_vec[id] = count;
}
// LSA 空間への射影
if let Ok(query_lsa) = model.project_query(&query_vec) {
// DB から全ベクトルを取得して比較 (件数が少ない想定)
// 本来はアイテム数が多い場合は BLOB を全件回すと遅いため、インメモリキャッシュ等を検討
let rows = sqlx::query("SELECT id, vector FROM items_lsa")
.fetch_all(&state.db_pool)
.await
.unwrap_or_default();
let mut results = Vec::new();
for row in rows {
let id: i64 = row.get(0);
let vector_blob: Vec<u8> = row.get(1);
if let Ok(vector_f64) = bincode::deserialize::<Vec<f64>>(&vector_blob) {
let doc_vec = ndarray::Array1::from_vec(vector_f64);
let sim = crate::utils::lsa::LsaModel::cosine_similarity(&query_lsa, &doc_vec);
results.push((id, sim));
}
}
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(limit as usize);
let mut filtered_results = Vec::new();
for (id, sim) in results {
if let Ok(doc_row) = sqlx::query("SELECT content FROM items WHERE id = ?").bind(id).fetch_one(&state.db_pool).await {
let content: String = doc_row.get(0);
filtered_results.push(serde_json::json!({
"id": id,
"content": content,
"similarity": sim
}));
}
}
Some(serde_json::json!({ "content": filtered_results }))
} else {
Some(serde_json::json!({ "error": "Query projection failed" }))
}
} else {
Some(serde_json::json!({ "error": "LSA model not initialized or no data available" }))
}
}
"update_item" => {
let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
let path = args.get("path").and_then(|v| v.as_str());
async fn update_item_inner(
state: &AppState,
id: i64,
content: &str,
path: Option<&str>,
) -> Result<(), String> {
let mut tx =
state.db_pool.begin().await.map_err(|e| {
format!("Failed to begin transaction: {}", e)
})?;
sqlx::query("UPDATE items SET content = ?, path = ? WHERE id = ?")
.bind(content)
.bind(path)
.bind(id)
.execute(&mut *tx)
.await
.map_err(|e| format!("Failed to update item: {}", e))?;
// LSA ベクトルの更新
let lsa_guard = state.lsa_model.read().await;
if let Some(model) = lsa_guard.as_ref() {
let mut query_counts = HashMap::new();
let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
for token in tokens {
if let Some(&tid) = model.vocabulary.get(&token) {
*query_counts.entry(tid).or_insert(0.0) += 1.0;
}
}
let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
for (tid, count) in query_counts {
query_vec[tid] = count;
}
if let Ok(projected) = model.project_query(&query_vec) {
let vector_blob = bincode::serialize(&projected.to_vec()).unwrap_or_default();
sqlx::query("INSERT OR REPLACE INTO items_lsa (id, vector) VALUES (?, ?)")
.bind(id)
.bind(vector_blob)
.execute(&mut *tx)
.await
.map_err(|e| format!("Failed to update LSA vector: {}", e))?;
}
}
tx.commit()
.await
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
Ok(())
}
if let Err(e) = update_item_inner(&state, id, content, path).await
{
Some(serde_json::json!({ "error": e }))
} else {
let _ = state.tx.send("data_changed".to_string());
Some(
serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully updated item {} (LSA)", id) }] }),
)
}
}
"delete_item" => {
let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
async fn delete_item_inner(state: &AppState, id: i64) -> Result<(), String> {
let mut tx = state
.db_pool
.begin()
.await
.map_err(|e| format!("Failed to begin transaction: {}", e))?;
sqlx::query("DELETE FROM items WHERE id = ?")
.bind(id)
.execute(&mut *tx)
.await
.map_err(|e| format!("Failed to delete item: {}", e))?;
sqlx::query("DELETE FROM vec_items WHERE id = ?")
.bind(id)
.execute(&mut *tx)
.await
.map_err(|e| format!("Failed to delete vector: {}", e))?;
tx.commit()
.await
.map_err(|e| format!("Failed to commit transaction: {}", e))?;
Ok(())
}
if let Err(e) = delete_item_inner(&state, id).await {
Some(serde_json::json!({ "error": e }))
} else {
let _ = state.tx.send("data_changed".to_string());
Some(
serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully deleted item {}", id) }] }),
)
}
}
"lsa_retrain" => {
log::info!("Manual LSA retrain triggered.");
let state_clone = state.clone();
tokio::spawn(async move {
if let Ok(rows) = sqlx::query("SELECT id, content FROM items").fetch_all(&state_clone.db_pool).await {
if !rows.is_empty() {
let mut builder = crate::utils::lsa::TermDocumentMatrixBuilder::new();
let mut ids = Vec::new();
for row in rows {
let id: i64 = row.get(0);
let content: String = row.get(1);
let tokens = state_clone.tokenizer.tokenize_to_vec(&content).unwrap_or_default();
builder.add_document(tokens);
ids.push(id);
}
let (matrix, idfs) = builder.build_matrix();
match crate::utils::lsa::LsaModel::train(&matrix, builder.vocabulary, idfs, 50) {
Ok(model) => {
let mut tx = state_clone.db_pool.begin().await.unwrap();
sqlx::query("DELETE FROM items_lsa").execute(&mut *tx).await.unwrap();
sqlx::query("DELETE FROM vec_items").execute(&mut *tx).await.unwrap();
for (i, &id) in ids.iter().enumerate() {
let mut doc_tf = ndarray::Array1::zeros(model.vocabulary.len());
for (&tid, &count) in &builder.counts[i] {
doc_tf[tid] = count;
}
if let Ok(projected) = model.project_query(&doc_tf) {
let mut proj_f32: Vec<f32> = projected.iter().map(|&x| x as f32).collect();
if proj_f32.len() < 50 { proj_f32.resize(50, 0.0); } else { proj_f32.truncate(50); }
let vector_blob = bincode::serialize(&proj_f32).unwrap_or_default();
sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)")
.bind(id)
.bind(vector_blob)
.execute(&mut *tx)
.await
.unwrap();
sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
.bind(id)
.bind(serde_json::to_string(&proj_f32).unwrap_or("[]".to_string()))
.execute(&mut *tx)
.await
.unwrap();
}
}
tx.commit().await.unwrap();
let mut lsa = state_clone.lsa_model.write().await;
*lsa = Some(model);
// HNSW インデックスの再構築
let hnsw: Hnsw<f32, DistCosine> = Hnsw::new(16, ids.len().max(100), 16, 200, DistCosine {});
// 登録済みの全ベクトルを HNSW に入れ直す
// (簡易実装:DBから再度引き直すか、現在のループで生成したものを入れる)
// ここでは sync_all_vectors(state, Some(hnsw)) を呼ぶのが楽
log::info!("Manual LSA retrain completed successfully. Rebuilding HNSW...");
sync_all_vectors(state_clone.clone(), Some(hnsw)).await;
}
Err(e) => log::error!("Manual LSA training failed: {}", e),
}
}
}
});
Some(serde_json::json!({ "content": [{ "type": "text", "text": "LSA retrain started in background." }] }))
}
_ => Some(serde_json::json!({ "error": "Unknown tool" })),
}
}
_ => Some(serde_json::json!({ "error": "Not implemented" })),
};
// Notifications (id == null) MUST NOT receive a response
if req.id.is_none() || req.id.as_ref().map_or(false, |v| v.is_null()) {
log::info!("MCP Notification received: {} (No response sent)", method);
return axum::http::StatusCode::NO_CONTENT.into_response();
}
if let Some(id_val) = req.id {
let resp = JsonRpcResponse {
jsonrpc: "2.0",
result,
error: None,
id: Some(id_val),
};
if let Some(sid) = query.session_id {
// MCP Client (SSE Mode)
let resp_str = serde_json::to_string(&resp).unwrap();
log::info!("Sending MCP Response (Session: {}, ID: {:?}): {}", sid, resp.id, resp_str);
let sessions = state.sessions.read().await;
if let Some(tx) = sessions.get(&sid) {
let _ = tx.send(resp_str);
}
axum::http::StatusCode::ACCEPTED.into_response()
} else {
// App UI (Direct Mode)
resp.into_response()
}
} else {
axum::http::StatusCode::NO_CONTENT.into_response()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Split `text` into chunks of at most `size` characters — the same
    // char-based chunking `add_item_text` applies with a size of 800.
    fn chunk_by_chars(text: &str, size: usize) -> Vec<String> {
        text.chars()
            .collect::<Vec<char>>()
            .chunks(size)
            .map(|c| c.iter().collect())
            .collect()
    }

    #[test]
    fn test_text_chunking_logic() {
        let size = 800;

        // 1. Exactly 800 chars -> a single full chunk.
        let one = chunk_by_chars(&"a".repeat(800), size);
        assert_eq!(one.len(), 1);
        assert_eq!(one[0].len(), 800);

        // 2. 801 chars -> one full chunk plus a 1-char remainder.
        let two = chunk_by_chars(&"a".repeat(801), size);
        assert_eq!(two.len(), 2);
        assert_eq!(two[0].len(), 800);
        assert_eq!(two[1].len(), 1);

        // 3. 1600 chars -> exactly two chunks.
        let pair = chunk_by_chars(&"a".repeat(1600), size);
        assert_eq!(pair.len(), 2);

        // 4. Empty input -> no chunks at all.
        let none = chunk_by_chars("", size);
        assert_eq!(none.len(), 0);
    }
}