Skip to content

Add timeout parameter. #621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
import threading
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

import ujson as json
Expand Down Expand Up @@ -593,6 +594,7 @@ def optimize(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanMLE:
"""
Run the specified CmdStan optimize algorithm to produce a
Expand Down Expand Up @@ -692,6 +694,8 @@ def optimize(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"

:param timeout: Duration at which optimization times out in seconds.

:return: CmdStanMLE object
"""
optimize_args = OptimizeArgs(
Expand Down Expand Up @@ -723,7 +727,13 @@ def optimize(
)
dummy_chain_id = 0
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
self._run_cmdstan(
runset,
dummy_chain_id,
show_console=show_console,
timeout=timeout,
)
runset.raise_for_timeouts()

if not runset._check_retcodes():
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
Expand Down Expand Up @@ -767,6 +777,7 @@ def sample(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
*,
force_one_process_per_chain: Optional[bool] = None,
) -> CmdStanMCMC:
Expand Down Expand Up @@ -964,6 +975,8 @@ def sample(
model was compiled with STAN_THREADS=True, and utilize the
parallel chain functionality if those conditions are met.

:param timeout: Duration at which sampling times out in seconds.

:return: CmdStanMCMC object
"""
if fixed_param is None:
Expand Down Expand Up @@ -1139,6 +1152,7 @@ def sample(
show_progress=show_progress,
show_console=show_console,
progress_hook=progress_hook,
timeout=timeout,
)
if show_progress and progress_hook is not None:
progress_hook("Done", -1) # -1 == all chains finished
Expand All @@ -1154,6 +1168,8 @@ def sample(
sys.stdout.write('\n')
get_logger().info('CmdStan done processing.')

runset.raise_for_timeouts()

get_logger().debug('runset\n%s', repr(runset))

# hack needed to parse CSV files if model has no params
Expand Down Expand Up @@ -1209,6 +1225,7 @@ def generate_quantities(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanGQ:
"""
Run CmdStan's generate_quantities method which runs the generated
Expand Down Expand Up @@ -1267,6 +1284,8 @@ def generate_quantities(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"

:param timeout: Duration at which generation times out in seconds.

:return: CmdStanGQ object
"""
if isinstance(mcmc_sample, CmdStanMCMC):
Expand Down Expand Up @@ -1329,8 +1348,10 @@ def generate_quantities(
runset,
i,
show_console=show_console,
timeout=timeout,
)

runset.raise_for_timeouts()
errors = runset.get_err_msgs()
if errors:
msg = (
Expand Down Expand Up @@ -1366,6 +1387,7 @@ def variational(
show_console: bool = False,
refresh: Optional[int] = None,
time_fmt: str = "%Y%m%d%H%M%S",
timeout: Optional[float] = None,
) -> CmdStanVB:
"""
Run CmdStan's variational inference algorithm to approximate
Expand Down Expand Up @@ -1458,6 +1480,9 @@ def variational(
:meth:`~datetime.datetime.strftime` to decide the file names for
output CSVs. Defaults to "%Y%m%d%H%M%S"

:param timeout: Duration at which variational Bayesian inference times
out in seconds.

:return: CmdStanVB object
"""
variational_args = VariationalArgs(
Expand Down Expand Up @@ -1491,7 +1516,13 @@ def variational(

dummy_chain_id = 0
runset = RunSet(args=args, chains=1, time_fmt=time_fmt)
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
self._run_cmdstan(
runset,
dummy_chain_id,
show_console=show_console,
timeout=timeout,
)
runset.raise_for_timeouts()

# treat failure to converge as failure
transcript_file = runset.stdout_files[dummy_chain_id]
Expand Down Expand Up @@ -1527,9 +1558,8 @@ def variational(
'current value is {}.'.format(grad_samples)
)
else:
msg = (
'Variational algorithm failed.\n '
'Console output:\n{}'.format(contents)
msg = 'Error during variational inference: {}'.format(
runset.get_err_msgs()
)
raise RuntimeError(msg)
# pylint: disable=invalid-name
Expand All @@ -1543,6 +1573,7 @@ def _run_cmdstan(
show_progress: bool = False,
show_console: bool = False,
progress_hook: Optional[Callable[[str, int], None]] = None,
timeout: Optional[float] = None,
) -> None:
"""
Helper function which encapsulates call to CmdStan.
Expand Down Expand Up @@ -1579,6 +1610,20 @@ def _run_cmdstan(
env=os.environ,
universal_newlines=True,
)
if timeout:

def _timer_target() -> None:
# Abort if the process has already terminated.
if proc.poll() is not None:
return
proc.terminate()
runset._set_timeout_flag(idx, True)

timer = threading.Timer(timeout, _timer_target)
timer.daemon = True
timer.start()
else:
timer = None
while proc.poll() is None:
if proc.stdout is not None:
line = proc.stdout.readline()
Expand All @@ -1592,6 +1637,8 @@ def _run_cmdstan(
stdout, _ = proc.communicate()
retcode = proc.returncode
runset._set_retcode(idx, retcode)
if timer:
timer.cancel()

if stdout:
fd_out.write(stdout)
Expand Down
12 changes: 12 additions & 0 deletions cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
else:
self._num_procs = 1
self._retcodes = [-1 for _ in range(self._num_procs)]
self._timeout_flags = [False for _ in range(self._num_procs)]
if chain_ids is None:
chain_ids = [i + 1 for i in range(chains)]
self._chain_ids = chain_ids
Expand Down Expand Up @@ -230,6 +231,10 @@ def _set_retcode(self, idx: int, val: int) -> None:
"""Set retcode at process[idx] to val."""
self._retcodes[idx] = val

def _set_timeout_flag(self, idx: int, val: bool) -> None:
    """
    Record the timeout flag for the process at position ``idx``.

    :param idx: Position of the process within the timeout flag list.
    :param val: Flag value to store for that process.
    """
    flags = self._timeout_flags
    flags[idx] = val

def get_err_msgs(self) -> str:
"""Checks console messages for each CmdStan run."""
msgs = []
Expand Down Expand Up @@ -294,3 +299,10 @@ def save_csvfiles(self, dir: Optional[str] = None) -> None:
raise ValueError(
'Cannot save to file: {}'.format(to_path)
) from e

def raise_for_timeouts(self) -> None:
    """
    Raise if any process has its timeout flag set.

    :raises TimeoutError: if at least one timeout flag is set; the
        message reports how many of the processes timed out.
    """
    flags = self._timeout_flags
    # Guard clause: nothing to report when no flag is set.
    if not any(flags):
        return
    timed_out = sum(flags)
    raise TimeoutError(
        f"{timed_out} of {self.num_procs} processes timed out"
    )
21 changes: 21 additions & 0 deletions test/data/timeout.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Test fixture for the `timeout` argument: depending on the `loop` data
// value, the program either finishes normally or never terminates.
data {
// Indicator for endless looping: nonzero keeps `transformed data` spinning.
int loop;
}

transformed data {
// Maybe loop forever so the model times out.
// `y` starts at 1 and only grows, so it is never zero: the loop runs
// forever when `loop` is nonzero and is skipped entirely when `loop` is 0.
real y = 1;
while(loop && y) {
y += 1;
}
}

parameters {
real x;
}

model {
// A nice model so we can get a fit for the `generate_quantities` call.
x ~ normal(0, 1);
}
9 changes: 9 additions & 0 deletions test/test_generate_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,15 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
    """Generating quantities from a never-ending model must time out."""
    model_path = os.path.join(DATAFILES_PATH, 'timeout.stan')
    looping_model = CmdStanModel(stan_file=model_path)
    # With `loop` set to 0 the model terminates, giving a fit to reuse.
    finite_fit = looping_model.sample(
        data={'loop': 0}, chains=1, iter_sampling=10
    )
    # With `loop` set to 1 the run cannot finish before the timeout.
    with self.assertRaises(TimeoutError):
        looping_model.generate_quantities(
            timeout=0.1, mcmc_sample=finite_fit, data={'loop': 1}
        )


if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions test/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,12 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
    """Optimizing a never-ending model must raise ``TimeoutError``."""
    model_path = os.path.join(DATAFILES_PATH, 'timeout.stan')
    looping_model = CmdStanModel(stan_file=model_path)
    # `loop` set to 1 keeps the model from ever finishing on its own.
    endless_data = {'loop': 1}
    with self.assertRaises(TimeoutError):
        looping_model.optimize(data=endless_data, timeout=0.1)


if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions test/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,6 +1913,12 @@ def test_diagnostics(self):
self.assertEqual(fit.max_treedepths, None)
self.assertEqual(fit.divergences, None)

def test_timeout(self):
    """Sampling a never-ending model must raise ``TimeoutError``."""
    model_path = os.path.join(DATAFILES_PATH, 'timeout.stan')
    looping_model = CmdStanModel(stan_file=model_path)
    # `loop` set to 1 keeps the model from ever finishing on its own.
    endless_data = {'loop': 1}
    with self.assertRaises(TimeoutError):
        looping_model.sample(timeout=0.1, chains=1, data=endless_data)


if __name__ == '__main__':
unittest.main()
6 changes: 6 additions & 0 deletions test/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,12 @@ def test_attrs(self):
with self.assertRaisesRegex(AttributeError, 'Unknown variable name:'):
dummy = fit.c

def test_timeout(self):
    """Variational inference on a never-ending model must time out."""
    model_path = os.path.join(DATAFILES_PATH, 'timeout.stan')
    looping_model = CmdStanModel(stan_file=model_path)
    # `loop` set to 1 keeps the model from ever finishing on its own.
    endless_data = {'loop': 1}
    with self.assertRaises(TimeoutError):
        looping_model.variational(timeout=0.1, data=endless_data)


if __name__ == '__main__':
unittest.main()