use sqlx::Row;
use sqlx::SqlitePool;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use std::borrow::Cow;
use std::path::Path;
use std::str::FromStr;
/// データベースを初期化し、コネクションプールを返す。
/// 埋め込み次元が変更されている場合は vec_items テーブルを再作成する。
pub async fn initialize_database(
db_path: &Path,
extension_path: &Path,
dimension: usize,
) -> Result<SqlitePool, String> {
// ディレクトリの作成
if let Some(parent) = db_path.parent() {
if !parent.exists() {
std::fs::create_dir_all(parent).map_err(|e| e.to_string())?;
}
}
let db_path_str = db_path.to_str().ok_or("Invalid DB path")?;
let ext_path_str = extension_path.to_str().ok_or("Invalid extension path")?;
let opts = SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", db_path_str))
.map_err(|e| e.to_string())?
.extension(ext_path_str.to_owned());
let pool = SqlitePoolOptions::new()
.max_connections(5)
.connect_with(opts)
.await
.map_err(|e| e.to_string())?;
// スキーマ初期化
init_schema(&pool, dimension).await?;
Ok(pool)
}
async fn init_schema(pool: &SqlitePool, dimension: usize) -> Result<(), String> {
// PRAGMA設定
sqlx::query("PRAGMA journal_mode = WAL")
.execute(pool)
.await
.map_err(|e| e.to_string())?;
// 標準テーブル作成
sqlx::query(
"CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
content TEXT NOT NULL,
path TEXT,
created_at TEXT DEFAULT (datetime('now', 'localtime')),
updated_at TEXT DEFAULT (datetime('now', 'localtime'))
)",
)
.execute(pool)
.await
.map_err(|e| e.to_string())?;
// トリガー作成
sqlx::query(
"CREATE TRIGGER IF NOT EXISTS update_items_updated_at
AFTER UPDATE ON items
FOR EACH ROW
BEGIN
UPDATE items SET updated_at = datetime('now', 'localtime') WHERE id = OLD.id;
END",
)
.execute(pool)
.await
.map_err(|e| e.to_string())?;
// vec_items の次元数チェックと初期化
check_and_init_vector_table(pool, dimension).await?;
Ok(())
}
async fn check_and_init_vector_table(pool: &SqlitePool, dimension: usize) -> Result<(), String> {
// 現在のテーブル定義を確認
let row = sqlx::query("SELECT sql FROM sqlite_master WHERE type='table' AND name='vec_items'")
.fetch_optional(pool)
.await
.map_err(|e| e.to_string())?;
let should_recreate = if let Some(row) = row {
let sql: String = row.get(0);
// "FLOAT[640]" のような文字列が含まれているかチェック
let expected = format!("FLOAT[{}]", dimension);
if !sql.contains(&expected) {
log::info!("Dimension mismatch detected (expected {}). Rebuilding vec_items...", dimension);
true
} else {
false
}
} else {
true
};
if should_recreate {
let mut tx = pool.begin().await.map_err(|e| e.to_string())?;
sqlx::query("DROP TABLE IF EXISTS vec_items")
.execute(&mut *tx)
.await
.map_err(|e| e.to_string())?;
let create_sql = format!(
"CREATE VIRTUAL TABLE vec_items USING vec0(id INTEGER PRIMARY KEY, embedding FLOAT[{}])",
dimension
);
sqlx::query(&create_sql)
.execute(&mut *tx)
.await
.map_err(|e| e.to_string())?;
// 既存の items があればベクトルを再生成する必要があるが、
// ここでは空のテーブル作成に留める。
// (/v1/embeddings が必要になるため、再生成は別のユーティリティで行うのが安全)
tx.commit().await.map_err(|e| e.to_string())?;
}
Ok(())
}
/// 全データのベクトルを再生成する
pub async fn rebuild_vector_data<F, Fut>(
pool: &SqlitePool,
dimension: usize,
embed_fn: F,
) -> Result<(), String>
where
F: Fn(String) -> Fut,
Fut: std::future::Future<Output = Result<Vec<f32>, String>>,
{
// 強制的に再作成
{
let mut tx = pool.begin().await.map_err(|e| e.to_string())?;
sqlx::query("DROP TABLE IF EXISTS vec_items")
.execute(&mut *tx)
.await
.map_err(|e| e.to_string())?;
let create_sql = format!(
"CREATE VIRTUAL TABLE vec_items USING vec0(id INTEGER PRIMARY KEY, embedding FLOAT[{}])",
dimension
);
sqlx::query(&create_sql)
.execute(&mut *tx)
.await
.map_err(|e| e.to_string())?;
tx.commit().await.map_err(|e| e.to_string())?;
}
let rows = sqlx::query("SELECT id, content FROM items")
.fetch_all(pool)
.await
.map_err(|e| e.to_string())?;
for row in rows {
let id: i64 = row.get("id");
let content: String = row.get("content");
let emb = embed_fn(content).await?;
sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
.bind(id)
.bind(serde_json::to_string(&emb).unwrap_or_default())
.execute(pool)
.await
.map_err(|e| e.to_string())?;
}
Ok(())
}
/// itemsテーブルにあってvec_itemsテーブルにないデータを同期する(ヒーリング)
pub async fn sync_vectors<F, Fut>(
pool: &SqlitePool,
embed_fn: F,
) -> Result<usize, String>
where
F: Fn(String) -> Fut,
Fut: std::future::Future<Output = Result<Vec<f32>, String>>,
{
// vec_itemsに存在しないIDを抽出
let rows = sqlx::query(
"SELECT i.id, i.content
FROM items i
LEFT JOIN vec_items v ON i.id = v.id
WHERE v.id IS NULL"
)
.fetch_all(pool)
.await
.map_err(|e| e.to_string())?;
let count = rows.len();
if count == 0 {
return Ok(0);
}
log::info!("Healing {} missing vectors...", count);
for row in rows {
let id: i64 = row.get("id");
let content: String = row.get("content");
match embed_fn(content).await {
Ok(emb) => {
sqlx::query("INSERT INTO vec_items (id, embedding) VALUES (?, ?)")
.bind(id)
.bind(serde_json::to_string(&emb).unwrap_or_default())
.execute(pool)
.await
.map_err(|e| e.to_string())?;
}
Err(e) => {
log::error!("Failed to generate embedding for item {}: {}", id, e);
}
}
}
Ok(count)
}
// 下位互換性(既存の依存箇所が多いため、一時的に定義を残すか、順次書き換える)
pub async fn init_pool(
db_path: &str,
extension_path: impl Into<Cow<'static, str>>,
) -> Result<SqlitePool, sqlx::Error> {
let opts = SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", db_path))?
.extension(extension_path);
SqlitePoolOptions::new().connect_with(opts).await
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
#[tokio::test]
async fn test_dimension_change_and_rebuild() {
let dir = tempdir().unwrap();
let db_path = dir.path().join("test_telos.db");
// DLLのパス取得(ビルドディレクトリにあることが前提)
// テスト環境では $OUT_DIR や $CARGO_MANIFEST_DIR 基準で探す
let manifest_dir = env!("CARGO_MANIFEST_DIR");
let ext_path = Path::new(manifest_dir).join("../node_modules/sqlite-vec-windows-x64/vec0.dll");
if !ext_path.exists() {
println!("Skipping test: vec0.dll not found at {:?}", ext_path);
return;
}
// 1. 最初は次元数 384 で初期化
let pool = initialize_database(&db_path, &ext_path, 384).await.unwrap();
// テーブル定義の確認
let row: (String,) = sqlx::query_as("SELECT sql FROM sqlite_master WHERE name='vec_items'")
.fetch_one(&pool).await.unwrap();
assert!(row.0.contains("FLOAT[384]"));
// データ挿入
sqlx::query("INSERT INTO items (content) VALUES ('test content')")
.execute(&pool).await.unwrap();
pool.close().await;
// 2. 次は次元数 768 で再初期化(不整合検知 -> 再作成)
let pool2 = initialize_database(&db_path, &ext_path, 768).await.unwrap();
let row2: (String,) = sqlx::query_as("SELECT sql FROM sqlite_master WHERE name='vec_items'")
.fetch_one(&pool2).await.unwrap();
assert!(row2.0.contains("FLOAT[768]"));
// items テーブルは維持されていることを確認
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM items")
.fetch_one(&pool2).await.unwrap();
assert_eq!(count.0, 1);
pool2.close().await;
}
}