Skip to content

Commit 73eb244

Browse files
committed
SetAsserter works only outside unit tests, they MUST use pkg asserter
1 parent a9dc394 commit 73eb244

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

assert/assert.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -967,8 +967,6 @@ func doNotZero[T Number](val T, a []any) {
967967
//
968968
// NOTE that since our TLS [asserterMap] we still continue to use indexing.
969969
func current() (curAsserter asserter) {
970-
// we need thread local storage, maybe we'll implement that to x.package?
971-
// study `tester` and copy ideas from it.
972970
tlsID := goid()
973971
asserterMap.Rx(func(m map[int]asserter) {
974972
aster, found := m[tlsID]
@@ -1023,9 +1021,16 @@ func SetDefault(i defInd) (old defInd) {
10231021
// to return plain error messages instead of the panic asserts, they can use
10241022
// following:
10251023
//
1026-
// assert.SetAsserter(assert.Plain)
1024+
// defer assert.SetAsserter(assert.Plain)()
10271025
func SetAsserter(i defInd) func() {
1028-
asserterMap.Set(goid(), defAsserter[i])
1026+
// get pkg lvl asserter
1027+
curAsserter := defAsserter[def]
1028+
// .. to check if we are doing unit tests
1029+
if !curAsserter.isUnitTesting() {
1030+
// .. allow TLS specific asserter. NOTE see current()
1031+
curGoRID := goid()
1032+
asserterMap.Set(curGoRID, defAsserter[i])
1033+
}
10291034
return popCurrentAsserter
10301035
}
10311036

assert/assert_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,22 @@ func ExampleCLen() {
261261
// Output: sample: assert_test.go:253: ExampleCLen.func1(): assertion failure: length: got '2', want '3'
262262
}
263263

264+
func ExampleThatNot() {
265+
sample := func() (err error) {
266+
defer err2.Handle(&err)
267+
268+
assert.ThatNot(true, "overrides if Plain asserter")
269+
return err
270+
}
271+
272+
// set asserter for this thread/goroutine only, we want plain errors
273+
defer assert.SetAsserter(assert.Plain)()
274+
275+
err := sample()
276+
fmt.Printf("%v", err)
277+
// Output: testing: run example: overrides if Plain asserter
278+
}
279+
264280
func BenchmarkMKeyExists(b *testing.B) {
265281
bs := map[int]int{0: 0, 1: 1}
266282
for n := 0; n < b.N; n++ {

0 commit comments

Comments
 (0)