// TelosDB / src-backend / src / llama.rs
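//! Async client for a local llama-server instance: OpenAI-compatible
//! embeddings (with chunked mean pooling for long inputs), non-streaming
//! completions, and a basic health check.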
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    model: String,
    input: String,
}

#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    data: Vec<EmbeddingData>,
}

#[derive(Debug, Deserialize)]
struct EmbeddingData {
    embedding: Vec<f32>,
}

#[derive(Debug, Serialize)]
struct CompletionRequest {
    model: String,
    prompt: String,
    n_predict: i32,
    temperature: f32,
    stream: bool,
}

#[derive(Debug, Deserialize)]
struct CompletionResponse {
    content: String,
}

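/// HTTP client for a llama-server instance. Holds the base URL plus the model
/// names sent with embedding and completion requests.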
pub struct LlamaClient {
    client: Client,
    base_url: String,
    embedding_model: String,
    completion_model: String,
}

impl LlamaClient {
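    /// Build a client, e.g. `LlamaClient::new("http://localhost:8080".into(), emb, gen)`
    /// (the base URL shown is only a hypothetical default; pass whatever address
    /// your llama-server is listening on).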
    pub fn new(base_url: String, embedding_model: String, completion_model: String) -> Self {
        Self {
            client: Client::new(),
            base_url,
            embedding_model,
            completion_model,
        }
    }

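    /// Embed `text`. Inputs longer than the chunk size are split into chunks,
    /// each chunk is embedded separately, and the chunk vectors are averaged
    /// (mean pooling) into a single embedding.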
    pub async fn get_embedding(&self, text: &str) -> Result<Vec<f32>> {
        // Heuristic: roughly 3-4 characters per token, so a 2048-token context
        // is on the order of 6000-8000 characters. Use 4000 characters as a
        // conservative chunk size to leave headroom.
        let chunk_size = 4000;
        
        if text.len() <= chunk_size {
            return self.get_single_embedding(text).await;
        }

        // Chunking logic: advance in chunk_size-byte steps, but back `end` off
        // to the nearest UTF-8 char boundary so the slice below cannot panic
        // on multi-byte characters.
        let mut chunks = Vec::new();
        let mut start = 0;
        while start < text.len() {
            let mut end = (start + chunk_size).min(text.len());
            while !text.is_char_boundary(end) {
                end -= 1;
            }
            chunks.push(&text[start..end]);
            start = end;
        }

        println!("DEBUG: Text length {} exceeds chunk size {}. Splitting into {} chunks.", text.len(), chunk_size, chunks.len());

        let mut all_embeddings = Vec::new();
        for chunk in chunks {
            let emb = self.get_single_embedding(chunk).await?;
            all_embeddings.push(emb);
        }

        // Average the embeddings (Mean Pooling)
        if all_embeddings.is_empty() {
             return Err(anyhow::anyhow!("No embeddings generated for chunks"));
        }

        let dim = all_embeddings[0].len();
        let mut mean_embedding = vec![0.0f32; dim];
        
        for emb in &all_embeddings {
            for i in 0..dim {
                mean_embedding[i] += emb[i];
            }
        }

        for i in 0..dim {
            mean_embedding[i] /= all_embeddings.len() as f32;
        }

        Ok(mean_embedding)
    }

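    /// POST one input to the OpenAI-compatible `/v1/embeddings` endpoint and
    /// return the first embedding vector from the response.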
    async fn get_single_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let url = format!("{}/v1/embeddings", self.base_url);
        let req = EmbeddingRequest {
            model: self.embedding_model.clone(),
            input: text.to_string(),
        };

        let res = self
            .client
            .post(&url)
            .json(&req)
            .send()
            .await?;

        if !res.status().is_success() {
            let status = res.status();
            let text = res.text().await.unwrap_or_else(|_| "Could not read error body".to_string());
            return Err(anyhow::anyhow!("llama-server error ({}): {}", status, text));
        }

        let text_res = res.text().await.map_err(|e| {
            anyhow::anyhow!("Failed to read response text: {}", e)
        })?;

        let res_json: EmbeddingResponse = serde_json::from_str(&text_res).map_err(|e| {
            // Truncate the body snippet on a char boundary so the error path itself cannot panic.
            let snippet: String = text_res.chars().take(200).collect();
            anyhow::anyhow!("Failed to parse embedding JSON: {}. Body snippet: {}", e, snippet)
        })?;

        if res_json.data.is_empty() {
            return Err(anyhow::anyhow!("llama-server returned empty data"));
        }

        Ok(res_json.data[0].embedding.clone())
    }

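    /// Request a non-streaming completion of `prompt` from the `/completion`
    /// endpoint, limited to `n_predict` tokens at the given `temperature`.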
    pub async fn completion(
        &self,
        prompt: &str,
        n_predict: i32,
        temperature: f32,
    ) -> Result<String> {
        let url = format!("{}/completion", self.base_url);
        let req = CompletionRequest {
            model: self.completion_model.clone(),
            prompt: prompt.to_string(),
            n_predict,
            temperature,
            stream: false,
        };

        let res = self
            .client
            .post(&url)
            .json(&req)
            .send()
            .await?;

        if !res.status().is_success() {
            let status = res.status();
            let body = res.text().await.unwrap_or_else(|_| "Could not read error body".to_string());
            return Err(anyhow::anyhow!("llama-server error ({}): {}", status, body));
        }

        let res = res.json::<CompletionResponse>().await?;
        Ok(res.content)
    }

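    /// Probe the server's `/health` endpoint to see whether llama-server is up.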
    pub async fn check_health(&self) -> bool {
        let url = format!("{}/health", self.base_url);
        // Only report healthy when the request succeeds with a success status,
        // not merely when the HTTP request itself went through.
        match self.client.get(&url).send().await {
            Ok(res) => res.status().is_success(),
            Err(_) => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    #[tokio::test]
    async fn test_get_embedding() -> Result<()> {
        let mut server = Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"data":[{"embedding":[0.1, 0.2, 0.3]}]}"#)
            .create_async()
            .await;

        let client = LlamaClient::new(url, "test-model".to_string(), "test-model".to_string());
        let embedding = client.get_embedding("hello").await?;

        assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
        mock.assert_async().await;
        Ok(())
    }

    #[tokio::test]
    async fn test_completion() -> Result<()> {
        let mut server = Server::new_async().await;
        let url = server.url();

        let mock = server
            .mock("POST", "/completion")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"content":"Hello, world!"}"#)
            .create_async()
            .await;

        let client = LlamaClient::new(url, "test-model".to_string(), "test-model".to_string());
        let result = client.completion("hi", 10, 0.7).await?;

        assert_eq!(result, "Hello, world!");
        mock.assert_async().await;
        Ok(())
    }
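
    // Sketch of a test for the chunking + mean-pooling path of get_embedding;
    // assumes the 4000-character chunk size defined above.
    #[tokio::test]
    async fn test_get_embedding_chunked() -> Result<()> {
        let mut server = Server::new_async().await;
        let url = server.url();

        // A 5000-character input splits into two chunks, so the mock must be hit twice.
        let mock = server
            .mock("POST", "/v1/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"data":[{"embedding":[1.0, 2.0, 3.0]}]}"#)
            .expect(2)
            .create_async()
            .await;

        let client = LlamaClient::new(url, "test-model".to_string(), "test-model".to_string());
        let long_text = "a".repeat(5000);
        let embedding = client.get_embedding(&long_text).await?;

        // Both chunks return the same vector, so the mean equals that vector.
        assert_eq!(embedding, vec![1.0, 2.0, 3.0]);
        mock.assert_async().await;
        Ok(())
    }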
}