Skip to content

Commit 23b99a4

Browse files
authored
Add rrule for inv and test (#236)
1 parent 6578475 commit 23b99a4

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

ext/TensorKitChainRulesCoreExt/linalg.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,14 @@ function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
9797
return n, norm_pullback
9898
end
9999

100+
function ChainRulesCore.rrule(::typeof(inv), A::AbstractTensorMap)
101+
Ainv = inv(A)
102+
inv_pullback = let Ainv = Ainv
103+
inv_pullback(ΔAinv) = NoTangent(), -Ainv' * unthunk(ΔAinv) * Ainv'
104+
end
105+
return Ainv, inv_pullback
106+
end
107+
100108
function ChainRulesCore.rrule(::typeof(real), a::AbstractTensorMap)
101109
a_real = real(a)
102110
real_pullback(Δa) = NoTangent(), eltype(a) <: Real ? Δa : complex(unthunk(Δa))

test/ad.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
234234
E = randn(T, (V[1:i]...) (V[1:i]...))
235235
test_rrule(LinearAlgebra.tr, E)
236236
test_rrule(exp, E; check_inferred=false)
237+
test_rrule(inv, E)
237238
end
238239

239240
A = randn(T, V[1] V[2] V[3] V[4] V[5])

0 commit comments

Comments
 (0)