Skip to content

Commit a68b842

Browse files
authored
Merge pull request #5 from opensass/emb-mod
feat: impl `list-embedding-models` endpoint && fix error && refactor
2 parents 8570f7b + 2279ecd commit a68b842

15 files changed

+416
-303
lines changed

completions

Whitespace-only changes.

src/api_key.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//! Reference: https://docs.x.ai/api/endpoints#api-key
22
3+
use crate::error::check_for_model_error;
34
use crate::traits::ApiKeyFetcher;
45
use crate::{error::XaiError, traits::ClientConfig};
56
use serde::{Deserialize, Serialize};
@@ -46,12 +47,16 @@ where
4647
.await?;
4748

4849
if response.status().is_success() {
49-
let api_key_info = response.json::<ApiKeyInfo>().await?;
50-
Ok(api_key_info)
50+
let chat_completion = response.json::<ApiKeyInfo>().await?;
51+
Ok(chat_completion)
5152
} else {
52-
Err(XaiError::Http(
53-
response.error_for_status().unwrap_err().to_string(),
54-
))
53+
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
54+
55+
if let Some(model_error) = check_for_model_error(&error_body) {
56+
return Err(model_error);
57+
}
58+
59+
Err(XaiError::Http(error_body))
5560
}
5661
}
5762
}

src/chat_compl.rs

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,44 @@
11
//! Reference: https://docs.x.ai/api/endpoints#chat-completions
22
3+
use crate::error::check_for_model_error;
34
use crate::error::XaiError;
45
use crate::traits::ChatCompletionsFetcher;
56
use crate::traits::ClientConfig;
67
use serde::{Deserialize, Serialize};
8+
use std::collections::HashMap;
79

810
#[derive(Debug, Clone, Serialize, Deserialize)]
911
pub struct ChatCompletionRequest {
1012
pub model: String,
1113
pub messages: Vec<Message>,
14+
pub stream: bool,
15+
#[serde(skip_serializing_if = "Option::is_none")]
1216
pub temperature: Option<f32>,
17+
#[serde(skip_serializing_if = "Option::is_none")]
1318
pub max_tokens: Option<u32>,
19+
#[serde(skip_serializing_if = "Option::is_none")]
1420
pub frequency_penalty: Option<f32>,
21+
#[serde(skip_serializing_if = "Option::is_none")]
1522
pub presence_penalty: Option<f32>,
23+
#[serde(skip_serializing_if = "Option::is_none")]
1624
pub n: Option<u32>,
25+
#[serde(skip_serializing_if = "Option::is_none")]
1726
pub stop: Option<Vec<String>>,
18-
pub stream: Option<bool>,
27+
#[serde(skip_serializing_if = "Option::is_none")]
1928
pub logprobs: Option<bool>,
29+
#[serde(skip_serializing_if = "Option::is_none")]
2030
pub top_p: Option<f32>,
31+
#[serde(skip_serializing_if = "Option::is_none")]
2132
pub top_logprobs: Option<u32>,
33+
#[serde(skip_serializing_if = "Option::is_none")]
2234
pub seed: Option<u32>,
35+
#[serde(skip_serializing_if = "Option::is_none")]
2336
pub user: Option<String>,
37+
#[serde(skip_serializing_if = "Option::is_none")]
38+
pub logit_bias: Option<HashMap<u32, f32>>,
2439
}
2540

26-
#[derive(Debug, Clone, Serialize, Deserialize)]
41+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
2742
pub struct Message {
2843
pub role: String,
2944
pub content: String,
@@ -36,7 +51,9 @@ pub struct ChatCompletionResponse {
3651
pub created: u64,
3752
pub model: String,
3853
pub choices: Vec<Choice>,
54+
#[serde(skip_serializing_if = "Option::is_none")]
3955
pub usage: Option<Usage>,
56+
#[serde(skip_serializing_if = "Option::is_none")]
4057
pub system_fingerprint: Option<String>,
4158
}
4259

@@ -76,12 +93,13 @@ where
7693
presence_penalty: None,
7794
n: None,
7895
stop: None,
79-
stream: None,
96+
stream: false,
8097
logprobs: None,
8198
top_p: None,
8299
top_logprobs: None,
83100
seed: None,
84101
user: None,
102+
logit_bias: None,
85103
},
86104
}
87105
}
@@ -96,6 +114,16 @@ where
96114
self
97115
}
98116

117+
pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
118+
self.request.frequency_penalty = Some(frequency_penalty);
119+
self
120+
}
121+
122+
pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
123+
self.request.presence_penalty = Some(presence_penalty);
124+
self
125+
}
126+
99127
pub fn n(mut self, n: u32) -> Self {
100128
self.request.n = Some(n);
101129
self
@@ -106,6 +134,41 @@ where
106134
self
107135
}
108136

137+
pub fn stream(mut self, stream: bool) -> Self {
138+
self.request.stream = stream;
139+
self
140+
}
141+
142+
pub fn logprobs(mut self, logprobs: bool) -> Self {
143+
self.request.logprobs = Some(logprobs);
144+
self
145+
}
146+
147+
pub fn top_p(mut self, top_p: f32) -> Self {
148+
self.request.top_p = Some(top_p);
149+
self
150+
}
151+
152+
pub fn top_logprobs(mut self, top_logprobs: u32) -> Self {
153+
self.request.top_logprobs = Some(top_logprobs);
154+
self
155+
}
156+
157+
pub fn seed(mut self, seed: u32) -> Self {
158+
self.request.seed = Some(seed);
159+
self
160+
}
161+
162+
pub fn user(mut self, user: String) -> Self {
163+
self.request.user = Some(user);
164+
self
165+
}
166+
167+
pub fn logit_bias(mut self, logit_bias: HashMap<u32, f32>) -> Self {
168+
self.request.logit_bias = Some(logit_bias);
169+
self
170+
}
171+
109172
pub fn build(self) -> Result<ChatCompletionRequest, XaiError> {
110173
Ok(self.request)
111174
}
@@ -130,9 +193,13 @@ where
130193
let chat_completion = response.json::<ChatCompletionResponse>().await?;
131194
Ok(chat_completion)
132195
} else {
133-
Err(XaiError::Http(
134-
response.error_for_status().unwrap_err().to_string(),
135-
))
196+
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
197+
198+
if let Some(model_error) = check_for_model_error(&error_body) {
199+
return Err(model_error);
200+
}
201+
202+
Err(XaiError::Http(error_body))
136203
}
137204
}
138205
}

src/client.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ impl ClientConfig for XaiClient {
3434
let builder = self
3535
.http_client
3636
.request(method, &url)
37-
.header("Authorization", format!("Bearer {}", api_key));
37+
.header("Authorization", format!("Bearer {}", api_key))
38+
.header("Content-Type", "application/json");
3839
Ok(builder)
3940
}
4041
}

src/completions.rs

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,45 @@
11
//! Reference: https://docs.x.ai/api/endpoints#completions
22
3+
use crate::error::check_for_model_error;
34
use crate::error::XaiError;
45
use crate::traits::{ClientConfig, CompletionsFetcher};
56
use reqwest::Method;
67
use serde::{Deserialize, Serialize};
8+
use std::collections::HashMap;
79

810
#[derive(Debug, Clone, Serialize, Deserialize)]
911
pub struct CompletionsRequest {
1012
pub model: String,
1113
pub prompt: String,
14+
#[serde(skip_serializing_if = "Option::is_none")]
1215
pub best_of: Option<u32>,
16+
#[serde(skip_serializing_if = "Option::is_none")]
1317
pub echo: Option<bool>,
18+
#[serde(skip_serializing_if = "Option::is_none")]
1419
pub frequency_penalty: Option<f32>,
15-
pub logit_bias: Option<std::collections::HashMap<String, i32>>,
20+
#[serde(skip_serializing_if = "Option::is_none")]
21+
pub logit_bias: Option<HashMap<String, i32>>,
22+
#[serde(skip_serializing_if = "Option::is_none")]
1623
pub logprobs: Option<u32>,
24+
#[serde(skip_serializing_if = "Option::is_none")]
1725
pub max_tokens: Option<u32>,
26+
#[serde(skip_serializing_if = "Option::is_none")]
1827
pub n: Option<u32>,
28+
#[serde(skip_serializing_if = "Option::is_none")]
1929
pub presence_penalty: Option<f32>,
30+
#[serde(skip_serializing_if = "Option::is_none")]
2031
pub seed: Option<u32>,
32+
#[serde(skip_serializing_if = "Option::is_none")]
2133
pub stop: Option<Vec<String>>,
34+
#[serde(skip_serializing_if = "Option::is_none")]
2235
pub stream: Option<bool>,
36+
#[serde(skip_serializing_if = "Option::is_none")]
2337
pub suffix: Option<String>,
38+
#[serde(skip_serializing_if = "Option::is_none")]
2439
pub temperature: Option<f32>,
40+
#[serde(skip_serializing_if = "Option::is_none")]
2541
pub top_p: Option<f32>,
42+
#[serde(skip_serializing_if = "Option::is_none")]
2643
pub user: Option<String>,
2744
}
2845

@@ -101,7 +118,7 @@ where
101118
self
102119
}
103120

104-
pub fn logit_bias(mut self, logit_bias: std::collections::HashMap<String, i32>) -> Self {
121+
pub fn logit_bias(mut self, logit_bias: HashMap<String, i32>) -> Self {
105122
self.request.logit_bias = Some(logit_bias);
106123
self
107124
}
@@ -162,6 +179,12 @@ where
162179
}
163180

164181
pub fn build(self) -> Result<CompletionsRequest, XaiError> {
182+
if self.request.model.trim().is_empty() {
183+
return Err(XaiError::Validation("Model is required".to_string()));
184+
}
185+
if self.request.prompt.trim().is_empty() {
186+
return Err(XaiError::Validation("Prompt is required".to_string()));
187+
}
165188
Ok(self.request)
166189
}
167190
}
@@ -176,18 +199,22 @@ where
176199
) -> Result<CompletionsResponse, XaiError> {
177200
let response = self
178201
.client
179-
.request(Method::POST, "/v1/completions")?
202+
.request(Method::POST, "completions")?
180203
.json(&request)
181204
.send()
182205
.await?;
183206

184207
if response.status().is_success() {
185-
let completions = response.json::<CompletionsResponse>().await?;
186-
Ok(completions)
208+
let chat_completion = response.json::<CompletionsResponse>().await?;
209+
Ok(chat_completion)
187210
} else {
188-
Err(XaiError::Http(
189-
response.error_for_status().unwrap_err().to_string(),
190-
))
211+
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
212+
213+
if let Some(model_error) = check_for_model_error(&error_body) {
214+
return Err(model_error);
215+
}
216+
217+
Err(XaiError::Http(error_body))
191218
}
192219
}
193220
}

src/embedding.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
//! Reference: https://docs.x.ai/api/endpoints#create-embeddings
2+
3+
use crate::error::check_for_model_error;
14
use crate::error::XaiError;
25
use crate::traits::{ClientConfig, EmbeddingFetcher};
36
use reqwest::Method;
@@ -66,18 +69,22 @@ where
6669
) -> Result<EmbeddingResponse, XaiError> {
6770
let response = self
6871
.client
69-
.request(Method::POST, "/v1/embeddings")?
72+
.request(Method::POST, "embeddings")?
7073
.json(&request)
7174
.send()
7275
.await?;
7376

7477
if response.status().is_success() {
75-
let embedding_response = response.json::<EmbeddingResponse>().await?;
76-
Ok(embedding_response)
78+
let chat_completion = response.json::<EmbeddingResponse>().await?;
79+
Ok(chat_completion)
7780
} else {
78-
Err(XaiError::Http(
79-
response.error_for_status().unwrap_err().to_string(),
80-
))
81+
let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
82+
83+
if let Some(model_error) = check_for_model_error(&error_body) {
84+
return Err(model_error);
85+
}
86+
87+
Err(XaiError::Http(error_body))
8188
}
8289
}
8390
}

src/embedding_mod.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
//! Reference: https://docs.x.ai/api/endpoints#list-embedding-models
2+
3+
use crate::error::XaiError;
4+
use crate::traits::{ClientConfig, EmbeddingModelsFetcher};
5+
use serde::{Deserialize, Serialize};
6+
7+
#[derive(Debug, Clone, Serialize, Deserialize)]
8+
pub struct EmbeddingModelsResponse {
9+
pub models: Vec<EmbeddingModel>,
10+
}
11+
12+
#[derive(Debug, Clone, Serialize, Deserialize)]
13+
pub struct EmbeddingModel {
14+
pub created: u64,
15+
pub id: String,
16+
pub input_modalities: Vec<String>,
17+
pub object: String,
18+
pub owned_by: String,
19+
pub prompt_image_token_price: u64,
20+
pub prompt_text_token_price: u64,
21+
pub version: String,
22+
}
23+
24+
impl<T> EmbeddingModelsFetcher for T
25+
where
26+
T: ClientConfig + Send + Sync,
27+
{
28+
async fn list_embedding_models(&self) -> Result<EmbeddingModelsResponse, XaiError> {
29+
let response = self
30+
.request(reqwest::Method::GET, "embedding-models")?
31+
.send()
32+
.await?;
33+
34+
if response.status().is_success() {
35+
let models_response = response.json::<EmbeddingModelsResponse>().await?;
36+
Ok(models_response)
37+
} else {
38+
Err(XaiError::Http(
39+
response.error_for_status().unwrap_err().to_string(),
40+
))
41+
}
42+
}
43+
}

0 commit comments

Comments
 (0)