Skip to content

added helpful failure massages for tests #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ makedocs(;
"ChainRulesTestUtils" => "index.md",
"API" => "api.md",
],
strict=true,
checkdocs=:exports,
# doctest=:fix
)

const repo = "github.com/JuliaDiff/ChainRulesTestUtils.jl.git"
Expand Down
32 changes: 20 additions & 12 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
```@meta
DocTestFilters = [r"[0-9\.]+s",r"isapprox\(.*\)"]
```
# ChainRulesTestUtils

[![CI](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/workflows/CI/badge.svg?branch=main)](https://github.com/JuliaDiff/ChainRulesTestUtils.jl/actions?query=workflow%3ACI)
Expand Down Expand Up @@ -38,12 +41,12 @@ end
# output

```
and `rrule`
and `rrule` which contains a mistake in the first cotangent
```jldoctest ex
function ChainRulesCore.rrule(::typeof(two2three), x1, x2)
y = two2three(x1, x2)
function two2three_pullback(Ȳ)
return (NoTangent(), 2.0*Ȳ[2], 3.0*Ȳ[3])
return (NoTangent(), 2.1*Ȳ[2], 3.0*Ȳ[3])
end
return y, two2three_pullback
end
Expand All @@ -65,23 +68,28 @@ Keep this in mind when testing discontinuous rules for functions like [ReLU](htt
julia> using ChainRulesTestUtils;

julia> test_frule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_frule: two2three on Float64,Float64 | 6 6
Test Summary: | Pass Total Time
test_frule: two2three on Float64,Float64 | 6 6 2.7s

```

### Testing the `rrule`

[`test_rrule`](@ref) takes in the function `f`, and primal inputsr `x`.
[`test_rrule`](@ref) takes in the function `f`, and primal inputs `x`.
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.

```jldoctest ex
julia> test_rrule(two2three, 3.33, -7.77);
Test Summary: | Pass Total
test_rrule: two2three on Float64,Float64 | 9 9

test_rrule: two2three on Float64,Float64: Test Failed at /home/lior/.julia/dev/ChainRulesTestUtils/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Problem: cotangent for input 2, Float64
Evaluated: isapprox(-4.032, -3.840000000001641; rtol = 1.0e-9, atol = 1.0e-9)
[...]
```

The output of the test indicates to us the cause of the failure under "Problem:" with the expected (`rrule` derived) and actual finite difference results.
The Problem lies with the cotangent corresponding to input 2 of `rrule`, which is the first cotangent as expected.

## Scalar example

For functions with a single argument and a single output, such as e.g. ReLU,
Expand All @@ -105,13 +113,13 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
call.
```jldoctest ex
julia> test_scalar(relu, 0.5);
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 11 11
Test Summary: | Pass Total Time
test_scalar: relu at 0.5 | 12 12 1.2s


julia> test_scalar(relu, -0.5);
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 11 11
Test Summary: | Pass Total Time
test_scalar: relu at -0.5 | 12 12 0.0s

```

Expand Down
23 changes: 14 additions & 9 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ function test_rrule(
res === nothing && throw(MethodError(rrule_f, Tuple{Core.Typeof.(primals)...}))
y_ad, pullback = res
y = call(primals...)
test_approx(y_ad, y; isapprox_kwargs...) # make sure primal is correct
test_approx(y_ad, y, "Failed primal value check"; isapprox_kwargs...) # make sure primal is correct

ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

Expand All @@ -231,7 +231,8 @@ function test_rrule(
# Correctness testing via finite differencing.
is_ignored = isa.(accum_cotangents, NoTangent)
fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored)
foreach(accum_cotangents, ad_cotangents, fd_cotangents) do args...
msgs = ntuple(i->"cotangent for input $i, $(summary(fd_cotangents[i]))", length(fd_cotangents))
foreach(accum_cotangents, ad_cotangents, fd_cotangents, msgs) do args...
_test_cotangent(args...; check_inferred=check_inferred, isapprox_kwargs...)
end

Expand Down Expand Up @@ -282,14 +283,16 @@ function _is_inferrable(f, args...; kwargs...)
end

"""
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent; kwargs...)
_test_cotangent(accum_cotangent, ad_cotangent, fd_cotangent[, msg]; kwargs...)

Check if the cotangent `ad_cotangent` from `rrule` is consistent with `accum_tangent` and
approximately equal to the cotangent `fd_cotangent` obtained with finite differencing.

If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-differentiable,
`ad_cotangent` and `fd_cotangent` should be `NoTangent()` as well.

If a msg string is given, it is emmited on test failure.

# Keyword arguments
- If `check_inferred=true` (the default) and `ad_cotangent` is a thunk, then it is checked if
its content can be inferred.
Expand All @@ -298,22 +301,23 @@ If `accum_cotangent` is `NoTangent()`, i.e., the argument was marked as non-diff
function _test_cotangent(
accum_cotangent,
ad_cotangent,
fd_cotangent;
fd_cotangent,
msg="";
check_inferred=true,
kwargs...,
)
ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent)

# The main test of the actual derivative being correct:
test_approx(ad_cotangent, fd_cotangent; kwargs...)
test_approx(ad_cotangent, fd_cotangent, msg; kwargs...)
_test_add!!_behaviour(accum_cotangent, ad_cotangent; kwargs...)
end

# we marked the argument as non-differentiable
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent; kwargs...)
function _test_cotangent(::NoTangent, ad_cotangent, ::NoTangent, msg=""; kwargs...)
@test ad_cotangent isa NoTangent
end
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent; kwargs...)
function _test_cotangent(::NoTangent, ::ZeroTangent, ::NoTangent, msg=""; kwargs...)
error(
"The pullback in the rrule should use NoTangent()" *
" rather than ZeroTangent() for non-perturbable arguments."
Expand All @@ -322,7 +326,8 @@ end
function _test_cotangent(
::NoTangent,
ad_cotangent::ChainRulesCore.NotImplemented,
::NoTangent;
::NoTangent,
msg="";
kwargs...,
)
# this situation can occur if a cotangent is not implemented and
Expand All @@ -332,6 +337,6 @@ function _test_cotangent(
# https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/217
@test_broken ad_cotangent isa NoTangent
end
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent; kwargs...)
function _test_cotangent(::NoTangent, ad_cotangent, fd_cotangent, msg=""; kwargs...)
error("cotangent obtained with finite differencing has to be NoTangent()")
end