Skip to content

Commit f6fc011

Browse files
arcadiaphy Ubuntu
authored and committed
Add matrix determinant operator in linalg (apache#15007)
* add backbone * cpu forward det * refactor for gpu forward det * fix * register gpu det forward * add gpu det backward * register gpu det backward * fix * add logdet slogdet backward * stop grad for zero det * fix * fix * reduce grad transfer * fix docs * update comments * fix docs * fix lint * add test * update docs * add operator * update test * trigger CI * remove slash * update comments and docs * update det helper function * update operator check * remove logdet * add no grad when det = 0 * update comments and docs * remove remaining logdet
1 parent 90e3d3f commit f6fc011

File tree

9 files changed

+679
-137
lines changed

9 files changed

+679
-137
lines changed

docs/api/python/symbol/linalg.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ In the rest of this document, we list routines provided by the `symbol.linalg` p
6060
extracttrian
6161
maketrian
6262
inverse
63+
det
64+
slogdet
6365
```
6466

6567
## API Reference

python/mxnet/contrib/amp/lists/symbol.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,8 @@
433433
'_linalg_maketrian',
434434
'_linalg_extracttrian',
435435
'_linalg_inverse',
436+
'_linalg_det',
437+
'_linalg_slogdet',
436438
'linalg_syrk',
437439
'linalg_potrf',
438440
'linalg_potri',
@@ -446,6 +448,8 @@
446448
'linalg_maketrian',
447449
'linalg_extracttrian',
448450
'linalg_inverse',
451+
'linalg_det',
452+
'linalg_slogdet',
449453
'_NDArray',
450454
'_Native',
451455
'_contrib_count_sketch',

src/operator/linalg.h

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -195,50 +195,68 @@ int linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A,
195195

196196
// CPU/GPU-versions of LAPACK function "getrf". Please refer to the
197197
// LAPACK documentation for further details.
198-
// Note that this is A = getrf(A), so A is input and output parameter.
199198

199+
// Note:
200+
// - A is input and output parameter (overwritten by LU)
201+
// - Param check_singular is only useful in cpu version. If check_singular is false,
202+
// don't throw an error when A is a non-invertible matrix.
200203
template<typename xpu, typename DType>
201204
void linalg_getrf(const Tensor<xpu, 2, DType>& A,
202-
const Tensor<xpu, 1, DType>& work,
205+
const Tensor<xpu, 1, int>& pivot,
206+
bool check_singular,
203207
Stream<xpu> *s = 0);
204208

205209
template<typename xpu, typename DType>
206210
void linalg_batch_getrf(const Tensor<xpu, 3, DType>& A,
207-
const Tensor<xpu, 1, DType>& work,
211+
const Tensor<xpu, 2, int>& pivot,
212+
bool check_singular,
208213
Stream<xpu> *s = 0);
209214

210215
//////////////////////////////// GETRI ////////////////////////////////////////////
211216

212217
// CPU/GPU-versions of LAPACK function "getri". Please refer to the
213218
// LAPACK documentation for further details.
214-
// Note that this is A = getri(A), so A is input and output parameter.
215219

220+
// Note:
221+
// - pivot and LU are the outputs of getrf(A)
222+
// - LU is also the output parameter (overwritten by inverse(A))
216223
template<typename xpu, typename DType>
217-
void linalg_getri(const Tensor<xpu, 2, DType>& A,
224+
void linalg_getri(const Tensor<xpu, 2, DType>& LU,
225+
const Tensor<xpu, 1, int>& pivot, \
218226
const Tensor<xpu, 1, DType>& work,
219227
Stream<xpu> *s = 0);
220228

229+
// Note that this function only implements the GPU version with "getriBatched" in cuBLAS.
230+
// Unlike the LAPACK routines on CPU, it is computed out-of-place, so the final matrix
231+
// inverse is stored in A.
221232
template<typename xpu, typename DType>
222233
void linalg_batch_getri(const Tensor<xpu, 3, DType>& A,
223-
const Tensor<xpu, 3, DType>& B,
224-
const Tensor<xpu, 1, DType>& work,
234+
const Tensor<xpu, 3, DType>& LU,
235+
const Tensor<xpu, 2, int>& pivot,
225236
Stream<xpu> *s = 0);
226237

227-
// This function determines the amount of workspace needed for linalg_getri to operate
228-
// on a batch of matrices which is returned as number of elements of type DType.
229-
template<typename xpu, typename DType>
230-
int linalg_getri_workspace_query(const Tensor<xpu, 3, DType>& A,
231-
Stream<xpu> *s = 0);
232-
233238
//////////////////////////////// INVERSE ////////////////////////////////////////////
234239

235-
// CPU/GPU-versions of matrix inversion combining LAPACK function "getrf" and "getri"
240+
// CPU/GPU-versions of matrix inverse combining LAPACK functions "getrf" and "getri"
236241
// Note that A = inverse(B)
237242
template<typename xpu, typename DType>
238243
void linalg_batch_inverse(const Tensor<xpu, 3, DType>& A,
239244
const Tensor<xpu, 3, DType>& B,
240-
const Tensor<xpu, 1, DType>& work,
241-
Stream<xpu> *s = 0);
245+
const mxnet::OpContext& ctx);
246+
247+
//////////////////////////////// DET ////////////////////////////////////////////
248+
249+
// CPU/GPU-versions of helper functions used in matrix determinant operators
250+
251+
// Helper function in determinant backward computation: compute matrix inverse
252+
// from LU and pivot using temp workspace; the result is stored back to LU
253+
template<typename xpu, typename DType>
254+
void linalg_batch_det_backward_helper(const Tensor<xpu, 3, DType>& LU,
255+
const Tensor<xpu, 2, int>& pivot,
256+
const Tensor<xpu, 1, DType>& det,
257+
const Tensor<xpu, 3, DType>& temp,
258+
const DType zero_det,
259+
const mxnet::OpContext& ctx);
242260

243261
#include "linalg_impl.h"
244262

0 commit comments

Comments
 (0)