use std::collections::HashMap;
use sqlx::Row;
use crate::mcp::types::AppState;
/// Runs a hybrid keyword + vector search over indexed items and returns the
/// best matches as an MCP text-content payload.
///
/// Two candidate sources are combined:
/// 1. SQLite FTS5 full-text search on `items_fts`, ranked with `bm25()`.
/// 2. A vector search using the query projected through the LSA model, served
///    by the in-memory HNSW index when present and by the `vec_items` table
///    otherwise.
///
/// When an item is found by both sources, the higher of its two similarities
/// is kept. Candidates are then filtered by `min_score`, sorted by descending
/// similarity (ties broken by ascending id), and truncated to `limit`.
///
/// # Arguments
/// * `state` - Application state; this function reads `db_pool`, `lsa_model`,
///   `hnsw_index` and `tokenizer` from it.
/// * `actual_method` - When equal to `"lsa_search"` the query text is taken
///   from `args["query"]`; for any other value it is taken from
///   `args["content"]`.
/// * `args` - Tool arguments. Besides the query text:
///   * `limit` (integer, default 10): maximum number of results. Negative
///     values fall back to the default; values above 1000 are capped.
///   * `min_score` / `minScore` (number, default 0.3, clamped to `0.0..=1.0`):
///     results with a lower similarity are dropped.
///
/// # Returns
/// Always `Some`. The value is `{"content": [{"type": "text", "text": ...}]}`
/// where `text` is either a fixed message for an empty query or a
/// pretty-printed JSON array of `{id, content, path, mime, similarity}`.
///
/// Database and projection failures are treated as "no results from that
/// source" rather than as errors, so the search degrades instead of failing.
pub async fn handle_search_text(
    state: &AppState,
    actual_method: &str,
    args: &serde_json::Map<String, serde_json::Value>,
) -> Option<serde_json::Value> {
    const DEFAULT_LIMIT: i64 = 10;
    // Upper bound on `limit` so a caller cannot request an unbounded scan.
    const MAX_LIMIT: i64 = 1_000;
    const DEFAULT_MIN_SCORE: f64 = 0.3;
    // Length the query embedding is padded/truncated to before vector search.
    const EMBEDDING_DIM: usize = 50;
    // Divisor applied to the BM25 magnitude before squashing it with tanh.
    const BM25_SCALE: f64 = 10.0;

    let query_key = if actual_method == "lsa_search" {
        "query"
    } else {
        "content"
    };
    let search_content = args.get(query_key).and_then(|v| v.as_str()).unwrap_or("");
    if search_content.is_empty() {
        return Some(serde_json::json!({
            "content": [{ "type": "text", "text": "Empty search query provided." }]
        }));
    }

    // A negative limit would become a huge `usize` when cast and an unbounded
    // SQL LIMIT, so it is treated as "not specified".
    let search_limit = args
        .get("limit")
        .and_then(|v| v.as_i64())
        .filter(|&n| n >= 0)
        .unwrap_or(DEFAULT_LIMIT)
        .min(MAX_LIMIT);
    let max_items = usize::try_from(search_limit).unwrap_or(0);
    if max_items == 0 {
        return Some(serde_json::json!({
            "content": [{ "type": "text", "text": "[]" }]
        }));
    }
    // Over-fetch vector candidates so that merging with FTS hits still leaves
    // enough results. Cannot overflow: search_limit <= MAX_LIMIT.
    let candidate_count = search_limit * 2;

    // Cutoff: similarities below this value are not returned. The default of
    // 0.3 drops noise while keeping results that are plausibly related.
    let min_score = args
        .get("min_score")
        .or_else(|| args.get("minScore"))
        .and_then(|v| v.as_f64())
        .unwrap_or(DEFAULT_MIN_SCORE)
        .clamp(0.0, 1.0) as f32;

    // Best similarity seen so far for each item id, across both sources.
    let mut scores: HashMap<i64, f32> = HashMap::new();

    // 1. FTS5 (BM25) keyword search.
    // A MATCH syntax error in the user query surfaces as Err here and simply
    // yields no keyword hits.
    if let Ok(rows) = sqlx::query(
        "SELECT rowid, bm25(items_fts) as score
         FROM items_fts
         WHERE items_fts MATCH ?
         ORDER BY score LIMIT ?",
    )
    .bind(search_content)
    .bind(search_limit)
    .fetch_all(&state.db_pool)
    .await
    {
        for row in rows {
            if let (Ok(id), Ok(bm25_score)) =
                (row.try_get::<i64, _>(0), row.try_get::<f64, _>(1))
            {
                // FTS5 bm25() returns values <= 0 where a MORE negative value
                // is a better match, so the sign is flipped before squashing
                // into 0..1. (Without the flip every hit saturates at 1.0.)
                let sim = (-bm25_score / BM25_SCALE).tanh().clamp(0.0, 1.0) as f32;
                scores.insert(id, sim);
            }
        }
    }

    // 2. Project the query into LSA space. The model lock is confined to this
    // block so it is not held across the database round-trips below.
    let mut query_embedding: Option<Vec<f32>> = None;
    {
        let lsa_guard = state.lsa_model.read().await;
        if let Some(model) = lsa_guard.as_ref() {
            let mut query_counts = HashMap::new();
            let tokens = state
                .tokenizer
                .tokenize_to_vec(search_content)
                .unwrap_or_default();
            for token in tokens {
                if let Some(&tid) = model.vocabulary.get(&token) {
                    *query_counts.entry(tid).or_insert(0.0) += 1.0;
                }
            }
            // If no query term is in the vocabulary the query vector would be
            // all zeros, and near-zero documents would come back with
            // distance ~0 (similarity 1.0). Skip vector search in that case
            // and rely on FTS alone.
            if !query_counts.is_empty() {
                let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
                for (tid, count) in query_counts {
                    query_vec[tid] = count;
                }
                if let Ok(projected) = model.project_query(&query_vec) {
                    let mut embedding: Vec<f32> =
                        projected.iter().map(|&x| x as f32).collect();
                    // `resize` both zero-pads short vectors and truncates long ones.
                    embedding.resize(EMBEDDING_DIM, 0.0);
                    query_embedding = Some(embedding);
                }
            }
        }
    }

    // 3. Vector search: in-memory HNSW first, `vec_items` table as fallback.
    let mut vector_hits: Vec<(i64, f32)> = Vec::new();
    if let Some(embedding) = query_embedding.as_ref() {
        {
            // Index lock is released before any database access.
            let hnsw_guard = state.hnsw_index.read().await;
            if let Some(index) = hnsw_guard.as_ref() {
                let knbn = usize::try_from(candidate_count).unwrap_or(0);
                let neighbors = index.search(embedding, knbn, 100);
                for n in neighbors {
                    vector_hits.push((n.d_id as i64, 1.0f32 - n.distance));
                }
            }
        }
        if vector_hits.is_empty() {
            let embedding_json =
                serde_json::to_string(embedding).unwrap_or_else(|_| "[]".to_string());
            if let Ok(rows) = sqlx::query(
                "SELECT id, distance FROM vec_items WHERE embedding MATCH ? AND k = ?",
            )
            .bind(embedding_json)
            .bind(candidate_count)
            .fetch_all(&state.db_pool)
            .await
            {
                for row in rows {
                    if let (Ok(id), Ok(dist)) =
                        (row.try_get::<i64, _>(0), row.try_get::<f64, _>(1))
                    {
                        vector_hits.push((id, (1.0 - (dist / 2.0)) as f32));
                    }
                }
            }
        }
    }

    // 4. Merge: keep the higher of the vector and keyword similarity per item.
    for (id, v_sim) in vector_hits {
        // `max(0.0)` also turns a NaN similarity into 0.0.
        let v_sim = v_sim.max(0.0);
        scores
            .entry(id)
            .and_modify(|s| *s = s.max(v_sim))
            .or_insert(v_sim);
    }

    // 5. Rank before touching the database again so rows are only loaded for
    // candidates that can actually be returned.
    let mut ranked: Vec<(i64, f32)> = scores
        .into_iter()
        .map(|(id, sim)| (id, sim.clamp(0.0, 1.0)))
        .filter(|&(_, sim)| sim >= min_score)
        .collect();
    ranked.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });

    // 6. Load rows in rank order until `limit` results are collected. An item
    // whose row cannot be loaded is skipped and the next candidate is tried.
    let mut final_items: Vec<serde_json::Value> =
        Vec::with_capacity(max_items.min(ranked.len()));
    for (id, sim) in ranked {
        if final_items.len() >= max_items {
            break;
        }
        if let Some(item) = fetch_search_hit(state, id, sim).await {
            final_items.push(item);
        }
    }

    let result_text =
        serde_json::to_string_pretty(&final_items).unwrap_or_else(|_| "[]".to_string());
    Some(serde_json::json!({
        "content": [{ "type": "text", "text": result_text }]
    }))
}

/// Loads the item `id` joined with its document and builds one search-result
/// JSON object carrying the given `similarity`.
///
/// Returns `None` when the row does not exist, the query fails, or a column
/// cannot be decoded.
async fn fetch_search_hit(
    state: &AppState,
    id: i64,
    similarity: f32,
) -> Option<serde_json::Value> {
    let row = sqlx::query(
        "SELECT i.content, d.path, d.mime FROM items i JOIN documents d ON i.document_id = d.id WHERE i.id = ?",
    )
    .bind(id)
    .fetch_one(&state.db_pool)
    .await
    .ok()?;
    let content = row.try_get::<String, _>(0).ok()?;
    let path = row.try_get::<String, _>(1).ok()?;
    let mime = row.try_get::<Option<String>, _>(2).ok()?;
    Some(serde_json::json!({
        "id": id,
        "content": content,
        "path": path,
        "mime": mime,
        "similarity": similarity
    }))
}