Skip to content

Commit be49d65

Browse files
committed
Fix the numerical stability issue of log det function
Updated the tests as well. The natural logarithm is used.
1 parent 8b46637 commit be49d65

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

selector/diversity.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,19 @@ def logdet(X: np.ndarray) -> float:
134134
135135
Notes
136136
-----
137-
Nakamura, T., Sakaue, S., Fujii, K., Harabuchi, Y., Maeda, S., and Iwata, S.. (2022)
138-
Selecting molecules with diverse structures and properties by maximizing
139-
submodular functions of descriptors learned with graph neural networks.
140-
Scientific Reports 12.
137+
The log-determinant function is based on the formula in [1]_. Please note that we used the
138+
natural logrithim to avoid the numerical stability issues,
139+
https://github.com/theochem/Selector/issues/229.
140+
141+
.. [1] Nakamura, T., Sakaue, S., Fujii, K., Harabuchi, Y., Maeda, S., and Iwata, S..,
142+
Selecting molecules with diverse structures and properties by maximizing
143+
submodular functions of descriptors learned with graph neural networks.
144+
Scientific Reports 12, 2022.
145+
141146
"""
142-
mid = np.dot(X, np.transpose(X))
143-
f_logdet = np.log10(np.linalg.det(mid + np.identity(len(X))))
147+
mid = np.dot(X, np.transpose(X)) + np.identity(X.shape[0])
148+
logdet_mid = np.linalg.slogdet(mid)
149+
f_logdet = logdet_mid.sign * logdet_mid.logabsdet
144150
return f_logdet
145151

146152

selector/tests/test_diversity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ def test_compute_diversity_invalid():
108108
def test_logdet():
109109
"""Test the log determinant function with predefined subset matrix."""
110110
sel = logdet(sample3)
111-
expected = np.log10(131)
111+
expected = np.log(131)
112112
assert_almost_equal(sel, expected)
113113

114114

115115
def test_logdet_non_square_matrix():
116116
"""Test the log determinant function with a rectangular matrix."""
117117
sel = logdet(sample4)
118-
expected = np.log10(8)
118+
expected = np.log(8)
119119
assert_almost_equal(sel, expected)
120120

121121

0 commit comments

Comments
 (0)