1
1
require " ../tensor"
2
2
3
# Sets up the state that the strided-iteration loops use before they start
# calling `advance_strided_iteration`.
#
# NOTE: this span is a diff rendering. The `-` line is the previous
# signature; the `+` lines are the version that adds `t_data`.
#
# Arguments:
# - coord, backstrides: emitted through `.id` and assigned to, so call sites
#   pass symbols (e.g. `:coord`, `:backstrides`) naming the locals to create.
# - t_shape, t_strides: indexed with `[i]` for `i` in `0...t_rank`.
# - t_rank: number of entries allocated for both pointers and the number of
#   loop iterations.
# - t_data: updated in place with `+=`, so it must be an assignable variable
#   at the call site (call sites pass `data`, `t1data`, `t2data`, `t3data`).
- macro init_strided_iteration (coord , backstrides , t_shape , t_strides , t_rank )
3
+ macro init_strided_iteration (coord , backstrides , t_shape , t_strides , t_rank , t_data )
4
4
# One Int32 slot per dimension, allocated with an initial value of 0.
{{ coord.id }} = Pointer (Int32 ).malloc({{ t_rank }}, 0 )
5
5
{{ backstrides.id }} = Pointer (Int32 ).malloc({{ t_rank }})
6
6
{{ t_rank }}.times do |i |
7
7
# stride * (extent - 1): the offset spanned when stepping from the first to
# the last index of dimension i. It carries the sign of the stride.
{{ backstrides.id }}[i] = {{ t_strides }}[i] * ({{ t_shape }}[i] - 1 )
8
# For a negative stride, move t_data forward by (extent - 1) * |stride|.
# NOTE(review): presumably this places the pointer on the logical first
# element of a reversed dimension — confirm against how Tensor stores its
# base pointer for negative strides.
+ if {{ t_strides }}[i] < 0
9
+ {{ t_data }} += ({{ t_shape }}[i] - 1 ) * {{ t_strides }}[i].abs
10
+ end
8
11
end
9
12
end
10
13
@@ -31,7 +34,7 @@ def strided_iteration(t : Tensor)
31
34
end
32
35
else
33
36
t_shape, t_strides, t_rank = t.iter_attrs
34
- init_strided_iteration(:coord , :backstrides , t_shape, t_strides, t_rank)
37
+ init_strided_iteration(:coord , :backstrides , t_shape, t_strides, t_rank, data )
35
38
t.size.times do |i |
36
39
yield i, data
37
40
advance_strided_iteration(:coord , :backstrides , t_shape, t_strides, t_rank, data)
@@ -59,22 +62,22 @@ def dual_strided_iteration(t1 : Tensor, t2 : Tensor)
59
62
t2data += 1
60
63
end
61
64
elsif t1_contiguous
62
- init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank)
65
+ init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data )
63
66
n.times do |i |
64
67
yield i, t1data, t2data
65
68
t1data += 1
66
69
advance_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data)
67
70
end
68
71
elsif t2_contiguous
69
- init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank)
72
+ init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data )
70
73
n.times do |i |
71
74
yield i, t1data, t2data
72
75
advance_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data)
73
76
t2data += 1
74
77
end
75
78
else
76
- init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank)
77
- init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank)
79
+ init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data )
80
+ init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data )
78
81
n.times do |i |
79
82
yield i, t1data, t2data
80
83
advance_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data)
@@ -107,26 +110,26 @@ def tri_strided_iteration(t1 : Tensor, t2 : Tensor, t3 : Tensor)
107
110
t3data += 1
108
111
end
109
112
elsif t1_contiguous && t2_contiguous
110
- init_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank)
113
+ init_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank, t3data )
111
114
n.times do |i |
112
115
yield i, t1data, t2data, t3data
113
116
t1data += 1
114
117
t2data += 1
115
118
advance_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank, t3data)
116
119
end
117
120
elsif t1_contiguous
118
- init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank)
119
- init_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank)
121
+ init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data )
122
+ init_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank, t3data )
120
123
n.times do |i |
121
124
yield i, t1data, t2data, t3data
122
125
t1data += 1
123
126
advance_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data)
124
127
advance_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank, t3data)
125
128
end
126
129
else
127
- init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank)
128
- init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank)
129
- init_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank)
130
+ init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data )
131
+ init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data )
132
+ init_strided_iteration(:t3_coord , :t3_backstrides , t3_shape, t3_strides, t3_rank, t3data )
130
133
n.times do |i |
131
134
yield i, t1data, t2data, t3data
132
135
advance_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data)
@@ -162,7 +165,7 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
162
165
t1data += 1
163
166
end
164
167
elsif t1_contiguous
165
- init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank)
168
+ init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data )
166
169
n.times do
167
170
m.times do
168
171
yield index, t1data, t2data
@@ -172,7 +175,7 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
172
175
t1data += 1
173
176
end
174
177
elsif t2_contiguous
175
- init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank)
178
+ init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data )
176
179
n.times do
177
180
m.times do
178
181
yield index, t1data, t2data
@@ -182,8 +185,8 @@ def outer_strided_iteration(t1 : Tensor, t2 : Tensor)
182
185
advance_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data)
183
186
end
184
187
else
185
- init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank)
186
- init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank)
188
+ init_strided_iteration(:t1_coord , :t1_backstrides , t1_shape, t1_strides, t1_rank, t1data )
189
+ init_strided_iteration(:t2_coord , :t2_backstrides , t2_shape, t2_strides, t2_rank, t2data )
187
190
n.times do
188
191
m.times do
189
192
yield index, t1data, t2data
0 commit comments