@@ -207,6 +207,33 @@ def test_model_info(self):
207
207
self .assertIn ('theta' , model_info_include ['parameters' ])
208
208
self .assertIn ('included_files' , model_info_include )
209
209
210
+ def test_compile_with_bad_includes (self ):
211
+ # Ensure compilation fails if we break an included file.
212
+ stan_file = os .path .join (DATAFILES_PATH , "add_one_model.stan" )
213
+ exe_file = os .path .splitext (stan_file )[0 ] + EXTENSION
214
+ if os .path .isfile (exe_file ):
215
+ os .unlink (exe_file )
216
+ with tempfile .TemporaryDirectory () as include_path :
217
+ include_source = os .path .join (
218
+ DATAFILES_PATH , "include-path" , "add_one_function.stan"
219
+ )
220
+ include_target = os .path .join (include_path , "add_one_function.stan" )
221
+ shutil .copy (include_source , include_target )
222
+ model = CmdStanModel (
223
+ stan_file = stan_file ,
224
+ compile = False ,
225
+ stanc_options = {"include-paths" : [include_path ]},
226
+ )
227
+ with LogCapture (level = logging .INFO ) as log :
228
+ model .compile ()
229
+ log .check_present (
230
+ ('cmdstanpy' , 'INFO' , StringComparison ('compiling stan file' ))
231
+ )
232
+ with open (include_target , "w" ) as fd :
233
+ fd .write ("gobbledygook" )
234
+ with pytest .raises (ValueError , match = "Failed to get source info" ):
235
+ model .compile ()
236
+
210
237
def test_compile_with_includes (self ):
211
238
getmtime = os .path .getmtime
212
239
configs = [
@@ -215,6 +242,9 @@ def test_compile_with_includes(self):
215
242
]
216
243
for stan_file , include_paths in configs :
217
244
stan_file = os .path .join (DATAFILES_PATH , stan_file )
245
+ exe_file = os .path .splitext (stan_file )[0 ] + EXTENSION
246
+ if os .path .isfile (exe_file ):
247
+ os .unlink (exe_file )
218
248
include_paths = [
219
249
os .path .join (DATAFILES_PATH , path ) for path in include_paths
220
250
]
@@ -348,6 +378,10 @@ def test_model_syntax_error(self):
348
378
with self .assertRaisesRegex (ValueError , r'.*Syntax error.*' ):
349
379
CmdStanModel (stan_file = stan )
350
380
381
+ def test_model_syntax_error_without_compile (self ):
382
+ stan = os .path .join (DATAFILES_PATH , 'bad_syntax.stan' )
383
+ CmdStanModel (stan_file = stan , compile = False )
384
+
351
385
def test_repr (self ):
352
386
model = CmdStanModel (stan_file = BERN_STAN )
353
387
model_repr = repr (model )
0 commit comments