Skip to content

Commit 9d2b864

Browse files
Fix yield iterations for Tensors with negative strides (#60)
1 parent 6ba53ea commit 9d2b864

File tree

2 files changed

+21
-18
lines changed

2 files changed

+21
-18
lines changed

shard.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
name: num
2-
version: 0.4.2
2+
version: 0.4.4
33

44
authors:
55
- Chris Zimmerman <[email protected]>
66

7-
crystal: 0.34.0
7+
crystal: 0.35.1
88

99
license: MIT
1010

src/tensor/internal/yield_iterators.cr

+19-16
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
require "../tensor"
22

3-
macro init_strided_iteration(coord, backstrides, t_shape, t_strides, t_rank)
3+
macro init_strided_iteration(coord, backstrides, t_shape, t_strides, t_rank, t_data)
44
{{ coord.id }} = Pointer(Int32).malloc({{ t_rank }}, 0)
55
{{ backstrides.id }} = Pointer(Int32).malloc({{ t_rank }})
66
{{ t_rank }}.times do |i|
77
{{ backstrides.id }}[i] = {{ t_strides }}[i] * ({{ t_shape }}[i] - 1)
8+
if {{ t_strides }}[i] < 0
9+
{{ t_data }} += ({{ t_shape }}[i] - 1) * {{ t_strides }}[i].abs
10+
end
811
end
912
end
1013

@@ -31,7 +34,7 @@ def strided_iteration(t : Tensor)
3134
end
3235
else
3336
t_shape, t_strides, t_rank = t.iter_attrs
34-
init_strided_iteration(:coord, :backstrides, t_shape, t_strides, t_rank)
37+
init_strided_iteration(:coord, :backstrides, t_shape, t_strides, t_rank, data)
3538
t.size.times do |i|
3639
yield i, data
3740
advance_strided_iteration(:coord, :backstrides, t_shape, t_strides, t_rank, data)
@@ -59,22 +62,22 @@ def dual_strided_iteration(t1 : Tensor, t2 : Tensor)
5962
t2data += 1
6063
end
6164
elsif t1_contiguous
62-
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
65+
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
6366
n.times do |i|
6467
yield i, t1data, t2data
6568
t1data += 1
6669
advance_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
6770
end
6871
elsif t2_contiguous
69-
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
72+
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
7073
n.times do |i|
7174
yield i, t1data, t2data
7275
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
7376
t2data += 1
7477
end
7578
else
76-
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
77-
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
79+
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
80+
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
7881
n.times do |i|
7982
yield i, t1data, t2data
8083
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
@@ -107,26 +110,26 @@ def tri_strided_iteration(t1 : Tensor, t2 : Tensor, t3 : Tensor)
107110
t3data += 1
108111
end
109112
elsif t1_contiguous && t2_contiguous
110-
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank)
113+
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
111114
n.times do |i|
112115
yield i, t1data, t2data, t3data
113116
t1data += 1
114117
t2data += 1
115118
advance_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
116119
end
117120
elsif t1_contiguous
118-
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
119-
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank)
121+
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
122+
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
120123
n.times do |i|
121124
yield i, t1data, t2data, t3data
122125
t1data += 1
123126
advance_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
124127
advance_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
125128
end
126129
else
127-
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
128-
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
129-
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank)
130+
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
131+
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
132+
init_strided_iteration(:t3_coord, :t3_backstrides, t3_shape, t3_strides, t3_rank, t3data)
130133
n.times do |i|
131134
yield i, t1data, t2data, t3data
132135
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
@@ -162,7 +165,7 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
162165
t1data += 1
163166
end
164167
elsif t1_contiguous
165-
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
168+
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
166169
n.times do
167170
m.times do
168171
yield index, t1data, t2data
@@ -172,7 +175,7 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
172175
t1data += 1
173176
end
174177
elsif t2_contiguous
175-
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
178+
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
176179
n.times do
177180
m.times do
178181
yield index, t1data, t2data
@@ -182,8 +185,8 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
182185
advance_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
183186
end
184187
else
185-
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank)
186-
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank)
188+
init_strided_iteration(:t1_coord, :t1_backstrides, t1_shape, t1_strides, t1_rank, t1data)
189+
init_strided_iteration(:t2_coord, :t2_backstrides, t2_shape, t2_strides, t2_rank, t2data)
187190
n.times do
188191
m.times do
189192
yield index, t1data, t2data

0 commit comments

Comments
 (0)