@@ -56,7 +56,7 @@ isa_newblock(state::InvertibleOperationsTape, X) = (state.counter_block == 0) ||
56
56
"""
57
57
Error if mismatch between state and network
58
58
"""
59
- function check_coherence (state:: InvertibleOperationsTape , net:: Union{NeuralNetLayer,InvertibleNetwork} )
59
+ function check_coherence (state:: InvertibleOperationsTape , net:: Invertible )
60
60
if state. counter_block != 0 && state. counter_layer != 0 && state. layer_blocks[state. counter_block][state. counter_layer] != net
61
61
reset! (state)
62
62
throw (ArgumentError (" Current state does not correspond to current layer, resetting state..." ))
66
66
"""
67
67
Update state in the forward pass.
68
68
"""
69
- function forward_update! (state:: InvertibleOperationsTape , X:: AbstractArray{T,N} , Y:: AbstractArray{T,N} , logdet:: Union{Nothing,T} , net:: Union{NeuralNetLayer,InvertibleNetwork} ) where {T, N}
69
+ function forward_update! (state:: InvertibleOperationsTape , X:: AbstractArray{T,N} , Y:: AbstractArray{T,N} , logdet:: Union{Nothing,T} , net:: Invertible ) where {T, N}
70
70
71
71
if isa_newblock (state, X)
72
72
push! (state. Y, Y)
104
104
105
105
# # Chain rules for invertible networks
106
106
# General pullback function
107
- function pullback (net:: Union{NeuralNetLayer,InvertibleNetwork} , ΔY:: AbstractArray{T,N} ;
107
+ function pullback (net:: Invertible , ΔY:: AbstractArray{T,N} ;
108
108
state:: InvertibleOperationsTape = GLOBAL_STATE_INVOPS) where {T, N}
109
109
110
110
# Check state coherency
124
124
125
125
126
126
# Reverse-mode AD rule
127
- function ChainRulesCore. rrule (net:: Union{NeuralNetLayer,InvertibleNetwork} , X:: AbstractArray{T, N} ;
127
+ function ChainRulesCore. rrule (net:: Invertible , X:: AbstractArray{T, N} ;
128
128
state:: InvertibleOperationsTape = GLOBAL_STATE_INVOPS) where {T, N}
129
129
130
130
# Forward pass
0 commit comments