Skip to content

Commit 5558b13

Browse files
authored
Merge pull request #22 from zoj613/numpify
2 parents 4b70fa9 + 556833c commit 5558b13

16 files changed

+432
-501
lines changed

Makefile

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#Copyright (c) 2020, Zolisa Bleki
22
#SPDX-License-Identifier: BSD-3-Clause */
3-
.PHONY: clean pkg test wheels cythonize lib
3+
.PHONY: clean pkg test wheels cythonize lib install
44

55
NAME := htnorm
66
CC := gcc
@@ -57,6 +57,9 @@ test:
5757
cythonize:
5858
cythonize pyhtnorm/*.pyx
5959

60+
install: clean cythonize
61+
poetry install
62+
6063
sdist: cythonize
6164
poetry build -f sdist
6265

README.md

+25-25
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,19 @@ int main (void)
7777

7878
## Python API
7979

80+
### Dependencies
81+
- NumPy >= 1.17
82+
8083
A high level python interface to the library is also provided. Linux users can
8184
install it using wheels via pip (thus not needing to worry about availability of C libraries),
8285
```bash
83-
pip install pyhtnorm
86+
pip install -U pyhtnorm
8487
```
8588
Wheels are not provided for MacOS. To install via pip, one can run the following commands:
8689
```bash
8790
#set the path to LAPACK shared library
8891
export LIBS_DIR=<some directory>
89-
pip install pyhtnorm
92+
pip install -U pyhtnorm
9093
```
9194
Alternatively, one can install it from source. This requires an installation of [poetry][7] and the following shell commands:
9295

@@ -99,37 +102,34 @@ $ export PYTHONPATH=$PWD:$PYTHONPATH
99102
```
100103

101104
Below is an example of how to use htnorm in python to sample from a multivariate
102-
gaussian truncated on the hyperplane ![sumzero](https://latex.codecogs.com/svg.latex?%5Cmathbf%7B1%7D%5ET%5Cmathbf%7Bx%7D%20%3D%200) (i.e. making sure the sampled values sum to zero)
105+
gaussian truncated on the hyperplane ![sumzero](https://latex.codecogs.com/svg.latex?%5Cmathbf%7B1%7D%5ET%5Cmathbf%7Bx%7D%20%3D%200) (i.e. making sure the sampled values sum to zero). The python
106+
API is such that the code can be easily integrated into other existing libraries.
107+
Since `v1.0.0`, it supports passing a `numpy.random.Generator` instance.
108+
Thus, one can reuse the same generator without having to declare one specifically for `htnorm`.
103109

104110
```python
105-
from pyhtnorm import HTNGenerator
111+
from pyhtnorm import hyperplane_truncated_mvnorm
106112
import numpy as np
107113

108-
rng = HTNGenerator()
114+
rng = np.random.default_rng()
109115

110116
# generate example input
111-
k1 = 1000
112-
k2 = 1
113-
npy_rng = np.random.default_rng()
114-
temp = npy_rng.random((k1, k1))
115-
cov = temp @ temp.T + np.diag(npy_rng.random(k1))
117+
k1, k2 = 1000, 1
118+
temp = rng.random((k1, k1))
119+
cov = temp @ temp.T
116120
G = np.ones((k2, k1))
117121
r = np.zeros(k2)
118-
mean = npy_rng.random(k1)
122+
mean = rng.random(k1)
119123

120-
samples = rng.hyperplane_truncated_mvnorm(mean, cov, G, r)
121-
print(sum(samples)) # verify if sampled values sum to zero
124+
# passing `random_state` is optional. If the argument is not used, a fresh
125+
# random generator state is instantiated internally using system entropy.
126+
o = hyperplane_truncated_mvnorm(mean, cov, G, r, random_state=rng)
127+
print(o.sum()) # verify if sampled values sum to zero
122128
# alternatively one can pass an array to store the results in
123-
out = np.empty(k1)
124-
rng.hyperplane_truncated_mvnorm(mean, cov, G, r, out=out)
125-
print(out.sum()) # verify
129+
hyperplane_truncated_mvnorm(mean, cov, G, r, out=o)
126130
```
127131

128-
For more details about the parameters of the `HTNGenerator` and its methods,
129-
see the docstrings via python's `help` function.
130-
131-
The python API also exposes the `HTNGenerator` class as a Cython extension type
132-
that can be "cimported" in a cython script.
132+
For more information about the function's arguments, refer to its docstring.
133133

134134
A pure numpy implementation is demonstrated in this [example script][9].
135135

@@ -183,9 +183,9 @@ see the [LICENSE][6] file.
183183
[1]: https://projecteuclid.org/euclid.ba/1488337478
184184
[2]: https://www.pcg-random.org/
185185
[3]: https://en.wikipedia.org/wiki/Xoroshiro128%2B
186-
[4]: https://github.com/zoj613/htnorm/blob/main/include/htnorm.h
187-
[5]: https://github.com/zoj613/htnorm/blob/main/include/rng.h
188-
[6]: https://github.com/zoj613/htnorm/blob/main/LICENSE
186+
[4]: ./include/htnorm.h
187+
[5]: ./include/rng.h
188+
[6]: ./LICENSE
189189
[7]: https://python-poetry.org/docs/pyproject/
190190
[8]: https://www.sciencedirect.com/science/article/abs/pii/S1877584517301600
191-
[9]: https://github.com/zoj613/htnorm/blob/main/examples/numpy_implementation.py
191+
[9]: ./examples/numpy_implementation.py

build-wheels.sh

+15-10
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,32 @@ bin_arr=(
1010
/opt/python/cp36-cp36m/bin
1111
/opt/python/cp37-cp37m/bin
1212
/opt/python/cp38-cp38/bin
13+
/opt/python/cp39-cp39/bin
1314
)
15+
1416
# add python to image's path
1517
export PATH=/opt/python/cp38-cp38/bin/:$PATH
16-
# download && install poetry
17-
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python
18-
18+
# download install script
19+
curl -#sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py > get-poetry.py
20+
# install using local archive
21+
python get-poetry.py -y --file poetry-1.1.4-linux.tar.gz
1922
# install openblas
2023
yum install -y openblas-devel
2124

2225
function build_poetry_wheels
2326
{
24-
# build wheels for 3.6-3.8 with poetry
27+
# build wheels for 3.6-3.9 with poetry
2528
for BIN in "${bin_arr[@]}"; do
2629
rm -Rf build/*
30+
# install build deps
31+
"${BIN}/python" ${HOME}/.poetry/bin/poetry run pip install numpy
2732
BUILD_WHEELS=1 "${BIN}/python" ${HOME}/.poetry/bin/poetry build -f wheel
28-
done
29-
30-
# add C libraries to wheels
31-
for whl in dist/*.whl; do
32-
auditwheel repair "$whl" --plat $1
33-
rm "$whl"
33+
auditwheel repair dist/*.whl --plat $1
34+
whl="$(basename dist/*.whl)"
35+
"${BIN}/python" -m pip install wheelhouse/"$whl"
36+
# test if installed wheel imports correctly
37+
"${BIN}/python" -c "from pyhtnorm import *"
38+
rm dist/*.whl
3439
done
3540
}
3641

build.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from distutils.core import Extension
22
import os
33

4+
import numpy as np
5+
46

57
source_files = [
68
"pyhtnorm/_htnorm.c",
@@ -28,11 +30,12 @@
2830

2931
extensions = [
3032
Extension(
31-
"pyhtnorm._htnorm",
33+
"_htnorm",
3234
source_files,
33-
include_dirs=['./include'],
35+
include_dirs=[np.get_include(), './include'],
3436
library_dirs=library_dirs,
3537
libraries=libraries,
38+
define_macros=[('NPY_NO_DEPRECATED_API', 0)],
3639
extra_compile_args=['-std=c99']
3740
)
3841
]

examples/numpy_implementation.py

+38-36
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,42 @@
1010
import numpy as np
1111

1212

13-
class Generator(np.random.Generator):
14-
def hyperplane_truncated_mvnorm(self, mean, cov, G, r, diag=False):
15-
if diag:
16-
cov_diag = np.diag(cov)
17-
y = mean + cov_diag * self.standard_normal(mean.shape[0])
18-
covg = cov_diag[:, None] * G.T
19-
else:
20-
y = self.multivariate_normal(mean, cov, method='cholesky')
21-
covg = cov @ G.T
22-
gcovg = G @ covg
23-
alpha = np.linalg.solve(gcovg, r - G @ y)
24-
return y + covg @ alpha
13+
def hyperplane_truncated_mvnorm(mean, cov, G, r, diag=False, random_state=None):
14+
rng = np.random.default_rng(random_state)
15+
if diag:
16+
cov_diag = np.diag(cov)
17+
y = mean + cov_diag * rng.standard_normal(mean.shape[0])
18+
covg = cov_diag[:, None] * G.T
19+
else:
20+
y = rng.multivariate_normal(mean, cov, method='cholesky')
21+
covg = cov @ G.T
22+
gcovg = G @ covg
23+
alpha = np.linalg.solve(gcovg, r - G @ y)
24+
return y + covg @ alpha
2525

26-
def structured_precision_mvnorm(self, mean, a, phi, omega, a_type=0, o_type=0):
27-
if a_type:
28-
Ainv = 1 / np.diag(a)
29-
y1 = self.standard_normal(a.shape[0]) / np.sqrt(Ainv)
30-
ainv_phi = Ainv[:, None] * phi.T
31-
else:
32-
Ainv = np.linalg.inv(a)
33-
y1 = self.multivariate_normal(
34-
np.zeros(a.shape[0]), Ainv, method='cholesky'
35-
)
36-
ainv_phi = Ainv @ phi.T
37-
if o_type:
38-
omegainv = 1 / np.diag(omega)
39-
y2 = self.standard_normal(omega.shape[0]) / np.sqrt(omegainv)
40-
alpha = np.linalg.solve(
41-
np.diag(omegainv) + phi @ ainv_phi, phi @ y1 + y2
42-
)
43-
else:
44-
omegainv = np.linalg.inv(omega)
45-
y2 = self.multivariate_normal(
46-
np.zeros(omega.shape[0]), omegainv, method='cholesky'
47-
)
48-
alpha = np.linalg.solve(omegainv + phi @ ainv_phi, phi @ y1 + y2)
49-
return mean + y1 - ainv_phi @ alpha
26+
27+
def structured_precision_mvnorm(mean, a, phi, omega, a_type=0, o_type=0, random_state=None):
28+
rng = np.random.default_rng(random_state)
29+
if a_type:
30+
Ainv = 1 / np.diag(a)
31+
y1 = rng.standard_normal(a.shape[0]) / np.sqrt(Ainv)
32+
ainv_phi = Ainv[:, None] * phi.T
33+
else:
34+
Ainv = np.linalg.inv(a)
35+
y1 = rng.multivariate_normal(
36+
np.zeros(a.shape[0]), Ainv, method='cholesky'
37+
)
38+
ainv_phi = Ainv @ phi.T
39+
if o_type:
40+
omegainv = 1 / np.diag(omega)
41+
y2 = rng.standard_normal(omega.shape[0]) / np.sqrt(omegainv)
42+
alpha = np.linalg.solve(
43+
np.diag(omegainv) + phi @ ainv_phi, phi @ y1 + y2
44+
)
45+
else:
46+
omegainv = np.linalg.inv(omega)
47+
y2 = rng.multivariate_normal(
48+
np.zeros(omega.shape[0]), omegainv, method='cholesky'
49+
)
50+
alpha = np.linalg.solve(omegainv + phi @ ainv_phi, phi @ y1 + y2)
51+
return mean + y1 - ainv_phi @ alpha

include/rng.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
* their own bitgenerator, The minimum requirement is to provide a function
1616
* that generate 64bit unsigned integers and doubles in the range (0, 1).
1717
*/
18-
typedef struct bitgen {
18+
typedef struct {
1919
// the bsse bitgenerator
2020
void* base;
2121
// a function pointer that takes `base` as an input an returns a positive intgger
22-
uint64_t (*next_int)(void* base);
22+
uint64_t (*next_uint64)(void* base);
2323
// a function pointer that takes `base` as an input an returns a double in the range (0, 1)
2424
double (*next_double)(void* base);
2525
} rng_t;

0 commit comments

Comments
 (0)