Skip to content

Commit c31ea03

Browse files
authored
Support comparing byte slice (#1202)
* support comparing byte slice Signed-off-by: Ryan Leung <[email protected]> * address the comment Signed-off-by: Ryan Leung <[email protected]>
1 parent 48391ba commit c31ea03

File tree

2 files changed

+151
-1
lines changed

2 files changed

+151
-1
lines changed

assert/assertion_compare.go

+23-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package assert
22

33
import (
4+
"bytes"
45
"fmt"
56
"reflect"
67
"time"
@@ -32,7 +33,8 @@ var (
3233

3334
stringType = reflect.TypeOf("")
3435

35-
timeType = reflect.TypeOf(time.Time{})
36+
timeType = reflect.TypeOf(time.Time{})
37+
bytesType = reflect.TypeOf([]byte{})
3638
)
3739

3840
func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
@@ -323,6 +325,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
323325

324326
return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64)
325327
}
328+
case reflect.Slice:
329+
{
330+
// We only care about the []byte type.
331+
if !canConvert(obj1Value, bytesType) {
332+
break
333+
}
334+
335+
// []byte can be compared!
336+
bytesObj1, ok := obj1.([]byte)
337+
if !ok {
338+
bytesObj1 = obj1Value.Convert(bytesType).Interface().([]byte)
339+
340+
}
341+
bytesObj2, ok := obj2.([]byte)
342+
if !ok {
343+
bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte)
344+
}
345+
346+
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
347+
}
326348
}
327349

328350
return compareEqual, false

assert/assertion_compare_go1.17_test.go

+128
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,24 @@
88
package assert
99

1010
import (
11+
"bytes"
1112
"reflect"
1213
"testing"
1314
"time"
1415
)
1516

1617
func TestCompare17(t *testing.T) {
1718
type customTime time.Time
19+
type customBytes []byte
1820
for _, currCase := range []struct {
1921
less interface{}
2022
greater interface{}
2123
cType string
2224
}{
2325
{less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"},
2426
{less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"},
27+
{less: []byte{1, 1}, greater: []byte{1, 2}, cType: "[]byte"},
28+
{less: customBytes([]byte{1, 1}), greater: customBytes([]byte{1, 2}), cType: "[]byte"},
2529
} {
2630
resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind())
2731
if !isComparable {
@@ -52,3 +56,127 @@ func TestCompare17(t *testing.T) {
5256
}
5357
}
5458
}
59+
60+
func TestGreater17(t *testing.T) {
61+
mockT := new(testing.T)
62+
63+
if !Greater(mockT, 2, 1) {
64+
t.Error("Greater should return true")
65+
}
66+
67+
if Greater(mockT, 1, 1) {
68+
t.Error("Greater should return false")
69+
}
70+
71+
if Greater(mockT, 1, 2) {
72+
t.Error("Greater should return false")
73+
}
74+
75+
// Check error report
76+
for _, currCase := range []struct {
77+
less interface{}
78+
greater interface{}
79+
msg string
80+
}{
81+
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than "[1 2]"`},
82+
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than "0001-01-01 01:00:00 +0000 UTC"`},
83+
} {
84+
out := &outputT{buf: bytes.NewBuffer(nil)}
85+
False(t, Greater(out, currCase.less, currCase.greater))
86+
Contains(t, out.buf.String(), currCase.msg)
87+
Contains(t, out.helpers, "github.com/stretchr/testify/assert.Greater")
88+
}
89+
}
90+
91+
func TestGreaterOrEqual17(t *testing.T) {
92+
mockT := new(testing.T)
93+
94+
if !GreaterOrEqual(mockT, 2, 1) {
95+
t.Error("GreaterOrEqual should return true")
96+
}
97+
98+
if !GreaterOrEqual(mockT, 1, 1) {
99+
t.Error("GreaterOrEqual should return true")
100+
}
101+
102+
if GreaterOrEqual(mockT, 1, 2) {
103+
t.Error("GreaterOrEqual should return false")
104+
}
105+
106+
// Check error report
107+
for _, currCase := range []struct {
108+
less interface{}
109+
greater interface{}
110+
msg string
111+
}{
112+
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 1]" is not greater than or equal to "[1 2]"`},
113+
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 00:00:00 +0000 UTC" is not greater than or equal to "0001-01-01 01:00:00 +0000 UTC"`},
114+
} {
115+
out := &outputT{buf: bytes.NewBuffer(nil)}
116+
False(t, GreaterOrEqual(out, currCase.less, currCase.greater))
117+
Contains(t, out.buf.String(), currCase.msg)
118+
Contains(t, out.helpers, "github.com/stretchr/testify/assert.GreaterOrEqual")
119+
}
120+
}
121+
122+
func TestLess17(t *testing.T) {
123+
mockT := new(testing.T)
124+
125+
if !Less(mockT, 1, 2) {
126+
t.Error("Less should return true")
127+
}
128+
129+
if Less(mockT, 1, 1) {
130+
t.Error("Less should return false")
131+
}
132+
133+
if Less(mockT, 2, 1) {
134+
t.Error("Less should return false")
135+
}
136+
137+
// Check error report
138+
for _, currCase := range []struct {
139+
less interface{}
140+
greater interface{}
141+
msg string
142+
}{
143+
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than "[1 1]"`},
144+
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than "0001-01-01 00:00:00 +0000 UTC"`},
145+
} {
146+
out := &outputT{buf: bytes.NewBuffer(nil)}
147+
False(t, Less(out, currCase.greater, currCase.less))
148+
Contains(t, out.buf.String(), currCase.msg)
149+
Contains(t, out.helpers, "github.com/stretchr/testify/assert.Less")
150+
}
151+
}
152+
153+
func TestLessOrEqual17(t *testing.T) {
154+
mockT := new(testing.T)
155+
156+
if !LessOrEqual(mockT, 1, 2) {
157+
t.Error("LessOrEqual should return true")
158+
}
159+
160+
if !LessOrEqual(mockT, 1, 1) {
161+
t.Error("LessOrEqual should return true")
162+
}
163+
164+
if LessOrEqual(mockT, 2, 1) {
165+
t.Error("LessOrEqual should return false")
166+
}
167+
168+
// Check error report
169+
for _, currCase := range []struct {
170+
less interface{}
171+
greater interface{}
172+
msg string
173+
}{
174+
{less: []byte{1, 1}, greater: []byte{1, 2}, msg: `"[1 2]" is not less than or equal to "[1 1]"`},
175+
{less: time.Time{}, greater: time.Time{}.Add(time.Hour), msg: `"0001-01-01 01:00:00 +0000 UTC" is not less than or equal to "0001-01-01 00:00:00 +0000 UTC"`},
176+
} {
177+
out := &outputT{buf: bytes.NewBuffer(nil)}
178+
False(t, LessOrEqual(out, currCase.greater, currCase.less))
179+
Contains(t, out.buf.String(), currCase.msg)
180+
Contains(t, out.helpers, "github.com/stretchr/testify/assert.LessOrEqual")
181+
}
182+
}

0 commit comments

Comments
 (0)