use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// JSON body sent by [`LlamaClient::get_embedding`] to the `/embeddings` endpoint.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    /// Name of the model the server should use to embed the input.
    model: String,
    /// Text to be embedded.
    input: String,
}
/// JSON body expected back from the `/embeddings` endpoint.
///
/// Only the `data` key is read; any other keys in the reply are ignored by
/// serde's default behavior.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    /// Embedding entries returned by the server.
    data: Vec<EmbeddingData>,
}
/// A single entry of the `data` array in an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
struct EmbeddingData {
    /// The embedding vector itself, decoded as 32-bit floats.
    embedding: Vec<f32>,
}
/// JSON body sent by [`LlamaClient::completion`] to the `/completion` endpoint.
#[derive(Debug, Serialize)]
struct CompletionRequest {
    /// Name of the model the server should use to generate text.
    model: String,
    /// Prompt text the completion continues from.
    prompt: String,
    /// Forwarded unchanged from the caller.
    /// NOTE(review): presumably the maximum number of tokens to generate;
    /// confirm against the server's API documentation.
    n_predict: i32,
    /// Sampling temperature, forwarded unchanged from the caller.
    temperature: f32,
    /// Whether a streamed reply is requested; `completion` always sends `false`.
    stream: bool,
}
/// JSON body expected back from the `/completion` endpoint.
///
/// Only the `content` key is read; any other keys in the reply are ignored by
/// serde's default behavior.
#[derive(Debug, Deserialize)]
struct CompletionResponse {
    /// Generated text, handed back to the caller of `completion` as-is.
    content: String,
}
/// HTTP client for an inference server exposing `/embeddings`, `/completion`
/// and `/health` endpoints under a common base URL.
pub struct LlamaClient {
    /// Underlying `reqwest` client used for every request.
    client: Client,
    /// Base URL that the endpoint paths are appended to.
    base_url: String,
    /// Model name sent with embedding requests.
    embedding_model: String,
    /// Model name sent with completion requests.
    completion_model: String,
}
impl LlamaClient {
pub fn new(base_url: String, embedding_model: String, completion_model: String) -> Self {
Self {
client: Client::new(),
base_url,
embedding_model,
completion_model,
}
}
pub async fn get_embedding(&self, text: &str) -> Result<Vec<f32>> {
let url = format!("{}/embeddings", self.base_url);
let req = EmbeddingRequest {
model: self.embedding_model.clone(),
input: text.to_string(),
};
let res = self
.client
.post(&url)
.json(&req)
.send()
.await?
.json::<EmbeddingResponse>()
.await?;
Ok(res.data[0].embedding.clone())
}
pub async fn completion(
&self,
prompt: &str,
n_predict: i32,
temperature: f32,
) -> Result<String> {
let url = format!("{}/completion", self.base_url);
let req = CompletionRequest {
model: self.completion_model.clone(),
prompt: prompt.to_string(),
n_predict,
temperature,
stream: false,
};
let res = self
.client
.post(&url)
.json(&req)
.send()
.await?
.json::<CompletionResponse>()
.await?;
Ok(res.content)
}
pub async fn check_health(&self) -> bool {
let url = format!("{}/health", self.base_url);
self.client.get(&url).send().await.is_ok()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Server;

    /// Builds a client pointed at `base` that uses the same placeholder model
    /// name for both embeddings and completions.
    fn client_for(base: String) -> LlamaClient {
        let model_name = "test-model";
        LlamaClient::new(base, model_name.to_string(), model_name.to_string())
    }

    /// A well-formed `/embeddings` reply is decoded into the first vector.
    #[tokio::test]
    async fn test_get_embedding() -> Result<()> {
        let mut mock_server = Server::new_async().await;
        let base = mock_server.url();
        let embeddings_mock = mock_server
            .mock("POST", "/embeddings")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"data":[{"embedding":[0.1, 0.2, 0.3]}]}"#)
            .create_async()
            .await;

        let llama = client_for(base);
        let vector = llama.get_embedding("hello").await?;

        assert_eq!(vector, vec![0.1, 0.2, 0.3]);
        // The endpoint must have been hit exactly as mocked.
        embeddings_mock.assert_async().await;
        Ok(())
    }

    /// A well-formed `/completion` reply yields its `content` string.
    #[tokio::test]
    async fn test_completion() -> Result<()> {
        let mut mock_server = Server::new_async().await;
        let base = mock_server.url();
        let completion_mock = mock_server
            .mock("POST", "/completion")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"content":"Hello, world!"}"#)
            .create_async()
            .await;

        let llama = client_for(base);
        let generated = llama.completion("hi", 10, 0.7).await?;

        assert_eq!(generated, "Hello, world!");
        // The endpoint must have been hit exactly as mocked.
        completion_mock.assert_async().await;
        Ok(())
    }
}