Skip to content

Commit 64971e7

Browse files
authored
Merge pull request #288 from nomadbl/add_massages
added helpful failure massages for tests
2 parents 49a0324 + f7388ec commit 64971e7

File tree

3 files changed

+35
-22
lines changed

3 files changed

+35
-22
lines changed

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ makedocs(;
1010
"ChainRulesTestUtils" => "index.md",
1111
"API" => "api.md",
1212
],
13-
strict=true,
1413
checkdocs=:exports,
14+
# doctest=:fix
1515
)
1616

1717
const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"

docs/src/index.md

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
```@meta
2+
DocTestFilters = [r"[0-9\.]+s",r"isapprox\(.*\)"]
3+
```
14
# ChainRulesTestUtils
25

36
[![CI](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/workflows/CI/badge.svg?branch=main)](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/actions?query=workflow%3ACI)
@@ -38,12 +41,12 @@ end
3841
# output
3942
4043
```
41-
and `rrule`
44+
and `rrule` which contains a mistake in the first cotangent
4245
```jldoctest ex
4346
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
4447
y = two2three(x1, x2)
4548
function two2three_pullback(Ȳ)
46-
return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3])
49+
return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3])
4750
end
4851
return y, two2three_pullback
4952
end
@@ -65,23 +68,28 @@ Keep this in mind when testing discontinuous rules for functions like [ReLU](htt
6568
julia> using ChainRulesTestUtils;
6669
6770
julia> test_frule(two2three, 3.33, -7.77);
68-
Test Summary: | Pass Total
69-
test_frule: two2three on Float64,Float64 | 6 6
71+
Test Summary: | Pass Total Time
72+
test_frule: two2three on Float64,Float64 | 6 6 2.7s
7073
7174
```
7275

7376
### Testing the `rrule`
7477

75-
[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
78+
[`test_rrule`](@ref) takes in the function `f`, and primal inputs `x`.
7679
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.
7780

7881
```jldoctest ex
7982
julia> test_rrule(two2three, 3.33, -7.77);
80-
Test Summary: | Pass Total
81-
test_rrule: two2three on Float64,Float64 | 9 9
82-
83+
test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24
84+
Expression: isapprox(actual, expected; kwargs...)
85+
Problem: cotangent for input 2, Float64
86+
Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9)
87+
[...]
8388
```
8489

90+
The output of the test indicates to us the cause of the failure under "Problem:" with the expected (`rrule` derived) and actual finite difference results.
91+
The Problem lies with the cotangent corresponding to input 2 of `rrule`, which is the first cotangent as expected.
92+
8593
## Scalar example
8694

8795
For functions with a single argument and a single output, such as e.g. ReLU,
@@ -105,13 +113,13 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
105113
call.
106114
```jldoctest ex
107115
julia> test_scalar(relu, 0.5);
108-
Test Summary: | Pass Total
109-
test_scalar: relu at 0.5 | 11 11
116+
Test Summary: | Pass Total Time
117+
test_scalar: relu at 0.5 | 12 12 1.2s
110118
111119
112120
julia> test_scalar(relu, -0.5);
113-
Test Summary: | Pass Total
114-
test_scalar: relu at -0.5 | 11 11
121+
Test Summary: | Pass Total Time
122+
test_scalar: relu at -0.5 | 12 12 0.0s
115123
116124
```
117125

src/testers.jl

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ function test_rrule(
213213
res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...}))
214214
y_ad, pullback = res
215215
y = call(primals...)
216-
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
216+
test_approx(y_ad, y, "Failed primal value check"; isapprox_kwargs...) # make sure primal is correct
217217

218218
ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent
219219

@@ -231,7 +231,8 @@ function test_rrule(
231231
# Correctness testing via finite differencing.
232232
is_ignored = isa.(accum_cotangents, NoTangent)
233233
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
234-
foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args...
234+
msgs = ntuple(i->"cotangent for input $i, $(summary(fd_cotangents[i]))", length(fd_cotangents))
235+
foreach(accum_cotangents, ad_cotangents, fd_cotangents, msgs) do args...
235236
_test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...)
236237
end
237238

@@ -282,14 +283,16 @@ function _is_inferrable(f, args...; kwargs...)
282283
end
283284

284285
"""
285-
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...)
286+
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent[, msg]; kwargs...)
286287
287288
Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and
288289
approximately equal to the cotangent `fd_cotangent` obtained with finite differencing.
289290
290291
If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable,
291292
`ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well.
292293
294+
If a msg string is given, it is emmited on test failure.
295+
293296
# Keyword arguments
294297
- If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if
295298
its content can be inferred.
@@ -298,22 +301,23 @@ If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-diff
298301
function _test_cotangent(
299302
accum_cotangent,
300303
ad_cotangent,
301-
fd_cotangent;
304+
fd_cotangent,
305+
msg="";
302306
check_inferred=true,
303307
kwargs...,
304308
)
305309
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)
306310

307311
# The main test of the actual derivative being correct:
308-
test_approx(ad_cotangent, fd_cotangent; kwargs...)
312+
test_approx(ad_cotangent, fd_cotangent, msg; kwargs...)
309313
_test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...)
310314
end
311315

312316
# we marked the argument as non-differentiable
313-
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...)
317+
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent, msg=""; kwargs...)
314318
@test ad_cotangent isa NoTangent
315319
end
316-
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...)
320+
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent, msg=""; kwargs...)
317321
error(
318322
"The pullback in the rrule should use NoTangent()" *
319323
" rather than ZeroTangent() for non-perturbable arguments."
@@ -322,7 +326,8 @@ end
322326
function _test_cotangent(
323327
::NoTangent,
324328
ad_cotangent::ChainRulesCore.NotImplemented,
325-
::NoTangent;
329+
::NoTangent,
330+
msg="";
326331
kwargs...,
327332
)
328333
# this situation can occur if a cotangent is not implemented and
@@ -332,6 +337,6 @@ function _test_cotangent(
332337
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
333338
@test_broken ad_cotangent isa NoTangent
334339
end
335-
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...)
340+
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent, msg=""; kwargs...)
336341
error("cotangent obtained with finite differencing has to be NoTangent()")
337342
end

0 commit comments

Comments
 (0)