@@ -195,50 +195,68 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
195
195
196
196
// CPU/GPU-versions of LAPACK function "getrf". Please refer to the
197
197
// LAPACK documentation for further details.
198
- // Note that this is A = getrf(A), so A is input and output parameter.
199
198
199
+ // Note:
200
+ // - A is input and output parameter (overwritten by LU)
201
+ // - Param check_singular is only used in the CPU version. If check_singular is false,
202
+ // do not throw an error when A is a non-invertible matrix.
200
203
template <typename xpu, typename DType>
201
204
void linalg_getrf (const Tensor<xpu, 2 , DType>& A,
202
- const Tensor<xpu, 1 , DType>& work,
205
+ const Tensor<xpu, 1 , int >& pivot,
206
+ bool check_singular,
203
207
Stream<xpu> *s = 0 );
204
208
205
209
template <typename xpu, typename DType>
206
210
void linalg_batch_getrf (const Tensor<xpu, 3 , DType>& A,
207
- const Tensor<xpu, 1 , DType>& work,
211
+ const Tensor<xpu, 2 , int >& pivot,
212
+ bool check_singular,
208
213
Stream<xpu> *s = 0 );
209
214
210
215
// ////////////////////////////// GETRI ////////////////////////////////////////////
211
216
212
217
// CPU/GPU-versions of LAPACK function "getri". Please refer to the
213
218
// LAPACK documentation for further details.
214
- // Note that this is A = getri(A), so A is input and output parameter.
215
219
220
+ // Note:
221
+ // - pivot and LU are the outputs of getrf(A)
222
+ // - LU is also the output parameter (overwritten by inverse(A))
216
223
template <typename xpu, typename DType>
217
- void linalg_getri (const Tensor<xpu, 2 , DType>& A,
224
+ void linalg_getri (const Tensor<xpu, 2 , DType>& LU,
225
+ const Tensor<xpu, 1 , int >& pivot, \
218
226
const Tensor<xpu, 1 , DType>& work,
219
227
Stream<xpu> *s = 0 );
220
228
229
+ // Note that this function only implements the GPU version, using "getriBatched" from cuBLAS.
230
+ // Unlike the LAPACK routines on CPU, the computation is out-of-place, so the final matrix
231
+ // inverse is stored in A.
221
232
template <typename xpu, typename DType>
222
233
void linalg_batch_getri (const Tensor<xpu, 3 , DType>& A,
223
- const Tensor<xpu, 3 , DType>& B ,
224
- const Tensor<xpu, 1 , DType >& work ,
234
+ const Tensor<xpu, 3 , DType>& LU ,
235
+ const Tensor<xpu, 2 , int >& pivot ,
225
236
Stream<xpu> *s = 0 );
226
237
227
- // This function determines the amount of workspace needed for linalg_getri to operate
228
- // on a batch of matrices which is returned as number of elements of type DType.
229
- template <typename xpu, typename DType>
230
- int linalg_getri_workspace_query (const Tensor<xpu, 3 , DType>& A,
231
- Stream<xpu> *s = 0 );
232
-
233
238
// ////////////////////////////// INVERSE ////////////////////////////////////////////
234
239
235
- // CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
240
+ // CPU/GPU-versions of matrix inverse combining LAPACK function "getrf" and "getri"
236
241
// Note that A = inverse(B)
237
242
template <typename xpu, typename DType>
238
243
void linalg_batch_inverse (const Tensor<xpu, 3 , DType>& A,
239
244
const Tensor<xpu, 3 , DType>& B,
240
- const Tensor<xpu, 1 , DType>& work,
241
- Stream<xpu> *s = 0 );
245
+ const mxnet::OpContext& ctx);
246
+
247
+ // ////////////////////////////// DET ////////////////////////////////////////////
248
+
249
+ // CPU/GPU-versions of helper functions used in matrix determinant operators
250
+
251
+ // Helper function in determinant backward computation: compute matrix inverse
252
+ // from LU and pivot using temp workspace; the result is stored back into LU
253
+ template <typename xpu, typename DType>
254
+ void linalg_batch_det_backward_helper (const Tensor<xpu, 3 , DType>& LU,
255
+ const Tensor<xpu, 2 , int >& pivot,
256
+ const Tensor<xpu, 1 , DType>& det,
257
+ const Tensor<xpu, 3 , DType>& temp,
258
+ const DType zero_det,
259
+ const mxnet::OpContext& ctx);
242
260
243
261
#include " linalg_impl.h"
244
262
0 commit comments