Skip to content

Commit fe1013c

Browse files
jeasonnow authored and committed
langchain-community[patch]: #3369 Streaming support for Replicate models
1 parent 4839804 commit fe1013c

File tree

3 files changed

+108 -43 lines changed

libs/langchain-community/package.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@
195195
"puppeteer": "^19.7.2",
196196
"redis": "^4.6.6",
197197
"release-it": "^15.10.1",
198-
"replicate": "^0.18.0",
198+
"replicate": "^0.29.4",
199199
"rollup": "^3.19.1",
200200
"sonix-speech-recognition": "^2.1.1",
201201
"srt-parser-2": "^1.2.3",
@@ -316,7 +316,7 @@
316316
"portkey-ai": "^0.1.11",
317317
"puppeteer": "^19.7.2",
318318
"redis": "*",
319-
"replicate": "^0.18.0",
319+
"replicate": "^0.29.4",
320320
"sonix-speech-recognition": "^2.1.1",
321321
"srt-parser-2": "^1.2.3",
322322
"typeorm": "^0.3.12",

libs/langchain-community/src/llms/replicate.ts

+82-35
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import { LLM, type BaseLLMParams } from "@langchain/core/language_models/llms";
22
import { getEnvironmentVariable } from "@langchain/core/utils/env";
3+
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
4+
import { GenerationChunk } from "@langchain/core/outputs";
5+
6+
import type ReplicateInstance from "replicate";
37

48
/**
59
* Interface defining the structure of the input data for the Replicate
@@ -88,13 +92,85 @@ export class Replicate extends LLM implements ReplicateInput {
8892
prompt: string,
8993
options: this["ParsedCallOptions"]
9094
): Promise<string> {
95+
const replicate = await this._prepareReplicate();
96+
const input = await this._getReplicateInput(replicate, prompt);
97+
98+
const output = await this.caller.callWithOptions(
99+
{ signal: options.signal },
100+
() =>
101+
replicate.run(this.model, {
102+
input,
103+
})
104+
);
105+
106+
if (typeof output === "string") {
107+
return output;
108+
} else if (Array.isArray(output)) {
109+
return output.join("");
110+
} else {
111+
// Note this is a little odd, but the output format is not consistent
112+
// across models, so it makes some amount of sense.
113+
return String(output);
114+
}
115+
}
116+
117+
/**
 * Streams a completion from Replicate, yielding one chunk per "output"
 * event received from the stream.
 *
 * @param prompt - Prompt text, placed under the model's prompt key when the
 * request input is built.
 * @param options - Parsed call options; only `signal` is read here.
 * @param runManager - Optional callback manager notified of every new token.
 * @returns An async generator of `GenerationChunk`s. When a "done" event
 * arrives, a final empty chunk with `generationInfo: { finished: true }` is
 * emitted.
 */
async *_streamResponseChunks(
  prompt: string,
  options: this["ParsedCallOptions"],
  runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
  const replicate = await this._prepareReplicate();
  const input = await this._getReplicateInput(replicate, prompt);

  const stream = await this.caller.callWithOptions(
    { signal: options?.signal },
    async () =>
      replicate.stream(this.model, {
        input,
      })
  );

  for await (const chunk of stream) {
    if (chunk.event === "output") {
      // Normalize once so the yielded chunk and the token callback always
      // carry the same string, even when the event has no data.
      const text = chunk.data ?? "";
      yield new GenerationChunk({ text, generationInfo: chunk });
      await runManager?.handleLLMNewToken(text);
    } else if (chunk.event === "done") {
      // stream is done
      yield new GenerationChunk({
        text: "",
        generationInfo: { finished: true },
      });
    }
  }
}
146+
147+
/**
 * Dynamically imports the `replicate` package and exposes its default export.
 *
 * @returns An object whose `Replicate` property is the package's default
 * export.
 * @throws Error with installation instructions when the import fails.
 * @ignore
 */
static async imports(): Promise<{
  Replicate: typeof ReplicateInstance;
}> {
  try {
    const replicateModule = await import("replicate");
    return { Replicate: replicateModule.default };
  } catch {
    const installHint =
      "Please install replicate as a dependency with, e.g. `yarn add replicate`";
    throw new Error(installHint);
  }
}
160+
161+
/**
 * Builds a Replicate client for this instance.
 *
 * The client class is obtained through `Replicate.imports()` and constructed
 * with the "langchain" user agent and this instance's API key.
 *
 * @returns A newly constructed Replicate client.
 */
private async _prepareReplicate(): Promise<ReplicateInstance> {
  const { Replicate: ReplicateClient } = await Replicate.imports();
  const clientOptions = {
    userAgent: "langchain",
    auth: this.apiKey,
  };
  return new ReplicateClient(clientOptions);
}
97169

170+
private async _getReplicateInput(
171+
replicate: ReplicateInstance,
172+
prompt: string
173+
) {
98174
if (this.promptKey === undefined) {
99175
const [modelString, versionString] = this.model.split(":");
100176
const version = await replicate.models.versions.get(
@@ -119,40 +195,11 @@ export class Replicate extends LLM implements ReplicateInput {
119195
this.promptKey = sortedInputProperties[0][0] ?? "prompt";
120196
}
121197
}
122-
const output = await this.caller.callWithOptions(
123-
{ signal: options.signal },
124-
() =>
125-
replicate.run(this.model, {
126-
input: {
127-
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
128-
[this.promptKey!]: prompt,
129-
...this.input,
130-
},
131-
})
132-
);
133-
134-
if (typeof output === "string") {
135-
return output;
136-
} else if (Array.isArray(output)) {
137-
return output.join("");
138-
} else {
139-
// Note this is a little odd, but the output format is not consistent
140-
// across models, so it makes some amount of sense.
141-
return String(output);
142-
}
143-
}
144198

145-
/** @ignore */
146-
static async imports(): Promise<{
147-
Replicate: typeof import("replicate").default;
148-
}> {
149-
try {
150-
const { default: Replicate } = await import("replicate");
151-
return { Replicate };
152-
} catch (e) {
153-
throw new Error(
154-
"Please install replicate as a dependency with, e.g. `yarn add replicate`"
155-
);
156-
}
199+
return {
200+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
201+
[this.promptKey!]: prompt,
202+
...this.input,
203+
};
157204
}
158205
}

yarn.lock

+24-6
Original file line numberDiff line numberDiff line change
@@ -9154,7 +9154,7 @@ __metadata:
91549154
puppeteer: ^19.7.2
91559155
redis: ^4.6.6
91569156
release-it: ^15.10.1
9157-
replicate: ^0.18.0
9157+
replicate: ^0.29.4
91589158
rollup: ^3.19.1
91599159
sonix-speech-recognition: ^2.1.1
91609160
srt-parser-2: ^1.2.3
@@ -9277,7 +9277,7 @@ __metadata:
92779277
portkey-ai: ^0.1.11
92789278
puppeteer: ^19.7.2
92799279
redis: "*"
9280-
replicate: ^0.18.0
9280+
replicate: ^0.29.4
92819281
sonix-speech-recognition: ^2.1.1
92829282
srt-parser-2: ^1.2.3
92839283
typeorm: ^0.3.12
@@ -32771,6 +32771,19 @@ __metadata:
3277132771
languageName: node
3277232772
linkType: hard
3277332773

32774+
"readable-stream@npm:>=4.0.0":
32775+
version: 4.5.2
32776+
resolution: "readable-stream@npm:4.5.2"
32777+
dependencies:
32778+
abort-controller: ^3.0.0
32779+
buffer: ^6.0.3
32780+
events: ^3.3.0
32781+
process: ^0.11.10
32782+
string_decoder: ^1.3.0
32783+
checksum: c4030ccff010b83e4f33289c535f7830190773e274b3fcb6e2541475070bdfd69c98001c3b0cb78763fc00c8b62f514d96c2b10a8bd35d5ce45203a25fa1d33a
32784+
languageName: node
32785+
linkType: hard
32786+
3277432787
"readable-stream@npm:^2.0.0, readable-stream@npm:^2.0.1, readable-stream@npm:^2.3.0, readable-stream@npm:^2.3.5, readable-stream@npm:~2.3.6":
3277532788
version: 2.3.8
3277632789
resolution: "readable-stream@npm:2.3.8"
@@ -33191,10 +33204,15 @@ __metadata:
3319133204
languageName: node
3319233205
linkType: hard
3319333206

33194-
"replicate@npm:^0.18.0":
33195-
version: 0.18.0
33196-
resolution: "replicate@npm:0.18.0"
33197-
checksum: 547a8b386418aedf6e5be2086a63090e5a5f6cda36202a0122c4036a2af8a80efea420393e5efa4810c9cff0616a7df5adbd40fd4a0560f4aa1b4eda60a34794
33207+
"replicate@npm:^0.29.4":
33208+
version: 0.29.4
33209+
resolution: "replicate@npm:0.29.4"
33210+
dependencies:
33211+
readable-stream: ">=4.0.0"
33212+
dependenciesMeta:
33213+
readable-stream:
33214+
optional: true
33215+
checksum: 9405e19f619134a312aa77b3c04156549e4c8ba5e0711a494b99358abd0378646c22cd9bf07e6f9c8ab4a2f80b69ba22ed0a5b8ec0610684e9fa5d413e3b5729
3319833216
languageName: node
3319933217
linkType: hard
3320033218

0 commit comments

Comments
 (0)