Newer
Older
TelosDB / src-tauri / src / db.rs
use sqlx::Row;
use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::borrow::Cow;
use std::path::Path;
use std::str::FromStr;

/// データベースを初期化し、コネクションプールを返す。
/// 埋め込み次元が変更されている場合は vec_items テーブルを再作成する。
pub async fn initialize_database(
    db_path: &Path,
    extension_path: &Path,
    dimension: usize,
) -> Result<SqlitePool, String> {
    // ディレクトリの作成
    if let Some(parent) = db_path.parent() {
        if !parent.exists() {
            std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
        }
    }

    let db_path_str = db_path.to_str().ok_or("Invalid DB path")?;
    let ext_path_str = extension_path.to_str().ok_or("Invalid extension path")?;

    let opts = SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", db_path_str))
        .map_err(|e| e.to_string())?
        .extension(ext_path_str.to_owned());

    let pool = SqlitePoolOptions::new()
        .max_connections(5)
        .connect_with(opts)
        .await
        .map_err(|e| e.to_string())?;

    // スキーマ初期化
    init_schema(&pool, dimension).await?;

    Ok(pool)
}

async fn init_schema(pool: &SqlitePool, dimension: usize) -> Result<(), String> {
    // PRAGMA設定
    sqlx::query("PRAGMA journal_mode = WAL")
        .execute(pool)
        .await
        .map_err(|e| e.to_string())?;

    // 標準テーブル作成
    sqlx::query(
        "CREATE TABLE IF NOT EXISTS items (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT NOT NULL,
            path TEXT,
            created_at TEXT DEFAULT (datetime('now', 'localtime')),
            updated_at TEXT DEFAULT (datetime('now', 'localtime'))
        )",
    )
    .execute(pool)
    .await
    .map_err(|e| e.to_string())?;

    // トリガー作成
    sqlx::query(
        "CREATE TRIGGER IF NOT EXISTS update_items_updated_at
        AFTER UPDATE ON items
        FOR EACH ROW
        BEGIN
            UPDATE items SET updated_at = datetime('now', 'localtime') WHERE id = OLD.id;
        END",
    )
    .execute(pool)
    .await
    .map_err(|e| e.to_string())?;

    // vec_items の次元数チェックと初期化
    check_and_init_vector_table(pool, dimension).await?;

    Ok(())
}

async fn check_and_init_vector_table(pool: &SqlitePool, dimension: usize) -> Result<(), String> {
    // 現在のテーブル定義を確認
    let row = sqlx::query("SELECT sql FROM sqlite_master WHERE type='table' AND name='vec_items'")
        .fetch_optional(pool)
        .await
        .map_err(|e| e.to_string())?;

    let should_recreate = if let Some(row) = row {
        let sql: String = row.get(0);
        // "FLOAT[640]" のような文字列が含まれているかチェック
        let expected = format!("FLOAT[{}]", dimension);
        if !sql.contains(&expected) {
            log::info!("Dimension mismatch detected (expected {}). Rebuilding vec_items...", dimension);
            true
        } else {
            false
        }
    } else {
        true
    };

    if should_recreate {
        let mut tx = pool.begin().await.map_err(|e| e.to_string())?;

        sqlx::query("DROP TABLE IF EXISTS vec_items")
            .execute(&mut *tx)
            .await
            .map_err(|e| e.to_string())?;

        let create_sql = format!(
            "CREATE VIRTUAL TABLE vec_items USING vec0(id INTEGER PRIMARY KEY, embedding FLOAT[{}])",
            dimension
        );
        sqlx::query(&create_sql)
            .execute(&mut *tx)
            .await
            .map_err(|e| e.to_string())?;

        // 既存の items があればベクトルを再生成する必要があるが、
        // ここでは空のテーブル作成に留める。
        // (/v1/embeddings が必要になるため、再生成は別のユーティリティで行うのが安全)

        tx.commit().await.map_err(|e| e.to_string())?;
    }

    Ok(())
}

/// 全データのベクトルを再生成する
pub async fn rebuild_vector_data<F, Fut>(
    pool: &SqlitePool,
    dimension: usize,
    embed_fn: F,
) -> Result<(), String> 
where 
    F: Fn(String) -> Fut,
    Fut: std::future::Future<Output = Result<Vec<f32>, String>>,
{
    // 強制的に再作成
    {
        let mut tx = pool.begin().await.map_err(|e| e.to_string())?;
        sqlx::query("DROP TABLE IF EXISTS vec_items")
            .execute(&mut *tx)
            .await
            .map_err(|e| e.to_string())?;
        let create_sql = format!(
            "CREATE VIRTUAL TABLE vec_items USING vec0(id INTEGER PRIMARY KEY, embedding FLOAT[{}])",
            dimension
        );
        sqlx::query(&create_sql)
            .execute(&mut *tx)
            .await
            .map_err(|e| e.to_string())?;
        tx.commit().await.map_err(|e| e.to_string())?;
    }

    let rows = sqlx::query("SELECT id, content FROM items")
        .fetch_all(pool)
        .await
        .map_err(|e| e.to_string())?;

    for row in rows {
        let id: i64 = row.get("id");
        let content: String = row.get("content");
        let emb = embed_fn(content).await?;
        
        sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
            .bind(id)
            .bind(serde_json::to_string(&emb).unwrap_or_default())
            .execute(pool)
            .await
            .map_err(|e| e.to_string())?;
    }

    Ok(())
}

/// itemsテーブルにあってvec_itemsテーブルにないデータを同期する(ヒーリング)
pub async fn sync_vectors<F, Fut>(
    pool: &SqlitePool,
    embed_fn: F,
) -> Result<usize, String>
where
    F: Fn(String) -> Fut,
    Fut: std::future::Future<Output = Result<Vec<f32>, String>>,
{
    // vec_itemsに存在しないIDを抽出
    let rows = sqlx::query(
        "SELECT i.id, i.content 
         FROM items i 
         LEFT JOIN vec_items v ON i.id = v.id 
         WHERE v.id IS NULL"
    )
    .fetch_all(pool)
    .await
    .map_err(|e| e.to_string())?;

    let count = rows.len();
    if count == 0 {
        return Ok(0);
    }

    log::info!("Healing {} missing vectors...", count);

    for row in rows {
        let id: i64 = row.get("id");
        let content: String = row.get("content");
        
        match embed_fn(content).await {
            Ok(emb) => {
                sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
                    .bind(id)
                    .bind(serde_json::to_string(&emb).unwrap_or_default())
                    .execute(pool)
                    .await
                    .map_err(|e| e.to_string())?;
            }
            Err(e) => {
                log::error!("Failed to generate embedding for item {}: {}", id, e);
            }
        }
    }

    Ok(count)
}

// 下位互換性(既存の依存箇所が多いため、一時的に定義を残すか、順次書き換える)
pub async fn init_pool(
    db_path: &str,
    extension_path: impl Into<Cow<'static, str>>,
) -> Result<SqlitePool, sqlx::Error> {
    let opts = SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", db_path))?
        .extension(extension_path);

    SqlitePoolOptions::new().connect_with(opts).await
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[tokio::test]
    async fn test_dimension_change_and_rebuild() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_telos.db");
        
        // DLLのパス取得(ビルドディレクトリにあることが前提)
        // テスト環境では $OUT_DIR や $CARGO_MANIFEST_DIR 基準で探す
        let manifest_dir = env!("CARGO_MANIFEST_DIR");
        let ext_path = Path::new(manifest_dir).join("../node_modules/sqlite-vec-windows-x64/vec0.dll");

        if !ext_path.exists() {
            println!("Skipping test: vec0.dll not found at {:?}", ext_path);
            return;
        }

        // 1. 最初は次元数 384 で初期化
        let pool = initialize_database(&db_path, &ext_path, 384).await.unwrap();
        
        // テーブル定義の確認
        let row: (String,) = sqlx::query_as("SELECT sql FROM sqlite_master WHERE name='vec_items'")
            .fetch_one(&pool).await.unwrap();
        assert!(row.0.contains("FLOAT[384]"));
        
        // データ挿入
        sqlx::query("INSERT INTO items (content) VALUES ('test content')")
            .execute(&pool).await.unwrap();
        
        pool.close().await;

        // 2. 次は次元数 768 で再初期化(不整合検知 -> 再作成)
        let pool2 = initialize_database(&db_path, &ext_path, 768).await.unwrap();
        let row2: (String,) = sqlx::query_as("SELECT sql FROM sqlite_master WHERE name='vec_items'")
            .fetch_one(&pool2).await.unwrap();
        assert!(row2.0.contains("FLOAT[768]"));

        // items テーブルは維持されていることを確認
        let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM items")
            .fetch_one(&pool2).await.unwrap();
        assert_eq!(count.0, 1);
        
        pool2.close().await;
    }
}