Skip to content

Commit 269ba66

Browse files
authored
GODRIVER-1765 Add MarshalJSON/UnmarshalJSON functions for bson.D. (mongodb#1594)
1 parent 214a035 commit 269ba66

File tree

2 files changed

+316
-0
lines changed

2 files changed

+316
-0
lines changed

bson/bson_test.go

+172
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,178 @@ func TestD(t *testing.T) {
297297
})
298298
}
299299

300+
func TestD_MarshalJSON(t *testing.T) {
301+
t.Parallel()
302+
303+
testcases := []struct {
304+
name string
305+
test D
306+
expected interface{}
307+
}{
308+
{
309+
"nil",
310+
nil,
311+
nil,
312+
},
313+
{
314+
"empty",
315+
D{},
316+
struct{}{},
317+
},
318+
{
319+
"non-empty",
320+
D{
321+
{"a", 42},
322+
{"b", true},
323+
{"c", "answer"},
324+
{"d", nil},
325+
{"e", 2.71828},
326+
{"f", A{42, true, "answer", nil, 2.71828}},
327+
{"g", D{{"foo", "bar"}}},
328+
},
329+
struct {
330+
A int `json:"a"`
331+
B bool `json:"b"`
332+
C string `json:"c"`
333+
D interface{} `json:"d"`
334+
E float32 `json:"e"`
335+
F []interface{} `json:"f"`
336+
G map[string]interface{} `json:"g"`
337+
}{
338+
A: 42,
339+
B: true,
340+
C: "answer",
341+
D: nil,
342+
E: 2.71828,
343+
F: []interface{}{42, true, "answer", nil, 2.71828},
344+
G: map[string]interface{}{"foo": "bar"},
345+
},
346+
},
347+
}
348+
for _, tc := range testcases {
349+
tc := tc
350+
t.Run("json.Marshal "+tc.name, func(t *testing.T) {
351+
t.Parallel()
352+
353+
got, err := json.Marshal(tc.test)
354+
assert.NoError(t, err)
355+
want, _ := json.Marshal(tc.expected)
356+
assert.Equal(t, want, got)
357+
})
358+
}
359+
for _, tc := range testcases {
360+
tc := tc
361+
t.Run("json.MarshalIndent "+tc.name, func(t *testing.T) {
362+
t.Parallel()
363+
364+
got, err := json.MarshalIndent(tc.test, "<prefix>", "<indent>")
365+
assert.NoError(t, err)
366+
want, _ := json.MarshalIndent(tc.expected, "<prefix>", "<indent>")
367+
assert.Equal(t, want, got)
368+
})
369+
}
370+
}
371+
372+
func TestD_UnmarshalJSON(t *testing.T) {
373+
t.Parallel()
374+
375+
t.Run("success", func(t *testing.T) {
376+
t.Parallel()
377+
378+
for _, tc := range []struct {
379+
name string
380+
test []byte
381+
expected D
382+
}{
383+
{
384+
"nil",
385+
[]byte(`null`),
386+
nil,
387+
},
388+
{
389+
"empty",
390+
[]byte(`{}`),
391+
D{},
392+
},
393+
{
394+
"non-empty",
395+
[]byte(`{"hello":"world","pi":3.142,"boolean":true,"nothing":null,"list":["hello world",3.142,false,null,{"Lorem":"ipsum"}],"document":{"foo":"bar"}}`),
396+
D{
397+
{"hello", "world"},
398+
{"pi", 3.142},
399+
{"boolean", true},
400+
{"nothing", nil},
401+
{"list", []interface{}{"hello world", 3.142, false, nil, D{{"Lorem", "ipsum"}}}},
402+
{"document", D{{"foo", "bar"}}},
403+
},
404+
},
405+
} {
406+
tc := tc
407+
t.Run(tc.name, func(t *testing.T) {
408+
t.Parallel()
409+
410+
var got D
411+
err := json.Unmarshal(tc.test, &got)
412+
assert.NoError(t, err)
413+
assert.Equal(t, tc.expected, got)
414+
})
415+
}
416+
})
417+
418+
t.Run("failure", func(t *testing.T) {
419+
t.Parallel()
420+
421+
for _, tc := range []struct {
422+
name string
423+
test string
424+
}{
425+
{
426+
"illegal",
427+
`nil`,
428+
},
429+
{
430+
"invalid",
431+
`{"pi": 3.142ipsum}`,
432+
},
433+
{
434+
"malformatted",
435+
`{"pi", 3.142}`,
436+
},
437+
{
438+
"truncated",
439+
`{"pi": 3.142`,
440+
},
441+
{
442+
"array type",
443+
`["pi", 3.142]`,
444+
},
445+
{
446+
"boolean type",
447+
`true`,
448+
},
449+
} {
450+
tc := tc
451+
t.Run(tc.name, func(t *testing.T) {
452+
t.Parallel()
453+
454+
var a map[string]interface{}
455+
want := json.Unmarshal([]byte(tc.test), &a)
456+
var b D
457+
got := json.Unmarshal([]byte(tc.test), &b)
458+
switch w := want.(type) {
459+
case *json.UnmarshalTypeError:
460+
w.Type = reflect.TypeOf(b)
461+
require.IsType(t, want, got)
462+
g := got.(*json.UnmarshalTypeError)
463+
assert.Equal(t, w, g)
464+
default:
465+
assert.Equal(t, want, got)
466+
}
467+
})
468+
}
469+
})
470+
}
471+
300472
type stringerString string
301473

302474
func (ss stringerString) String() string {

bson/primitive.go

+144
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"bytes"
1414
"encoding/json"
1515
"fmt"
16+
"reflect"
1617
"time"
1718
)
1819

@@ -216,6 +217,55 @@ func (d D) Map() M {
216217
return m
217218
}
218219

220+
// MarshalJSON encodes D into JSON.
221+
func (d D) MarshalJSON() ([]byte, error) {
222+
if d == nil {
223+
return json.Marshal(nil)
224+
}
225+
var err error
226+
var buf bytes.Buffer
227+
buf.Write([]byte("{"))
228+
enc := json.NewEncoder(&buf)
229+
for i, e := range d {
230+
err = enc.Encode(e.Key)
231+
if err != nil {
232+
return nil, err
233+
}
234+
buf.Write([]byte(":"))
235+
err = enc.Encode(e.Value)
236+
if err != nil {
237+
return nil, err
238+
}
239+
if i < len(d)-1 {
240+
buf.Write([]byte(","))
241+
}
242+
}
243+
buf.Write([]byte("}"))
244+
return json.RawMessage(buf.Bytes()).MarshalJSON()
245+
}
246+
247+
// UnmarshalJSON decodes D from JSON.
248+
func (d *D) UnmarshalJSON(b []byte) error {
249+
dec := json.NewDecoder(bytes.NewReader(b))
250+
t, err := dec.Token()
251+
if err != nil {
252+
return err
253+
}
254+
if t == nil {
255+
*d = nil
256+
return nil
257+
}
258+
if v, ok := t.(json.Delim); !ok || v != '{' {
259+
return &json.UnmarshalTypeError{
260+
Value: tokenString(t),
261+
Type: reflect.TypeOf(D(nil)),
262+
Offset: dec.InputOffset(),
263+
}
264+
}
265+
*d, err = jsonDecodeD(dec)
266+
return err
267+
}
268+
219269
// E represents a BSON element for a D. It is usually used inside a D.
220270
type E struct {
221271
Key string
@@ -237,3 +287,97 @@ type M map[string]interface{}
237287
//
238288
// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}}
239289
type A []interface{}
290+
291+
func jsonDecodeD(dec *json.Decoder) (D, error) {
292+
res := D{}
293+
for {
294+
var e E
295+
296+
t, err := dec.Token()
297+
if err != nil {
298+
return nil, err
299+
}
300+
key, ok := t.(string)
301+
if !ok {
302+
break
303+
}
304+
e.Key = key
305+
306+
t, err = dec.Token()
307+
if err != nil {
308+
return nil, err
309+
}
310+
switch v := t.(type) {
311+
case json.Delim:
312+
switch v {
313+
case '[':
314+
e.Value, err = jsonDecodeSlice(dec)
315+
if err != nil {
316+
return nil, err
317+
}
318+
case '{':
319+
e.Value, err = jsonDecodeD(dec)
320+
if err != nil {
321+
return nil, err
322+
}
323+
}
324+
default:
325+
e.Value = t
326+
}
327+
328+
res = append(res, e)
329+
}
330+
return res, nil
331+
}
332+
333+
func jsonDecodeSlice(dec *json.Decoder) ([]interface{}, error) {
334+
var res []interface{}
335+
done := false
336+
for !done {
337+
t, err := dec.Token()
338+
if err != nil {
339+
return nil, err
340+
}
341+
switch v := t.(type) {
342+
case json.Delim:
343+
switch v {
344+
case '[':
345+
a, err := jsonDecodeSlice(dec)
346+
if err != nil {
347+
return nil, err
348+
}
349+
res = append(res, a)
350+
case '{':
351+
d, err := jsonDecodeD(dec)
352+
if err != nil {
353+
return nil, err
354+
}
355+
res = append(res, d)
356+
default:
357+
done = true
358+
}
359+
default:
360+
res = append(res, t)
361+
}
362+
}
363+
return res, nil
364+
}
365+
366+
func tokenString(t json.Token) string {
367+
switch v := t.(type) {
368+
case json.Delim:
369+
switch v {
370+
case '{':
371+
return "object"
372+
case '[':
373+
return "array"
374+
}
375+
case bool:
376+
return "bool"
377+
case float64:
378+
return "number"
379+
case json.Number, string:
380+
return "string"
381+
}
382+
return "unknown"
383+
}

0 commit comments

Comments
 (0)