diff --git a/Project.toml b/Project.toml index d29563c..8055167 100644 --- a/Project.toml +++ b/Project.toml @@ -1,19 +1,30 @@ name = "SkewLinearAlgebra" uuid = "5c889d49-8c60-4500-9d10-5d3a22e2f4b9" authors = ["smataigne and contributors"] -version = "1.0" +version = "1.0.0" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + +[extensions] +SkewLinearAlgebraChainRulesCoreExt = "ChainRulesCore" + [compat] -julia = "1.6" LinearAlgebra = "1.6" +ChainRulesCore = "1" +julia = "1.6" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" [targets] -test = ["Test","Random", "SparseArrays"] +test = ["Test", "LinearAlgebra", "Random", "SparseArrays", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences"] diff --git a/ext/SkewLinearAlgebraChainRulesCoreExt.jl b/ext/SkewLinearAlgebraChainRulesCoreExt.jl new file mode 100644 index 0000000..0445ace --- /dev/null +++ b/ext/SkewLinearAlgebraChainRulesCoreExt.jl @@ -0,0 +1,30 @@ +module SkewLinearAlgebraChainRulesCoreExt + +using LinearAlgebra +using SkewLinearAlgebra +using ChainRulesCore + + +function ChainRulesCore.rrule(::Type{SkewHermitian}, val) + y = SkewHermitian(val) + function Foo_pb(ΔFoo) + if isa(ΔFoo, SkewHermitian) + return NoTangent(), unthunk(ΔFoo).data + else + return (NoTangent(), unthunk(ΔFoo)) + end + end + return y, Foo_pb +end + +function ChainRulesCore.rrule(::typeof(pfaffian), A::SkewHermitian) + Ω = pfaffian(A) + pfaffian_pullback(ΔΩ) = NoTangent(), SkewHermitian(rmul!(inv(A)', dot(Ω, ΔΩ))) #potentially need the 0.5 here ! + return Ω, pfaffian_pullback +end + +function ChainRulesCore.ProjectTo{<:SkewHermitian}(A::AbstractArray) + return skewhermitian(A) +end + +end diff --git a/test/chainrulestests.jl b/test/chainrulestests.jl new file mode 100644 index 0000000..cafb40f --- /dev/null +++ b/test/chainrulestests.jl @@ -0,0 +1,41 @@ +using ChainRulesTestUtils +using Random +using SkewLinearAlgebra +using LinearAlgebra +using FiniteDifferences + +ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::SkewHermitian) = skewhermitian!(rand(rng, eltype(x), size(x)...)) + +# Required to make finite differences behave correctly +function FiniteDifferences.to_vec(x::SkewHermitian) + m = size(x, 1) + v = Vector{eltype(x)}(undef, m * (m - 1) ÷ 2) + k = 1 + for i in 2:m, j in 1:i-1 + @inbounds v[k] = x[i, j] + k += 1 + end + + function from_vec(v) + y = zero(x) + k = 1 + for i in 2:m, j in 1:i-1 + @inbounds y[i, j] = v[k] + @inbounds y[j, i] = -v[k] + k += 1 + end + return y + end + return v, from_vec +end + +@testset "automatic differentiation" begin + m = 10 + inds = [1,2] + A = skewhermitian(rand(m, m)) + + test_rrule(SkewHermitian, A) # test constructor + + test_rrule(pfaffian, A) # test pfaffian + test_rrule(pfaffian, SkewHermitian(A[inds, inds])) # test pfaffian of submatrix +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index fa97c7b..f2e78c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -519,3 +519,5 @@ end @test E.vectors*Diagonal(E.values)*E.vectors' ≈ B end end + +include("chainrulestests.jl") \ No newline at end of file