diff --git a/bson/bson_test.go b/bson/bson_test.go index 6b8c0cd0b..695f9029d 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -36,6 +36,7 @@ import ( "reflect" "testing" "time" + "strings" "github.com/globalsign/mgo/bson" . "gopkg.in/check.v1" @@ -381,8 +382,54 @@ func (s *S) Test64bitInt(c *C) { // -------------------------------------------------------------------------- // Generic two-way struct marshaling tests. +type prefixPtr string +type prefixVal string + +func (t *prefixPtr) GetBSON() (interface{}, error) { + if t == nil { + return nil, nil + } + return "foo-" + string(*t), nil +} + +func (t *prefixPtr) SetBSON(raw bson.Raw) error { + var s string + if raw.Kind == 0x0A { + return bson.ErrSetZero + } + if err := raw.Unmarshal(&s); err != nil { + return err + } + if !strings.HasPrefix(s, "foo-") { + return errors.New("Prefix not found: " + s) + } + *t = prefixPtr(s[4:]) + return nil +} + +func (t prefixVal) GetBSON() (interface{}, error) { + return "foo-" + string(t), nil +} + +func (t *prefixVal) SetBSON(raw bson.Raw) error { + var s string + if raw.Kind == 0x0A { + return bson.ErrSetZero + } + if err := raw.Unmarshal(&s); err != nil { + return err + } + if !strings.HasPrefix(s, "foo-") { + return errors.New("Prefix not found: " + s) + } + *t = prefixVal(s[4:]) + return nil +} + var bytevar = byte(8) var byteptr = &bytevar +var prefixptr = prefixPtr("bar") +var prefixval = prefixVal("bar") var structItems = []testItemType{ {&struct{ Ptr *byte }{nil}, @@ -419,6 +466,24 @@ var structItems = []testItemType{ // Byte arrays. {&struct{ V [2]byte }{[2]byte{'y', 'o'}}, "\x05v\x00\x02\x00\x00\x00\x00yo"}, + + {&struct{ V prefixPtr }{prefixPtr("buzz")}, + "\x02v\x00\x09\x00\x00\x00foo-buzz\x00"}, + + {&struct{ V *prefixPtr }{&prefixptr}, + "\x02v\x00\x08\x00\x00\x00foo-bar\x00"}, + + {&struct{ V *prefixPtr }{nil}, + "\x0Av\x00"}, + + {&struct{ V prefixVal }{prefixVal("buzz")}, + "\x02v\x00\x09\x00\x00\x00foo-buzz\x00"}, + + {&struct{ V *prefixVal }{&prefixval}, + "\x02v\x00\x08\x00\x00\x00foo-bar\x00"}, + + {&struct{ V *prefixVal }{nil}, + "\x0Av\x00"}, } func (s *S) TestMarshalStructItems(c *C) { diff --git a/bson/decode.go b/bson/decode.go index 3b9e2856d..3e257f846 100644 --- a/bson/decode.go +++ b/bson/decode.go @@ -87,18 +87,20 @@ func setterStyle(outt reflect.Type) int { setterMutex.RLock() style := setterStyles[outt] setterMutex.RUnlock() - if style == setterUnknown { - setterMutex.Lock() - defer setterMutex.Unlock() - if outt.Implements(setterIface) { - setterStyles[outt] = setterType - } else if reflect.PtrTo(outt).Implements(setterIface) { - setterStyles[outt] = setterAddr - } else { - setterStyles[outt] = setterNone - } - style = setterStyles[outt] + if style != setterUnknown { + return style + } + + setterMutex.Lock() + defer setterMutex.Unlock() + if outt.Implements(setterIface) { + style = setterType + } else if reflect.PtrTo(outt).Implements(setterIface) { + style = setterAddr + } else { + style = setterNone } + setterStyles[outt] = style return style } diff --git a/bson/encode.go b/bson/encode.go index 75e503b57..61f388fa1 100644 --- a/bson/encode.go +++ b/bson/encode.go @@ -35,6 +35,7 @@ import ( "reflect" "sort" "strconv" + "sync" "time" ) @@ -60,13 +61,28 @@ var ( const itoaCacheSize = 32 +const ( + getterUnknown = iota + getterNone + getterTypeVal + getterTypePtr + getterAddr +) + var itoaCache []string +var getterStyles map[reflect.Type]int +var getterIface reflect.Type +var getterMutex sync.RWMutex + func init() { itoaCache = make([]string, itoaCacheSize) for i := 0; i != itoaCacheSize; i++ { itoaCache[i] = strconv.Itoa(i) } + var iface Getter + getterIface = reflect.TypeOf(&iface).Elem() + getterStyles = make(map[reflect.Type]int) } func itoa(i int) string { @@ -76,6 +92,52 @@ func itoa(i int) string { return strconv.Itoa(i) } +func getterStyle(outt reflect.Type) int { + getterMutex.RLock() + style := getterStyles[outt] + getterMutex.RUnlock() + if style != getterUnknown { + return style + } + + getterMutex.Lock() + defer getterMutex.Unlock() + if outt.Implements(getterIface) { + vt := outt + for vt.Kind() == reflect.Ptr { + vt = vt.Elem() + } + if vt.Implements(getterIface) { + style = getterTypeVal + } else { + style = getterTypePtr + } + } else if reflect.PtrTo(outt).Implements(getterIface) { + style = getterAddr + } else { + style = getterNone + } + getterStyles[outt] = style + return style +} + +func getGetter(outt reflect.Type, out reflect.Value) Getter { + style := getterStyle(outt) + if style == getterNone { + return nil + } + if style == getterAddr { + if !out.CanAddr() { + return nil + } + return out.Addr().Interface().(Getter) + } + if style == getterTypeVal && out.Kind() == reflect.Ptr && out.IsNil() { + return nil + } + return out.Interface().(Getter) +} + // -------------------------------------------------------------------------- // Marshaling of the document value itself. @@ -253,7 +315,7 @@ func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { return } - if getter, ok := v.Interface().(Getter); ok { + if getter := getGetter(v.Type(), v); getter != nil { getv, err := getter.GetBSON() if err != nil { panic(err)