Skip to content

Commit c4c6bb7

Browse files
Raise errors when src_info subprocess fails. (#638)
* Raise errors when `src_info` subprocess fails. * Log errors in constructor if not compiling. * Explicitly set `check=False` for source info subprocess. * Only catch `ValueError` when calling `src_info`.
1 parent 2068ca5 commit c4c6bb7

File tree

2 files changed

+57
-28
lines changed

2 files changed

+57
-28
lines changed

cmdstanpy/model.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,14 @@ def __init__(
169169
if not cmdstan_version_before(
170170
2, 27
171171
): # unknown end of version range
172-
model_info = self.src_info()
173-
if 'parameters' in model_info:
174-
self._fixed_param |= len(model_info['parameters']) == 0
172+
try:
173+
model_info = self.src_info()
174+
if 'parameters' in model_info:
175+
self._fixed_param |= len(model_info['parameters']) == 0
176+
except ValueError as e:
177+
if compile:
178+
raise
179+
get_logger().debug(e)
175180

176181
if exe_file is not None:
177182
self._exe_file = os.path.realpath(os.path.expanduser(exe_file))
@@ -269,32 +274,22 @@ def src_info(self) -> Dict[str, Any]:
269274
If stanc is older than 2.27 or if the stan
270275
file cannot be found, returns an empty dictionary.
271276
"""
272-
result: Dict[str, Any] = {}
273-
if self.stan_file is None:
274-
return result
275-
try:
276-
cmd = (
277-
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
278-
# handle include-paths, allow-undefined etc
279-
+ self._compiler_options.compose_stanc()
280-
+ [
281-
'--info',
282-
str(self.stan_file),
283-
]
284-
)
285-
proc = subprocess.run(
286-
cmd, capture_output=True, text=True, check=True
277+
if self.stan_file is None or cmdstan_version_before(2, 27):
278+
return {}
279+
cmd = (
280+
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
281+
# handle include-paths, allow-undefined etc
282+
+ self._compiler_options.compose_stanc()
283+
+ ['--info', str(self.stan_file)]
284+
)
285+
proc = subprocess.run(cmd, capture_output=True, text=True, check=False)
286+
if proc.returncode:
287+
raise ValueError(
288+
f"Failed to get source info for Stan model "
289+
f"'{self._stan_file}'. Console:\n{proc.stderr}"
287290
)
288-
result = json.loads(proc.stdout)
289-
return result
290-
except (
291-
ValueError,
292-
RuntimeError,
293-
OSError,
294-
subprocess.CalledProcessError,
295-
) as e:
296-
get_logger().debug(e)
297-
return result
291+
result: Dict[str, Any] = json.loads(proc.stdout)
292+
return result
298293

299294
def format(
300295
self,

test/test_model.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,33 @@ def test_model_info(self):
207207
self.assertIn('theta', model_info_include['parameters'])
208208
self.assertIn('included_files', model_info_include)
209209

210+
def test_compile_with_bad_includes(self):
211+
# Ensure compilation fails if we break an included file.
212+
stan_file = os.path.join(DATAFILES_PATH, "add_one_model.stan")
213+
exe_file = os.path.splitext(stan_file)[0] + EXTENSION
214+
if os.path.isfile(exe_file):
215+
os.unlink(exe_file)
216+
with tempfile.TemporaryDirectory() as include_path:
217+
include_source = os.path.join(
218+
DATAFILES_PATH, "include-path", "add_one_function.stan"
219+
)
220+
include_target = os.path.join(include_path, "add_one_function.stan")
221+
shutil.copy(include_source, include_target)
222+
model = CmdStanModel(
223+
stan_file=stan_file,
224+
compile=False,
225+
stanc_options={"include-paths": [include_path]},
226+
)
227+
with LogCapture(level=logging.INFO) as log:
228+
model.compile()
229+
log.check_present(
230+
('cmdstanpy', 'INFO', StringComparison('compiling stan file'))
231+
)
232+
with open(include_target, "w") as fd:
233+
fd.write("gobbledygook")
234+
with pytest.raises(ValueError, match="Failed to get source info"):
235+
model.compile()
236+
210237
def test_compile_with_includes(self):
211238
getmtime = os.path.getmtime
212239
configs = [
@@ -215,6 +242,9 @@ def test_compile_with_includes(self):
215242
]
216243
for stan_file, include_paths in configs:
217244
stan_file = os.path.join(DATAFILES_PATH, stan_file)
245+
exe_file = os.path.splitext(stan_file)[0] + EXTENSION
246+
if os.path.isfile(exe_file):
247+
os.unlink(exe_file)
218248
include_paths = [
219249
os.path.join(DATAFILES_PATH, path) for path in include_paths
220250
]
@@ -348,6 +378,10 @@ def test_model_syntax_error(self):
348378
with self.assertRaisesRegex(ValueError, r'.*Syntax error.*'):
349379
CmdStanModel(stan_file=stan)
350380

381+
def test_model_syntax_error_without_compile(self):
382+
stan = os.path.join(DATAFILES_PATH, 'bad_syntax.stan')
383+
CmdStanModel(stan_file=stan, compile=False)
384+
351385
def test_repr(self):
352386
model = CmdStanModel(stan_file=BERN_STAN)
353387
model_repr = repr(model)

0 commit comments

Comments
 (0)