Skip to content

Commit ff81b58

Browse files
committed
Correct BFGS update with outer product
1 parent 890ff61 commit ff81b58

File tree

2 files changed

+79
-9
lines changed

2 files changed

+79
-9
lines changed

src/pyvmcon/vmcon.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,18 @@ def _powells_gamma(gamma: np.ndarray, ksi: np.ndarray, B: np.ndarray) -> np.ndar
504504
return theta * gamma + (1 - theta) * (B @ ksi) # eqn 9
505505

506506

507+
def _revise_B(current_B: np.ndarray, ksi: np.ndarray, gamma: np.ndarray) -> np.ndarray:
508+
"""Revises B using a BFGS update.
509+
510+
Implements Equation 8 of the Crane report.
511+
"""
512+
return (
513+
current_B
514+
- ((current_B @ np.outer(ksi, ksi) @ current_B.T) / (ksi.T @ current_B @ ksi))
515+
+ (np.outer(gamma, gamma) / (ksi.T @ gamma))
516+
)
517+
518+
507519
def calculate_new_B(
508520
result: Result,
509521
new_result: Result,
@@ -513,10 +525,7 @@ def calculate_new_B(
513525
lamda_equality: np.ndarray,
514526
lamda_inequality: np.ndarray,
515527
) -> np.ndarray:
516-
"""Updates the hessian approximation matrix.
517-
518-
Uses Equation 8 of the Crane report.
519-
"""
528+
"""Updates the hessian approximation matrix."""
520529
# xi (the symbol name) would be a bit confusing in this context,
521530
# ksi is how its pronounced in modern greek
522531
# reshape ksi to be a matrix
@@ -542,11 +551,7 @@ def calculate_new_B(
542551
logger.warning("All xi (ksi) components are 0")
543552
ksi[:] = 1e-10
544553

545-
# eqn 8
546-
B_ksi = B @ ksi
547-
B += (gamma @ gamma.T) / (ksi.T @ gamma) - ((B_ksi @ ksi.T @ B) / (ksi.T @ B_ksi))
548-
549-
return B
554+
return _revise_B(B, ksi, gamma)
550555

551556

552557
def _find_out_of_bounds_vars(higher: np.ndarray, lower: np.ndarray) -> list[str]:

tests/test_vmcon.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""Test individual units of the PyVMCON implementation."""
2+
3+
from dataclasses import dataclass
4+
5+
import numpy as np
6+
import pytest
7+
8+
from pyvmcon.vmcon import _revise_B
9+
10+
11+
@dataclass
12+
class BRevisionAsset:
13+
"""Test asset for testing B matrix revision."""
14+
15+
B: np.ndarray
16+
ksi: np.ndarray
17+
eta: np.ndarray
18+
expected_new_B: np.ndarray # noqa: N815
19+
20+
21+
@pytest.mark.parametrize(
22+
"test_asset",
23+
[
24+
BRevisionAsset(
25+
B=np.identity(2),
26+
ksi=np.array([-0.66666666666666663, -0.83333333333333348]),
27+
eta=np.array([-1.3425925925925923, -1.7129629629629632]),
28+
expected_new_B=np.array(
29+
[
30+
[1.3858727457706470, 0.50241291449459347],
31+
[0.50241291449459347, 1.6536252239598812],
32+
]
33+
),
34+
),
35+
BRevisionAsset(
36+
B=np.array(
37+
[
38+
[2.1875467668036239, 1.4714414127452644],
39+
[1.4714414127452644, 2.7501870672148332],
40+
]
41+
),
42+
ksi=np.array([-1.2385140125071988e-6, -6.1925700625482853e-7]),
43+
eta=np.array([-3.6205427119684330e-6, -3.5255433852299234e-6]),
44+
expected_new_B=np.array(
45+
[
46+
[2.1875592084316073, 1.4714730232083293],
47+
[1.4714730232083293, 2.7502368318126678],
48+
]
49+
),
50+
),
51+
],
52+
)
53+
def test_revise_B(test_asset):
54+
"""Tests the hessian update implementation.
55+
56+
Uses data from Example 1 of the NEA (Crane) to ensure PyVMCON agrees with that
57+
implementation to at least 14 decimal places.
58+
"""
59+
new_B = _revise_B(test_asset.B, test_asset.ksi, test_asset.eta)
60+
61+
# check symmetric
62+
np.testing.assert_array_almost_equal(new_B, new_B.T, decimal=14)
63+
64+
# check our revision agrees with NEA version of VMCON
65+
np.testing.assert_array_almost_equal(new_B, test_asset.expected_new_B, decimal=14)

0 commit comments

Comments
 (0)