Skip to content

Commit 0955815

Browse files
committed
Additional allocation to convert index array to same type as x/y
1 parent 8ac3c79 commit 0955815

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,11 @@ function _pullback_via_pushforward(
168168
dy,
169169
contexts::Vararg{Context,C},
170170
) where {F,C}
171-
dx = map(x, CartesianIndices(x)) do xj, j
171+
ind = CartesianIndices(x)
172+
T = typeof(similar(x, eltype(ind)))
173+
dx = map(x, T(ind)) do xj, j
172174
t1 = pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)
173-
convert(eltype(x), dot(only(t1), dy))
175+
dot(only(t1), dy)
174176
end
175177
return dx
176178
end
@@ -254,9 +256,11 @@ function _pullback_via_pushforward(
254256
dy,
255257
contexts::Vararg{Context,C},
256258
) where {F,C}
257-
dx = map(x, CartesianIndices(x)) do xj, j # preserve shape
259+
ind = CartesianIndices(x)
260+
T = typeof(similar(x, eltype(ind)))
261+
dx = map(x, T(ind)) do xj, j # preserve shape
258262
t1 = pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)
259-
convert(eltype(x), dot(only(t1), dy))
263+
dot(only(t1), dy)
260264
end
261265
return dx
262266
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,11 @@ function _pushforward_via_pullback(
171171
dx,
172172
contexts::Vararg{Context,C},
173173
) where {F,C}
174-
dy = map(y, CartesianIndices(y)) do yi, i
174+
ind = CartesianIndices(y)
175+
T = typeof(similar(y, eltype(ind)))
176+
dy = map(y, T(ind)) do yi, i
175177
t1 = pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)
176-
convert(eltype(y), dot(only(t1), dx))
178+
dot(only(t1), dx)
177179
end
178180
return dy
179181
end
@@ -243,9 +245,11 @@ function _pushforward_via_pullback(
243245
dx,
244246
contexts::Vararg{Context,C},
245247
) where {F,C}
246-
dy = map(y, CartesianIndices(y)) do yi, i # preserve shape
248+
ind = CartesianIndices(y)
249+
T = typeof(similar(y, eltype(ind)))
250+
dy = map(y, T(ind)) do yi, i # preserve shape
247251
t1 = pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)
248-
convert(eltype(y), dot(only(t1), dx))
252+
dot(only(t1), dx)
249253
end
250254
return dy
251255
end

0 commit comments

Comments
 (0)