Newer
Older
TelosDB / src / backend / src / lib.rs
@楽曲作りまくりおじさん 楽曲作りまくりおじさん 9 hours ago 15 KB refactor(journals): mask absolute paths and update environment docs
pub mod db;
pub mod entities;
pub mod llama;
pub mod mcp;

use crate::llama::LlamaClient;
use dotenvy::dotenv;
use sea_orm::DatabaseConnection;
use std::env;
use std::path::PathBuf;
use std::sync::Arc;
use tauri::menu::{Menu, MenuItem};
use tauri::tray::{TrayIconBuilder, TrayIconEvent};
use tauri::Manager;

/// Shared backend state, registered with Tauri (wrapped in an `Arc`) via
/// `app_handle.manage` in `initialize_app` and read by the `#[tauri::command]`
/// handlers through `tauri::State<'_, Arc<AppState>>`.
pub struct AppState {
    /// SeaORM connection to the SQLite database opened by `db::init_db`.
    pub db: DatabaseConnection,
    /// Client for the llama.cpp server (constructed from `LLAMA_CPP_*`
    /// environment variables); used for health checks and embeddings.
    pub llama: Arc<LlamaClient>,
    /// Slot for a llama-server child process. Initialized to `None` in
    /// `initialize_app`; presumably filled by the `llama` module when it
    /// spawns the server — TODO confirm.
    pub llama_server: Arc<tokio::sync::Mutex<Option<std::process::Child>>>,
}

/// Locates a bundled resource file and returns the first existing candidate.
///
/// Candidates are checked in this order:
/// 1. `file` relative to the current working directory;
/// 2. `<resource_dir>/folder/file`;
/// 3. `<dir>/folder/file` for the executable's directory and up to four of
///    its ancestors (at most five directories).
///
/// Returns `None` when no candidate exists.
fn find_resource_path(app_handle: &tauri::AppHandle, folder: &str, file: &str) -> Option<PathBuf> {
    let mut search: Vec<PathBuf> = Vec::new();
    search.push(PathBuf::from(file));

    if let Ok(resources) = app_handle.path().resource_dir() {
        search.push(resources.join(folder).join(file));
    }

    let exe_dir = env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.to_path_buf()));
    if let Some(start) = exe_dir {
        // The executable's own directory followed by at most four ancestors.
        search.extend(
            start
                .ancestors()
                .take(5)
                .map(|dir| dir.join(folder).join(file)),
        );
    }

    search.into_iter().find(|candidate| candidate.exists())
}

/// Tauri command: returns the parsed contents of `build_assets/mcp.json`,
/// located via `find_resource_path`.
///
/// # Errors
/// `"mcp.json not found"` if no candidate exists; otherwise the stringified
/// I/O or JSON parse error.
#[tauri::command]
fn get_mcp_info(app_handle: tauri::AppHandle) -> Result<serde_json::Value, String> {
    let path = match find_resource_path(&app_handle, "build_assets", "mcp.json") {
        Some(found) => found,
        None => return Err("mcp.json not found".to_string()),
    };

    let raw = std::fs::read_to_string(&path).map_err(|err| err.to_string())?;
    serde_json::from_str::<serde_json::Value>(&raw).map_err(|err| err.to_string())
}

/// Tauri command: returns `{ "itemCount": <n> }`, where `n` is the number of
/// rows in the `items` entity table.
#[tauri::command]
async fn get_db_stats(state: tauri::State<'_, Arc<AppState>>) -> Result<serde_json::Value, String> {
    use crate::entities::items;
    use sea_orm::{EntityTrait, PaginatorTrait};

    let item_count = items::Entity::find()
        .count(&state.db)
        .await
        .map_err(|err| err.to_string())?;

    Ok(serde_json::json!({ "itemCount": item_count }))
}

/// Tauri command: reports the result of `LlamaClient::check_health` for the
/// llama.cpp server. This command never returns `Err`.
#[tauri::command]
async fn get_sidecar_status(state: tauri::State<'_, Arc<AppState>>) -> Result<bool, String> {
    let healthy = state.llama.check_health().await;
    Ok(healthy)
}

/// Tauri command: lists the names of user tables in the SQLite database.
///
/// Reads `sqlite_master` rows of type `table`, excluding SQLite's reserved
/// `sqlite_*` tables. The `_` in that prefix is escaped so it matches
/// literally instead of acting as a LIKE single-character wildcard.
///
/// # Errors
/// Returns the stringified database error if the query fails, or if a `name`
/// value cannot be decoded as text (instead of reporting an empty table name).
#[tauri::command]
async fn get_table_list(state: tauri::State<'_, Arc<AppState>>) -> Result<Vec<String>, String> {
    use sea_orm::{ConnectionTrait, Statement};

    let backend = state.db.get_database_backend();
    let sql = "SELECT name FROM sqlite_master \
               WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'";
    let rows = state
        .db
        .query_all(Statement::from_string(backend, sql))
        .await
        .map_err(|e| e.to_string())?;

    rows.into_iter()
        .map(|row| row.try_get::<String>("", "name").map_err(|e| e.to_string()))
        .collect()
}

/// Tauri command: returns one page of rows from `table_name` plus the table's
/// total row count, as `{ "data": [...], "total": <n> }`.
///
/// `table_name` is embedded as a double-quoted SQLite identifier with embedded
/// `"` characters doubled. `limit` and `offset` are formatted into the SQL as
/// integers. Each row is converted with `database_row_to_json`. A missing or
/// undecodable count is reported as `0`.
#[tauri::command]
async fn get_table_data(
    state: tauri::State<'_, Arc<AppState>>,
    table_name: String,
    limit: u64,
    offset: u64,
) -> Result<serde_json::Value, String> {
    use sea_orm::{ConnectionTrait, Statement};

    let backend = state.db.get_database_backend();
    let escaped = table_name.replace("\"", "\"\"");

    let count_sql = format!("SELECT COUNT(*) as total FROM \"{}\"", escaped);
    let count_row = state
        .db
        .query_one(Statement::from_string(backend, &count_sql))
        .await
        .map_err(|e| e.to_string())?;
    let total: i64 = match count_row {
        Some(row) => row.try_get::<i64>("", "total").unwrap_or(0),
        None => 0,
    };

    let data_sql = format!(
        "SELECT * FROM \"{}\" LIMIT {} OFFSET {}",
        escaped, limit, offset
    );
    let rows = state
        .db
        .query_all(Statement::from_string(backend, &data_sql))
        .await
        .map_err(|e| e.to_string())?;

    let mut page = Vec::with_capacity(rows.len());
    for row in rows {
        page.push(database_row_to_json(row).await);
    }

    Ok(serde_json::json!({
        "data": page,
        "total": total
    }))
}

/// Tauri command: returns `PRAGMA table_info` for `table_name` as a JSON array
/// with one object per result row, converted by `database_row_to_json`.
///
/// `table_name` is embedded as a double-quoted identifier with embedded `"`
/// characters doubled.
#[tauri::command]
async fn get_table_schema(
    state: tauri::State<'_, Arc<AppState>>,
    table_name: String,
) -> Result<serde_json::Value, String> {
    use sea_orm::{ConnectionTrait, Statement};

    let escaped = table_name.replace("\"", "\"\"");
    let pragma = format!("PRAGMA table_info(\"{}\")", escaped);

    let backend = state.db.get_database_backend();
    let columns = state
        .db
        .query_all(Statement::from_string(backend, &pragma))
        .await
        .map_err(|e| e.to_string())?;

    let mut described = Vec::with_capacity(columns.len());
    for column in columns {
        described.push(database_row_to_json(column).await);
    }
    Ok(serde_json::Value::Array(described))
}

/// Converts a raw query row into a JSON object keyed by column name.
///
/// Each column is decoded by trying, in order, `String`, `i64` and `f64`; the
/// first successful decode wins. Columns that decode as none of these become
/// JSON `null`. This includes SQL `NULL`, and presumably BLOB values such as
/// stored embeddings — TODO confirm.
///
/// The function stays `async` so existing `.await` call sites remain valid,
/// even though its body does no asynchronous work.
async fn database_row_to_json(row: sea_orm::QueryResult) -> serde_json::Value {
    let mut map = serde_json::Map::new();
    for col in row.column_names() {
        let value = if let Ok(text) = row.try_get::<String>("", &col) {
            serde_json::Value::String(text)
        } else if let Ok(int) = row.try_get::<i64>("", &col) {
            // SQLite INTEGER is 64-bit. Decoding as i32 first may truncate or
            // reject out-of-range values, depending on the sqlx version. i64
            // is a lossless superset, so no separate i32 attempt is needed.
            serde_json::json!(int)
        } else if let Ok(real) = row.try_get::<f64>("", &col) {
            serde_json::json!(real)
        } else {
            serde_json::Value::Null
        };
        map.insert(col, value);
    }
    serde_json::Value::Object(map)
}

#[tauri::command]
async fn vector_search_text(
    state: tauri::State<'_, Arc<AppState>>,
    text: String,
    limit: i32,
) -> Result<serde_json::Value, String> {
    let embedding = state.llama.get_embedding(&text).await.map_err(|e: anyhow::Error| e.to_string())?;

    use sea_orm::{ConnectionTrait, Statement};
    let backend = state.db.get_database_backend();
    
    let vec_json = serde_json::to_string(&embedding).unwrap();
    let sql = format!(
        "SELECT i.*, v.distance \
         FROM items i \
         JOIN vec_items v ON i.id = v.id \
         WHERE v.embedding MATCH '{}' \
         AND k = {} \
         ORDER BY v.distance ASC",
        vec_json, limit
    );
    
    let res = state.db.query_all(Statement::from_string(backend, &sql)).await.map_err(|e| e.to_string())?;
    
    let mut results = Vec::new();
    for row in res {
        results.push(database_row_to_json(row).await);
    }
    
    Ok(serde_json::Value::Array(results))
}

fn get_config(app_handle: &tauri::AppHandle) -> serde_json::Value {
    let mut config_paths = vec![];
    if let Ok(app_data) = app_handle.path().app_config_dir() { config_paths.push(app_data.join("config.json")); }
    if let Ok(exe_path) = env::current_exe() {
        if let Some(exe_dir) = exe_path.parent() {
            let mut p = exe_dir.to_path_buf();
            for _ in 0..5 {
                config_paths.push(p.join("config.json"));
                if !p.pop() { break; }
            }
        }
    }
    if let Ok(cwd) = env::current_dir() { config_paths.push(cwd.join("config.json")); }
    if let Ok(res_dir) = app_handle.path().resource_dir() {
        config_paths.push(res_dir.join("build_assets").join("config.json"));
    }

    if let Some(path) = find_resource_path(app_handle, "build_assets", "config.json") {
        if let Ok(content) = std::fs::read_to_string(path) {
            if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&content) { return parsed; }
        }
    }
    serde_json::json!({
        "database": { "path": "data/vector.db" },
        "model": { "path": "models/embeddinggemma-300m-q4_0.gguf" },
        "llama_server": { "port": 8080 }
    })
}

/// Resolves the SQLite database file path.
///
/// Precedence:
/// 1. The `DB_PATH` environment variable, returned verbatim.
/// 2. A non-empty `database.path` string in `config`:
///    - non-relative paths are returned as-is;
///    - debug builds join a relative path onto the first directory that
///      contains a `config.json`. The directories checked are the executable's
///      directory and up to three ancestors. If none matches, the path is
///      returned unchanged (relative to the working directory);
///    - release builds join a relative path onto the app data directory.
/// 3. Otherwise `data/vector.db`: relative in debug builds, under the app
///    data directory in release builds.
///
/// # Panics
/// In release builds, panics if Tauri cannot determine the app data directory.
fn resolve_db_path(app_handle: &tauri::AppHandle, config: &serde_json::Value) -> String {
    if let Ok(p) = env::var("DB_PATH") {
        return p;
    }

    let configured = config
        .get("database")
        .and_then(|d| d.get("path"))
        .and_then(|p| p.as_str())
        .filter(|p| !p.is_empty());

    if let Some(p) = configured {
        let candidate = PathBuf::from(p);
        let resolved = if !candidate.is_relative() {
            candidate
        } else if cfg!(debug_assertions) {
            debug_anchor_dir()
                .map(|dir| dir.join(p))
                .unwrap_or(candidate)
        } else {
            // Honour the configured relative path instead of always using
            // `data/vector.db`. The default config resolves to the same file.
            app_data_db_path(app_handle, &candidate)
        };
        return resolved.to_string_lossy().to_string();
    }

    // Default fallback
    if cfg!(debug_assertions) {
        "data/vector.db".to_string()
    } else {
        app_data_db_path(app_handle, std::path::Path::new("data/vector.db"))
            .to_string_lossy()
            .to_string()
    }
}

/// Returns the first of the executable's directory and its next three
/// ancestors that contains a `config.json`. Used to anchor relative database
/// paths in debug builds.
fn debug_anchor_dir() -> Option<PathBuf> {
    let exe_dir = env::current_exe().ok()?.parent()?.to_path_buf();
    let anchor = exe_dir
        .ancestors()
        .take(4)
        .find(|dir| dir.join("config.json").exists())?
        .to_path_buf();
    Some(anchor)
}

/// Joins `relative` onto the app data directory component by component, so
/// separators are normalized for the platform. Also creates the resulting
/// file's parent directory, best-effort.
///
/// # Panics
/// Panics if Tauri cannot determine the app data directory.
fn app_data_db_path(app_handle: &tauri::AppHandle, relative: &std::path::Path) -> PathBuf {
    let mut path = app_handle
        .path()
        .app_data_dir()
        .expect("App data dir not found");
    path.extend(relative.components());
    if let Some(parent) = path.parent() {
        // Best-effort: a failure here is not treated as fatal.
        let _ = std::fs::create_dir_all(parent);
    }
    path
}

/// Locates the sqlite-vec loadable extension (`vec0.dll`) and returns its path.
///
/// Candidates are checked in this order:
/// 1. `<resource_dir>/build_assets/vec0.dll` (release builds only);
/// 2. paths relative to the executable's directory (or the working directory
///    if that cannot be determined): `vec0.dll`, the
///    `sqlite-vec-windows-x64` npm package one or two levels up, and
///    `../../../bin/vec0.dll`.
///
/// The first existing candidate is canonicalized and any Windows verbatim
/// prefix (`\\?\`) is removed. Presumably this is because the extension
/// loader does not accept verbatim paths — TODO confirm. If nothing exists,
/// the bare name `vec0.dll` is returned.
fn resolve_extension_path(app_handle: &tauri::AppHandle) -> String {
    let mut candidates = vec![];

    if !cfg!(debug_assertions) {
        if let Ok(res_dir) = app_handle.path().resource_dir() {
            candidates.push(res_dir.join("build_assets").join("vec0.dll"));
        }
    }

    // Fall back to the working directory, then to a relative base, instead of
    // panicking when the executable location is unavailable.
    let exe_dir = env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.to_path_buf()))
        .or_else(|| env::current_dir().ok())
        .unwrap_or_default();

    for rel in [
        "vec0.dll",
        "../node_modules/sqlite-vec-windows-x64/vec0.dll",
        "../../node_modules/sqlite-vec-windows-x64/vec0.dll",
        "../../../bin/vec0.dll",
    ] {
        candidates.push(exe_dir.join(rel));
    }

    match candidates.iter().find(|cand| cand.exists()) {
        Some(found) => match found.canonicalize() {
            Ok(canon) => strip_verbatim_prefix(&canon.to_string_lossy()),
            Err(_) => found.to_string_lossy().to_string(),
        },
        None => "vec0.dll".to_string(),
    }
}

/// Converts a Windows verbatim path string to its conventional form.
///
/// `\\?\UNC\server\share\…` becomes `\\server\share\…`, and `\\?\C:\…` becomes
/// `C:\…`. Any other string is returned unchanged.
fn strip_verbatim_prefix(path: &str) -> String {
    if let Some(rest) = path.strip_prefix(r"\\?\UNC\") {
        format!(r"\\{}", rest)
    } else if let Some(rest) = path.strip_prefix(r"\\?\") {
        rest.to_string()
    } else {
        path.to_string()
    }
}

/// Best-effort termination of stray `llama-server` processes whose executable
/// path contains this application's executable directory. Servers belonging
/// to other applications or projects are left alone.
///
/// Runs a PowerShell one-liner, with no console window on Windows. Failures
/// are logged and never propagated. `_app_handle` is currently unused.
pub fn cleanup_orphaned_sidecars<R: tauri::Runtime>(_app_handle: &tauri::AppHandle<R>) {
    let base_dir = match env::current_exe() {
        Ok(exe_path) => exe_path.parent().map(|p| p.to_path_buf()).unwrap_or_default(),
        Err(_) => return,
    };
    // An empty directory would match every llama-server process on the
    // machine, defeating the point of scoping the cleanup.
    if base_dir.as_os_str().is_empty() {
        log::warn!("Skipping orphaned sidecar cleanup: executable directory unknown");
        return;
    }

    let base_dir_str = base_dir.to_string_lossy();
    log::info!("Cleaning up orphaned sidecars in: {}", base_dir_str);

    // The directory is embedded as a PowerShell single-quoted string. There
    // the only special character is `'`, escaped by doubling. Backslashes are
    // literal, so they must NOT be doubled, or the path would never match.
    // A case-insensitive substring test replaces `-like` so that wildcard
    // characters such as `[` in the path are not interpreted. `$_.Path` is
    // tested first because it may be null for processes whose image path is
    // not accessible.
    let ps_literal = base_dir_str.replace('\'', "''");
    let script = format!(
        "Get-Process | Where-Object {{ ($_.Name -like '*llama-server*') -and $_.Path -and \
         ($_.Path.IndexOf('{}', [System.StringComparison]::OrdinalIgnoreCase) -ge 0) }} \
         | Stop-Process -Force",
        ps_literal
    );

    let mut cmd = std::process::Command::new("powershell");
    cmd.args(["-NoProfile", "-Command", &script]);

    #[cfg(windows)]
    {
        use std::os::windows::process::CommandExt;
        const CREATE_NO_WINDOW: u32 = 0x08000000;
        cmd.creation_flags(CREATE_NO_WINDOW);
    }

    match cmd.output() {
        Ok(out) if out.status.success() => log::info!("Orphaned sidecar cleanup completed."),
        Ok(out) => log::warn!(
            "Orphaned sidecar cleanup script exited with {}: {}",
            out.status,
            String::from_utf8_lossy(&out.stderr).trim()
        ),
        Err(e) => log::error!("Failed to run cleanup script: {}", e),
    }
}

fn setup_logging(app: &mut tauri::App) {
    let mut log_builder = tauri_plugin_log::Builder::default()
        .targets([
            tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::Stdout),
            tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::LogDir { file_name: None }),
        ])
        .max_file_size(10 * 1024 * 1024)
        .level(log::LevelFilter::Info);
    
    if cfg!(debug_assertions) {
        log_builder = log_builder.target(tauri_plugin_log::Target::new(tauri_plugin_log::TargetKind::Folder { path: std::path::PathBuf::from("logs"), file_name: None }));
    }
    let _ = app.handle().plugin(log_builder.build());
}

async fn initialize_app(app_handle: tauri::AppHandle) -> Result<(), String> {
    dotenv().ok();
    cleanup_orphaned_sidecars(&app_handle);
    let config = get_config(&app_handle);

    let db_path = resolve_db_path(&app_handle, &config);
    let ext_path = resolve_extension_path(&app_handle);
    log::info!("DB Path: {}, Ext Path: {}", db_path, ext_path);
    
    let conn = db::init_db(&db_path, &ext_path).await.expect("Failed to init db");

    let state = Arc::new(AppState {
        db: conn,
        llama: Arc::new(LlamaClient::new(
            env::var("LLAMA_CPP_BASE_URL").unwrap_or_else(|_| "http://localhost:8080".to_string()),
            env::var("LLAMA_CPP_EMBEDDING_MODEL").unwrap_or_else(|_| "nomic-embed-text".to_string()),
            env::var("LLAMA_CPP_MODEL").unwrap_or_else(|_| "mistral".to_string()),
        )),
        llama_server: Arc::new(tokio::sync::Mutex::new(None)),
    });
    app_handle.manage(state.clone());
    
    let port_str = env::var("MCP_PORT").unwrap_or_else(|_| "4242".to_string());
    log::info!("MCP_PORT from env: {}", port_str);
    let port = port_str.parse::<u16>().unwrap_or(4242);
    
    tokio::spawn(async move { mcp::start_mcp_server(state, port).await; });

    // Spawn llama server in background
    let llama_handle = app_handle.clone();
    tokio::spawn(async move {
        if let Err(e) = llama::spawn_server(&llama_handle).await {
            eprintln!("Failed to spawn llama server: {}", e);
        }
    });

    Ok(())
}

/// Installs the system tray icon with a "Show" / "Quit" menu.
///
/// "Show" and a double-click on the tray icon show and focus the `main`
/// webview window. "Quit" exits the app with code 0.
///
/// # Errors
/// Returns an error if no default window icon is configured, or if creating
/// the menu items, the menu or the tray icon fails.
fn setup_tray(app: &tauri::App) -> Result<(), Box<dyn std::error::Error>> {
    /// Shows and focuses the `main` webview window, if it exists.
    fn show_main_window<R: tauri::Runtime>(app: &tauri::AppHandle<R>) {
        if let Some(w) = app.get_webview_window("main") {
            let _ = w.show();
            let _ = w.set_focus();
        }
    }

    let quit_i = MenuItem::with_id(app, "quit", "Quit", true, None::<&str>)?;
    let show_i = MenuItem::with_id(app, "show", "Show", true, None::<&str>)?;
    let menu = Menu::with_items(app, &[&show_i, &quit_i])?;

    // Report a missing icon as an error instead of panicking.
    let icon = app
        .default_window_icon()
        .cloned()
        .ok_or("no default window icon configured")?;

    let _tray = TrayIconBuilder::new()
        .icon(icon)
        .menu(&menu)
        .on_menu_event(|app, event| match event.id.as_ref() {
            "quit" => app.exit(0),
            "show" => show_main_window(app),
            _ => {}
        })
        .on_tray_icon_event(|tray, event| {
            if let TrayIconEvent::DoubleClick { .. } = event {
                show_main_window(tray.app_handle());
            }
        })
        .build(app)?;
    Ok(())
}

/// Application entry point: builds and runs the Tauri app.
///
/// Registers the shell plugin and hides (rather than closes) windows on close
/// requests. During `setup` it:
/// 1. loads `.env`;
/// 2. initializes logging;
/// 3. runs `initialize_app` to completion;
/// 4. installs the tray icon.
///
/// Setup failures are returned to Tauri through `?`, with context, instead of
/// panicking inside the closure.
///
/// # Panics
/// Panics with "error while running tauri application" if Tauri returns an
/// error from `run`.
#[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() {
    tauri::Builder::default()
        .plugin(tauri_plugin_shell::init())
        .on_window_event(|window, event| {
            // Closing a window only hides it; "Quit" in the tray menu exits.
            if let tauri::WindowEvent::CloseRequested { api, .. } = event {
                api.prevent_close();
                let _ = window.hide();
            }
        })
        .setup(|app| {
            dotenvy::dotenv().ok();
            setup_logging(app);
            log::info!("Application starting (Tauri 2)...");

            // Block so that AppState is registered before setup returns.
            let app_handle = app.handle().clone();
            tauri::async_runtime::block_on(initialize_app(app_handle))
                .map_err(|e| format!("Failed to initialize app: {}", e))?;

            setup_tray(app).map_err(|e| format!("Failed to setup tray: {}", e))?;
            log::info!("Tauri setup completed");
            Ok(())
        })
        .invoke_handler(tauri::generate_handler![
            get_mcp_info,
            get_db_stats,
            get_sidecar_status,
            get_table_list,
            get_table_data,
            get_table_schema,
            vector_search_text
        ])
        .run(tauri::generate_context!())
        .expect("error while running tauri application");
}