Skip to content

Commit 1adcc49

Browse files
authored
Merge pull request #75 from slimgroup/mix-type-rr
Mix type rr
2 parents 11a4a77 + 0d4c1db commit 1adcc49

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "InvertibleNetworks"
22
uuid = "b7115f24-5f92-4794-81e8-23b0ddb121d3"
33
authors = ["Philipp Witte <[email protected]>", "Ali Siahkoohi <[email protected]>", "Mathias Louboutin <[email protected]>", "Gabrio Rizzuti <[email protected]>", "Rafael Orozco <[email protected]>", "Felix J. herrmann <[email protected]>"]
4-
version = "2.2.3"
4+
version = "2.2.4"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/utils/chainrules.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ function forward_update!(state::InvertibleOperationsTape, X::AbstractArray{T,N},
7070

7171
if isa_newblock(state, X)
7272
push!(state.Y, Y)
73-
push!(state.layer_blocks, [net])
73+
push!(state.layer_blocks, Vector{Any}([net]))
7474
state.counter_block += 1
7575
state.counter_layer = 1
7676
else

src/utils/compute_utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function chain_lr(x::AbstractMatrix{T}, vi::Vararg{AbstractVector{T}, N}) where
2121
out = T(1) .* x
2222
tmp = cuzeros(vi[1], size(x, 1))
2323
for v=vi
24-
n = -2/norm(v)^2
24+
n = -2/dot(v, v)
2525
mul!(tmp, out, v)
2626
rmul!(tmp, n)
2727
gemm_outer!(out, tmp, v)

0 commit comments

Comments
 (0)