@@ -12,7 +12,9 @@ import {
12
12
} from '../bson' ;
13
13
import { type ProxyOptions } from '../cmap/connection' ;
14
14
import { getSocks , type SocksLib } from '../deps' ;
15
+ import { MongoOperationTimeoutError } from '../error' ;
15
16
import { type MongoClient , type MongoClientOptions } from '../mongo_client' ;
17
+ import { Timeout , type TimeoutContext , TimeoutError } from '../timeout' ;
16
18
import { BufferPool , MongoDBCollectionNamespace , promiseWithResolvers } from '../utils' ;
17
19
import { autoSelectSocketOptions , type DataKey } from './client_encryption' ;
18
20
import { MongoCryptError } from './errors' ;
@@ -173,6 +175,7 @@ export type StateMachineOptions = {
173
175
* An internal class that executes across a MongoCryptContext until either
174
176
* a finishing state or an error is reached. Do not instantiate directly.
175
177
*/
178
+ // TODO(DRIVERS-2671): clarify CSOT behavior for FLE APIs
176
179
export class StateMachine {
177
180
constructor (
178
181
private options : StateMachineOptions ,
@@ -182,7 +185,11 @@ export class StateMachine {
182
185
/**
183
186
* Executes the state machine according to the specification
184
187
*/
185
- async execute ( executor : StateMachineExecutable , context : MongoCryptContext ) : Promise < Uint8Array > {
188
+ async execute (
189
+ executor : StateMachineExecutable ,
190
+ context : MongoCryptContext ,
191
+ timeoutContext ?: TimeoutContext
192
+ ) : Promise < Uint8Array > {
186
193
const keyVaultNamespace = executor . _keyVaultNamespace ;
187
194
const keyVaultClient = executor . _keyVaultClient ;
188
195
const metaDataClient = executor . _metaDataClient ;
@@ -201,8 +208,13 @@ export class StateMachine {
201
208
'unreachable state machine state: entered MONGOCRYPT_CTX_NEED_MONGO_COLLINFO but metadata client is undefined'
202
209
) ;
203
210
}
204
- const collInfo = await this . fetchCollectionInfo ( metaDataClient , context . ns , filter ) ;
205
211
212
+ const collInfo = await this . fetchCollectionInfo (
213
+ metaDataClient ,
214
+ context . ns ,
215
+ filter ,
216
+ timeoutContext
217
+ ) ;
206
218
if ( collInfo ) {
207
219
context . addMongoOperationResponse ( collInfo ) ;
208
220
}
@@ -222,9 +234,9 @@ export class StateMachine {
222
234
// When we are using the shared library, we don't have a mongocryptd manager.
223
235
const markedCommand : Uint8Array = mongocryptdManager
224
236
? await mongocryptdManager . withRespawn (
225
- this . markCommand . bind ( this , mongocryptdClient , context . ns , command )
237
+ this . markCommand . bind ( this , mongocryptdClient , context . ns , command , timeoutContext )
226
238
)
227
- : await this . markCommand ( mongocryptdClient , context . ns , command ) ;
239
+ : await this . markCommand ( mongocryptdClient , context . ns , command , timeoutContext ) ;
228
240
229
241
context . addMongoOperationResponse ( markedCommand ) ;
230
242
context . finishMongoOperation ( ) ;
@@ -233,7 +245,12 @@ export class StateMachine {
233
245
234
246
case MONGOCRYPT_CTX_NEED_MONGO_KEYS : {
235
247
const filter = context . nextMongoOperation ( ) ;
236
- const keys = await this . fetchKeys ( keyVaultClient , keyVaultNamespace , filter ) ;
248
+ const keys = await this . fetchKeys (
249
+ keyVaultClient ,
250
+ keyVaultNamespace ,
251
+ filter ,
252
+ timeoutContext
253
+ ) ;
237
254
238
255
if ( keys . length === 0 ) {
239
256
// See docs on EMPTY_V
@@ -255,9 +272,7 @@ export class StateMachine {
255
272
}
256
273
257
274
case MONGOCRYPT_CTX_NEED_KMS : {
258
- const requests = Array . from ( this . requests ( context ) ) ;
259
- await Promise . all ( requests ) ;
260
-
275
+ await Promise . all ( this . requests ( context , timeoutContext ) ) ;
261
276
context . finishKMSRequests ( ) ;
262
277
break ;
263
278
}
@@ -299,7 +314,7 @@ export class StateMachine {
299
314
* @param kmsContext - A C++ KMS context returned from the bindings
300
315
* @returns A promise that resolves when the KMS reply has be fully parsed
301
316
*/
302
- async kmsRequest ( request : MongoCryptKMSRequest ) : Promise < void > {
317
+ async kmsRequest ( request : MongoCryptKMSRequest , timeoutContext ?: TimeoutContext ) : Promise < void > {
303
318
const parsedUrl = request . endpoint . split ( ':' ) ;
304
319
const port = parsedUrl [ 1 ] != null ? Number . parseInt ( parsedUrl [ 1 ] , 10 ) : HTTPS_PORT ;
305
320
const socketOptions = autoSelectSocketOptions ( this . options . socketOptions || { } ) ;
@@ -329,10 +344,6 @@ export class StateMachine {
329
344
}
330
345
}
331
346
332
- function ontimeout ( ) {
333
- return new MongoCryptError ( 'KMS request timed out' ) ;
334
- }
335
-
336
347
function onerror ( cause : Error ) {
337
348
return new MongoCryptError ( 'KMS request failed' , { cause } ) ;
338
349
}
@@ -364,7 +375,6 @@ export class StateMachine {
364
375
resolve : resolveOnNetSocketConnect
365
376
} = promiseWithResolvers < void > ( ) ;
366
377
netSocket
367
- . once ( 'timeout' , ( ) => rejectOnNetSocketError ( ontimeout ( ) ) )
368
378
. once ( 'error' , err => rejectOnNetSocketError ( onerror ( err ) ) )
369
379
. once ( 'close' , ( ) => rejectOnNetSocketError ( onclose ( ) ) )
370
380
. once ( 'connect' , ( ) => resolveOnNetSocketConnect ( ) ) ;
@@ -410,8 +420,8 @@ export class StateMachine {
410
420
reject : rejectOnTlsSocketError ,
411
421
resolve
412
422
} = promiseWithResolvers < void > ( ) ;
423
+
413
424
socket
414
- . once ( 'timeout' , ( ) => rejectOnTlsSocketError ( ontimeout ( ) ) )
415
425
. once ( 'error' , err => rejectOnTlsSocketError ( onerror ( err ) ) )
416
426
. once ( 'close' , ( ) => rejectOnTlsSocketError ( onclose ( ) ) )
417
427
. on ( 'data' , data => {
@@ -425,20 +435,26 @@ export class StateMachine {
425
435
resolve ( ) ;
426
436
}
427
437
} ) ;
428
- await willResolveKmsRequest ;
438
+ await ( timeoutContext ?. csotEnabled ( )
439
+ ? Promise . all ( [ willResolveKmsRequest , Timeout . expires ( timeoutContext ?. remainingTimeMS ) ] )
440
+ : willResolveKmsRequest ) ;
441
+ } catch ( error ) {
442
+ if ( error instanceof TimeoutError )
443
+ throw new MongoOperationTimeoutError ( 'KMS request timed out' ) ;
444
+ throw error ;
429
445
} finally {
430
446
// There's no need for any more activity on this socket at this point.
431
447
destroySockets ( ) ;
432
448
}
433
449
}
434
450
435
- * requests ( context : MongoCryptContext ) {
451
+ * requests ( context : MongoCryptContext , timeoutContext ?: TimeoutContext ) {
436
452
for (
437
453
let request = context . nextKMSRequest ( ) ;
438
454
request != null ;
439
455
request = context . nextKMSRequest ( )
440
456
) {
441
- yield this . kmsRequest ( request ) ;
457
+ yield this . kmsRequest ( request , timeoutContext ) ;
442
458
}
443
459
}
444
460
@@ -498,15 +514,19 @@ export class StateMachine {
498
514
async fetchCollectionInfo (
499
515
client : MongoClient ,
500
516
ns : string ,
501
- filter : Document
517
+ filter : Document ,
518
+ timeoutContext ?: TimeoutContext
502
519
) : Promise < Uint8Array | null > {
503
520
const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
504
521
505
522
const collections = await client
506
523
. db ( db )
507
524
. listCollections ( filter , {
508
525
promoteLongs : false ,
509
- promoteValues : false
526
+ promoteValues : false ,
527
+ ...( timeoutContext ?. csotEnabled ( )
528
+ ? { timeoutMS : timeoutContext ?. remainingTimeMS , timeoutMode : 'cursorLifetime' }
529
+ : { } )
510
530
} )
511
531
. toArray ( ) ;
512
532
@@ -522,12 +542,22 @@ export class StateMachine {
522
542
* @param command - The command to execute.
523
543
* @param callback - Invoked with the serialized and marked bson command, or with an error
524
544
*/
525
- async markCommand ( client : MongoClient , ns : string , command : Uint8Array ) : Promise < Uint8Array > {
526
- const options = { promoteLongs : false , promoteValues : false } ;
545
+ async markCommand (
546
+ client : MongoClient ,
547
+ ns : string ,
548
+ command : Uint8Array ,
549
+ timeoutContext ?: TimeoutContext
550
+ ) : Promise < Uint8Array > {
527
551
const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
528
- const rawCommand = deserialize ( command , options ) ;
552
+ const bsonOptions = { promoteLongs : false , promoteValues : false } ;
553
+ const rawCommand = deserialize ( command , bsonOptions ) ;
529
554
530
- const response = await client . db ( db ) . command ( rawCommand , options ) ;
555
+ const response = await client . db ( db ) . command ( rawCommand , {
556
+ ...bsonOptions ,
557
+ ...( timeoutContext ?. csotEnabled ( )
558
+ ? { timeoutMS : timeoutContext ?. remainingTimeMS }
559
+ : undefined )
560
+ } ) ;
531
561
532
562
return serialize ( response , this . bsonOptions ) ;
533
563
}
@@ -543,15 +573,21 @@ export class StateMachine {
543
573
fetchKeys (
544
574
client : MongoClient ,
545
575
keyVaultNamespace : string ,
546
- filter : Uint8Array
576
+ filter : Uint8Array ,
577
+ timeoutContext ?: TimeoutContext
547
578
) : Promise < Array < DataKey > > {
548
579
const { db : dbName , collection : collectionName } =
549
580
MongoDBCollectionNamespace . fromString ( keyVaultNamespace ) ;
550
581
551
582
return client
552
583
. db ( dbName )
553
584
. collection < DataKey > ( collectionName , { readConcern : { level : 'majority' } } )
554
- . find ( deserialize ( filter ) )
585
+ . find (
586
+ deserialize ( filter ) ,
587
+ timeoutContext ?. csotEnabled ( )
588
+ ? { timeoutMS : timeoutContext ?. remainingTimeMS , timeoutMode : 'cursorLifetime' }
589
+ : { }
590
+ )
555
591
. toArray ( ) ;
556
592
}
557
593
}
0 commit comments