@@ -87,13 +87,13 @@ torch::Tensor inclusive_sum_cub(
87
87
#if CUB_SUPPORTS_SCAN_BY_KEY()
88
88
if (backward) {
89
89
inclusive_sum_by_key (
90
- thrust::make_reverse_iterator (indices.data_ptr <long >() + n_edges),
90
+ thrust::make_reverse_iterator (indices.data_ptr <int64_t >() + n_edges),
91
91
thrust::make_reverse_iterator (inputs.data_ptr <float >() + n_edges),
92
92
thrust::make_reverse_iterator (outputs.data_ptr <float >() + n_edges),
93
93
n_edges);
94
94
} else {
95
95
inclusive_sum_by_key (
96
- indices.data_ptr <long >(),
96
+ indices.data_ptr <int64_t >(),
97
97
inputs.data_ptr <float >(),
98
98
outputs.data_ptr <float >(),
99
99
n_edges);
@@ -129,13 +129,13 @@ torch::Tensor exclusive_sum_cub(
129
129
#if CUB_SUPPORTS_SCAN_BY_KEY()
130
130
if (backward) {
131
131
exclusive_sum_by_key (
132
- thrust::make_reverse_iterator (indices.data_ptr <long >() + n_edges),
132
+ thrust::make_reverse_iterator (indices.data_ptr <int64_t >() + n_edges),
133
133
thrust::make_reverse_iterator (inputs.data_ptr <float >() + n_edges),
134
134
thrust::make_reverse_iterator (outputs.data_ptr <float >() + n_edges),
135
135
n_edges);
136
136
} else {
137
137
exclusive_sum_by_key (
138
- indices.data_ptr <long >(),
138
+ indices.data_ptr <int64_t >(),
139
139
inputs.data_ptr <float >(),
140
140
outputs.data_ptr <float >(),
141
141
n_edges);
@@ -169,7 +169,7 @@ torch::Tensor inclusive_prod_cub_forward(
169
169
170
170
#if CUB_SUPPORTS_SCAN_BY_KEY()
171
171
inclusive_prod_by_key (
172
- indices.data_ptr <long >(),
172
+ indices.data_ptr <int64_t >(),
173
173
inputs.data_ptr <float >(),
174
174
outputs.data_ptr <float >(),
175
175
n_edges);
@@ -203,7 +203,7 @@ torch::Tensor inclusive_prod_cub_backward(
203
203
}
204
204
#if CUB_SUPPORTS_SCAN_BY_KEY()
205
205
inclusive_sum_by_key (
206
- thrust::make_reverse_iterator (indices.data_ptr <long >() + n_edges),
206
+ thrust::make_reverse_iterator (indices.data_ptr <int64_t >() + n_edges),
207
207
thrust::make_reverse_iterator ((grad_outputs * outputs).data_ptr <float >() + n_edges),
208
208
thrust::make_reverse_iterator (grad_inputs.data_ptr <float >() + n_edges),
209
209
n_edges);
@@ -237,7 +237,7 @@ torch::Tensor exclusive_prod_cub_forward(
237
237
}
238
238
#if CUB_SUPPORTS_SCAN_BY_KEY()
239
239
exclusive_prod_by_key (
240
- indices.data_ptr <long >(),
240
+ indices.data_ptr <int64_t >(),
241
241
inputs.data_ptr <float >(),
242
242
outputs.data_ptr <float >(),
243
243
n_edges);
@@ -272,7 +272,7 @@ torch::Tensor exclusive_prod_cub_backward(
272
272
273
273
#if CUB_SUPPORTS_SCAN_BY_KEY()
274
274
exclusive_sum_by_key (
275
- thrust::make_reverse_iterator (indices.data_ptr <long >() + n_edges),
275
+ thrust::make_reverse_iterator (indices.data_ptr <int64_t >() + n_edges),
276
276
thrust::make_reverse_iterator ((grad_outputs * outputs).data_ptr <float >() + n_edges),
277
277
thrust::make_reverse_iterator (grad_inputs.data_ptr <float >() + n_edges),
278
278
n_edges);
0 commit comments