1
1
import { LLM , type BaseLLMParams } from "@langchain/core/language_models/llms" ;
2
2
import { getEnvironmentVariable } from "@langchain/core/utils/env" ;
3
+ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager" ;
4
+ import { GenerationChunk } from "@langchain/core/outputs" ;
5
+
6
+ import type ReplicateInstance from "replicate" ;
3
7
4
8
/**
5
9
* Interface defining the structure of the input data for the Replicate
@@ -88,13 +92,85 @@ export class Replicate extends LLM implements ReplicateInput {
88
92
prompt : string ,
89
93
options : this[ "ParsedCallOptions" ]
90
94
) : Promise < string > {
95
+ const replicate = await this . _prepareReplicate ( ) ;
96
+ const input = await this . _getReplicateInput ( replicate , prompt ) ;
97
+
98
+ const output = await this . caller . callWithOptions (
99
+ { signal : options . signal } ,
100
+ ( ) =>
101
+ replicate . run ( this . model , {
102
+ input,
103
+ } )
104
+ ) ;
105
+
106
+ if ( typeof output === "string" ) {
107
+ return output ;
108
+ } else if ( Array . isArray ( output ) ) {
109
+ return output . join ( "" ) ;
110
+ } else {
111
+ // Note this is a little odd, but the output format is not consistent
112
+ // across models, so it makes some amount of sense.
113
+ return String ( output ) ;
114
+ }
115
+ }
116
+
117
+ async * _streamResponseChunks (
118
+ prompt : string ,
119
+ options : this[ "ParsedCallOptions" ] ,
120
+ runManager ?: CallbackManagerForLLMRun
121
+ ) : AsyncGenerator < GenerationChunk > {
122
+ const replicate = await this . _prepareReplicate ( ) ;
123
+ const input = await this . _getReplicateInput ( replicate , prompt ) ;
124
+
125
+ const stream = await this . caller . callWithOptions (
126
+ { signal : options ?. signal } ,
127
+ async ( ) =>
128
+ replicate . stream ( this . model , {
129
+ input,
130
+ } )
131
+ ) ;
132
+ for await ( const chunk of stream ) {
133
+ if ( chunk . event === "output" ) {
134
+ yield new GenerationChunk ( { text : chunk . data , generationInfo : chunk } ) ;
135
+ await runManager ?. handleLLMNewToken ( chunk . data ?? "" ) ;
136
+ }
137
+
138
+ // stream is done
139
+ if ( chunk . event === "done" )
140
+ yield new GenerationChunk ( {
141
+ text : "" ,
142
+ generationInfo : { finished : true } ,
143
+ } ) ;
144
+ }
145
+ }
146
+
147
+ /** @ignore */
148
+ static async imports ( ) : Promise < {
149
+ Replicate : typeof ReplicateInstance ;
150
+ } > {
151
+ try {
152
+ const { default : Replicate } = await import ( "replicate" ) ;
153
+ return { Replicate } ;
154
+ } catch ( e ) {
155
+ throw new Error (
156
+ "Please install replicate as a dependency with, e.g. `yarn add replicate`"
157
+ ) ;
158
+ }
159
+ }
160
+
161
+ private async _prepareReplicate ( ) : Promise < ReplicateInstance > {
91
162
const imports = await Replicate . imports ( ) ;
92
163
93
- const replicate = new imports . Replicate ( {
164
+ return new imports . Replicate ( {
94
165
userAgent : "langchain" ,
95
166
auth : this . apiKey ,
96
167
} ) ;
168
+ }
97
169
170
+ private async _getReplicateInput (
171
+ replicate : ReplicateInstance ,
172
+ prompt : string
173
+ ) {
98
174
if ( this . promptKey === undefined ) {
99
175
const [ modelString , versionString ] = this . model . split ( ":" ) ;
100
176
const version = await replicate . models . versions . get (
@@ -119,40 +195,11 @@ export class Replicate extends LLM implements ReplicateInput {
119
195
this . promptKey = sortedInputProperties [ 0 ] [ 0 ] ?? "prompt" ;
120
196
}
121
197
}
122
- const output = await this . caller . callWithOptions (
123
- { signal : options . signal } ,
124
- ( ) =>
125
- replicate . run ( this . model , {
126
- input : {
127
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
128
- [ this . promptKey ! ] : prompt ,
129
- ...this . input ,
130
- } ,
131
- } )
132
- ) ;
133
-
134
- if ( typeof output === "string" ) {
135
- return output ;
136
- } else if ( Array . isArray ( output ) ) {
137
- return output . join ( "" ) ;
138
- } else {
139
- // Note this is a little odd, but the output format is not consistent
140
- // across models, so it makes some amount of sense.
141
- return String ( output ) ;
142
- }
143
- }
144
198
145
- /** @ignore */
146
- static async imports ( ) : Promise < {
147
- Replicate : typeof import ( "replicate" ) . default ;
148
- } > {
149
- try {
150
- const { default : Replicate } = await import ( "replicate" ) ;
151
- return { Replicate } ;
152
- } catch ( e ) {
153
- throw new Error (
154
- "Please install replicate as a dependency with, e.g. `yarn add replicate`"
155
- ) ;
156
- }
199
+ return {
200
+ // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
201
+ [ this . promptKey ! ] : prompt ,
202
+ ...this . input ,
203
+ } ;
157
204
}
158
205
}
0 commit comments