// TelosDB: src-backend/src/db.rs
use sea_orm::{DatabaseConnection, SqlxSqliteConnector, ConnectionTrait, Statement, DatabaseBackend};
use sea_orm::sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::str::FromStr;
use std::time::Duration;
use std::path::Path;

pub async fn init_db(db_path: &str, extension_path: &str, vec_dim: usize) -> anyhow::Result<DatabaseConnection> {
    let db_url = format!("sqlite:{}?mode=rwc", db_path);

    // Strip \\?\ prefix that causes issues with sqlite3_load_extension
    let clean_ext_path = extension_path
        .strip_prefix("\\\\?\\")
        .unwrap_or(extension_path)
        .to_string();
    
    let static_ext_path: &'static str = Box::leak(clean_ext_path.into_boxed_str());
    log::info!("Checking SQLite extension at: {}", static_ext_path);

    // Check if the extension file exists
    let ext_path = Path::new(static_ext_path);
    if !ext_path.exists() {
        log::error!("❌ Extension file not found: {}", static_ext_path);
        return Err(anyhow::anyhow!("SQLite extension file not found: {}", static_ext_path));
    }

    // Also check for sqlite3.dll in the same directory (common dependency)
    if let Some(parent) = ext_path.parent() {
        let sqlite3_dll = parent.join("sqlite3.dll");
        if !sqlite3_dll.exists() {
            log::warn!("⚠️ sqlite3.dll not found in the same directory as vector.dll. This might cause load failure.");
        } else {
            log::info!("✅ sqlite3.dll found alongside extension.");
        }
    }

    log::info!("Loading SQLite extension: {}", static_ext_path);

    let options = SqliteConnectOptions::from_str(&db_url)?
        .create_if_missing(true)
        // Note: We use the full path. If it fails, common causes are missing dependencies (sqlite3.dll, etc.)
        .extension(static_ext_path);

    let pool = SqlitePoolOptions::new()
        .max_connections(10)
        .min_connections(5)
        .acquire_timeout(Duration::from_secs(8))
        .max_lifetime(Duration::from_secs(8))
        .connect_with(options)
        .await?;

    let db = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool);
    log::info!("Database connection established. Extension should be loaded via connect options.");

    // PRAGMA settings
    let _ = db.execute(Statement::from_string(
        DatabaseBackend::Sqlite,
        "PRAGMA journal_mode = WAL;".to_string(),
    )).await;

    // Schema Initialization
    // 1. Standard items table
    db.execute(Statement::from_string(
        DatabaseBackend::Sqlite,
        "CREATE TABLE IF NOT EXISTS items (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT NOT NULL,
            document_name TEXT,
            created_at TEXT DEFAULT (datetime('now', 'localtime')),
            updated_at TEXT DEFAULT (datetime('now', 'localtime'))
        );".to_string(),
    )).await?;

    // 2. Add embedding column to items (ALTER TABLE does not support IF NOT EXISTS in many sqlite versions)
    // We attempt it and ignore errors if column exists
    let _ = db.execute(Statement::from_string(
        DatabaseBackend::Sqlite,
        "ALTER TABLE items ADD COLUMN embedding BLOB;".to_string(),
    )).await;

    // 3. Initialize vector engine for the column
    match db.execute(Statement::from_string(
        DatabaseBackend::Sqlite,
        format!(
            "SELECT vector_init('items', 'embedding', 'type=FLOAT32,dimension={}');",
            vec_dim
        ),
    )).await {
        Ok(_) => log::info!("vector_init succeeded for items.embedding (dim={})", vec_dim),
        Err(e) => log::warn!("vector_init failed (may already be initialized): {}", e),
    }

    // triggers
    db.execute(Statement::from_string(
        DatabaseBackend::Sqlite,
        "CREATE TRIGGER IF NOT EXISTS update_items_updated_at
         AFTER UPDATE ON items
         FOR EACH ROW
         BEGIN
             UPDATE items SET updated_at = datetime('now', 'localtime') WHERE id = OLD.id;
         END;".to_string(),
    )).await?;

    Ok(db)
}

#[cfg(test)]
mod tests {
    use super::*;
    use sea_orm::{ConnectOptions, Database};

    /// Smoke test: SeaORM can open an in-memory SQLite database.
    /// This does not run `init_db` end-to-end, because that needs the real
    /// vector extension binary to be present.
    #[tokio::test]
    async fn test_init_db_basic() {
        let opt = ConnectOptions::new("sqlite::memory:");
        let db = Database::connect(opt).await;
        assert!(db.is_ok());
    }

    /// `init_db` must reject a missing extension file before opening the database.
    #[tokio::test]
    async fn test_init_db_missing_extension() {
        match init_db(":memory:", "this/path/does/not/exist/vector.dll", 4).await {
            Ok(_) => panic!("missing extension must be rejected"),
            Err(e) => assert!(e.to_string().contains("SQLite extension file not found")),
        }
    }
}