Skip to content

Commit a02885c

Browse files
authored
feat: impl imagen (#16)
1 parent aed35a7 commit a02885c

File tree

15 files changed

+197
-4
lines changed

15 files changed

+197
-4
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ gems info
100100
gems list
101101
```
102102

103+
### Generate an Image
104+
105+
```sh
106+
gems imagen -t "Hi, can you create a 3d rendered image of a pig with wings and a top hat flying over a happy futuristic scifi city with lots of greenery?"
107+
```
108+
103109
## 🎨 Options
104110

105111
| Option | Description |
@@ -115,6 +121,7 @@ gems list
115121
| `generate` | Generate creative content. |
116122
| `vision` | Analyze an image and generate content from text. |
117123
| `stream` | Stream the generation of content. |
124+
| `imagen` | Generate an image. |
118125
| `count` | Count the number of tokens in a text. |
119126
| `embed` | Embed content into a specified model. |
120127
| `batch` | Batch embed multiple contents. |

src/chat.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ impl Chats {
3232
let request_body = GeminiRequest {
3333
model: params.model.to_string(),
3434
contents: vec![content],
35+
config: None,
3536
};
3637

3738
let req = self

src/cli.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ pub enum Command {
101101
Batch(Batch),
102102
Info(Info),
103103
List(List),
104+
Imagen(Imagen),
104105
}
105106

106107
#[cfg(feature = "cli")]
@@ -161,3 +162,11 @@ pub struct Vision {
161162
#[arg(short, long, default_value_t = String::from("What is this picture?"))]
162163
pub text: String,
163164
}
165+
166+
#[cfg(feature = "cli")]
167+
#[derive(Args, Debug, Clone)]
168+
pub struct Imagen {
169+
/// The text to generate image from.
170+
#[arg(short, long, default_value_t = String::from("Hi, step bro... I need help generating a happy, humble, bumble Rustacean. he's stuck in the shower and won't compile."))]
171+
pub text: String,
172+
}

src/client.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::chat::Chats;
22
use crate::embed::Embeddings;
3+
use crate::imagen::Images;
34
use crate::models::Model;
45
use crate::models::Models;
56
use crate::stream::Streaming;
@@ -108,6 +109,12 @@ impl CTrait for Client {
108109
client: self.clone(),
109110
}
110111
}
112+
113+
fn images(&self) -> Images {
114+
Images {
115+
client: self.clone(),
116+
}
117+
}
111118
}
112119

113120
#[derive(Default)]

src/imagen.rs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use crate::client::Client;
2+
use crate::messages::Message;
3+
use crate::models::Model;
4+
use crate::requests::GenerationConfig;
5+
use crate::requests::{Content, GeminiRequest};
6+
use crate::responses::ImagenResponse;
7+
use crate::traits::CTrait;
8+
use crate::utils::extract_image_or_text;
9+
use anyhow::{anyhow, Result};
10+
use derive_builder::Builder;
11+
use reqwest::Method;
12+
13+
#[derive(Clone)]
14+
pub struct Images {
15+
pub client: Client,
16+
}
17+
18+
#[derive(Builder, Clone)]
19+
#[builder(setter(into))]
20+
pub struct ImageGen {
21+
pub model: Model,
22+
pub input: Message,
23+
}
24+
25+
impl Images {
26+
pub async fn generate(&self, params: ImageGen) -> Result<Vec<u8>> {
27+
let content = Content {
28+
parts: vec![params.input.to_part()],
29+
};
30+
let request_body = GeminiRequest {
31+
model: params.model.to_string(),
32+
contents: vec![content],
33+
config: Some(GenerationConfig {
34+
response_modalities: vec!["Text".into(), "Image".into()],
35+
}),
36+
};
37+
38+
let req = self
39+
.client
40+
.request(Method::POST, "generateContent")?
41+
.json(&request_body);
42+
43+
let res = req.send().await?;
44+
let json: ImagenResponse = res.json().await?;
45+
46+
let parts = json
47+
.candidates
48+
.ok_or_else(|| anyhow!("Missing candidates"))?
49+
.first()
50+
.ok_or_else(|| anyhow!("No candidate response"))?
51+
.content
52+
.parts
53+
.clone();
54+
55+
extract_image_or_text(&parts)
56+
}
57+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
pub mod chat;
55
pub mod client;
66
pub mod embed;
7+
pub mod imagen;
78
pub mod messages;
89
pub mod models;
910
pub mod requests;

src/main.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ async fn main() -> Result<()> {
1414
use gems::cli::{Cli, Command};
1515
use gems::embed::BatchEmbeddingBuilder;
1616
use gems::embed::EmbeddingBuilder;
17+
use gems::imagen::ImageGenBuilder;
1718
use gems::messages::Content;
1819
use gems::messages::Message;
1920
use gems::models::ModBuilder;
@@ -167,6 +168,22 @@ async fn main() -> Result<()> {
167168
let models = gemini_client.models().list().await?;
168169
models.print();
169170
}
171+
Command::Imagen(cmd) => {
172+
gemini_client.set_model(Model::FlashExpImage);
173+
174+
let params = ImageGenBuilder::default()
175+
.input(Message::User {
176+
content: Content::Text(cmd.text),
177+
name: None,
178+
})
179+
.model(Model::FlashExpImage)
180+
.build()
181+
.unwrap();
182+
183+
let image_data = gemini_client.images().generate(params).await?;
184+
185+
std::fs::write("output.png", &image_data)?;
186+
}
170187
}
171188
}
172189
Ok(())

src/models.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub enum Model {
2020
Imagen3,
2121
Veo2,
2222
Flash20Live,
23+
FlashExpImage,
2324
}
2425

2526
#[allow(clippy::to_string_trait_impl)]
@@ -36,6 +37,7 @@ impl ToString for Model {
3637
Model::Imagen3 => "imagen-3.0-generate-002",
3738
Model::Veo2 => "veo-2.0-generate-001",
3839
Model::Flash20Live => "gemini-2.0-flash-live-001",
40+
Model::FlashExpImage => "gemini-2.0-flash-exp-image-generation",
3941
}
4042
.to_string()
4143
}
@@ -56,6 +58,7 @@ impl FromStr for Model {
5658
"imagen-3.0-generate-002" => Ok(Model::Imagen3),
5759
"veo-2.0-generate-001" => Ok(Model::Veo2),
5860
"gemini-2.0-flash-live-001" => Ok(Model::Flash20Live),
61+
"gemini-2.0-flash-exp-image-generation" => Ok(Model::Flash20Live),
5962
_ => Err(anyhow!("Unknown model: {}", s)),
6063
}
6164
}

src/requests.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ pub struct GeminiRequest {
88

99
/// List of content items for generation.
1010
pub contents: Vec<Content>,
11+
12+
#[serde(rename = "generationConfig", skip_serializing_if = "Option::is_none")]
13+
pub config: Option<GenerationConfig>,
1114
}
15+
1216
/// Request structure for content embedding.
1317
#[derive(Debug, Serialize, Deserialize)]
1418
pub struct GeminiEmbedRequest {
@@ -32,7 +36,7 @@ pub struct Content {
3236
}
3337

3438
/// Define an enum to represent different types of parts in the content.
35-
#[derive(Debug, Serialize, Deserialize)]
39+
#[derive(Debug, Serialize, Deserialize, Clone)]
3640
#[serde(untagged)]
3741
pub enum Part {
3842
/// Represents a text part in the content.
@@ -64,10 +68,16 @@ pub struct Candidate {
6468
}
6569

6670
/// Structure representing the image part of the Gemini request.
67-
#[derive(Debug, Serialize, Deserialize)]
71+
#[derive(Debug, Serialize, Deserialize, Clone)]
6872
pub struct ImageContent {
6973
/// The MIME type of the image.
7074
pub mime_type: String,
7175
/// The actual image data in a base64-encoded string.
7276
pub data: String,
7377
}
78+
79+
#[derive(Debug, Serialize, Deserialize)]
80+
pub struct GenerationConfig {
81+
#[serde(rename = "responseModalities")]
82+
pub response_modalities: Vec<String>,
83+
}

src/responses.rs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::requests::Candidate;
1+
use crate::requests::Candidate as ReqCandidate;
22
use serde::{Deserialize, Serialize};
33

44
/// Response structure for content embedding.
@@ -120,5 +120,62 @@ impl ModelsResponse {
120120
#[derive(Debug, Deserialize)]
121121
pub struct GeminiResponse {
122122
/// List of generated candidates.
123+
pub candidates: Option<Vec<ReqCandidate>>,
124+
}
125+
126+
#[derive(Debug, Serialize, Deserialize)]
127+
#[serde(rename_all = "camelCase")]
128+
pub struct ImagenResponse {
123129
pub candidates: Option<Vec<Candidate>>,
130+
pub usage_metadata: Option<UsageMetadata>,
131+
pub model_version: Option<String>,
132+
}
133+
134+
#[derive(Debug, Serialize, Deserialize)]
135+
#[serde(rename_all = "camelCase")]
136+
pub struct Candidate {
137+
pub content: Content,
138+
pub finish_reason: Option<String>,
139+
pub index: Option<i32>,
140+
}
141+
142+
#[derive(Debug, Serialize, Deserialize)]
143+
#[serde(rename_all = "camelCase")]
144+
pub struct Content {
145+
pub parts: Vec<Part>,
146+
pub role: Option<String>,
147+
}
148+
149+
#[derive(Debug, Serialize, Deserialize, Clone)]
150+
#[serde(untagged)]
151+
pub enum Part {
152+
Text {
153+
text: String,
154+
},
155+
Image {
156+
#[serde(rename = "inlineData")]
157+
inline_data: ImageContent,
158+
},
159+
}
160+
161+
#[derive(Debug, Serialize, Deserialize, Clone)]
162+
#[serde(rename_all = "camelCase")]
163+
pub struct ImageContent {
164+
pub mime_type: String,
165+
pub data: String,
166+
}
167+
168+
#[derive(Debug, Serialize, Deserialize)]
169+
#[serde(rename_all = "camelCase")]
170+
pub struct UsageMetadata {
171+
pub prompt_token_count: Option<i32>,
172+
pub total_token_count: Option<i32>,
173+
pub prompt_tokens_details: Option<Vec<PromptTokenDetail>>,
174+
}
175+
176+
#[derive(Debug, Serialize, Deserialize)]
177+
#[serde(rename_all = "camelCase")]
178+
pub struct PromptTokenDetail {
179+
pub modality: Option<String>,
180+
pub token_count: Option<i32>,
124181
}

src/stream.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ impl Streaming {
2828
contents: vec![Content {
2929
parts: vec![params.input.to_part()],
3030
}],
31+
config: None,
3132
};
3233

3334
let req = self

src/tokens.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ impl Tokens {
2828
contents: vec![Content {
2929
parts: vec![params.input.to_part()],
3030
}],
31+
config: None,
3132
};
3233

3334
let req = self

src/traits.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::chat::Chats;
22
use crate::embed::Embeddings;
3+
use crate::imagen::Images;
34
use crate::models::Model;
45
use crate::models::Models;
56
use crate::stream::Streaming;
@@ -21,4 +22,5 @@ pub trait CTrait {
2122
fn vision(&self) -> Visions;
2223
fn stream(&self) -> Streaming;
2324
fn models(&self) -> Models;
25+
fn images(&self) -> Images;
2426
}

src/utils.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use anyhow::Result;
1+
use crate::responses::Part;
2+
use anyhow::{anyhow, Result};
23
use base64::{engine::general_purpose::STANDARD, Engine as _};
34
use std::fs::File;
45
use std::io::Read;
@@ -97,3 +98,21 @@ pub fn load_and_encode_image(file_path: &str) -> Result<String> {
9798
let base64_encoded = STANDARD.encode(&buffer);
9899
Ok(base64_encoded)
99100
}
101+
102+
pub fn extract_image_or_text(parts: &[Part]) -> Result<Vec<u8>> {
103+
if let Some(base64_data) = parts.iter().find_map(|part| match part {
104+
Part::Image { inline_data } => Some(inline_data.data.clone()),
105+
_ => None,
106+
}) {
107+
let image_bytes = STANDARD.decode(&base64_data)?;
108+
Ok(image_bytes)
109+
}
110+
else if let Some(text) = parts.iter().find_map(|part| match part {
111+
Part::Text { text } => Some(text.clone()),
112+
_ => None,
113+
}) {
114+
Err(anyhow!("Expected image but got only text: {}", text))
115+
} else {
116+
Err(anyhow!("No image or text found in response"))
117+
}
118+
}

src/vision.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ impl Visions {
4747
contents: vec![crate::requests::Content {
4848
parts: vec![input_part, image_part],
4949
}],
50+
config: None,
5051
};
5152

5253
let req = self

0 commit comments

Comments
 (0)