// TelosDB: src-tauri/src/llama.rs
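//! Thin async HTTP client for a local llama.cpp-style inference server,
//! exposing the two calls used here: text embeddings and non-streaming
//! text completion.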
use anyhow::{anyhow, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};

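/// Request body for the embeddings endpoint; the `model` + `input` shape
/// matches the OpenAI-compatible embeddings API.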
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    model: String,
    input: String,
}

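/// Embeddings response envelope; fields other than `data` are ignored.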
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

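/// A single embedding vector inside the response's `data` array.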
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}

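/// Request body for the native llama.cpp `/completion` endpoint.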
#[derive(Debug, Serialize)]
struct CompletionRequest {
    model: String,
    prompt: String,
    n_predict: i32,
    temperature: f32,
    stream: bool,
}

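/// Completion response; only the generated `content` is deserialized.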
#[derive(Debug, Deserialize)]
struct CompletionResponse {
    content: String,
}

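/// Client for one llama.cpp-style server, holding a shared connection pool
/// plus the model names to send with each request.
///
/// A minimal usage sketch (the URL and model names below are placeholders,
/// not values this crate ships with):
///
/// ```ignore
/// let llama = LlamaClient::new(
///     "http://127.0.0.1:8080".to_string(),
///     "embed-model".to_string(),
///     "chat-model".to_string(),
/// );
/// let vector = llama.get_embedding("hello").await?;
/// let answer = llama.completion("Say hi.", 64, 0.7).await?;
/// ```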
pub struct LlamaClient {
    client: Client,
    base_url: String,
    embedding_model: String,
    completion_model: String,
}

impl LlamaClient {
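    /// Builds a client with a fresh `reqwest` connection pool. `base_url`
    /// should carry no trailing slash, since endpoint paths are appended
    /// with a leading `/`.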
    pub fn new(base_url: String, embedding_model: String, completion_model: String) -> Self {
        Self {
            client: Client::new(),
            base_url,
            embedding_model,
            completion_model,
        }
    }

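    /// POSTs `text` to `{base_url}/embeddings` and returns the first
    /// embedding vector. Errors on transport failure, a non-success HTTP
    /// status, or an empty `data` array.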
    pub async fn get_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let url = format!("{}/embeddings", self.base_url);
        let req = EmbeddingRequest {
            model: self.embedding_model.clone(),
            input: text.to_string(),
        };

        let res = self
            .client
            .post(&url)
            .json(&req)
            .send()
            .await?
            .error_for_status()?
            .json::<EmbeddingResponse>()
            .await?;

        // Avoid panicking if the server returns an empty `data` array.
        res.data
            .into_iter()
            .next()
            .map(|d| d.embedding)
            .ok_or_else(|| anyhow!("embedding response contained no embeddings"))
    }

    pub async fn completion(
        &self,
        prompt: &str,
        n_predict: i32,
        temperature: f32,
    ) -> Result<String> {
        let url = format!("{}/completion", self.base_url);
        let req = CompletionRequest {
            model: self.completion_model.clone(),
            prompt: prompt.to_string(),
            n_predict,
            temperature,
            stream: false,
        };

        let res = self
            .client
            .post(&url)
            .json(&req)
            .send()
            .await?
            .error_for_status()?
            .json::<CompletionResponse>()
            .await?;

        Ok(res.content)
    }
}

#[cfg(test)]
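// These tests stand up a mockito HTTP server in place of a real inference
// backend, so they run without any model loaded.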
mod tests {
    use super::*;
    use mockito::Server;

    #[tokio::test]
    async fn test_get_embedding() -> Result<()> {
        let mut server = Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"data":[{"embedding":[0.1, 0.2, 0.3]}]}"#)
            .create_async()
            .await;

        let client = LlamaClient::new(url, "test-model".to_string(), "test-model".to_string());
        let embedding = client.get_embedding("hello").await?;

        assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
        mock.assert_async().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_completion() -> Result<()> {
        let mut server = Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/completion")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"content":"Hello, world!"}"#)
            .create_async()
            .await;

        let client = LlamaClient::new(url, "test-model".to_string(), "test-model".to_string());
        let result = client.completion("hi", 10, 0.7).await?;

        assert_eq!(result, "Hello, world!");
        mock.assert_async().await;
        Ok(())
    }
}