Skip to content

Commit d3ea782

Browse files
committed
super performance! Split assert functions go get inlining
1 parent 9072412 commit d3ea782

File tree

2 files changed

+48
-29
lines changed

2 files changed

+48
-29
lines changed

assert/assert.go

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -438,8 +438,7 @@ func MNotNil[M ~map[T]U, T comparable, U any](m M, a ...any) {
438438
// assert violation message.
439439
func NotEqual[T comparable](val, want T, a ...any) {
440440
if want == val {
441-
defMsg := fmt.Sprintf(assertionMsg+": got '%v' want (!= '%v')", val, want)
442-
current().reportAssertionFault(defMsg, a)
441+
doShouldNotBeEqual(val, want, a)
443442
}
444443
}
445444

@@ -451,11 +450,20 @@ func NotEqual[T comparable](val, want T, a ...any) {
451450
// are used to override the auto-generated assert violation message.
452451
func Equal[T comparable](val, want T, a ...any) {
453452
if want != val {
454-
defMsg := fmt.Sprintf(assertionMsg+gotWantFmt, val, want)
455-
current().reportAssertionFault(defMsg, a)
453+
doShouldBeEqual(val, want, a)
456454
}
457455
}
458456

457+
func doShouldBeEqual[T comparable](val, want T, a []any) {
458+
defMsg := fmt.Sprintf(assertionMsg+gotWantFmt, val, want)
459+
current().newReportAssertionFault(defMsg, a)
460+
}
461+
462+
func doShouldNotBeEqual[T comparable](val, want T, a []any) {
463+
defMsg := fmt.Sprintf(assertionMsg+": got '%v' want (!= '%v')", val, want)
464+
current().reportAssertionFault(defMsg, a)
465+
}
466+
459467
// DeepEqual asserts that the (whatever) values are equal. If not it
460468
// panics/errors (according the current Asserter) with the auto-generated
461469
// message. You can append the generated got-want message by using optional
@@ -503,8 +511,7 @@ func Len(obj string, length int, a ...any) {
503511
l := len(obj)
504512

505513
if l != length {
506-
defMsg := fmt.Sprintf(assertionMsg+gotWantFmt, l, length)
507-
current().reportAssertionFault(defMsg, a)
514+
doShouldBeEqual(l, length, a)
508515
}
509516
}
510517

@@ -560,8 +567,7 @@ func SLen[S ~[]T, T any](obj S, length int, a ...any) {
560567
l := len(obj)
561568

562569
if l != length {
563-
defMsg := fmt.Sprintf(assertionMsg+gotWantFmt, l, length)
564-
current().reportAssertionFault(defMsg, a)
570+
doShouldBeEqual(l, length, a)
565571
}
566572
}
567573

@@ -617,8 +623,7 @@ func MLen[M ~map[T]U, T comparable, U any](obj M, length int, a ...any) {
617623
l := len(obj)
618624

619625
if l != length {
620-
defMsg := fmt.Sprintf(assertionMsg+gotWantFmt, l, length)
621-
current().reportAssertionFault(defMsg, a)
626+
doShouldBeEqual(l, length, a)
622627
}
623628
}
624629

@@ -916,11 +921,15 @@ func doZero[T Number](val T, a []any) {
916921
// are used to override the auto-generated assert violation message.
917922
func NotZero[T Number](val T, a ...any) {
918923
if val == 0 {
919-
defMsg := fmt.Sprintf(assertionMsg+": got '%v', want (!= 0)", val)
920-
current().reportAssertionFault(defMsg, a)
924+
doNotZero(val, a)
921925
}
922926
}
923927

928+
func doNotZero[T Number](val T, a []any) {
929+
defMsg := fmt.Sprintf(assertionMsg+": got '%v', want (!= 0)", val)
930+
current().newReportAssertionFault(defMsg, a)
931+
}
932+
924933
// current returns a current default asserter used for package-level
925934
// functions like assert.That().
926935
//

assert/assert_test.go

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -221,20 +221,17 @@ func ExampleGreater() {
221221
// Output: sample: assert_test.go:215: ExampleGreater.func1(): assertion violation: got '2', want <= '2'
222222
}
223223

224-
func assertZero(i int) {
225-
assert.Zero(i)
226-
}
227-
228-
func assertZeroGen(i int) {
229-
assert.Equal(i, 0)
230-
}
231-
232-
func assertMLen(b map[byte]byte, l int) {
233-
assert.MLen(b, l)
234-
}
224+
func ExampleNotZero() {
225+
sample := func(b int8) (err error) {
226+
defer err2.Handle(&err, "sample")
235227

236-
func assertEqualInt2(b int) {
237-
assert.Equal(b, 2)
228+
assert.NotZero(b)
229+
return err
230+
}
231+
var b int8
232+
err := sample(b)
233+
fmt.Printf("%v", err)
234+
// Output: sample: assert_test.go:228: ExampleNotZero.func1(): assertion violation: got '0', want (!= 0)
238235
}
239236

240237
func BenchmarkSNotNil(b *testing.B) {
@@ -267,13 +264,19 @@ func BenchmarkNotEmpty(b *testing.B) {
267264

268265
func BenchmarkZero(b *testing.B) {
269266
for n := 0; n < b.N; n++ {
270-
assertZero(0)
267+
assert.Zero(0)
268+
}
269+
}
270+
271+
func BenchmarkNotZero(b *testing.B) {
272+
for n := 0; n < b.N; n++ {
273+
assert.NotZero(n + 1)
271274
}
272275
}
273276

274277
func BenchmarkEqual(b *testing.B) {
275278
for n := 0; n < b.N; n++ {
276-
assertZeroGen(0)
279+
assert.Equal(n, n)
277280
}
278281
}
279282

@@ -292,7 +295,7 @@ func BenchmarkAsserter_TrueIfVersion(b *testing.B) {
292295
func BenchmarkMLen(b *testing.B) {
293296
d := map[byte]byte{1: 1, 2: 2}
294297
for n := 0; n < b.N; n++ {
295-
assertMLen(d, 2)
298+
assert.MLen(d, 2)
296299
}
297300
}
298301

@@ -317,10 +320,17 @@ func BenchmarkSLen_thatVersion(b *testing.B) {
317320
}
318321
}
319322

323+
func BenchmarkNotEqualInt(b *testing.B) {
324+
const d = 2
325+
for n := 0; n < b.N; n++ {
326+
assert.NotEqual(d, 3)
327+
}
328+
}
329+
320330
func BenchmarkEqualInt(b *testing.B) {
321331
const d = 2
322332
for n := 0; n < b.N; n++ {
323-
assertEqualInt2(d)
333+
assert.Equal(d, 2)
324334
}
325335
}
326336

0 commit comments

Comments
 (0)