use sea_orm::{DatabaseConnection, SqlxSqliteConnector, ConnectionTrait, Statement, DatabaseBackend};
use sea_orm::sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::str::FromStr;
use std::time::Duration;
use std::path::Path;
pub async fn init_db(db_path: &str, extension_path: &str, vec_dim: usize) -> anyhow::Result<DatabaseConnection> {
let db_url = format!("sqlite:{}?mode=rwc", db_path);
// Strip \\?\ prefix that causes issues with sqlite3_load_extension
let clean_ext_path = extension_path
.strip_prefix("\\\\?\\")
.unwrap_or(extension_path)
.to_string();
let static_ext_path: &'static str = Box::leak(clean_ext_path.into_boxed_str());
log::info!("Checking SQLite extension at: {}", static_ext_path);
// Check if the extension file exists
let ext_path = Path::new(static_ext_path);
if !ext_path.exists() {
log::error!("❌ Extension file not found: {}", static_ext_path);
return Err(anyhow::anyhow!("SQLite extension file not found: {}", static_ext_path));
}
// Also check for sqlite3.dll in the same directory (common dependency)
if let Some(parent) = ext_path.parent() {
let sqlite3_dll = parent.join("sqlite3.dll");
if !sqlite3_dll.exists() {
log::warn!("⚠️ sqlite3.dll not found in the same directory as vector.dll. This might cause load failure.");
} else {
log::info!("✅ sqlite3.dll found alongside extension.");
}
}
log::info!("Loading SQLite extension: {}", static_ext_path);
let options = SqliteConnectOptions::from_str(&db_url)?
.create_if_missing(true)
// Note: We use the full path. If it fails, common causes are missing dependencies (sqlite3.dll, etc.)
.extension(static_ext_path);
let pool = SqlitePoolOptions::new()
.max_connections(10)
.min_connections(5)
.acquire_timeout(Duration::from_secs(8))
.max_lifetime(Duration::from_secs(8))
.connect_with(options)
.await?;
let db = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool);
log::info!("Database connection established. Extension should be loaded via connect options.");
// PRAGMA settings
let _ = db.execute(Statement::from_string(
DatabaseBackend::Sqlite,
"PRAGMA journal_mode = WAL;".to_string(),
)).await;
// Schema Initialization
// 1. Standard items table
db.execute(Statement::from_string(
DatabaseBackend::Sqlite,
"CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
content TEXT NOT NULL,
document_name TEXT,
created_at TEXT DEFAULT (datetime('now', 'localtime')),
updated_at TEXT DEFAULT (datetime('now', 'localtime'))
);".to_string(),
)).await?;
// 2. Add embedding column to items (ALTER TABLE does not support IF NOT EXISTS in many sqlite versions)
// We attempt it and ignore errors if column exists
let _ = db.execute(Statement::from_string(
DatabaseBackend::Sqlite,
"ALTER TABLE items ADD COLUMN embedding BLOB;".to_string(),
)).await;
// 3. Initialize vector engine for the column
match db.execute(Statement::from_string(
DatabaseBackend::Sqlite,
format!(
"SELECT vector_init('items', 'embedding', 'type=FLOAT32,dimension={}');",
vec_dim
),
)).await {
Ok(_) => log::info!("vector_init succeeded for items.embedding (dim={})", vec_dim),
Err(e) => log::warn!("vector_init failed (may already be initialized): {}", e),
}
// triggers
db.execute(Statement::from_string(
DatabaseBackend::Sqlite,
"CREATE TRIGGER IF NOT EXISTS update_items_updated_at
AFTER UPDATE ON items
FOR EACH ROW
BEGIN
UPDATE items SET updated_at = datetime('now', 'localtime') WHERE id = OLD.id;
END;".to_string(),
)).await?;
Ok(db)
}
#[cfg(test)]
mod tests {
    use super::*;
    use sea_orm::{ConnectOptions, Database};

    /// Sanity check that SeaORM can open an in-memory SQLite database.
    ///
    /// A full `init_db` run needs the native vector extension, which is not available in unit
    /// tests, so this only exercises the SeaORM/sqlx side.
    #[tokio::test]
    async fn test_init_db_basic() {
        let opt = ConnectOptions::new("sqlite::memory:");
        let db = Database::connect(opt).await;
        assert!(db.is_ok());
    }

    /// `init_db` must fail fast, before touching the database, when the extension file is
    /// missing.
    #[tokio::test]
    async fn test_init_db_missing_extension() {
        let missing =
            std::env::temp_dir().join("init_db_test_missing_extension_does_not_exist.dll");
        let missing = missing.to_string_lossy();

        match init_db(":memory:", &missing, 4).await {
            Ok(_) => panic!("init_db should fail when the extension file is missing"),
            Err(err) => assert!(
                err.to_string().contains("SQLite extension file not found"),
                "unexpected error: {}",
                err
            ),
        }
    }
}