// TelosDB / src-tauri / src / bin / rebuild_vecs.rs
use std::env;
use std::path::PathBuf;

/// Command-line entry point that rebuilds the vector data of a database.
///
/// Usage: `rebuild_vecs [DB_PATH] [VEC0_PATH]`
///
/// * `DB_PATH` — database file; defaults to `telos.db` in the current working directory.
/// * `VEC0_PATH` — vec0 extension library; defaults to `target/debug/vec0.dll` under the
///   current working directory.
///
/// The program calls `app_lib::db::initialize_database`, opens a pool with
/// `app_lib::db::init_pool`, and then runs `app_lib::db::rebuild_vector_data` with an
/// embedding callback that POSTs each text to `http://127.0.0.1:8080/v1/embeddings`.
///
/// Any failure of those three steps is reported on stderr and terminates the process with
/// exit status 1.
///
/// # Panics
///
/// Panics if the current working directory cannot be determined.
#[tokio::main]
async fn main() {
    // CWD is expected to be src-tauri when run via our cargo command below.
    let cwd = env::current_dir().expect("failed to get cwd");

    let args: Vec<String> = env::args().collect();
    let db_path = args
        .get(1)
        .map(PathBuf::from)
        .unwrap_or_else(|| cwd.join("telos.db"));
    // NOTE(review): the default file name is Windows-specific (`.dll`); on other platforms
    // the extension path presumably has to be passed explicitly — confirm.
    let vec0_path = args
        .get(2)
        .map(PathBuf::from)
        .unwrap_or_else(|| cwd.join("target").join("debug").join("vec0.dll"));

    println!("Using db: {:?}", db_path);
    println!("Using vec0 extension: {:?}", vec0_path);

    // `init_pool` takes string arguments, so keep lossy string copies of both paths.
    let db_path_str = db_path.to_string_lossy().to_string();
    let vec0_path_str = vec0_path.to_string_lossy().to_string();

    let dimension = 640; // Default for Gemma-3, or should be dynamic
    // Ensure SQLite schema / extension is initialized (creates items and vec_items if missing)
    match app_lib::db::initialize_database(&db_path, &vec0_path, dimension).await {
        Ok(_) => println!("initialize_database succeeded or schema already present."),
        Err(e) => {
            eprintln!("initialize_database failed: {:?}", e);
            std::process::exit(1);
        }
    }

    // Initialize SQLx pool. `vec0_path_str` is not needed afterwards, so it is moved
    // instead of cloned.
    let pool = match app_lib::db::init_pool(&db_path_str, vec0_path_str).await {
        Ok(p) => p,
        Err(e) => {
            eprintln!("Failed to init pool: {:?}", e);
            std::process::exit(1);
        }
    };

    // embed_fn: posts to local llama-server embeddings endpoint and returns the first
    // embedding of the response as `Vec<f32>`, or a human-readable error string.
    let client = reqwest::Client::new();
    let embed_fn = move |txt: String| -> std::pin::Pin<
        Box<dyn std::future::Future<Output = Result<Vec<f32>, String>> + Send + 'static>,
    > {
        // Each returned future owns its own client handle so the closure stays callable
        // more than once.
        let client = client.clone();
        Box::pin(async move {
            let payload = serde_json::json!({"input": [txt], "model": "default"});
            let resp = client
                .post("http://127.0.0.1:8080/v1/embeddings")
                .json(&payload)
                .send()
                .await
                .map_err(|e| e.to_string())?;

            // Capture the status before consuming the response so that a server-side
            // failure is reported as such rather than as a JSON/shape error.
            let status = resp.status();
            let body = resp.text().await.map_err(|e| e.to_string())?;
            if !status.is_success() {
                return Err(format!(
                    "embeddings request failed with HTTP {}: {}",
                    status, body
                ));
            }

            let json: serde_json::Value =
                serde_json::from_str(&body).map_err(|e| e.to_string())?;
            let arr = json["data"][0]["embedding"]
                .as_array()
                .ok_or_else(|| format!("no embedding in response: {}", body))?;

            // Reject non-numeric components instead of silently storing 0.0 for them;
            // a partially zeroed vector would corrupt the rebuilt index unnoticed.
            arr.iter()
                .enumerate()
                .map(|(idx, component)| {
                    component.as_f64().map(|f| f as f32).ok_or_else(|| {
                        format!(
                            "non-numeric embedding component at index {}: {}",
                            idx, component
                        )
                    })
                })
                .collect::<Result<Vec<f32>, String>>()
        })
    };

    println!("Starting rebuild_vector_data...");
    match app_lib::db::rebuild_vector_data(&pool, dimension, embed_fn).await {
        Ok(_) => {
            println!("rebuild_vector_data completed successfully.");
        }
        Err(e) => {
            eprintln!("rebuild_vector_data failed: {}", e);
            std::process::exit(1);
        }
    }
}