@@ -20,31 +20,40 @@ import 'jasmine';
20
20
21
21
import { ChatSession , GenerativeModel , StartChatParams , VertexAI } from './index' ;
22
22
import * as StreamFunctions from './process_stream' ;
23
- import { CountTokensRequest , CountTokensResponse , GenerateContentParams , GenerateContentResult } from './types/content' ;
24
- import * as PostRequest from './util/post_request' ;
23
+ import { CountTokensRequest , GenerateContentRequest , GenerateContentResponse , GenerateContentResult , StreamGenerateContentResult } from './types/content' ;
25
24
26
25
const PROJECT = 'test_project' ;
27
26
const LOCATION = 'test_location' ;
28
- const MODEL_ID = 'test_model_id' ;
29
27
const TEST_USER_CHAT_MESSAGE =
30
28
[ { role : 'user' , parts : [ { text : 'How are you doing today?' } ] } ] ;
31
- const TEST_MODEL_RESPONSE = [ {
32
- candidates : [
33
- {
34
- index : 1 ,
35
- content :
36
- { role : 'assistant' , parts : [ { text : 'I\m doing great! How are you?' } ] } ,
37
- finish_reason : 0 ,
38
- finish_message : '' ,
39
- safety_ratings : [ { category : 0 , threshold : 0 } ] ,
40
- } ,
41
- ] ,
29
+ const TEST_CANDIDATES = [
30
+ {
31
+ index : 1 ,
32
+ content :
33
+ { role : 'assistant' , parts : [ { text : 'I\m doing great! How are you?' } ] } ,
34
+ finish_reason : 0 ,
35
+ finish_message : '' ,
36
+ safety_ratings : [ { category : 0 , threshold : 0 } ] ,
37
+ } ,
38
+ ] ;
39
+ const TEST_MODEL_RESPONSE = {
40
+ candidates : TEST_CANDIDATES ,
42
41
usage_metadata : { prompt_token_count : 0 , candidates_token_count : 0 }
43
42
44
- } ] ;
43
+ } ;
45
44
46
45
const TEST_ENDPOINT_BASE_PATH = 'test.googleapis.com' ;
47
46
47
+ /**
48
+ * Returns a generator, used to mock the streamGenerateContent response
49
+ */
50
+ export async function *
51
+ testGenerator ( ) : AsyncGenerator < GenerateContentResponse > {
52
+ yield {
53
+ candidates : TEST_CANDIDATES ,
54
+ } ;
55
+ }
56
+
48
57
describe ( 'VertexAI' , ( ) => {
49
58
let vertexai : VertexAI ;
50
59
let model : GenerativeModel ;
@@ -59,21 +68,18 @@ describe('VertexAI', () => {
59
68
expect ( vertexai ) . toBeInstanceOf ( VertexAI ) ;
60
69
} ) ;
61
70
62
- // TODO: update this test when stream and unary implementation is separated
63
71
describe ( 'generateContent' , ( ) => {
64
- it ( 'returns a GenerateContentResponse when stream=false ' , async ( ) => {
65
- const req : GenerateContentParams = {
72
+ it ( 'returns a GenerateContentResponse' , async ( ) => {
73
+ const req : GenerateContentRequest = {
66
74
contents : TEST_USER_CHAT_MESSAGE ,
67
- stream : false ,
68
75
} ;
69
76
const expectedResult : GenerateContentResult = {
70
- responses : TEST_MODEL_RESPONSE ,
77
+ response : TEST_MODEL_RESPONSE ,
71
78
} ;
72
79
spyOn ( StreamFunctions , 'processNonStream' ) . and . returnValue ( expectedResult ) ;
73
80
const resp = await model . generateContent ( req ) ;
74
81
expect ( resp ) . toEqual ( expectedResult ) ;
75
82
} ) ;
76
- // TODO: add test from stream=true here
77
83
} ) ;
78
84
79
85
describe ( 'generateContent' , ( ) => {
@@ -85,12 +91,11 @@ describe('VertexAI', () => {
85
91
model : 'gemini-pro'
86
92
} ) ;
87
93
88
- const req : GenerateContentParams = {
94
+ const req : GenerateContentRequest = {
89
95
contents : TEST_USER_CHAT_MESSAGE ,
90
- stream : false ,
91
96
} ;
92
97
const expectedResult : GenerateContentResult = {
93
- responses : TEST_MODEL_RESPONSE ,
98
+ response : TEST_MODEL_RESPONSE ,
94
99
} ;
95
100
const requestSpy = spyOn ( global , 'fetch' ) ;
96
101
spyOn ( StreamFunctions ,
@@ -110,12 +115,11 @@ describe('VertexAI', () => {
110
115
model : 'gemini-pro'
111
116
} ) ;
112
117
113
- const req : GenerateContentParams = {
118
+ const req : GenerateContentRequest = {
114
119
contents : TEST_USER_CHAT_MESSAGE ,
115
- stream : false ,
116
120
} ;
117
121
const expectedResult : GenerateContentResult = {
118
- responses : TEST_MODEL_RESPONSE ,
122
+ response : TEST_MODEL_RESPONSE ,
119
123
} ;
120
124
const requestSpy = spyOn ( global , 'fetch' ) ;
121
125
spyOn ( StreamFunctions , 'processNonStream' ) . and . returnValue ( expectedResult ) ; await
@@ -125,6 +129,21 @@ describe('VertexAI', () => {
125
129
} ) ;
126
130
} ) ;
127
131
132
+ describe ( 'streamGenerateContent' , ( ) => {
133
+ it ( 'returns a GenerateContentResponse' , async ( ) => {
134
+ const req : GenerateContentRequest = {
135
+ contents : TEST_USER_CHAT_MESSAGE ,
136
+ } ;
137
+ const expectedResult : StreamGenerateContentResult = {
138
+ response : Promise . resolve ( TEST_MODEL_RESPONSE ) ,
139
+ stream : testGenerator ( ) ,
140
+ } ;
141
+ spyOn ( StreamFunctions , 'processStream' ) . and . returnValue ( expectedResult ) ;
142
+ const resp = await model . streamGenerateContent ( req ) ;
143
+ expect ( resp ) . toEqual ( expectedResult ) ;
144
+ } ) ;
145
+ } ) ;
146
+
128
147
describe ( 'startChat' , ( ) => {
129
148
it ( 'returns a ChatSession' , ( ) => {
130
149
const req : StartChatParams = {
@@ -174,19 +193,24 @@ describe('ChatSession', () => {
174
193
expect ( chatSession . history . length ) . toEqual ( 1 ) ;
175
194
} ) ;
176
195
177
- describe ( 'sendMessage' , ( ) => {
178
- it ( 'returns a GenerateContentResponse' , async ( ) => {
179
- const req = 'How are you doing today?' ;
180
- const expectedResult : GenerateContentResult = {
181
- responses : TEST_MODEL_RESPONSE ,
182
- stream : StreamFunctions . emptyGenerator ( ) ,
183
- } ;
184
- spyOn ( StreamFunctions , 'processStream' ) . and . returnValue ( expectedResult ) ;
185
- const resp = await chatSession . sendMessage ( req ) ;
186
- expect ( resp ) . toEqual ( expectedResult ) ;
187
- expect ( chatSession . history . length ) . toEqual ( 3 ) ;
188
- } ) ;
189
-
190
- // TODO: add test cases for different content types passed to sendMessage
191
- } ) ;
196
+ // TODO: update sendMessage after generateContent and streamGenerateContent
197
+ // are working
198
+ describe (
199
+ 'sendMessage' ,
200
+ ( ) => {
201
+ // it('returns a GenerateContentResponse', async () => {
202
+ // const req = 'How are you doing today?';
203
+ // const expectedResult: GenerateContentResult = {
204
+ // responses: TEST_MODEL_RESPONSE,
205
+ // };
206
+ // spyOn(StreamFunctions,
207
+ // 'processStream').and.returnValue(expectedResult);
208
+ // const resp = await chatSession.sendMessage(req);
209
+ // expect(resp).toEqual(expectedResult);
210
+ // expect(chatSession.history.length).toEqual(3);
211
+ // });
212
+
213
+ // TODO: add test cases for different content types passed to
214
+ // sendMessage
215
+ } ) ;
192
216
} ) ;
0 commit comments