Newer
Older
TelosDB / src / backend / src / mcp / tools / items.rs
@楽曲作りまくりおじさん 楽曲作りまくりおじさん 1 day ago 12 KB refactor: Split mcp.rs into functional modules and sub-tools
use sqlx::Row;
use std::collections::HashMap;
use crate::mcp::types::AppState;
use std::path::Path;

pub async fn handle_get_item_by_id(
    state: &AppState,
    args: &serde_json::Map<String, serde_json::Value>,
) -> Option<serde_json::Value> {
    let id = args.get("id").and_then(|v| v.as_i64()).unwrap_or(0);
    let row: Option<sqlx::sqlite::SqliteRow> = sqlx::query(
        "SELECT i.content, d.path, d.mime FROM items i JOIN documents d ON i.document_id = d.id WHERE i.id = ?",
    )
        .bind(id)
        .fetch_optional(&state.db_pool)
        .await
        .unwrap_or(None);
    if let Some(row) = row {
        let content: String = row.get("content");
        let path: String = row.get("path");
        let mime: Option<String> = row.get("mime");
        Some(serde_json::json!({
            "id": id,
            "content": content,
            "path": path,
            "mime": mime
        }))
    } else {
        Some(serde_json::json!({
            "content": [{ "type": "text", "text": format!("Item not found: {}", id) }],
            "isError": true
        }))
    }
}

pub async fn handle_add_item_text(
    state: &AppState,
    args: &serde_json::Map<String, serde_json::Value>,
) -> Option<serde_json::Value> {
    let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
    let path_str = args.get("path").and_then(|v| v.as_str()).unwrap_or("unknown");
    let mut mime_str = args.get("mime").and_then(|v| v.as_str()).map(|s| s.to_string());

    // MIMEタイプが未指定なら拡張子から推測
    if mime_str.is_none() && path_str != "unknown" {
        let path = Path::new(path_str);
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            mime_str = Some(match ext.to_lowercase().as_str() {
                "md" | "markdown" => "text/markdown".to_string(),
                "txt" => "text/plain".to_string(),
                "rs" => "text/x-rust".to_string(),
                "js" | "mjs" => "text/javascript".to_string(),
                "ts" => "text/typescript".to_string(),
                "json" => "application/json".to_string(),
                "html" => "text/html".to_string(),
                "css" => "text/css".to_string(),
                _ => "application/octet-stream".to_string(),
            });
        }
    }

    log::info!(
        "Executing add_item_text (LSA-only): content length={}, path='{}', mime='{:?}'",
        content.chars().count(),
        path_str,
        mime_str
    );

    // 800文字ずつに分割
    let chars: Vec<char> = content.chars().collect();
    let chunk_strings: Vec<String> = chars
        .chunks(800)
        .map(|chunk| chunk.iter().collect::<String>())
        .collect();

    let mut results = Vec::new();
    
    // 1. ドキュメントレコードの取得または作成
    let doc_id_res = match sqlx::query("SELECT id FROM documents WHERE path = ?")
        .bind(path_str)
        .fetch_optional(&state.db_pool)
        .await 
    {
        Ok(Some(row)) => {
            let id = row.get::<i64, _>(0);
            if let Some(m) = &mime_str {
                let _ = sqlx::query("UPDATE documents SET mime = ? WHERE id = ? AND (mime IS NULL OR mime != ?)")
                    .bind(m)
                    .bind(id)
                    .bind(m)
                    .execute(&state.db_pool)
                    .await;
            }
            Ok(id)
        },
        Ok(None) => {
            match sqlx::query("INSERT INTO documents (path, mime) VALUES (?, ?)")
                .bind(path_str)
                .bind(mime_str)
                .execute(&state.db_pool)
                .await
            {
                Ok(res) => Ok(res.last_insert_rowid()),
                Err(e) => Err(serde_json::json!({
                    "content": [{ "type": "text", "text": format!("Failed to create document: {}", e) }],
                    "isError": true
                }))
            }
        },
        Err(e) => Err(serde_json::json!({
            "content": [{ "type": "text", "text": format!("Database error: {}", e) }],
            "isError": true
        }))
    };

    let doc_id = match doc_id_res {
        Ok(id) => id,
        Err(err_json) => {
            return Some(err_json);
        }
    };

    // 2. 既存の同一ドキュメントの全チャンクを削除(上書き)
    if let Err(e) = sqlx::query("DELETE FROM items WHERE document_id = ?")
        .bind(doc_id)
        .execute(&state.db_pool)
        .await 
    {
        log::error!("Failed to delete old chunks for document {}: {}", doc_id, e);
    }

    // 3. 各チャンクを保存
    for (idx, chunk_content) in chunk_strings.iter().enumerate() {
        match add_item_chunk_inner(state, doc_id, idx as i32, chunk_content).await {
            Ok(id) => results.push(id),
            Err(e) => log::error!("Failed to add chunk {}: {}", idx, e),
        }
    }

    if !results.is_empty() {
        let _ = state.tx.send("data_changed".to_string());
        log::info!("Successfully added {} chunks to document {}.", results.len(), path_str);
        Some(
            serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully added {} chunks for {}", results.len(), path_str) }] }),
        )
    } else {
        Some(serde_json::json!({
            "content": [{ "type": "text", "text": "Failed to add any chunks." }],
            "isError": true
        }))
    }
}

/// Inserts one chunk of document `doc_id` plus its search-index rows and returns the new item id.
///
/// A single transaction writes:
/// - the `items` row;
/// - an `items_fts` row with the same rowid;
/// - the LSA embedding as a JSON array in `vec_items`;
/// - the bincode-encoded embedding in `items_lsa`, but only while an LSA model is loaded.
///
/// After the commit the embedding is also inserted into the in-memory HNSW index, if one exists.
///
/// The embedding always has `LSA_DIM` floats. It is the model's projection padded with zeros
/// or truncated, or all zeros when no model is loaded or the projection fails.
async fn add_item_chunk_inner(
    state: &AppState,
    doc_id: i64,
    chunk_index: i32,
    content: &str,
) -> Result<i64, String> {
    // Embedding width written to `vec_items` and fed to the HNSW index.
    const LSA_DIM: usize = 50;

    // Project the chunk in its own scope before opening the transaction. This releases the
    // model's read lock before any database I/O, so the lock is not held while the HNSW lock
    // is taken below.
    let (lsa_vector_f32, has_model) = {
        let lsa_guard = state.lsa_model.read().await;
        let has_model = lsa_guard.as_ref().is_some();
        let mut vector: Vec<f32> = vec![0.0; LSA_DIM];
        if let Some(model) = lsa_guard.as_ref() {
            // Bag-of-words term counts over the model's vocabulary.
            let mut query_counts = HashMap::new();
            let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
            for token in tokens {
                if let Some(&tid) = model.vocabulary.get(&token) {
                    *query_counts.entry(tid).or_insert(0.0) += 1.0;
                }
            }
            let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
            for (tid, count) in query_counts {
                query_vec[tid] = count;
            }

            if let Ok(projected) = model.project_query(&query_vec) {
                vector = projected.iter().map(|&x| x as f32).collect();
                // `resize` pads short vectors with zeros and truncates long ones.
                vector.resize(LSA_DIM, 0.0);
            }
        }
        (vector, has_model)
    };

    let embedding_json = serde_json::to_string(&lsa_vector_f32)
        .map_err(|e| format!("Failed to serialize LSA vector: {}", e))?;

    let mut tx = state
        .db_pool
        .begin()
        .await
        .map_err(|e| format!("Failed to begin transaction: {}", e))?;

    let res = sqlx::query("INSERT INTO items (document_id, chunk_index, content) VALUES (?, ?, ?)")
        .bind(doc_id)
        .bind(chunk_index)
        .bind(content)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to insert chunk: {}", e))?;
    let id = res.last_insert_rowid();

    // Full-text index row shares the item's rowid.
    sqlx::query("INSERT INTO items_fts (rowid, content) VALUES (?, ?)")
        .bind(id)
        .bind(content)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to insert to FTS: {}", e))?;

    sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
        .bind(id)
        .bind(embedding_json)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to insert LSA vector to vec_items: {}", e))?;

    if has_model {
        let vector_blob = bincode::serialize(&lsa_vector_f32)
            .map_err(|e| format!("Failed to encode LSA blob: {}", e))?;
        sqlx::query("INSERT INTO items_lsa (id, vector) VALUES (?, ?)")
            .bind(id)
            .bind(vector_blob)
            .execute(&mut *tx)
            .await
            .map_err(|e| format!("Failed to insert LSA blob: {}", e))?;
    }

    tx.commit()
        .await
        .map_err(|e| format!("Failed to commit transaction: {}", e))?;

    // Only index in memory once the rows are durably committed.
    let hnsw_index_guard = state.hnsw_index.read().await;
    if let Some(hnsw_ptr) = hnsw_index_guard.as_ref() {
        let vec_ref: &[f32] = lsa_vector_f32.as_slice();
        hnsw_ptr.insert((vec_ref, id as usize));
    }

    Ok(id)
}

/// MCP handler: replaces the content of an existing item.
///
/// Requires `args["id"]` (integer) and `args["content"]` (string). On success it broadcasts
/// `"data_changed"` on `state.tx` and returns a success message; otherwise it returns an MCP
/// error result with `isError: true`. Always returns `Some`.
pub async fn handle_update_item(
    state: &AppState,
    args: &serde_json::Map<String, serde_json::Value>,
) -> Option<serde_json::Value> {
    // Both arguments are required. Defaulting a missing `content` to "" would silently wipe
    // the item, and defaulting a missing `id` to 0 would target row 0.
    let (id, content) = match (
        args.get("id").and_then(|v| v.as_i64()),
        args.get("content").and_then(|v| v.as_str()),
    ) {
        (Some(id), Some(content)) => (id, content),
        _ => {
            return Some(serde_json::json!({
                "content": [{ "type": "text", "text": "Error: 'id' (integer) and 'content' (string) are required" }],
                "isError": true
            }))
        }
    };

    match update_item_inner(state, id, content).await {
        Ok(()) => {
            let _ = state.tx.send("data_changed".to_string());
            Some(
                serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully updated item {} (LSA)", id) }] }),
            )
        }
        Err(e) => Some(serde_json::json!({
            "content": [{ "type": "text", "text": format!("Error: {}", e) }],
            "isError": true
        })),
    }
}

/// Replaces the content of item `id` in one transaction and refreshes its derived rows.
///
/// The `items_fts` row is always updated. When an LSA model is loaded and projection
/// succeeds, the `items_lsa` (bincode blob) and `vec_items` (JSON array) embeddings are
/// rewritten as well. Otherwise the previously stored embeddings are left unchanged.
///
/// Returns `Err` if no item with `id` exists; the transaction is then rolled back.
/// NOTE(review): the in-memory HNSW index is not updated here, so it keeps the old vector
/// until it is rebuilt elsewhere. Confirm this is intended.
async fn update_item_inner(
    state: &AppState,
    id: i64,
    content: &str,
) -> Result<(), String> {
    // Embedding width written to `items_lsa` / `vec_items`.
    const LSA_DIM: usize = 50;

    // Project before opening the transaction so the model's read lock is not held across
    // database I/O.
    let projected_vector: Option<Vec<f32>> = {
        let lsa_guard = state.lsa_model.read().await;
        let mut result = None;
        if let Some(model) = lsa_guard.as_ref() {
            // Bag-of-words term counts over the model's vocabulary.
            let mut query_counts = HashMap::new();
            let tokens = state.tokenizer.tokenize_to_vec(content).unwrap_or_default();
            for token in tokens {
                if let Some(&tid) = model.vocabulary.get(&token) {
                    *query_counts.entry(tid).or_insert(0.0) += 1.0;
                }
            }
            let mut query_vec = ndarray::Array1::zeros(model.vocabulary.len());
            for (tid, count) in query_counts {
                query_vec[tid] = count;
            }

            if let Ok(projected) = model.project_query(&query_vec) {
                let mut proj_f32: Vec<f32> = projected.iter().map(|&x| x as f32).collect();
                // `resize` pads short vectors with zeros and truncates long ones.
                proj_f32.resize(LSA_DIM, 0.0);
                result = Some(proj_f32);
            }
        }
        result
    };

    let mut tx = state
        .db_pool
        .begin()
        .await
        .map_err(|e| format!("Failed to begin transaction: {}", e))?;

    let updated = sqlx::query("UPDATE items SET content = ? WHERE id = ?")
        .bind(content)
        .bind(id)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to update item: {}", e))?;
    if updated.rows_affected() == 0 {
        // Returning here drops `tx` uncommitted (rolled back). It also avoids the
        // INSERT OR REPLACE statements below creating orphan vector rows for a missing item.
        return Err(format!("Item not found: {}", id));
    }

    sqlx::query("UPDATE items_fts SET content = ? WHERE rowid = ?")
        .bind(content)
        .bind(id)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to update FTS: {}", e))?;

    if let Some(proj_f32) = &projected_vector {
        let vector_blob = bincode::serialize(proj_f32)
            .map_err(|e| format!("Failed to encode LSA vector: {}", e))?;
        sqlx::query("INSERT OR REPLACE INTO items_lsa (id, vector) VALUES (?, ?)")
            .bind(id)
            .bind(vector_blob)
            .execute(&mut *tx)
            .await
            .map_err(|e| format!("Failed to update LSA vector: {}", e))?;

        let embedding_json = serde_json::to_string(proj_f32)
            .map_err(|e| format!("Failed to serialize LSA vector: {}", e))?;
        sqlx::query("INSERT OR REPLACE INTO vec_items (id, embedding) VALUES (?, ?)")
            .bind(id)
            .bind(embedding_json)
            .execute(&mut *tx)
            .await
            .map_err(|e| format!("Failed to update vec_items: {}", e))?;
    }

    tx.commit()
        .await
        .map_err(|e| format!("Failed to commit transaction: {}", e))?;
    Ok(())
}

/// MCP handler: deletes one item by `id`.
///
/// Requires `args["id"]` (integer). On success it broadcasts `"data_changed"` on `state.tx`
/// and returns a success message; otherwise it returns an MCP error result with
/// `isError: true`. Always returns `Some`.
pub async fn handle_delete_item(
    state: &AppState,
    args: &serde_json::Map<String, serde_json::Value>,
) -> Option<serde_json::Value> {
    // A missing id must not silently fall back to row 0 and report success.
    let id = match args.get("id").and_then(|v| v.as_i64()) {
        Some(id) => id,
        None => {
            return Some(serde_json::json!({
                "content": [{ "type": "text", "text": "Error: 'id' (integer) is required" }],
                "isError": true
            }))
        }
    };

    match delete_item_inner(state, id).await {
        Ok(()) => {
            let _ = state.tx.send("data_changed".to_string());
            Some(
                serde_json::json!({ "content": [{ "type": "text", "text": format!("Successfully deleted item {}", id) }] }),
            )
        }
        Err(e) => Some(serde_json::json!({
            "content": [{ "type": "text", "text": format!("Error: {}", e) }],
            "isError": true
        })),
    }
}

/// Deletes item `id` and its derived rows (`vec_items`, `items_fts`, `items_lsa`) in one
/// transaction.
///
/// Returns `Err` if no item with that id exists; nothing is deleted in that case.
/// NOTE(review): a vector for this id already inserted into the in-memory HNSW index is not
/// removed here.
async fn delete_item_inner(state: &AppState, id: i64) -> Result<(), String> {
    let mut tx = state
        .db_pool
        .begin()
        .await
        .map_err(|e| format!("Failed to begin transaction: {}", e))?;

    let deleted = sqlx::query("DELETE FROM items WHERE id = ?")
        .bind(id)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to delete item: {}", e))?;
    if deleted.rows_affected() == 0 {
        // Dropping `tx` without committing rolls the transaction back.
        return Err(format!("Item not found: {}", id));
    }

    sqlx::query("DELETE FROM vec_items WHERE id = ?")
        .bind(id)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to delete vector: {}", e))?;
    sqlx::query("DELETE FROM items_fts WHERE rowid = ?")
        .bind(id)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to delete from FTS: {}", e))?;
    // The add/update paths write `items_lsa` rows too; remove them so a reused id cannot
    // collide with a stale row.
    sqlx::query("DELETE FROM items_lsa WHERE id = ?")
        .bind(id)
        .execute(&mut *tx)
        .await
        .map_err(|e| format!("Failed to delete LSA vector: {}", e))?;

    tx.commit()
        .await
        .map_err(|e| format!("Failed to commit transaction: {}", e))?;
    Ok(())
}