Newer
Older
TelosDB / src / backend / src / mcp / tools / search.rs
use std::collections::HashMap;
use sqlx::Row;
use crate::mcp::types::AppState;

/// Upper bound on the `limit` argument; also bounds the `k` passed to the vector indexes.
const MAX_SEARCH_LIMIT: i64 = 1000;

/// Width of the stored embeddings; query projections are zero-padded or truncated to it.
const LSA_EMBEDDING_DIM: usize = 50;

/// Lexical channel query: best (most negative) FTS5 BM25 scores first.
const FTS_SEARCH_SQL: &str = "SELECT rowid, bm25(items_fts) AS score \
     FROM items_fts \
     WHERE items_fts MATCH ? \
     ORDER BY score LIMIT ?";

/// Loads the fields returned to the caller for a single item.
const ITEM_DETAILS_SQL: &str = "SELECT i.content, d.path, d.mime FROM items i JOIN documents d ON i.document_id = d.id WHERE i.id = ?";

/// Hybrid text search over indexed items, combining FTS5 (BM25) and LSA vector hits.
///
/// The query text is read from `args["query"]` when `actual_method == "lsa_search"`, and from
/// `args["content"]` otherwise. Optional arguments:
/// - `limit` (integer, default 10): maximum number of results, clamped to
///   `1..=MAX_SEARCH_LIMIT` so negative or huge values cannot wrap or overflow.
/// - `min_score` / `minScore` (number, default 0.3, clamped to `[0, 1]`): results whose
///   similarity is below this are dropped.
///
/// An item found by both channels keeps the higher of its two similarities. The result is an
/// MCP-style `{"content": [{"type": "text", "text": ...}]}` payload whose text is a
/// pretty-printed JSON array of `{id, content, path, mime, similarity}` objects, sorted by
/// descending similarity with ties broken by ascending id. Always returns `Some`; DB and model
/// failures degrade to fewer (or no) results rather than an error.
pub async fn handle_search_text(
    state: &AppState,
    actual_method: &str,
    args: &serde_json::Map<String, serde_json::Value>,
) -> Option<serde_json::Value> {
    let query_key = if actual_method == "lsa_search" { "query" } else { "content" };
    let search_content = args.get(query_key).and_then(|v| v.as_str()).unwrap_or("");
    let search_limit = args
        .get("limit")
        .and_then(|v| v.as_i64())
        .unwrap_or(10)
        .clamp(1, MAX_SEARCH_LIMIT);
    // Cut-off: hits with similarity below this are not returned. The 0.3 default is meant to
    // drop noise while keeping plausibly related items.
    let min_score = args
        .get("min_score")
        .or_else(|| args.get("minScore"))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.3)
        .clamp(0.0, 1.0) as f32;

    // A whitespace-only query is as empty as "" (and is an FTS5 syntax error).
    if search_content.trim().is_empty() {
        return Some(serde_json::json!({
            "content": [{ "type": "text", "text": "Empty search query provided." }]
        }));
    }

    // 1. Lexical candidates (FTS5 / BM25).
    let fts_hits = fts_search(state, search_content, search_limit).await;
    // 2. Semantic candidates (LSA projection + HNSW, or the vec_items fallback).
    let vector_hits = vector_search(state, search_content, search_limit * 2).await;

    // 3. Merge: an item seen by both channels keeps the better score; a vector-only hit is
    //    compared against 0.0, as if its FTS similarity were zero.
    let mut merged = fts_hits;
    for (id, v_sim) in vector_hits {
        let slot = merged.entry(id).or_insert(0.0);
        *slot = slot.max(v_sim);
    }

    // Filter and rank on the scores alone, before touching the DB for item details.
    // NaN scores stay NaN after clamp and fail the `>=` test, so they are dropped here.
    let mut ranked: Vec<(i64, f32)> = merged
        .into_iter()
        .map(|(id, sim)| (id, sim.clamp(0.0, 1.0)))
        .filter(|&(_, sim)| sim >= min_score)
        .collect();
    ranked.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));

    // 4. Load details in rank order, stopping once `limit` rows have been loaded. Items whose
    //    row is missing or undecodable are skipped, so later candidates can fill their place.
    // `search_limit` was clamped to at least 1 above, so this cast cannot wrap.
    let limit = search_limit as usize;
    let mut final_items = Vec::with_capacity(limit.min(ranked.len()));
    for (id, sim) in ranked {
        if final_items.len() >= limit {
            break;
        }
        if let Some(item) = fetch_item(state, id, sim).await {
            final_items.push(item);
        }
    }

    let result_text =
        serde_json::to_string_pretty(&final_items).unwrap_or_else(|_| "[]".to_string());

    Some(serde_json::json!({
        "content": [{ "type": "text", "text": result_text }]
    }))
}

/// Lexical channel: runs an FTS5 `MATCH` and returns `item id -> similarity in [0, 1)`.
///
/// FTS5 parses `MATCH` input as a query expression, so ordinary text containing characters
/// outside FTS5 bareword syntax (quotes, parentheses, `:`, `-`, `*`, ...) or a bare
/// `AND`/`OR`/`NOT` is a syntax error. When the raw query fails, it is retried once with every
/// whitespace-separated term quoted as a literal string. If that also fails, the channel
/// returns no hits. Rows whose columns cannot be decoded are skipped.
async fn fts_search(state: &AppState, query: &str, limit: i64) -> HashMap<i64, f32> {
    let rows = match sqlx::query(FTS_SEARCH_SQL)
        .bind(query)
        .bind(limit)
        .fetch_all(&state.db_pool)
        .await
    {
        Ok(rows) => rows,
        Err(_) => {
            let quoted = fts5_quote_terms(query);
            if quoted.is_empty() {
                return HashMap::new();
            }
            match sqlx::query(FTS_SEARCH_SQL)
                .bind(quoted)
                .bind(limit)
                .fetch_all(&state.db_pool)
                .await
            {
                Ok(rows) => rows,
                Err(_) => return HashMap::new(),
            }
        }
    };

    rows.iter()
        .filter_map(|row| {
            let id: i64 = row.try_get(0).ok()?;
            let bm25_score: f64 = row.try_get(1).ok()?;
            Some((id, bm25_to_similarity(bm25_score)))
        })
        .collect()
}

/// Quotes each whitespace-separated term of `query` as an FTS5 string (embedded `"` doubled),
/// so the terms are matched literally and implicitly AND-ed. Returns "" for a blank query.
fn fts5_quote_terms(query: &str) -> String {
    query
        .split_whitespace()
        .map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
        .collect::<Vec<_>>()
        .join(" ")
}

/// Maps an FTS5 `bm25()` value onto a similarity in `[0, 1]`, higher meaning more relevant.
///
/// FTS5 returns BM25 negated, so better matches are *more negative*. Treating the value as a
/// non-negative distance (e.g. `1 - tanh(bm25 / 10)`) yields values above 1 for every hit.
/// Instead, the relevance `r = -bm25` (never below 0) is squashed as `r / (1 + r)`. The result
/// is monotonic in relevance, 0 for no evidence, and approaches 1 for strong matches.
fn bm25_to_similarity(bm25: f64) -> f32 {
    let relevance = (-bm25).max(0.0);
    // Written as `1 - 1/(1+r)` so an infinite relevance maps to 1.0 rather than NaN.
    (1.0 - 1.0 / (1.0 + relevance)) as f32
}

/// Projects `query` into the LSA space as a `LSA_EMBEDDING_DIM`-long f32 vector.
///
/// Returns `None` when:
/// - no LSA model is loaded;
/// - projection fails;
/// - none of the query tokens are in the model vocabulary. In that case the query vector
///   would be all zeros, and near-zero documents would come back at distance ~0, i.e.
///   similarity ~1.0.
///
/// The model read lock is released when this returns.
async fn embed_query(state: &AppState, query: &str) -> Option<Vec<f32>> {
    let lsa_guard = state.lsa_model.read().await;
    let model = lsa_guard.as_ref()?;

    let mut term_counts = HashMap::new();
    for token in state.tokenizer.tokenize_to_vec(query).unwrap_or_default() {
        if let Some(&tid) = model.vocabulary.get(&token) {
            *term_counts.entry(tid).or_insert(0.0) += 1.0;
        }
    }
    if term_counts.is_empty() {
        return None;
    }

    let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
    for (tid, count) in term_counts {
        query_vec[tid] = count;
    }

    let projected = model.project_query(&query_vec).ok()?;
    let mut embedding: Vec<f32> = projected.iter().map(|&x| x as f32).collect();
    // `resize` both zero-pads a short projection and truncates a long one.
    embedding.resize(LSA_EMBEDDING_DIM, 0.0);
    Some(embedding)
}

/// Semantic channel: returns `(item id, similarity)` pairs for the `k` nearest neighbours of
/// the embedded query.
///
/// The in-memory HNSW index is consulted first. The `vec_items` virtual table is queried only
/// when that produces no hits (index not loaded, or empty result). Returns an empty list when
/// the query cannot be embedded (see [`embed_query`]) or both lookups fail.
async fn vector_search(state: &AppState, query: &str, k: i64) -> Vec<(i64, f32)> {
    let embedding = match embed_query(state, query).await {
        Some(embedding) => embedding,
        None => return Vec::new(),
    };

    let mut hits = Vec::new();
    {
        // Scoped so the index read lock is released before the fallback DB round-trip.
        let index_guard = state.hnsw_index.read().await;
        if let Some(index) = index_guard.as_ref() {
            // `k` is at least 2 (clamped limit * 2), so the cast cannot wrap.
            // NOTE(review): the trailing 100 is passed straight through to the index's
            // `search` (presumably its ef/search-breadth parameter) — confirm.
            // NOTE(review): `1 - distance` assumes the index metric yields distances in
            // [0, 1]; confirm against how the HNSW index is built.
            for n in index.search(&embedding, k as usize, 100) {
                hits.push((n.d_id as i64, 1.0f32 - n.distance));
            }
        }
    }

    if hits.is_empty() {
        let embedding_json =
            serde_json::to_string(&embedding).unwrap_or_else(|_| "[]".to_string());
        if let Ok(rows) =
            sqlx::query("SELECT id, distance FROM vec_items WHERE embedding MATCH ? AND k = ?")
                .bind(embedding_json)
                .bind(k)
                .fetch_all(&state.db_pool)
                .await
        {
            // NOTE(review): `1 - d/2` maps a distance in [0, 2] (e.g. cosine distance) onto
            // [0, 1]; confirm the vec_items distance metric.
            hits.extend(rows.iter().filter_map(|row| {
                let id: i64 = row.try_get(0).ok()?;
                let dist: f64 = row.try_get(1).ok()?;
                Some((id, (1.0 - (dist / 2.0)) as f32))
            }));
        }
    }

    hits
}

/// Loads the display fields for item `id` and packages them with `similarity` (expected to be
/// already clamped to `[0, 1]`).
///
/// Returns `None` if the query fails, no row matches (missing item or document), or a column
/// cannot be decoded (e.g. NULL `content`/`path`), so one bad row is skipped instead of
/// panicking.
async fn fetch_item(state: &AppState, id: i64, similarity: f32) -> Option<serde_json::Value> {
    let row = sqlx::query(ITEM_DETAILS_SQL)
        .bind(id)
        .fetch_one(&state.db_pool)
        .await
        .ok()?;
    let content: String = row.try_get(0).ok()?;
    let path: String = row.try_get(1).ok()?;
    let mime: Option<String> = row.try_get(2).ok()?;
    Some(serde_json::json!({
        "id": id,
        "content": content,
        "path": path,
        "mime": mime,
        "similarity": similarity
    }))
}