Skip to content

Stan 2.33: Move IO munging to external package, refactors #681

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ on:
required: false
default: ''

# only run one copy per PR
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

jobs:
get-cmdstan-version:
# get the latest cmdstan version to use as part of the cache key
Expand All @@ -27,7 +32,8 @@ jobs:
if [[ "${{ github.event.inputs.cmdstan-version }}" != "" ]]; then
echo "version=${{ github.event.inputs.cmdstan-version }}" >> $GITHUB_OUTPUT
else
python -c 'import requests;print("version="+requests.get("https://api.github.com/repos/stan-dev/cmdstan/releases/latest").json()["tag_name"][1:])' >> $GITHUB_OUTPUT
echo "version=git:develop" >> $GITHUB_OUTPUT
# python -c 'import requests;print("version="+requests.get("https://api.github.com/repos/stan-dev/cmdstan/releases/latest").json()["tag_name"][1:])' >> $GITHUB_OUTPUT
fi
outputs:
version: ${{ steps.check-cmdstan.outputs.version }}
Expand All @@ -39,7 +45,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.7.1 - 3.7.16", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
steps:
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ repos:
- id: isort
# https://github.com/python/black#version-control-integration
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.982
rev: v1.5.0
hooks:
- id: mypy
# Copied from setup.cfg
Expand Down
10 changes: 7 additions & 3 deletions cmdstanpy/compiler_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,12 @@ def add_include_path(self, path: str) -> None:
elif path not in self._stanc_options['include-paths']:
self._stanc_options['include-paths'].append(path)

def compose_stanc(self) -> List[str]:
def compose_stanc(self, filename_in_msg: Optional[str]) -> List[str]:
opts = []

if filename_in_msg is not None:
opts.append(f'--filename-in-msg={filename_in_msg}')

if self._stanc_options is not None and len(self._stanc_options) > 0:
for key, val in self._stanc_options.items():
if key == 'include-paths':
Expand All @@ -295,11 +299,11 @@ def compose_stanc(self) -> List[str]:
opts.append(f'--{key}')
return opts

def compose(self) -> List[str]:
def compose(self, filename_in_msg: Optional[str] = None) -> List[str]:
"""Format makefile options as list of strings."""
opts = [
'STANCFLAGS+=' + flag.replace(" ", "\\ ")
for flag in self.compose_stanc()
for flag in self.compose_stanc(filename_in_msg)
]
if self._cpp_options is not None and len(self._cpp_options) > 0:
for key, val in self._cpp_options.items():
Expand Down
6 changes: 3 additions & 3 deletions cmdstanpy/install_cxx_toolchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Linux: Not implemented
Optional command line arguments:
-v, --version : version, defaults to latest
-d, --dir : install directory, defaults to '~/.cmdstan(py)
-d, --dir : install directory, defaults to '~/.cmdstan
-s (--silent) : install with /VERYSILENT instead of /SILENT for RTools
-m --no-make : don't install mingw32-make (Windows RTools 4.0 only)
--progress : flag, when specified show progress bar for RTools download
Expand All @@ -27,7 +27,7 @@
from cmdstanpy.utils import pushd, validate_dir, wrap_url_progress_hook

EXTENSION = '.exe' if platform.system() == 'Windows' else ''
IS_64BITS = sys.maxsize > 2 ** 32
IS_64BITS = sys.maxsize > 2**32


def usage() -> None:
Expand Down Expand Up @@ -333,7 +333,7 @@ def parse_cmdline_args() -> Dict[str, Any]:
parser = argparse.ArgumentParser()
parser.add_argument('--version', '-v', help="version, defaults to latest")
parser.add_argument(
'--dir', '-d', help="install directory, defaults to '~/.cmdstan(py)"
'--dir', '-d', help="install directory, defaults to '~/.cmdstan"
)
parser.add_argument(
'--silent',
Expand Down
31 changes: 7 additions & 24 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Dict,
Iterable,
List,
Literal,
Mapping,
Optional,
TypeVar,
Expand Down Expand Up @@ -117,8 +118,7 @@ def __init__(
model_name: Optional[str] = None,
stan_file: OptionalPath = None,
exe_file: OptionalPath = None,
# TODO should be Literal['force'] not str
compile: Union[bool, str] = True,
compile: Union[bool, Literal['force']] = True,
stanc_options: Optional[Dict[str, Any]] = None,
cpp_options: Optional[Dict[str, Any]] = None,
user_header: OptionalPath = None,
Expand Down Expand Up @@ -300,7 +300,7 @@ def src_info(self) -> Dict[str, Any]:
cmd = (
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
# handle include-paths, allow-undefined etc
+ self._compiler_options.compose_stanc()
+ self._compiler_options.compose_stanc(None)
+ ['--info', str(self.stan_file)]
)
proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
Expand Down Expand Up @@ -343,7 +343,7 @@ def format(
cmd = (
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
# handle include-paths, allow-undefined etc
+ self._compiler_options.compose_stanc()
+ self._compiler_options.compose_stanc(None)
+ [str(self.stan_file)]
)

Expand Down Expand Up @@ -528,7 +528,7 @@ def compile(
)
cmd = [make]
if self._compiler_options is not None:
cmd.extend(self._compiler_options.compose())
cmd.extend(self._compiler_options.compose(self._stan_file))
cmd.append(Path(exe_file).as_posix())

sout = io.StringIO()
Expand Down Expand Up @@ -996,10 +996,7 @@ def sample(
fixed_param = self._fixed_param

if chains is None:
if fixed_param:
chains = 1
else:
chains = 4
chains = 4
if chains < 1:
raise ValueError(
'Chains must be a positive integer value, found {}.'.format(
Expand Down Expand Up @@ -1090,8 +1087,7 @@ def sample(
one_process_per_chain = True
info_dict = self.exe_info()
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
# run multi-chain sampler unless algo is fixed_param or 1 chain
if fixed_param or (chains == 1):
if chains == 1:
force_one_process_per_chain = True

if (
Expand Down Expand Up @@ -1195,19 +1191,6 @@ def sample(
sampler_args.fixed_param = True
runset._args.method_args = sampler_args

# if there was an exe-file only initialization,
# this could happen, so throw a nice error
if (
sampler_args.fixed_param
and not one_process_per_chain
and chains > 1
):
raise RuntimeError(
"Cannot use single-process multichain parallelism"
" with algorithm fixed_param.\nTry setting argument"
" force_one_process_per_chain to True"
)

errors = runset.get_err_msgs()
if not runset._check_retcodes():
msg = (
Expand Down
96 changes: 43 additions & 53 deletions cmdstanpy/stanfit/gq.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
get_logger,
scan_generated_quantities_csv,
)
from cmdstanpy.utils.data_munging import extract_reshape

from .mcmc import CmdStanMCMC
from .metadata import InferenceMetadata
Expand Down Expand Up @@ -242,7 +241,9 @@ def draws(
]
drop_cols: List[int] = []
for dup in dups:
drop_cols.extend(self.previous_fit.metadata.stan_vars_cols[dup])
drop_cols.extend(
self.previous_fit._metadata.stan_vars[dup].columns()
)

start_idx, _ = self._draws_start(inc_warmup)
previous_draws = self._previous_draws(True)
Expand Down Expand Up @@ -324,18 +325,24 @@ def draws_pd(

self._assemble_generated_quantities()

gq_cols = []
mcmc_vars = []
gq_cols: List[str] = []
mcmc_vars: List[str] = []
if vars is not None:
for var in vars_list:
if var in self.metadata.stan_vars_cols:
for idx in self.metadata.stan_vars_cols[var]:
gq_cols.append(self.column_names[idx])
if var in self._metadata.stan_vars:
info = self._metadata.stan_vars[var]
gq_cols.extend(
self.column_names[info.start_idx : info.end_idx]
)
elif (
inc_sample
and var in self.previous_fit.metadata.stan_vars_cols
inc_sample and var in self.previous_fit._metadata.stan_vars
):
mcmc_vars.append(var)
info = self.previous_fit._metadata.stan_vars[var]
mcmc_vars.extend(
self.previous_fit.column_names[
info.start_idx : info.end_idx
]
)
else:
raise ValueError('Unknown variable: {}'.format(var))
else:
Expand Down Expand Up @@ -463,18 +470,18 @@ def draws_xr(
else:
vars_list = vars
for var in vars_list:
if var not in self.metadata.stan_vars_cols:
if var not in self._metadata.stan_vars:
if inc_sample and (
var in self.previous_fit.metadata.stan_vars_cols
var in self.previous_fit._metadata.stan_vars
):
mcmc_vars_list.append(var)
dup_vars.append(var)
else:
raise ValueError('Unknown variable: {}'.format(var))
else:
vars_list = list(self.metadata.stan_vars_cols.keys())
vars_list = list(self._metadata.stan_vars.keys())
if inc_sample:
for var in self.previous_fit.metadata.stan_vars_cols.keys():
for var in self.previous_fit._metadata.stan_vars.keys():
if var not in vars_list and var not in mcmc_vars_list:
mcmc_vars_list.append(var)
for var in dup_vars:
Expand All @@ -483,7 +490,7 @@ def draws_xr(
self._assemble_generated_quantities()

num_draws = self.previous_fit.num_draws_sampling
sample_config = self.previous_fit.metadata.cmdstan_config
sample_config = self.previous_fit._metadata.cmdstan_config
attrs: MutableMapping[Hashable, Any] = {
"stan_version": f"{sample_config['stan_version_major']}."
f"{sample_config['stan_version_minor']}."
Expand All @@ -504,23 +511,15 @@ def draws_xr(
for var in vars_list:
build_xarray_data(
data,
var,
self._metadata.stan_vars_dims[var],
self._metadata.stan_vars_cols[var],
0,
self._metadata.stan_vars[var],
self.draws(inc_warmup=inc_warmup),
self._metadata.stan_vars_types[var],
)
if inc_sample:
for var in mcmc_vars_list:
build_xarray_data(
data,
var,
self.previous_fit.metadata.stan_vars_dims[var],
self.previous_fit.metadata.stan_vars_cols[var],
0,
self.previous_fit._metadata.stan_vars[var],
self.previous_fit.draws(inc_warmup=inc_warmup),
self.previous_fit._metadata.stan_vars_types[var],
)

return xr.Dataset(data, coords=coordinates, attrs=attrs).transpose(
Expand All @@ -545,13 +544,13 @@ def stan_variable(
the next M are from chain 2, and the last M elements are from chain N.

* If the variable is a scalar variable, the return array has shape
( draws X chains, 1).
( draws * chains, 1).
* If the variable is a vector, the return array has shape
( draws X chains, len(vector))
( draws * chains, len(vector))
* If the variable is a matrix, the return array has shape
( draws X chains, size(dim 1) X size(dim 2) )
( draws * chains, size(dim 1), size(dim 2) )
* If the variable is an array with N dimensions, the return array
has shape ( draws X chains, size(dim 1) X ... X size(dim N))
has shape ( draws * chains, size(dim 1), ..., size(dim N))

For example, if the Stan program variable ``theta`` is a 3x3 matrix,
and the sample consists of 4 chains with 1000 post-warmup draws,
Expand All @@ -573,8 +572,8 @@ def stan_variable(
CmdStanMLE.stan_variable
CmdStanVB.stan_variable
"""
model_var_names = self.previous_fit.metadata.stan_vars_cols.keys()
gq_var_names = self.metadata.stan_vars_cols.keys()
model_var_names = self.previous_fit._metadata.stan_vars.keys()
gq_var_names = self._metadata.stan_vars.keys()
if not (var in model_var_names or var in gq_var_names):
raise ValueError(
f'Unknown variable name: {var}\n'
Expand All @@ -588,30 +587,21 @@ def stan_variable(
)
elif isinstance(self.previous_fit, CmdStanMLE):
return np.atleast_1d( # type: ignore
np.asarray(
self.previous_fit.stan_variable(
var, inc_iterations=inc_warmup
)
self.previous_fit.stan_variable(
var, inc_iterations=inc_warmup
)
)
else:
return np.atleast_1d( # type: ignore
np.asarray(self.previous_fit.stan_variable(var))
self.previous_fit.stan_variable(var)
)

# is gq variable
self._assemble_generated_quantities()
draw1, num_draws = self._draws_start(inc_warmup)
dims = (num_draws * self.chains,)
col_idxs = self._metadata.stan_vars_cols[var]

return extract_reshape(
dims=dims + self._metadata.stan_vars_dims[var],
col_idxs=col_idxs,
var_type=self._metadata.stan_vars_types[var],
start_row=draw1,
draws_in=self._draws,
)

draw1, _ = self._draws_start(inc_warmup)
draws = flatten_chains(self._draws[draw1:])
out: np.ndarray = self._metadata.stan_vars[var].extract_reshape(draws)
return out

def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
"""
Expand All @@ -630,8 +620,8 @@ def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:
CmdStanVB.stan_variables
"""
result = {}
sample_var_names = self.previous_fit.metadata.stan_vars_cols.keys()
gq_var_names = self.metadata.stan_vars_cols.keys()
sample_var_names = self.previous_fit._metadata.stan_vars.keys()
gq_var_names = self._metadata.stan_vars.keys()
for name in gq_var_names:
result[name] = self.stan_variable(name, inc_warmup)
for name in sample_var_names:
Expand Down Expand Up @@ -697,9 +687,9 @@ def _previous_draws(self, inc_warmup: bool) -> np.ndarray:
if inc_warmup and p_fit._save_iterations:
return p_fit.optimized_iterations_np[:, None] # type: ignore

return np.atleast_2d(p_fit.optimized_params_np,)[ # type: ignore
:, None
]
return np.atleast_2d( # type: ignore
p_fit.optimized_params_np,
)[:, None]
else: # CmdStanVB:
if inc_warmup:
return np.vstack(
Expand Down
Loading