Skip to content

Commit fd1ca35

Browse files
authored
GODRIVER-2117 - Check clientSession is not nil inside executeTestRunnerOperation (#1457)
1 parent d52c9e1 commit fd1ca35

File tree

1 file changed

+31
-14
lines changed

1 file changed

+31
-14
lines changed

mongo/integration/unified_spec_test.go

+31-14
Original file line numberDiff line numberDiff line change
@@ -462,46 +462,64 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation,
462462

463463
var fp mtest.FailPoint
464464
if err := bson.Unmarshal(fpDoc.Document(), &fp); err != nil {
465-
return fmt.Errorf("Unmarshal error: %v", err)
465+
return fmt.Errorf("Unmarshal error: %w", err)
466466
}
467467

468+
if clientSession == nil {
469+
return errors.New("expected valid session, got nil")
470+
}
468471
targetHost := clientSession.PinnedServer.Addr.String()
469472
opts := options.Client().ApplyURI(mtest.ClusterURI()).SetHosts([]string{targetHost})
470473
integtest.AddTestServerAPIVersion(opts)
471474
client, err := mongo.Connect(context.Background(), opts)
472475
if err != nil {
473-
return fmt.Errorf("Connect error for targeted client: %v", err)
476+
return fmt.Errorf("Connect error for targeted client: %w", err)
474477
}
475478
defer func() { _ = client.Disconnect(context.Background()) }()
476479

477480
if err = client.Database("admin").RunCommand(context.Background(), fp).Err(); err != nil {
478-
return fmt.Errorf("error setting targeted fail point: %v", err)
481+
return fmt.Errorf("error setting targeted fail point: %w", err)
479482
}
480483
mt.TrackFailPoint(fp.ConfigureFailPoint)
481484
case "configureFailPoint":
482485
fp, err := op.Arguments.LookupErr("failPoint")
483-
assert.Nil(mt, err, "failPoint not found in arguments")
486+
if err != nil {
487+
return fmt.Errorf("unable to find 'failPoint' in arguments: %w", err)
488+
}
484489
mt.SetFailPointFromDocument(fp.Document())
485490
case "assertSessionTransactionState":
486491
stateVal, err := op.Arguments.LookupErr("state")
487-
assert.Nil(mt, err, "state not found in arguments")
492+
if err != nil {
493+
return fmt.Errorf("unable to find 'state' in arguments: %w", err)
494+
}
488495
expectedState, ok := stateVal.StringValueOK()
489-
assert.True(mt, ok, "state argument is not a string")
496+
if !ok {
497+
return errors.New("expected 'state' argument to be string")
498+
}
490499

491-
assert.NotNil(mt, clientSession, "expected valid session, got nil")
500+
if clientSession == nil {
501+
return errors.New("expected valid session, got nil")
502+
}
492503
actualState := clientSession.TransactionState.String()
493504

494505
// actualState should match expectedState, but "in progress" is the same as
495506
// "in_progress".
496507
stateMatch := actualState == expectedState ||
497508
actualState == "in progress" && expectedState == "in_progress"
498-
assert.True(mt, stateMatch, "expected transaction state %v, got %v",
499-
expectedState, actualState)
509+
if !stateMatch {
510+
return fmt.Errorf("expected transaction state %v, got %v", expectedState, actualState)
511+
}
500512
case "assertSessionPinned":
513+
if clientSession == nil {
514+
return errors.New("expected valid session, got nil")
515+
}
501516
if clientSession.PinnedServer == nil {
502517
return errors.New("expected pinned server, got nil")
503518
}
504519
case "assertSessionUnpinned":
520+
if clientSession == nil {
521+
return errors.New("expected valid session, got nil")
522+
}
505523
// We don't use a combined helper for assertSessionPinned and assertSessionUnpinned because the unpinned
506524
// case provides the pinned server address in the error msg for debugging.
507525
if clientSession.PinnedServer != nil {
@@ -544,7 +562,7 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation,
544562
case "waitForThread":
545563
waitForThread(mt, testCase, op)
546564
default:
547-
mt.Fatalf("unrecognized testRunner operation %v", op.Name)
565+
return fmt.Errorf("unrecognized testRunner operation %v", op.Name)
548566
}
549567

550568
return nil
@@ -571,7 +589,7 @@ func indexExists(dbName, collName, indexName string) (bool, error) {
571589
iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes()
572590
cursor, err := iv.List(context.Background())
573591
if err != nil {
574-
return false, fmt.Errorf("IndexView.List error: %v", err)
592+
return false, fmt.Errorf("IndexView.List error: %w", err)
575593
}
576594
defer cursor.Close(context.Background())
577595

@@ -606,7 +624,7 @@ func collectionExists(dbName, collName string) (bool, error) {
606624
// Use global client because listCollections cannot be executed inside a transaction.
607625
collections, err := mtest.GlobalClient().Database(dbName).ListCollectionNames(context.Background(), filter)
608626
if err != nil {
609-
return false, fmt.Errorf("ListCollectionNames error: %v", err)
627+
return false, fmt.Errorf("ListCollectionNames error: %w", err)
610628
}
611629

612630
return len(collections) > 0, nil
@@ -636,9 +654,8 @@ func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) err
636654
case "withTransaction":
637655
return executeWithTransaction(mt, sess, op.Arguments)
638656
default:
639-
mt.Fatalf("unrecognized session operation: %v", op.Name)
657+
return fmt.Errorf("unrecognized session operation: %v", op.Name)
640658
}
641-
return nil
642659
}
643660

644661
func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) error {

0 commit comments

Comments
 (0)