
Commit d3510d1

hans00 and HenryHengZJ authored

Support cache system instructs for Google GenAI (#4148)

* Support cache system instructs for Google GenAI
* format code
* Update FlowiseGoogleAICacheManager.ts

Co-authored-by: Henry Heng <[email protected]>
1 parent 654bd48 · commit d3510d1
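In outline, the commit adds three pieces: a FlowiseGoogleAICacheManager subclass of Google's GoogleAICacheManager that memoizes CachedContent entries under an object-hash key of (model, tools, contents); a Flowise node (GoogleGenerativeAIContextCache) that constructs that manager from a Google Generative AI credential and a TTL input; and wiring in the Gemini chat model so that getClient resolves and attaches the cached content before each generate or stream call.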

File tree

5 files changed: +160 -5 lines
packages/components/nodes/cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager.ts

+44
@@ -0,0 +1,44 @@
import type { CachedContentBase, CachedContent, Content } from '@google/generative-ai'
import { GoogleAICacheManager as GoogleAICacheManagerBase } from '@google/generative-ai/server'
import hash from 'object-hash'

type CacheContentOptions = Omit<CachedContentBase, 'contents'> & { contents?: Content[] }

export class GoogleAICacheManager extends GoogleAICacheManagerBase {
    private ttlSeconds: number
    private cachedContents: Map<string, CachedContent> = new Map()

    setTtlSeconds(ttlSeconds: number) {
        this.ttlSeconds = ttlSeconds
    }

    async lookup(options: CacheContentOptions): Promise<CachedContent | undefined> {
        const { model, tools, contents } = options
        if (!contents?.length) {
            return undefined
        }
        const hashKey = hash({
            model,
            tools,
            contents
        })
        if (this.cachedContents.has(hashKey)) {
            return this.cachedContents.get(hashKey)
        }
        const { cachedContents } = await this.list()
        const cachedContent = (cachedContents ?? []).find((cache) => cache.displayName === hashKey)
        if (cachedContent) {
            this.cachedContents.set(hashKey, cachedContent)
            return cachedContent
        }
        const res = await this.create({
            ...(options as CachedContentBase),
            displayName: hashKey,
            ttlSeconds: this.ttlSeconds
        })
        this.cachedContents.set(hashKey, res)
        return res
    }
}

export default GoogleAICacheManager
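A brief, hypothetical sketch of how this manager could be driven directly, outside Flowise, to make the caching behaviour concrete. The API key, model name, and prompt text are placeholders, and an ESM context with top-level await is assumed; lookup is the method added above, while create and list are inherited from @google/generative-ai/server.

import GoogleAICacheManager from './FlowiseGoogleAICacheManager'

// Placeholder API key; ttlSeconds must be set before the first lookup,
// since every created cache entry inherits it.
const manager = new GoogleAICacheManager(process.env.GOOGLE_API_KEY ?? '')
manager.setTtlSeconds(60 * 60) // expire new cache entries after one hour

// First call: no entry with this displayName exists, so create() runs.
// An identical second call in the same process is served from the in-memory
// Map; in a fresh process it is found via list() by displayName.
const cached = await manager.lookup({
    model: 'models/gemini-1.5-flash-001',
    contents: [{ role: 'user', parts: [{ text: 'A long shared preamble...' }] }]
})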
packages/components/nodes/cache/GoogleGenerativeAIContextCache/GoogleGenerativeAIContextCache.ts

+53
@@ -0,0 +1,53 @@
import { getBaseClasses, getCredentialData, getCredentialParam, ICommonObject, INode, INodeData, INodeParams } from '../../../src'
import FlowiseGoogleAICacheManager from './FlowiseGoogleAICacheManager'

class GoogleGenerativeAIContextCache implements INode {
    label: string
    name: string
    version: number
    description: string
    type: string
    icon: string
    category: string
    baseClasses: string[]
    inputs: INodeParams[]
    credential: INodeParams

    constructor() {
        this.label = 'Google GenAI Context Cache'
        this.name = 'googleGenerativeAIContextCache'
        this.version = 1.0
        this.type = 'GoogleAICacheManager'
        this.description = 'Large context cache for Google Gemini large language models'
        this.icon = 'GoogleGemini.svg'
        this.category = 'Cache'
        this.baseClasses = [this.type, ...getBaseClasses(FlowiseGoogleAICacheManager)]
        this.inputs = [
            {
                label: 'TTL',
                name: 'ttl',
                type: 'number',
                default: 60 * 60 * 24 * 30
            }
        ]
        this.credential = {
            label: 'Connect Credential',
            name: 'credential',
            type: 'credential',
            credentialNames: ['googleGenerativeAI'],
            optional: false,
            description: 'Google Generative AI credential.'
        }
    }

    async init(nodeData: INodeData, _: string, options: ICommonObject): Promise<any> {
        const ttl = nodeData.inputs?.ttl as number
        const credentialData = await getCredentialData(nodeData.credential ?? '', options)
        const apiKey = getCredentialParam('googleGenerativeAPIKey', credentialData, nodeData)
        const manager = new FlowiseGoogleAICacheManager(apiKey)
        manager.setTtlSeconds(ttl)
        return manager
    }
}

module.exports = { nodeClass: GoogleGenerativeAIContextCache }
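For reference, the node's default TTL of 60 * 60 * 24 * 30 seconds works out to 2,592,000 seconds, i.e. 30 days. At runtime, init reads that TTL input and the googleGenerativeAPIKey credential, constructs the FlowiseGoogleAICacheManager from the previous file, and applies the TTL, so every cache entry the manager later creates inherits it.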

packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/ChatGoogleGenerativeAI.ts

+9
@@ -5,6 +5,7 @@ import { ICommonObject, IMultiModalOption, INode, INodeData, INodeOptionsValue,
 import { convertMultiOptionsToStringArray, getBaseClasses, getCredentialData, getCredentialParam } from '../../../src/utils'
 import { getModels, MODEL_TYPE } from '../../../src/modelLoader'
 import { ChatGoogleGenerativeAI, GoogleGenerativeAIChatInput } from './FlowiseChatGoogleGenerativeAI'
+import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'

 class GoogleGenerativeAI_ChatModels implements INode {
     label: string
@@ -42,6 +43,12 @@ class GoogleGenerativeAI_ChatModels implements INode {
                 type: 'BaseCache',
                 optional: true
             },
+            {
+                label: 'Context Cache',
+                name: 'contextCache',
+                type: 'GoogleAICacheManager',
+                optional: true
+            },
             {
                 label: 'Model Name',
                 name: 'modelName',
@@ -188,6 +195,7 @@ class GoogleGenerativeAI_ChatModels implements INode {
         const harmCategory = nodeData.inputs?.harmCategory as string
         const harmBlockThreshold = nodeData.inputs?.harmBlockThreshold as string
         const cache = nodeData.inputs?.cache as BaseCache
+        const contextCache = nodeData.inputs?.contextCache as FlowiseGoogleAICacheManager
         const streaming = nodeData.inputs?.streaming as boolean

         const allowImageUploads = nodeData.inputs?.allowImageUploads as boolean
@@ -225,6 +233,7 @@ class GoogleGenerativeAI_ChatModels implements INode {

         const model = new ChatGoogleGenerativeAI(nodeData.id, obj)
         model.setMultiModalOption(multiModalOption)
+        if (contextCache) model.setContextCache(contextCache)

         return model
     }
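These hunks expose the new node to the chat model: an optional Context Cache input typed 'GoogleAICacheManager' (matching the cache node's base class), which init reads back and, when present, attaches to the model instance via setContextCache.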

packages/components/nodes/chatmodels/ChatGoogleGenerativeAI/FlowiseChatGoogleGenerativeAI.ts

+20-5
@@ -25,6 +25,7 @@ import { StructuredToolInterface } from '@langchain/core/tools'
 import { isStructuredTool } from '@langchain/core/utils/function_calling'
 import { zodToJsonSchema } from 'zod-to-json-schema'
 import { BaseLanguageModelCallOptions } from '@langchain/core/language_models/base'
+import type FlowiseGoogleAICacheManager from '../../cache/GoogleGenerativeAIContextCache/FlowiseGoogleAICacheManager'

 const DEFAULT_IMAGE_MAX_TOKEN = 8192
 const DEFAULT_IMAGE_MODEL = 'gemini-1.5-flash-latest'
@@ -86,6 +87,8 @@ class LangchainChatGoogleGenerativeAI

     private client: GenerativeModel

+    private contextCache?: FlowiseGoogleAICacheManager
+
     get _isMultimodalModel() {
         return this.modelName.includes('vision') || this.modelName.startsWith('gemini-1.5')
     }
@@ -147,7 +150,7 @@
         this.getClient()
     }

-    getClient(tools?: Tool[]) {
+    async getClient(prompt?: Content[], tools?: Tool[]) {
         this.client = new GenerativeAI(this.apiKey ?? '').getGenerativeModel({
             model: this.modelName,
             tools,
@@ -161,6 +164,14 @@
                 topK: this.topK
             }
         })
+        if (this.contextCache) {
+            const cachedContent = await this.contextCache.lookup({
+                contents: prompt ? [{ ...prompt[0], parts: prompt[0].parts.slice(0, 1) }] : [],
+                model: this.modelName,
+                tools
+            })
+            this.client.cachedContent = cachedContent as any
+        }
     }

     _combineLLMOutput() {
@@ -209,6 +220,10 @@
         }
     }

+    setContextCache(contextCache: FlowiseGoogleAICacheManager): void {
+        this.contextCache = contextCache
+    }
+
     async getNumTokens(prompt: BaseMessage[]) {
         const contents = convertBaseMessagesToContent(prompt, this._isMultimodalModel)
         const { totalTokens } = await this.client.countTokens({ contents })
@@ -226,9 +241,9 @@
         this.convertFunctionResponse(prompt)

         if (tools.length > 0) {
-            this.getClient(tools as Tool[])
+            await this.getClient(prompt, tools as Tool[])
         } else {
-            this.getClient()
+            await this.getClient(prompt)
         }
         const res = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
             let output
@@ -296,9 +311,9 @@

         const tools = options.tools ?? []
         if (tools.length > 0) {
-            this.getClient(tools as Tool[])
+            await this.getClient(prompt, tools as Tool[])
         } else {
-            this.getClient()
+            await this.getClient(prompt)
         }

         const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
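Because getClient now keys the cache on the model, the tools, and the first part of the first prompt message, a cache hit depends on object-hash producing a stable key for structurally identical inputs. A small illustrative sketch of that property; the model name and prompt text below are stand-ins, not values from this commit:

import hash from 'object-hash'

// Structurally identical inputs hash to the same key, so lookup() finds the
// CachedContent created on the first call instead of creating a duplicate.
const keyA = hash({ model: 'gemini-1.5-pro', tools: undefined, contents: [{ role: 'user', parts: [{ text: 'shared preamble' }] }] })
const keyB = hash({ model: 'gemini-1.5-pro', tools: undefined, contents: [{ role: 'user', parts: [{ text: 'shared preamble' }] }] })
console.assert(keyA === keyB) // same displayName, so the cached context is reused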
