Newer
Older
TelosDB / src / backend / src / lib.rs
pub mod db;
pub mod entities;
pub mod llama;
pub mod mcp;

use crate::llama::LlamaClient;
use dotenvy::dotenv;
use sea_orm::DatabaseConnection;
use std::env;
use std::path::PathBuf;
use std::sync::Arc;
use tauri::menu::{Menu, MenuItem};
use tauri::tray::{TrayIconBuilder, TrayIconEvent};
use tauri::Manager;

use std::sync::atomic::AtomicUsize;

/// Shared application state.
///
/// Built once during setup in `run`, registered with Tauri via `app_handle.manage`, and
/// handed to `mcp::start_mcp_server`; Tauri commands access it as `State<Arc<AppState>>`.
pub struct AppState {
    /// Database connection returned by `db::init_db`.
    pub db: DatabaseConnection,
    /// Client for the llama.cpp HTTP server; configured from the `LLAMA_CPP_*` env vars in `run`.
    pub llama: Arc<LlamaClient>,
    /// Broadcast sender for JSON values (created with capacity 100 in `run`).
    /// NOTE(review): presumably used by the MCP server to fan out messages — confirm in `mcp`.
    pub mcp_tx: tokio::sync::broadcast::Sender<serde_json::Value>,
    /// Counter initialised to 0 in `run`.
    /// NOTE(review): presumably tracks active MCP client connections — confirm in `mcp`.
    pub connection_count: Arc<AtomicUsize>,
    /// Handle to the running Tauri application.
    pub app_handle: tauri::AppHandle,
}

/// Tauri command: reads `mcp.json` and returns its parsed JSON contents.
///
/// Looks first for `mcp.json` in the current working directory, then next to the
/// executable and in up to four of its parent directories. The first existing file wins.
///
/// # Errors
/// `"mcp.json not found"` if no candidate exists; otherwise the I/O or JSON error message
/// produced while reading or parsing the chosen file.
#[tauri::command]
fn get_mcp_info() -> Result<serde_json::Value, String> {
    let exe = env::current_exe().ok();
    // Executable folder plus up to four ancestors, nearest first.
    let beside_exe = exe
        .as_deref()
        .and_then(|path| path.parent())
        .into_iter()
        .flat_map(|dir| dir.ancestors().take(5))
        .map(|dir| dir.join("mcp.json"));

    let mcp_path = std::iter::once(PathBuf::from("mcp.json"))
        .chain(beside_exe)
        .find(|candidate| candidate.exists())
        .ok_or_else(|| "mcp.json not found".to_string())?;

    let raw = std::fs::read_to_string(&mcp_path).map_err(|e| e.to_string())?;
    serde_json::from_str::<serde_json::Value>(&raw).map_err(|e| e.to_string())
}

#[tauri::command]
async fn get_db_stats(state: tauri::State<'_, Arc<AppState>>) -> Result<serde_json::Value, String> {
    use crate::entities::items;
    use sea_orm::{EntityTrait, PaginatorTrait};
    let count = items::Entity::find().count(&state.db).await.map_err(|e| e.to_string())?;
    Ok(serde_json::json!({ "itemCount": count }))
}


/// Tauri command: returns whether the llama server answers `LlamaClient::check_health`.
///
/// Never returns `Err`; the `Result` wrapper is part of the command's signature.
#[tauri::command]
async fn get_sidecar_status(state: tauri::State<'_, Arc<AppState>>) -> Result<bool, String> {
    let healthy = state.llama.check_health().await;
    Ok(healthy)
}

fn get_config(app_handle: &tauri::AppHandle) -> serde_json::Value {
    let mut config_paths = vec![];
    if let Ok(app_data) = app_handle.path().app_config_dir() { config_paths.push(app_data.join("config.json")); }
    if let Ok(exe_path) = env::current_exe() {
        if let Some(exe_dir) = exe_path.parent() {
            let mut p = exe_dir.to_path_buf();
            for _ in 0..5 {
                config_paths.push(p.join("config.json"));
                if !p.pop() { break; }
            }
        }
    }
    if let Ok(cwd) = env::current_dir() { config_paths.push(cwd.join("config.json")); }
    if let Ok(res_dir) = app_handle.path().resource_dir() { config_paths.push(res_dir.join("config.json")); }

    for path in config_paths {
        if path.exists() {
            if let Ok(content) = std::fs::read_to_string(path) {
                if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&content) { return parsed; }
            }
        }
    }
    serde_json::json!({
        "database": { "path": "data/vector.db" },
        "model": { "path": "models/embeddinggemma-300m-q4_0.gguf" },
        "llama_server": { "port": 8080 }
    })
}

/// Resolves the SQLite database file path.
///
/// Precedence:
/// 1. the `DB_PATH` environment variable, returned verbatim;
/// 2. `database.path` from `config`:
///    - absolute paths are used as-is;
///    - in debug builds, relative paths are anchored at the project root, i.e. the first of
///      the executable's directory and its parents (four directories in total) that
///      contains `config.json`, or left relative to the working directory if none does;
///    - in release builds, relative paths are anchored at the app data directory. The
///      configured value is honoured; it was previously replaced by `data/vector.db`.
/// 3. `data/vector.db`, relative to the working directory in debug builds or to the app
///    data directory in release builds.
///
/// For cases 2 and 3 the parent directory is created if missing. A failure to create it is
/// reported on stderr but does not abort resolution.
///
/// # Panics
/// In release builds, if the platform app data directory cannot be determined.
fn resolve_db_path(app_handle: &tauri::AppHandle, config: &serde_json::Value) -> String {
    if let Ok(p) = env::var("DB_PATH") {
        return p;
    }

    let configured = config
        .get("database")
        .and_then(|d| d.get("path"))
        .and_then(|p| p.as_str());

    let resolved = match configured {
        Some(p) if std::path::Path::new(p).is_absolute() => PathBuf::from(p),
        Some(p) if cfg!(debug_assertions) => {
            let exe = env::current_exe().ok();
            let project_root = exe
                .as_deref()
                .and_then(|exe| exe.parent())
                .and_then(|dir| {
                    dir.ancestors()
                        .take(4)
                        .find(|candidate| candidate.join("config.json").exists())
                });
            match project_root {
                Some(root) => root.join(p),
                None => PathBuf::from(p),
            }
        }
        Some(p) => {
            let mut base = app_handle.path().app_data_dir().expect("App data dir not found");
            // Push component-wise so `/` in the config value becomes the native separator.
            base.extend(std::path::Path::new(p).components());
            base
        }
        None if cfg!(debug_assertions) => PathBuf::from("data/vector.db"),
        None => {
            let mut base = app_handle.path().app_data_dir().expect("App data dir not found");
            base.push("data");
            base.push("vector.db");
            base
        }
    };

    // SQLite does not create missing directories, so make sure the parent exists.
    if let Some(parent) = resolved.parent().filter(|dir| !dir.as_os_str().is_empty()) {
        if let Err(e) = std::fs::create_dir_all(parent) {
            eprintln!("WARNING: failed to create database directory {:?}: {}", parent, e);
        }
    }
    resolved.to_string_lossy().to_string()
}

fn resolve_extension_path(app_handle: &tauri::AppHandle) -> String {
    let mut candidates = vec![];

    if !cfg!(debug_assertions) {
        if let Ok(res_dir) = app_handle.path().resource_dir() {
            candidates.push(res_dir.join("vec0.dll"));
            // Also look in resources root in case it's placed there by bundle.resources
            candidates.push(res_dir.join("bin").join("vec0.dll"));
        }
    }

    let exe_dir = env::current_exe().map(|p| p.parent().unwrap().to_path_buf()).unwrap_or_else(|_| env::current_dir().unwrap());
    candidates.push(exe_dir.join("vec0.dll"));
    candidates.push(exe_dir.join("../node_modules/sqlite-vec-windows-x64/vec0.dll"));
    candidates.push(exe_dir.join("../../node_modules/sqlite-vec-windows-x64/vec0.dll"));
    candidates.push(exe_dir.join("../../../bin/vec0.dll"));
    
    for cand in &candidates {
        if cand.exists() { 
            if let Ok(canon) = cand.canonicalize() {
                let s = canon.to_string_lossy().to_string();
                if s.starts_with(r"\\?\") {
                    return s[4..].to_string();
                }
                return s;
            }
            return cand.to_string_lossy().to_string(); 
        }
    }
    "vec0.dll".to_string()
}

fn spawn_llama_server(app_handle: &tauri::AppHandle, config: &serde_json::Value) {
    let model_path = if let Ok(p) = env::var("LLAMA_CPP_MODEL_PATH") { PathBuf::from(p) } else if let Some(p) = config.get("model").and_then(|m| m.get("path")).and_then(|p| p.as_str()) {
        let mut candidate = PathBuf::from(p);
        if candidate.is_relative() {
            if cfg!(debug_assertions) {
                if let Ok(exe_path) = env::current_exe() {
                    if let Some(exe_dir) = exe_path.parent() {
                        let mut pr = exe_dir.to_path_buf();
                        for _ in 0..4 {
                            if pr.join("config.json").exists() { candidate = pr.join(p); break; }
                            if !pr.pop() { break; }
                        }
                    }
                }
            } else {
                if let Ok(res_dir) = app_handle.path().resource_dir() {
                    candidate = res_dir.join(p);
                }
            }
        }
        candidate
    } else {
        let mut found = None;
        if cfg!(debug_assertions) {
            if let Ok(exe_path) = env::current_exe() {
                if let Some(exe_dir) = exe_path.parent() {
                    let mut pr = exe_dir.to_path_buf();
                    for _ in 0..4 {
                        if pr.join("models").exists() { found = Some(pr.join("models").join("embeddinggemma-300m-q4_0.gguf")); break; }
                        if !pr.pop() { break; }
                    }
                }
            }
        } else {
            if let Ok(res_dir) = app_handle.path().resource_dir() {
                found = Some(res_dir.join("models").join("embeddinggemma-300m-q4_0.gguf"));
            }
        }
        found.unwrap_or_else(|| PathBuf::from("models/embeddinggemma-300m-q4_0.gguf"))
    };

    let sidecar_exe = if cfg!(debug_assertions) {
        let mut p = env::current_dir().unwrap();
        if p.ends_with(format!("src{}backend", std::path::MAIN_SEPARATOR)) { p.pop(); p.pop(); }
        p.join("bin").join("llama-server-x86_64-pc-windows-msvc.exe")
    } else {
        // In Tauri 2, sidecar is managed via tauri-plugin-shell-sidecar.
        // But if you're spawning manually, you need the absolute path in Resources.
        let res_dir = app_handle.path().resource_dir().expect("Failed to get resource dir");
        // Binaries are typically in resources/bin or root resources
        let candidate_a = res_dir.join("bin").join("llama-server-x86_64-pc-windows-msvc.exe");
        let candidate_b = res_dir.join("llama-server-x86_64-pc-windows-msvc.exe");
        if candidate_a.exists() { candidate_a } else { candidate_b }
    };

    let mut cmd = std::process::Command::new(&sidecar_exe);
    cmd.args(&["--model", &model_path.to_string_lossy(), "--port", "8080", "--embedding", "--host", "127.0.0.1", "-c", "8192", "-b", "8192", "-ub", "8192", "--parallel", "1"]);
    
    // Add bin dir to PATH so DLLs in the same folder are found
    if let Some(bin_dir) = sidecar_exe.parent() {
        let path = env::var("PATH").unwrap_or_default();
        cmd.env("PATH", format!("{};{}", bin_dir.display(), path));
        cmd.current_dir(bin_dir);
    }
    
    // Also check resource_dir for DLLs
    if let Ok(res_dir) = app_handle.path().resource_dir() {
        if let Some(old_path) = cmd.get_envs().find(|(k, _)| k == "PATH").and_then(|(_, v)| v) {
            cmd.env("PATH", format!("{};{}", res_dir.display(), old_path.to_string_lossy()));
        }
    }

    println!("DEBUG: Spawning llama-server: {:?}", cmd);
    match cmd.spawn() {
        Ok(child) => { 
            let pid = child.id(); 
            println!("llama-server started (PID: {})", pid); 
            std::thread::spawn(move || { 
                match child.wait_with_output() {
                    Ok(out) => println!("llama-server exited (OK): {:?}", out.status),
                    Err(e) => eprintln!("llama-server error during wait: {}", e),
                }
            }); 
        }
        Err(e) => eprintln!("CRITICAL: Failed to spawn llama-server: {}. Exe path: {:?}", e, sidecar_exe),
    }
}

/// Application entry point: builds and runs the Tauri application.
///
/// Setup sequence (order is significant):
/// 1. load `.env` and register the shell plugin;
/// 2. read the config and spawn the `llama-server` sidecar;
/// 3. synchronously open the database, register [`AppState`] as managed state, and start
///    the MCP server on `MCP_PORT` (default 3000) in a background task;
/// 4. register the log plugin;
/// 5. build the tray icon with "Show" and "Quit" menu items.
///
/// Window close requests only hide the window; the tray's "Quit" item calls `app.exit(0)`.
///
/// # Panics
/// If the database cannot be initialised, if the app has no default window icon, or if the
/// Tauri application fails to run.
#[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() {
    tauri::Builder::default()
        .invoke_handler(tauri::generate_handler![get_mcp_info, get_db_stats, get_sidecar_status])
        // Closing a window hides it instead of closing it, keeping the app alive in the tray.
        .on_window_event(|window, event| {
            if let tauri::WindowEvent::CloseRequested { api, .. } = event {
                api.prevent_close();
                let _ = window.hide();
            }
        })
        .setup(|app| {
            dotenv().ok();
            let app_handle = app.handle().clone();
            // Plugin registration errors are deliberately ignored (best-effort).
            let _ = app.handle().plugin(tauri_plugin_shell::init());

            let config = get_config(&app_handle);
            spawn_llama_server(&app_handle, &config);

            // Block setup until the DB is open so `AppState` is registered before setup returns.
            tauri::async_runtime::block_on(async move {
                let db_path = resolve_db_path(&app_handle, &config);
                let ext_path = resolve_extension_path(&app_handle);
                // Embedding vector dimension; falls back to 768 if unset or unparsable.
                let vec_dim = env::var("VEC_DIM").unwrap_or_else(|_| "768".to_string()).parse::<usize>().unwrap_or(768);
                let conn = db::init_db(&db_path, &ext_path, vec_dim).await.expect("Failed to init db");

                let state = Arc::new(AppState {
                    db: conn,
                    // Default base URL port 8080 matches the port hard-coded in spawn_llama_server.
                    llama: Arc::new(LlamaClient::new(
                        env::var("LLAMA_CPP_BASE_URL").unwrap_or_else(|_| "http://localhost:8080".to_string()),
                        env::var("LLAMA_CPP_EMBEDDING_MODEL").unwrap_or_else(|_| "nomic-embed-text".to_string()),
                        env::var("LLAMA_CPP_MODEL").unwrap_or_else(|_| "mistral".to_string()),
                    )),
                    mcp_tx: tokio::sync::broadcast::channel(100).0,
                    connection_count: Arc::new(AtomicUsize::new(0)),
                    app_handle: app_handle.clone(),
                });
                app_handle.manage(state.clone());
                let port = env::var("MCP_PORT").unwrap_or_else(|_| "3000".to_string()).parse::<u16>().unwrap_or(3000);
                // NOTE(review): `tokio::spawn` relies on running inside Tauri's Tokio runtime
                // via `block_on`; `tauri::async_runtime::spawn` would make that explicit.
                tokio::spawn(async move { mcp::start_mcp_server(state, port).await; });
            });

            // NOTE(review): the log plugin is registered only after the sidecar and DB setup
            // above, so `log` macros used during that phase are presumably not captured — confirm.
            let mut log_builder = tauri_plugin_log::Builder::default()
                .targets([
                    tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::Stdout),
                    tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::LogDir { file_name: None }),
                ])
                .max_file_size(10 * 1024 * 1024)
                .level(log::LevelFilter::Info);
            // Debug builds additionally log to a relative `logs` folder.
            if cfg!(debug_assertions) {
                log_builder = log_builder.target(tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::Folder { path: PathBuf::from("logs"), file_name: None }));
            }
            let _ = app.handle().plugin(log_builder.build());

            // Tray icon: "show" and a double-click both restore and focus the "main" window.
            let quit_i = MenuItem::with_id(app, "quit", "Quit", true, None::<&str>)?;
            let show_i = MenuItem::with_id(app, "show", "Show", true, None::<&str>)?;
            let menu = Menu::with_items(app, &[&show_i, &quit_i])?;
            let _tray = TrayIconBuilder::new()
                .icon(app.default_window_icon().unwrap().clone())
                .menu(&menu)
                .on_menu_event(|app, event| match event.id.as_ref() {
                    "quit" => app.exit(0),
                    "show" => if let Some(w) = app.get_webview_window("main") { let _ = w.show(); let _ = w.set_focus(); }
                    _ => {}
                })
                .on_tray_icon_event(|tray, event| if let TrayIconEvent::DoubleClick { .. } = event {
                    let app = tray.app_handle();
                    if let Some(w) = app.get_webview_window("main") { let _ = w.show(); let _ = w.set_focus(); }
                })
                .build(app)?;
            Ok(())
        })
        .run(tauri::generate_context!())
        .expect("error while running tauri application");
}