Skip to content

Add timeout parameter. #621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
import threading
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

import ujson as json
Expand Down Expand Up @@ -593,6 +594,7 @@ def optimize(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanMLE:
"""
Run the specified CmdStan optimize algorithm to produce a
Expand Down Expand Up @@ -692,6 +694,8 @@ def optimize(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"

:param timeout: Duration at which optimization times out in seconds.

:return: CmdStanMLE object
"""
optimize_args = OptimizeArgs(
Expand Down Expand Up @@ -723,7 +727,8 @@ def optimize(
)
dummy_chain_id = 0
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console,
timeout=timeout)

if not runset._check_retcodes():
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
Expand Down Expand Up @@ -767,6 +772,7 @@ def sample(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
*,
force_one_process_per_chain: Optional[bool] = None,
) -> CmdStanMCMC:
Expand Down Expand Up @@ -964,6 +970,8 @@ def sample(
model was compiled with STAN_THREADS=True, and utilize the
parallel chain functionality if those conditions are met.

:param timeout: Duration at which sampling times out in seconds.

:return: CmdStanMCMC object
"""
if fixed_param is None:
Expand Down Expand Up @@ -1139,6 +1147,7 @@ def sample(
show_progress=show_progress,
show_console=show_console,
progress_hook=progress_hook,
timeout=timeout,
)
if show_progress and progress_hook is not None:
progress_hook("Done", -1) # -1 == all chains finished
Expand Down Expand Up @@ -1209,6 +1218,7 @@ def generate_quantities(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanGQ:
"""
Run CmdStan's generate_quantities method which runs the generated
Expand Down Expand Up @@ -1267,6 +1277,8 @@ def generate_quantities(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"

:param timeout: Duration at which generation times out in seconds.

:return: CmdStanGQ object
"""
if isinstance(mcmc_sample, CmdStanMCMC):
Expand Down Expand Up @@ -1329,6 +1341,7 @@ def generate_quantities(
runset,
i,
show_console=show_console,
timeout=timeout,
)

errors = runset.get_err_msgs()
Expand Down Expand Up @@ -1366,6 +1379,7 @@ def variational(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanVB:
"""
Run CmdStan's variational inference algorithm to approximate
Expand Down Expand Up @@ -1458,6 +1472,9 @@ def variational(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"

:param timeout: Duration at which variational Bayesian inference times
out in seconds.

:return: CmdStanVB object
"""
variational_args = VariationalArgs(
Expand Down Expand Up @@ -1491,7 +1508,8 @@ def variational(

dummy_chain_id = 0
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console,
timeout=timeout)

# treat failure to converge as failure
transcript_file = runset.stdout_files[dummy_chain_id]
Expand Down Expand Up @@ -1527,10 +1545,8 @@ def variational(
'current value is {}.'.format(grad_samples)
)
else:
msg = (
'Variational algorithm failed.\n '
'Console output:\n{}'.format(contents)
)
msg = 'Error during variational inference: {}'.format(
runset.get_err_msgs())
raise RuntimeError(msg)
# pylint: disable=invalid-name
vb = CmdStanVB(runset)
Expand All @@ -1543,6 +1559,7 @@ def _run_cmdstan(
show_progress: bool = False,
show_console: bool = False,
progress_hook: Optional[Callable[[str, int], None]] = None,
timeout: Optional[float] = None,
) -> None:
"""
Helper function which encapsulates call to CmdStan.
Expand Down Expand Up @@ -1579,6 +1596,12 @@ def _run_cmdstan(
env=os.environ,
universal_newlines=True,
)
if timeout:
timer = threading.Timer(timeout, proc.terminate)
timer.setDaemon(True)
timer.start()
else:
timer = None
while proc.poll() is None:
if proc.stdout is not None:
line = proc.stdout.readline()
Expand All @@ -1591,6 +1614,8 @@ def _run_cmdstan(

stdout, _ = proc.communicate()
retcode = proc.returncode
if timer and retcode == -15:
retcode = 60
runset._set_retcode(idx, retcode)

if stdout:
Expand Down
2 changes: 2 additions & 0 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def get_err_msgs(self) -> str:
"""Checks console messages for each CmdStan run."""
msgs = []
for i in range(self._num_procs):
if self._retcodes[i] == 60:
msgs.append("processing timed out")
if (
os.path.exists(self._stdout_files[i])
and os.stat(self._stdout_files[i]).st_size > 0
Expand Down
21 changes: 21 additions & 0 deletions test/data/timeout.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
data {
// Indicator for endless looping.
int loop;
}

transformed data {
// Maybe loop forever so the model times out.
real y = 1;
while(loop && y) {
y += 1;
}
}

parameters {
real x;
}

model {
// A nice model so we can get a fit for the `generated_quantities` call.
x ~ normal(0, 1);
}
10 changes: 10 additions & 0 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,16 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
fit = timeout_model.sample(data={'loop': 0}, chains=1, iter_sampling=10)
self.assertRaisesRegex(
RuntimeError, 'processing timed out',
timeout_model.generate_quantities, timeout=0.1,
mcmc_sample=fit, data={'loop': 1},
)


if __name__ == '__main__':
unittest.main()
8 changes: 8 additions & 0 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,14 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
self.assertRaisesRegex(
RuntimeError, 'processing timed out', timeout_model.optimize,
data={'loop': 1}, timeout=0.1,
)


if __name__ == '__main__':
unittest.main()
8 changes: 8 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,14 @@ def test_diagnostics(self):
self.assertEqual(fit.max_treedepths, None)
self.assertEqual(fit.divergences, None)

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
self.assertRaisesRegex(
RuntimeError, 'processing timed out', timeout_model.sample,
timeout=0.1, chains=1, data={'loop': 1},
)


if __name__ == '__main__':
unittest.main()
8 changes: 8 additions & 0 deletions test/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
stan = os.path.join(DATAFILES_PATH, 'timeout.stan')
timeout_model = CmdStanModel(stan_file=stan)
self.assertRaisesRegex(
RuntimeError, 'processing timed out', timeout_model.variational,
timeout=0.1, data={'loop': 1}, show_console=True,
)


if __name__ == '__main__':
unittest.main()