Commit 5a7632a (parent: b734c00)

tensordot and reshape bug (#37)
2 files changed (+116 -45 lines)

src/tensor/linalg.cr (+94)
```diff
@@ -569,6 +569,100 @@ class Tensor(T)
     dest
   end
 
+  # Compute tensor dot product along specified axes.
+  #
+  # Given two tensors, a and b, and an array_like object containing two
+  # array_like objects, (a_axes, b_axes), sum the products of a’s and b’s
+  # elements (components) over the axes specified by a_axes and b_axes.
+  # The third argument can be a single non-negative integer_like scalar,
+  # N; if it is such, then the last N dimensions of a and the first N
+  # dimensions of b are summed over.
+  #
+  # Arguments
+  # ---------
+  # *b* : Tensor
+  #   Right hand side of dot products
+  # *axes* : Array(Array(Int)) | Array(Int) | Int
+  #   Axes of summation
+  #
+  # Examples
+  # --------
+  # ```
+  # a = Tensor.range(60.0).reshape(3, 4, 5)
+  # b = Tensor.range(24.0).reshape(4, 3, 2)
+  # puts a.tensordot(b, axes: [[1, 0], [0, 1]])
+  #
+  # # [[4400, 4730],
+  # #  [4532, 4874],
+  # #  [4664, 5018],
+  # #  [4796, 5162],
+  # #  [4928, 5306]]
+  # ```
+  def tensordot(b : Tensor(T), axes : Array(Array(Int)))
+    axes_a, axes_b = axes
+    na = axes_a.size
+    nb = axes_b.size
+    as_ = self.shape
+    nda = self.rank
+    bs = b.shape
+    ndb = b.rank
+    equal = na == nb
+    na.times do |k|
+      if as_[axes_a[k]] != bs[axes_b[k]]
+        equal = false
+        break
+      end
+      if axes_a[k] < 0
+        axes_a[k] += nda
+      end
+      if axes_b[k] < 0
+        axes_b[k] += ndb
+      end
+    end
+    unless equal
+      raise Num::Internal::ShapeError.new("Shape mismatch for sum")
+    end
+    notin = (0...nda).select do |k|
+      !axes_a.includes?(k)
+    end
+    newaxes_a = notin + axes_a
+    n2 = 1
+    axes_a.each do |axis|
+      n2 *= as_[axis]
+    end
+    newshape_a = [(notin.map { |ax| as_[ax] }).product, n2]
+    olda = notin.map { |ax| as_[ax] }
+
+    notin = (0...ndb).select do |k|
+      !axes_b.includes?(k)
+    end
+    newaxes_b = axes_b + notin
+    n2 = 1
+    axes_b.each do |axis|
+      n2 *= bs[axis]
+    end
+    newshape_b = [n2, (notin.map { |ax| bs[ax] }).product]
+    oldb = notin.map { |ax| bs[ax] }
+
+    at = self.transpose(newaxes_a).reshape(newshape_a)
+    bt = b.transpose(newaxes_b).reshape(newshape_b)
+    res = at.matmul(bt)
+    res.reshape(olda + oldb)
+  end
+
+  # :ditto:
+  def tensordot(b : Tensor(T), axes : Int)
+    axes_a = (-axes...0).to_a
+    axes_b = (0...axes).to_a
+    self.tensordot(b, [axes_a, axes_b])
+  end
+
+  # :ditto:
+  def tensordot(b : Tensor(T), axes : Array(Int))
+    axes_a, axes_b = axes
+    self.tensordot(b, [[axes_a], [axes_b]])
+  end
+
   # :nodoc:
   def is_matrix
     unless self.rank == 2
```
src/tensor/tensor.cr (+22 -45)
```diff
@@ -1217,64 +1217,41 @@ class Tensor(T)
   # # [3, 4]]
   # ```
   def reshape(new_shape : Array(Int))
-    result_shape = new_shape.map &.to_i
-
-    if result_shape == @shape
+    newshape = new_shape.map &.to_i
+    if newshape == shape
       return self.view
     end
-
-    n = 1
-    c = @size
-    auto = -1
-
-    result_shape.each_with_index do |v, i|
-      if v < 0
-        if auto >= 0
-          raise Num::Internal::ValueError.new(
-            "Only a single dimension can be inferred"
-          )
+    newsize = 1
+    cur_size = size
+    autosize = -1
+    newshape.each_with_index do |val, i|
+      if val < 0
+        if autosize >= 0
+          raise Num::Internal::ValueError.new("Only shape dimension can be automatic")
         end
-        auto = i
+        autosize = i
       else
-        n *= v
+        newsize *= val
       end
     end
 
-    if auto >= 0
-      result_shape = result_shape.dup
-      result_shape[auto] = c // n
-      n *= result_shape[auto]
+    if autosize >= 0
+      newshape = newshape.dup
+      newshape[autosize] = cur_size // newsize
+      newsize *= newshape[autosize]
     end
 
-    if n != c
-      raise Num::Internal::ShapeError.new(
-        "Shape #{@shape} cannot be reshaped to #{result_shape}"
-      )
+    if newsize != cur_size
+      raise Num::Internal::ShapeError.new "Shapes #{shape} cannot be reshaped to #{newshape}"
     end
 
+    newstrides = Num::Internal.shape_to_strides(newshape, Num::RowMajor)
+
     if @flags.contiguous?
-      new_strides = Num::Internal.shape_to_strides(
-        result_shape,
-        Num::RowMajor
-      )
-      t = Tensor(T).new(@buffer, result_shape, new_strides)
-      t.flags &= ~Num::ArrayFlags::OwnData
-      t
-    elsif @flags.fortran?
-      new_strides = Num::Internal.shape_to_strides(
-        result_shape,
-        Num::ColMajor
-      )
-      t = Tensor(T).new(@buffer, result_shape, new_strides)
-      t.flags &= ~Num::ArrayFlags::OwnData
-      t
+      self.class.new(@buffer, newshape, newstrides)
     else
-      t = dup(Num::ColMajor)
-      new_strides = Num::Internal.shape_to_strides(
-        result_shape,
-        Num::ColMajor
-      )
-      t
+      tmp = self.dup(Num::RowMajor)
+      self.class.new(tmp.to_unsafe, newshape, newstrides)
     end
   end
```
