Skip to content

Commit 8a2856c

Browse files
committed
feat: add system prompts to req && add unit tests (#17)
1 parent 7420afa commit 8a2856c

File tree

16 files changed

+565
-15
lines changed

16 files changed

+565
-15
lines changed

.circleci/config.yml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Rust CircleCI 2.1 configuration file
2+
version: 2.1
3+
aliases:
4+
- &rust_container
5+
docker:
6+
- image: cimg/rust:1.86.0
7+
jobs:
8+
testing:
9+
<<: *rust_container
10+
steps:
11+
- checkout
12+
- run:
13+
name: Run Tests
14+
command: cargo test --all-features
15+
16+
workflows:
17+
version: 2
18+
test:
19+
jobs:
20+
- testing

.github/workflows/ci.yml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,18 @@ name: CI
22

33
on:
44
push:
5-
branches: [ "main" ]
5+
branches: ["main"]
66
pull_request:
7-
branches: [ "main" ]
7+
branches: ["main"]
88

99
env:
1010
CARGO_TERM_COLOR: always
1111

1212
jobs:
1313
build:
14-
1514
runs-on: ubuntu-latest
1615

1716
steps:
18-
- uses: actions/checkout@v3
19-
- name: Build
20-
run: cargo build --verbose
21-
- name: Run tests
22-
run: cargo test --verbose
17+
- uses: actions/checkout@v3
18+
- name: Build
19+
run: cargo build --all-features --release

src/chat.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ use reqwest::Method;
1616
pub struct Chat {
1717
pub model: Model,
1818
pub messages: Vec<Message>,
19+
#[builder(setter(into, strip_option), default)]
20+
pub system: Option<Vec<Message>>,
1921
}
2022

2123
#[derive(Clone)]
@@ -29,9 +31,14 @@ impl Chats {
2931
parts: params.messages.iter().map(|msg| msg.to_part()).collect(),
3032
};
3133

34+
let system_instruction = params.system.as_ref().map(|messages| Content {
35+
parts: messages.iter().map(|msg| msg.to_part()).collect(),
36+
});
37+
3238
let request_body = GeminiRequest {
3339
model: params.model.to_string(),
3440
contents: vec![content],
41+
system_instruction,
3542
config: None,
3643
};
3744

src/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use std::sync::{Arc, RwLock};
1818

1919
const GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
2020

21-
#[derive(Clone)]
21+
#[derive(Debug, Clone)]
2222
#[allow(dead_code)]
2323
pub struct Client {
2424
http_client: Arc<HttpClient>,

src/imagen.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,24 @@ pub struct Images {
2020
pub struct ImageGen {
2121
pub model: Model,
2222
pub input: Message,
23+
#[builder(setter(into, strip_option), default)]
24+
pub system: Option<Vec<Message>>,
2325
}
2426

2527
impl Images {
2628
pub async fn generate(&self, params: ImageGen) -> Result<Vec<u8>> {
2729
let content = Content {
2830
parts: vec![params.input.to_part()],
2931
};
32+
33+
let system_instruction = params.system.as_ref().map(|messages| Content {
34+
parts: messages.iter().map(|msg| msg.to_part()).collect(),
35+
});
36+
3037
let request_body = GeminiRequest {
3138
model: params.model.to_string(),
3239
contents: vec![content],
40+
system_instruction,
3341
config: Some(GenerationConfig {
3442
response_modalities: vec!["Text".into(), "Image".into()],
3543
}),

src/messages.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::requests::Part;
22

3-
#[derive(Debug, Clone)]
3+
#[derive(Debug, Clone, PartialEq)]
44
pub enum Content {
55
Text(String),
66
}
@@ -11,7 +11,7 @@ impl Default for Content {
1111
}
1212
}
1313

14-
#[derive(Debug, Clone)]
14+
#[derive(Debug, Clone, PartialEq)]
1515
pub enum Message {
1616
User {
1717
content: Content,

src/models.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use derive_builder::Builder;
77
use reqwest::Method;
88
use std::str::FromStr;
99

10-
#[derive(Debug, Clone, Default)]
10+
#[derive(Debug, Clone, Default, PartialEq)]
1111
pub enum Model {
1212
Pro25Preview,
1313
#[default]

src/requests.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ pub struct GeminiRequest {
1111

1212
#[serde(rename = "generationConfig", skip_serializing_if = "Option::is_none")]
1313
pub config: Option<GenerationConfig>,
14+
15+
/// Optional system-level instruction.
16+
#[serde(skip_serializing_if = "Option::is_none")]
17+
pub system_instruction: Option<Content>,
1418
}
1519

1620
/// Request structure for content embedding.

src/responses.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub struct BatchEmbedContentsResponse {
1616
}
1717

1818
/// Structure representing embedding information.
19-
#[derive(Debug, Deserialize)]
19+
#[derive(Debug, Deserialize, Default, Clone)]
2020
pub struct Embedding {
2121
/// List of values for the embedding.
2222
pub values: Vec<f64>,

src/stream.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,22 @@ pub struct Streaming {
1919
pub struct Stream {
2020
pub model: Model,
2121
pub input: Message,
22+
#[builder(setter(into, strip_option), default)]
23+
pub system: Option<Vec<Message>>,
2224
}
2325

2426
impl Streaming {
2527
pub async fn generate(&self, params: Stream) -> Result<Response> {
28+
let system_instruction = params.system.as_ref().map(|messages| Content {
29+
parts: messages.iter().map(|msg| msg.to_part()).collect(),
30+
});
31+
2632
let request_body = GeminiRequest {
2733
model: params.model.to_string(),
2834
contents: vec![Content {
2935
parts: vec![params.input.to_part()],
3036
}],
37+
system_instruction,
3138
config: None,
3239
};
3340

src/tokens.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,21 @@ pub struct Tokens {
1919
pub struct Token {
2020
model: Model,
2121
input: Message,
22+
system: Vec<Message>,
2223
}
2324

2425
impl Tokens {
2526
pub async fn count(&self, params: Token) -> Result<usize> {
27+
let system_instruction = Content {
28+
parts: params.system.iter().map(|msg| msg.to_part()).collect(),
29+
};
30+
2631
let request_body = GeminiRequest {
2732
model: params.model.to_string(),
2833
contents: vec![Content {
2934
parts: vec![params.input.to_part()],
3035
}],
36+
system_instruction: Some(system_instruction),
3137
config: None,
3238
};
3339

src/utils.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ pub fn extract_image_or_text(parts: &[Part]) -> Result<Vec<u8>> {
106106
}) {
107107
let image_bytes = STANDARD.decode(&base64_data)?;
108108
Ok(image_bytes)
109-
}
110-
else if let Some(text) = parts.iter().find_map(|part| match part {
109+
} else if let Some(text) = parts.iter().find_map(|part| match part {
111110
Part::Text { text } => Some(text.clone()),
112111
_ => None,
113112
}) {

src/vision.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::messages::Content;
22
use crate::messages::Message;
33
use crate::models::Model;
4+
use crate::requests::Content as ReqContent;
45
use crate::requests::GeminiRequest;
56
use crate::requests::ImageContent;
67
use crate::requests::Part;
@@ -22,12 +23,18 @@ pub struct Visions {
2223
pub struct Vision {
2324
pub input: Message,
2425
pub image: Message,
26+
#[builder(setter(into, strip_option), default)]
27+
pub system: Option<Vec<Message>>,
2528
}
2629

2730
impl Visions {
2831
pub async fn generate(&self, params: Vision) -> Result<String> {
2932
let input_part = params.input.to_part();
3033

34+
let system_instruction = params.system.as_ref().map(|messages| ReqContent {
35+
parts: messages.iter().map(|msg| msg.to_part()).collect(),
36+
});
37+
3138
let image_data = match &params.image {
3239
Message::Tool { content } => content.clone(),
3340
Message::User { content, .. }
@@ -47,6 +54,7 @@ impl Visions {
4754
contents: vec![crate::requests::Content {
4855
parts: vec![input_part, image_part],
4956
}],
57+
system_instruction,
5058
config: None,
5159
};
5260

0 commit comments

Comments
 (0)