Skip to content

Commit 3adad2b

Browse files
GODRIVER-3009 Fix concurrent panic in struct codec. (#1477) (#1489)
Co-authored-by: Qingyang Hu <[email protected]>
1 parent a8fa12a commit 3adad2b

File tree

4 files changed

+69
-6
lines changed

4 files changed

+69
-6
lines changed

bson/bsoncodec/registry.go

+3-6
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,9 @@ func (r *Registry) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) {
388388
// If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for
389389
// concurrent use by multiple goroutines after all codecs and encoders are registered.
390390
func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
391+
if valueType == nil {
392+
return nil, ErrNoEncoder{Type: valueType}
393+
}
391394
enc, found := r.lookupTypeEncoder(valueType)
392395
if found {
393396
if enc == nil {
@@ -400,15 +403,10 @@ func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) {
400403
if found {
401404
return r.typeEncoders.LoadOrStore(valueType, enc), nil
402405
}
403-
if valueType == nil {
404-
r.storeTypeEncoder(valueType, nil)
405-
return nil, ErrNoEncoder{Type: valueType}
406-
}
407406

408407
if v, ok := r.kindEncoders.Load(valueType.Kind()); ok {
409408
return r.storeTypeEncoder(valueType, v), nil
410409
}
411-
r.storeTypeEncoder(valueType, nil)
412410
return nil, ErrNoEncoder{Type: valueType}
413411
}
414412

@@ -474,7 +472,6 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) {
474472
if v, ok := r.kindDecoders.Load(valueType.Kind()); ok {
475473
return r.storeTypeDecoder(valueType, v), nil
476474
}
477-
r.storeTypeDecoder(valueType, nil)
478475
return nil, ErrNoDecoder{Type: valueType}
479476
}
480477

bson/bsoncodec/registry_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,36 @@ func TestRegistry(t *testing.T) {
792792
})
793793
})
794794
}
795+
t.Run("nil type", func(t *testing.T) {
796+
t.Parallel()
797+
798+
t.Run("Encoder", func(t *testing.T) {
799+
t.Parallel()
800+
801+
wanterr := ErrNoEncoder{Type: reflect.TypeOf(nil)}
802+
803+
gotcodec, goterr := reg.LookupEncoder(nil)
804+
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
805+
t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
806+
}
807+
if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
808+
t.Errorf("codecs did not match: got %#v, want nil", gotcodec)
809+
}
810+
})
811+
t.Run("Decoder", func(t *testing.T) {
812+
t.Parallel()
813+
814+
wanterr := ErrNilType
815+
816+
gotcodec, goterr := reg.LookupDecoder(nil)
817+
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
818+
t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr)
819+
}
820+
if !cmp.Equal(gotcodec, nil, allowunexported, cmp.Comparer(comparepc)) {
821+
t.Errorf("codecs did not match: got %v: want nil", gotcodec)
822+
}
823+
})
824+
})
795825
// lookup a type whose pointer implements an interface and expect that the registered hook is
796826
// returned
797827
t.Run("interface implementation with hook (pointer)", func(t *testing.T) {

bson/marshal_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"errors"
1212
"fmt"
1313
"reflect"
14+
"sync"
1415
"testing"
1516
"time"
1617

@@ -380,3 +381,19 @@ func TestMarshalExtJSONIndent(t *testing.T) {
380381
})
381382
}
382383
}
384+
385+
func TestMarshalConcurrently(t *testing.T) {
386+
t.Parallel()
387+
388+
const size = 10_000
389+
390+
wg := sync.WaitGroup{}
391+
wg.Add(size)
392+
for i := 0; i < size; i++ {
393+
go func() {
394+
defer wg.Done()
395+
_, _ = Marshal(struct{ LastError error }{})
396+
}()
397+
}
398+
wg.Wait()
399+
}

bson/unmarshal_test.go

+19
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package bson
99
import (
1010
"math/rand"
1111
"reflect"
12+
"sync"
1213
"testing"
1314

1415
"go.mongodb.org/mongo-driver/bson/bsoncodec"
@@ -773,3 +774,21 @@ func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) {
773774
})
774775
}
775776
}
777+
778+
func TestUnmarshalConcurrently(t *testing.T) {
779+
t.Parallel()
780+
781+
const size = 10_000
782+
783+
data := []byte{16, 0, 0, 0, 10, 108, 97, 115, 116, 101, 114, 114, 111, 114, 0, 0}
784+
wg := sync.WaitGroup{}
785+
wg.Add(size)
786+
for i := 0; i < size; i++ {
787+
go func() {
788+
defer wg.Done()
789+
var res struct{ LastError error }
790+
_ = Unmarshal(data, &res)
791+
}()
792+
}
793+
wg.Wait()
794+
}

0 commit comments

Comments
 (0)