// TelosDB / src-backend / src / mcp / handlers.rs
use crate::entities::items;
use crate::AppState;
use sea_orm::*;

use tauri::Emitter;

/// Saves a document together with its embedding and notifies the UI.
///
/// The embedding for `content` is computed first, so a failed embedding request
/// leaves the database untouched. The row insert and the embedding write then run
/// inside a single transaction, so an item is never committed without its embedding.
///
/// # Errors
/// Returns an error if embedding generation, the insert, the embedding update, or
/// the commit fails. Nothing is committed in that case.
pub async fn handle_save_document(
    state: &AppState,
    content: &str,
    document_name: &str,
) -> anyhow::Result<serde_json::Value> {
    let embedding = state.llama.get_embedding(content).await?;
    // The embedding column stores the vector as consecutive little-endian f32 bytes.
    let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();

    let new_item = items::ActiveModel {
        content: Set(content.to_owned()),
        document_name: Set(Some(document_name.to_owned())),
        ..Default::default()
    };

    // Insert and embedding update must succeed or fail together. An early return
    // via `?` drops the transaction without committing it.
    let txn = state.db.begin().await?;
    let id = new_item.insert(&txn).await?.id;
    txn.execute(Statement::from_sql_and_values(
        DatabaseBackend::Sqlite,
        "UPDATE items SET embedding = ? WHERE id = ?",
        [embedding_bytes.into(), id.into()],
    ))
    .await?;
    txn.commit().await?;

    // Best-effort UI notification: a failed emit must not fail an already-committed save.
    let _ = state.app_handle.emit("mcp-db-update", ());

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": format!("Saved document with id {}", id) }]
    }))
}

/// Embeds the query text `content` and returns up to `limit` nearest stored items.
///
/// This is a thin wrapper: the search itself and the response format come from
/// `handle_find_by_vector`.
///
/// # Errors
/// Propagates embedding failures and any error from the vector search.
pub async fn handle_find_documents(
    state: &AppState,
    content: &str,
    limit: usize,
) -> anyhow::Result<serde_json::Value> {
    handle_find_by_vector(state, state.llama.get_embedding(content).await?, limit).await
}

/// Returns up to `limit` items nearest to `vector`, ordered by ascending distance.
///
/// The search joins `items` against sqlite-vector's `vector_quantize_scan`
/// table-valued function over the `items.embedding` column. The query vector is
/// passed as a little-endian f32 blob, matching how embeddings are stored.
///
/// The response is an MCP text content item. Its text is a pretty-printed JSON array
/// of objects with `id`, `content`, `document_name`, `created_at`, `updated_at` and
/// `distance` fields.
///
/// # Errors
/// Fails if the query fails, a column cannot be decoded as the expected type, or the
/// result cannot be serialized.
pub async fn handle_find_by_vector(
    state: &AppState,
    vector: Vec<f32>,
    limit: usize,
) -> anyhow::Result<serde_json::Value> {
    let embedding_bytes: Vec<u8> = vector.iter().flat_map(|f| f.to_le_bytes()).collect();
    // SQLite integers are i64. Saturate instead of letting `as` wrap a huge usize
    // into a negative k.
    let k = i64::try_from(limit).unwrap_or(i64::MAX);

    // raw SQL query via SeaORM for vector search using sqlite-vector scan
    // JOIN vector_quantize_scan('table', 'column', query_vector, k)
    let rows = state
        .db
        .query_all(Statement::from_sql_and_values(
            DatabaseBackend::Sqlite,
            "SELECT i.id, i.content, i.document_name, i.created_at, i.updated_at, v.distance
             FROM items i
             JOIN vector_quantize_scan('items', 'embedding', ?, ?) AS v ON i.rowid = v.rowid
             ORDER BY distance",
            [embedding_bytes.into(), k.into()],
        ))
        .await?;

    // `DbErr` converts into `anyhow::Error` directly, so plain `?` suffices.
    let hits = rows
        .iter()
        .map(|row| -> anyhow::Result<serde_json::Value> {
            Ok(serde_json::json!({
                "id": row.try_get::<i32>("", "id")?,
                "content": row.try_get::<String>("", "content")?,
                "document_name": row.try_get::<Option<String>>("", "document_name")?,
                "created_at": row.try_get::<String>("", "created_at")?,
                "updated_at": row.try_get::<String>("", "updated_at")?,
                "distance": row.try_get::<f64>("", "distance")?
            }))
        })
        .collect::<anyhow::Result<Vec<_>>>()?;

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": serde_json::to_string_pretty(&hits)? }]
    }))
}

/// Deletes the item with the given `id` and notifies the UI.
///
/// The embedding lives in the same `items` row, so deleting the row removes it too.
///
/// # Errors
/// Fails if the delete query fails. Also fails with "Item not found" when no row had
/// that `id`, matching `handle_get_document` / `handle_get_vector`. In that case no
/// UI update event is emitted.
pub async fn handle_delete_item(
    state: &AppState,
    id: i32,
) -> anyhow::Result<serde_json::Value> {
    let result = items::Entity::delete_by_id(id).exec(&state.db).await?;

    if result.rows_affected == 0 {
        return Err(anyhow::anyhow!("Item not found with id {}", id));
    }

    // Best-effort UI notification: a failed emit must not fail the delete.
    let _ = state.app_handle.emit("mcp-db-update", ());

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": format!("Deleted item with id {}", id) }]
    }))
}

/// Returns the stored embedding of item `id` as a JSON array of f32 in the MCP
/// text content.
///
/// The embedding column is decoded as consecutive little-endian f32 values, the
/// format written by `handle_save_document`. A NULL embedding yields an empty array.
///
/// # Errors
/// Fails if the query fails, no item has that `id`, the column cannot be read as a
/// blob, or the blob length is not a multiple of 4 bytes (a corrupt embedding).
pub async fn handle_get_vector(
    state: &AppState,
    id: i32,
) -> anyhow::Result<serde_json::Value> {
    let row = state
        .db
        .query_one(Statement::from_sql_and_values(
            DatabaseBackend::Sqlite,
            "SELECT embedding FROM items WHERE id = ?",
            [id.into()],
        ))
        .await?
        .ok_or_else(|| anyhow::anyhow!("Item not found with id {}", id))?;

    // Only NULL maps to "no embedding". Genuine decode errors are surfaced instead
    // of being hidden as an empty vector.
    let bytes: Vec<u8> = row
        .try_get::<Option<Vec<u8>>>("", "embedding")?
        .unwrap_or_default();

    if bytes.len() % 4 != 0 {
        anyhow::bail!(
            "Embedding for item {} is corrupt: {} bytes is not a multiple of 4",
            id,
            bytes.len()
        );
    }

    let vector: Vec<f32> = bytes
        .chunks_exact(4)
        .map(|chunk| {
            f32::from_le_bytes(chunk.try_into().expect("chunks_exact(4) yields 4-byte slices"))
        })
        .collect();

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": serde_json::to_string(&vector)? }]
    }))
}

/// Returns the stored `content` text of the item with the given `id`.
///
/// # Errors
/// Fails if the lookup query fails or no item has that `id`.
pub async fn handle_get_document(
    state: &AppState,
    id: i32,
) -> anyhow::Result<serde_json::Value> {
    let Some(found) = items::Entity::find_by_id(id).one(&state.db).await? else {
        return Err(anyhow::anyhow!("Item not found with id {}", id));
    };

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": found.content }]
    }))
}

/// Returns the total number of rows in `items`, as decimal text in the MCP content.
///
/// # Errors
/// Fails if the count query fails.
pub async fn handle_get_documents_count(
    state: &AppState,
) -> anyhow::Result<serde_json::Value> {
    let total = items::Entity::find().count(&state.db).await?;
    let text = total.to_string();

    Ok(serde_json::json!({ "content": [{ "type": "text", "text": text }] }))
}

/// Returns one page of items ordered by ascending `id`, serialized as pretty JSON.
///
/// `offset` rows are skipped, then at most `limit` rows are returned.
///
/// # Errors
/// Fails if the query fails or the rows cannot be serialized.
pub async fn handle_list_documents(
    state: &AppState,
    limit: u64,
    offset: u64,
) -> anyhow::Result<serde_json::Value> {
    let page_query = items::Entity::find()
        .order_by_asc(items::Column::Id)
        .offset(offset)
        .limit(limit);
    let page = page_query.all(&state.db).await?;
    let text = serde_json::to_string_pretty(&page)?;

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": text }]
    }))
}

/// Runs a text completion for `prompt` and returns the generated text as MCP content.
///
/// `n_predict` and `temperature` are forwarded unchanged to the LLM client's
/// `completion` call.
///
/// # Errors
/// Propagates any error from the completion call.
pub async fn handle_llm_generate(
    state: &AppState,
    prompt: &str,
    n_predict: i32,
    temperature: f32,
) -> anyhow::Result<serde_json::Value> {
    let generated = state.llama.completion(prompt, n_predict, temperature).await?;

    Ok(serde_json::json!({ "content": [{ "type": "text", "text": generated }] }))
}

/// Returns up to `limit` items with the highest `id` first, serialized as pretty JSON.
///
/// # Errors
/// Fails if the query fails or the rows cannot be serialized.
pub async fn handle_read_recent_items(
    state: &AppState,
    limit: u64,
) -> anyhow::Result<serde_json::Value> {
    let newest_first = items::Entity::find()
        .order_by_desc(items::Column::Id)
        .limit(limit)
        .all(&state.db)
        .await?;
    let text = serde_json::to_string_pretty(&newest_first)?;

    Ok(serde_json::json!({
        "content": [{ "type": "text", "text": text }]
    }))
}