Skip to content

Commit 0b98669

Browse files
authored
feat(vertexai): Add responseModality (#17326)
* Add responseModality * review comments
1 parent e53c707 commit 0b98669

File tree

6 files changed

+105
-27
lines changed

6 files changed

+105
-27
lines changed

packages/firebase_vertexai/firebase_vertexai/example/lib/main.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void main() async {
4141

4242
var vertexInstance =
4343
FirebaseVertexAI.instanceFor(auth: FirebaseAuth.instance);
44-
final model = vertexInstance.generativeModel(model: 'gemini-1.5-flash');
44+
final model = vertexInstance.generativeModel(model: 'gemini-2.0-flash');
4545

4646
runApp(GenerativeAISample(model: model));
4747
}

packages/firebase_vertexai/firebase_vertexai/example/lib/pages/chat_page.dart

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,17 @@ class _ChatPageState extends State<ChatPage> {
9595
const SizedBox.square(
9696
dimension: 15,
9797
),
98+
if (!_loading)
99+
IconButton(
100+
onPressed: () async {
101+
await _imageResponse(_textController.text);
102+
},
103+
icon: Icon(
104+
Icons.image,
105+
color: Theme.of(context).colorScheme.primary,
106+
),
107+
tooltip: 'Image response',
108+
),
98109
if (!_loading)
99110
IconButton(
100111
onPressed: () async {
@@ -152,6 +163,58 @@ class _ChatPageState extends State<ChatPage> {
152163
}
153164
}
154165

166+
Future<void> _imageResponse(String message) async {
167+
setState(() {
168+
_loading = true;
169+
});
170+
171+
try {
172+
_messages.add(MessageData(text: message, fromUser: true));
173+
var response = await widget.model.generateContent(
174+
[Content.text(message)],
175+
generationConfig: GenerationConfig(
176+
responseModalities: [
177+
ResponseModalities.text,
178+
ResponseModalities.image,
179+
],
180+
),
181+
);
182+
var inlineDatas = response.inlineDataParts.toList();
183+
184+
if (inlineDatas.isEmpty) {
185+
_showError('No response from API.');
186+
return;
187+
} else {
188+
for (final inlineData in inlineDatas) {
189+
if (inlineData.mimeType.contains('image')) {
190+
_messages.add(
191+
MessageData(
192+
text: response.text,
193+
image: Image.memory(inlineData.bytes),
194+
fromUser: false,
195+
),
196+
);
197+
}
198+
}
199+
setState(() {
200+
_loading = false;
201+
_scrollDown();
202+
});
203+
}
204+
} catch (e) {
205+
_showError(e.toString());
206+
setState(() {
207+
_loading = false;
208+
});
209+
} finally {
210+
_textController.clear();
211+
setState(() {
212+
_loading = false;
213+
});
214+
_textFieldFocus.requestFocus();
215+
}
216+
}
217+
155218
void _showError(String message) {
156219
showDialog<void>(
157220
context: context,

packages/firebase_vertexai/firebase_vertexai/lib/firebase_vertexai.dart

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ export 'src/api.dart'
2727
HarmProbability,
2828
HarmBlockMethod,
2929
PromptFeedback,
30+
ResponseModalities,
3031
SafetyRating,
3132
SafetySetting,
3233
// TODO(cynthiajiang) remove in next breaking change.
@@ -72,7 +73,6 @@ export 'src/live_api.dart'
7273
show
7374
LiveGenerationConfig,
7475
SpeechConfig,
75-
ResponseModalities,
7676
LiveServerMessage,
7777
LiveServerContent,
7878
LiveServerToolCall,

packages/firebase_vertexai/firebase_vertexai/lib/src/api.dart

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ final class GenerateContentResponse {
108108
Iterable<FunctionCall> get functionCalls =>
109109
candidates.firstOrNull?.content.parts.whereType<FunctionCall>() ??
110110
const [];
111+
112+
/// The inline data parts of the first candidate in [candidates], if any.
113+
///
114+
/// Returns an empty list if there are no candidates, or if the first
115+
/// candidate has no [InlineDataPart] parts. There is no error thrown if the
116+
/// prompt or response were blocked.
117+
Iterable<InlineDataPart> get inlineDataParts =>
118+
candidates.firstOrNull?.content.parts.whereType<InlineDataPart>() ??
119+
const [];
111120
}
112121

113122
/// Feedback metadata of a prompt specified in a [GenerativeModel] request.
@@ -656,6 +665,24 @@ enum HarmBlockMethod {
656665
Object toJson() => _jsonString;
657666
}
658667

668+
/// The available response modalities.
669+
enum ResponseModalities {
670+
/// Text response modality.
671+
text('TEXT'),
672+
673+
/// Image response modality.
674+
image('IMAGE'),
675+
676+
/// Audio response modality.
677+
audio('AUDIO');
678+
679+
const ResponseModalities(this._jsonString);
680+
final String _jsonString;
681+
682+
/// Convert to json format
683+
String toJson() => _jsonString;
684+
}
685+
659686
/// Configuration options for model generation and outputs.
660687
abstract class BaseGenerationConfig {
661688
// ignore: public_member_api_docs
@@ -667,6 +694,7 @@ abstract class BaseGenerationConfig {
667694
this.topK,
668695
this.presencePenalty,
669696
this.frequencyPenalty,
697+
this.responseModalities,
670698
});
671699

672700
/// Number of generated responses to return.
@@ -743,6 +771,9 @@ abstract class BaseGenerationConfig {
743771
/// for more details.
744772
final double? frequencyPenalty;
745773

774+
/// The list of desired response modalities.
775+
final List<ResponseModalities>? responseModalities;
776+
746777
// ignore: public_member_api_docs
747778
Map<String, Object?> toJson() => {
748779
if (candidateCount case final candidateCount?)
@@ -756,6 +787,9 @@ abstract class BaseGenerationConfig {
756787
'presencePenalty': presencePenalty,
757788
if (frequencyPenalty case final frequencyPenalty?)
758789
'frequencyPenalty': frequencyPenalty,
790+
if (responseModalities case final responseModalities?)
791+
'responseModalities':
792+
responseModalities.map((modality) => modality.toJson()).toList(),
759793
};
760794
}
761795

@@ -771,6 +805,7 @@ final class GenerationConfig extends BaseGenerationConfig {
771805
super.topK,
772806
super.presencePenalty,
773807
super.frequencyPenalty,
808+
super.responseModalities,
774809
this.responseMimeType,
775810
this.responseSchema,
776811
});
@@ -996,6 +1031,9 @@ SafetyRating _parseSafetyRating(Object? jsonObject) {
9961031
if (jsonObject is! Map) {
9971032
throw unhandledFormat('SafetyRating', jsonObject);
9981033
}
1034+
if (jsonObject.isEmpty) {
1035+
return SafetyRating(HarmCategory.unknown, HarmProbability.unknown);
1036+
}
9991037
return SafetyRating(HarmCategory._parseValue(jsonObject['category']),
10001038
HarmProbability._parseValue(jsonObject['probability']),
10011039
probabilityScore: jsonObject['probabilityScore'] as double?,

packages/firebase_vertexai/firebase_vertexai/lib/src/live_api.dart

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -71,30 +71,12 @@ class SpeechConfig {
7171
};
7272
}
7373

74-
/// The available response modalities.
75-
enum ResponseModalities {
76-
/// Text response modality.
77-
text('TEXT'),
78-
79-
/// Image response modality.
80-
image('IMAGE'),
81-
82-
/// Audio response modality.
83-
audio('AUDIO');
84-
85-
const ResponseModalities(this._jsonString);
86-
final String _jsonString;
87-
88-
/// Convert to json format
89-
String toJson() => _jsonString;
90-
}
91-
9274
/// Configures live generation settings.
9375
final class LiveGenerationConfig extends BaseGenerationConfig {
9476
// ignore: public_member_api_docs
9577
LiveGenerationConfig({
9678
this.speechConfig,
97-
this.responseModalities,
79+
super.responseModalities,
9880
super.candidateCount,
9981
super.maxOutputTokens,
10082
super.temperature,
@@ -107,17 +89,11 @@ final class LiveGenerationConfig extends BaseGenerationConfig {
10789
/// The speech configuration.
10890
final SpeechConfig? speechConfig;
10991

110-
/// The list of desired response modalities.
111-
final List<ResponseModalities>? responseModalities;
112-
11392
@override
11493
Map<String, Object?> toJson() => {
11594
...super.toJson(),
11695
if (speechConfig case final speechConfig?)
11796
'speechConfig': speechConfig.toJson(),
118-
if (responseModalities case final responseModalities?)
119-
'responseModalities':
120-
responseModalities.map((modality) => modality.toJson()).toList(),
12197
};
12298
}
12399

packages/firebase_vertexai/firebase_vertexai/test/live_test.dart

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414
import 'dart:typed_data';
1515

16+
import 'package:firebase_vertexai/src/api.dart';
1617
import 'package:firebase_vertexai/src/content.dart';
1718
import 'package:firebase_vertexai/src/error.dart';
1819
import 'package:firebase_vertexai/src/live_api.dart';

0 commit comments

Comments
 (0)