Skip to content

feat: impl list-embedding-models endpoint && fix error && refactor #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed completions
Empty file.
15 changes: 10 additions & 5 deletions src/api_key.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Reference: https://docs.x.ai/api/endpoints#api-key

use crate::error::check_for_model_error;
use crate::traits::ApiKeyFetcher;
use crate::{error::XaiError, traits::ClientConfig};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -46,12 +47,16 @@ where
.await?;

if response.status().is_success() {
let api_key_info = response.json::<ApiKeyInfo>().await?;
Ok(api_key_info)
let chat_completion = response.json::<ApiKeyInfo>().await?;
Ok(chat_completion)
} else {
Err(XaiError::Http(
response.error_for_status().unwrap_err().to_string(),
))
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());

if let Some(model_error) = check_for_model_error(&error_body) {
return Err(model_error);
}

Err(XaiError::Http(error_body))
}
}
}
79 changes: 73 additions & 6 deletions src/chat_compl.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,44 @@
//! Reference: https://docs.x.ai/api/endpoints#chat-completions

use crate::error::check_for_model_error;
use crate::error::XaiError;
use crate::traits::ChatCompletionsFetcher;
use crate::traits::ClientConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<Message>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<u32, f32>>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: String,
Expand All @@ -36,7 +51,9 @@ pub struct ChatCompletionResponse {
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}

Expand Down Expand Up @@ -76,12 +93,13 @@ where
presence_penalty: None,
n: None,
stop: None,
stream: None,
stream: false,
logprobs: None,
top_p: None,
top_logprobs: None,
seed: None,
user: None,
logit_bias: None,
},
}
}
Expand All @@ -96,6 +114,16 @@ where
self
}

/// Sets the optional `frequency_penalty` field on the request being built.
pub fn frequency_penalty(mut self, value: f32) -> Self {
    self.request.frequency_penalty = Some(value);
    self
}

/// Sets the optional `presence_penalty` field on the request being built.
pub fn presence_penalty(mut self, value: f32) -> Self {
    self.request.presence_penalty = Some(value);
    self
}

pub fn n(mut self, n: u32) -> Self {
self.request.n = Some(n);
self
Expand All @@ -106,6 +134,41 @@ where
self
}

/// Sets the `stream` flag on the request being built.
pub fn stream(mut self, enabled: bool) -> Self {
    self.request.stream = enabled;
    self
}

/// Sets the optional `logprobs` field on the request being built.
pub fn logprobs(mut self, enabled: bool) -> Self {
    self.request.logprobs = Some(enabled);
    self
}

/// Sets the optional `top_p` field on the request being built.
pub fn top_p(mut self, value: f32) -> Self {
    self.request.top_p = Some(value);
    self
}

/// Sets the optional `top_logprobs` field on the request being built.
pub fn top_logprobs(mut self, count: u32) -> Self {
    self.request.top_logprobs = Some(count);
    self
}

/// Sets the optional `seed` field on the request being built.
pub fn seed(mut self, value: u32) -> Self {
    self.request.seed = Some(value);
    self
}

/// Sets the optional `user` field on the request being built.
pub fn user(mut self, identifier: String) -> Self {
    self.request.user = Some(identifier);
    self
}

/// Sets the optional `logit_bias` map on the request being built.
pub fn logit_bias(mut self, bias: HashMap<u32, f32>) -> Self {
    self.request.logit_bias = Some(bias);
    self
}

/// Consumes the builder and yields the assembled request.
pub fn build(self) -> Result<ChatCompletionRequest, XaiError> {
    let request = self.request;
    Ok(request)
}
Expand All @@ -130,9 +193,13 @@ where
let chat_completion = response.json::<ChatCompletionResponse>().await?;
Ok(chat_completion)
} else {
Err(XaiError::Http(
response.error_for_status().unwrap_err().to_string(),
))
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());

if let Some(model_error) = check_for_model_error(&error_body) {
return Err(model_error);
}

Err(XaiError::Http(error_body))
}
}
}
3 changes: 2 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ impl ClientConfig for XaiClient {
let builder = self
.http_client
.request(method, &url)
.header("Authorization", format!("Bearer {}", api_key));
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json");
Ok(builder)
}
}
Expand Down
43 changes: 35 additions & 8 deletions src/completions.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,45 @@
//! Reference: https://docs.x.ai/api/endpoints#completions

use crate::error::check_for_model_error;
use crate::error::XaiError;
use crate::traits::{ClientConfig, CompletionsFetcher};
use reqwest::Method;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionsRequest {
pub model: String,
pub prompt: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub echo: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
pub logit_bias: Option<std::collections::HashMap<String, i32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, i32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}

Expand Down Expand Up @@ -101,7 +118,7 @@ where
self
}

pub fn logit_bias(mut self, logit_bias: std::collections::HashMap<String, i32>) -> Self {
pub fn logit_bias(mut self, logit_bias: HashMap<String, i32>) -> Self {
self.request.logit_bias = Some(logit_bias);
self
}
Expand Down Expand Up @@ -162,6 +179,12 @@ where
}

/// Consumes the builder, validating that the mandatory fields are non-empty.
pub fn build(self) -> Result<CompletionsRequest, XaiError> {
    // Both `model` and `prompt` are required by the endpoint; reject
    // whitespace-only values with the same messages as before.
    let checks = [
        (self.request.model.trim().is_empty(), "Model is required"),
        (self.request.prompt.trim().is_empty(), "Prompt is required"),
    ];
    if let Some((_, message)) = checks.iter().find(|(missing, _)| *missing) {
        return Err(XaiError::Validation((*message).to_string()));
    }
    Ok(self.request)
}
}
Expand All @@ -176,18 +199,22 @@ where
) -> Result<CompletionsResponse, XaiError> {
let response = self
.client
.request(Method::POST, "/v1/completions")?
.request(Method::POST, "completions")?
.json(&request)
.send()
.await?;

if response.status().is_success() {
let completions = response.json::<CompletionsResponse>().await?;
Ok(completions)
let chat_completion = response.json::<CompletionsResponse>().await?;
Ok(chat_completion)
} else {
Err(XaiError::Http(
response.error_for_status().unwrap_err().to_string(),
))
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());

if let Some(model_error) = check_for_model_error(&error_body) {
return Err(model_error);
}

Err(XaiError::Http(error_body))
}
}
}
19 changes: 13 additions & 6 deletions src/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//! Reference: https://docs.x.ai/api/endpoints#create-embeddings

use crate::error::check_for_model_error;
use crate::error::XaiError;
use crate::traits::{ClientConfig, EmbeddingFetcher};
use reqwest::Method;
Expand Down Expand Up @@ -66,18 +69,22 @@ where
) -> Result<EmbeddingResponse, XaiError> {
let response = self
.client
.request(Method::POST, "/v1/embeddings")?
.request(Method::POST, "embeddings")?
.json(&request)
.send()
.await?;

if response.status().is_success() {
let embedding_response = response.json::<EmbeddingResponse>().await?;
Ok(embedding_response)
let chat_completion = response.json::<EmbeddingResponse>().await?;
Ok(chat_completion)
} else {
Err(XaiError::Http(
response.error_for_status().unwrap_err().to_string(),
))
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());

if let Some(model_error) = check_for_model_error(&error_body) {
return Err(model_error);
}

Err(XaiError::Http(error_body))
}
}
}
43 changes: 43 additions & 0 deletions src/embedding_mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//! Reference: https://docs.x.ai/api/endpoints#list-embedding-models

use crate::error::XaiError;
use crate::traits::{ClientConfig, EmbeddingModelsFetcher};
use serde::{Deserialize, Serialize};

/// Deserialized response body of the list-embedding-models endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingModelsResponse {
// All embedding models returned by the API for this account.
pub models: Vec<EmbeddingModel>,
}

/// A single embedding model entry as returned by the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingModel {
// Creation timestamp (presumably Unix epoch seconds — TODO confirm against API docs).
pub created: u64,
// Model identifier used when requesting embeddings.
pub id: String,
// Input modalities the model accepts (e.g. values like "text" — verify against API docs).
pub input_modalities: Vec<String>,
// Object type discriminator returned by the API.
pub object: String,
// Owner of the model as reported by the API.
pub owned_by: String,
// Price fields; units (presumably per-token, smallest currency unit) are not
// visible here — NOTE(review): confirm units against the endpoint reference.
pub prompt_image_token_price: u64,
pub prompt_text_token_price: u64,
// Model version string.
pub version: String,
}

impl<T> EmbeddingModelsFetcher for T
where
T: ClientConfig + Send + Sync,
{
async fn list_embedding_models(&self) -> Result<EmbeddingModelsResponse, XaiError> {
let response = self
.request(reqwest::Method::GET, "embedding-models")?
.send()
.await?;

if response.status().is_success() {
let models_response = response.json::<EmbeddingModelsResponse>().await?;
Ok(models_response)
} else {
Err(XaiError::Http(
response.error_for_status().unwrap_err().to_string(),
))
}
}
}
Loading