Skip to content

[WIP] Change tests to pytest format. #639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@ pytest
pytest-cov
pytest-order
mypy
testfixtures
tqdm
xarray
122 changes: 55 additions & 67 deletions test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,69 @@
"""Testing utilities for CmdStanPy."""

import contextlib
import os
import sys
import unittest
import logging
import platform
import re
from typing import List, Type
from unittest import mock
from importlib import reload
from io import StringIO
import pytest


class CustomTestCase(unittest.TestCase):
# pylint: disable=invalid-name
@contextlib.contextmanager
def assertRaisesRegexNested(self, exc, msg):
"""A version of assertRaisesRegex that checks the full traceback.
mark_windows_only = pytest.mark.skipif(
platform.system() != 'Windows', reason='only runs on windows'
)
mark_not_windows = pytest.mark.skipif(
platform.system() == 'Windows', reason='does not run on windows'
)

Useful for when an exception is raised from another and you wish to
inspect the inner exception.
"""
with self.assertRaises(exc) as ctx:
yield
exception = ctx.exception
exn_string = str(ctx.exception)
while exception.__cause__ is not None:
exception = exception.__cause__
exn_string += "\n" + str(exception)
self.assertRegex(exn_string, msg)

@contextlib.contextmanager
def without_import(self, library, module):
with unittest.mock.patch.dict('sys.modules', {library: None}):
reload(module)
yield
reload(module)
# pylint: disable=invalid-name
@contextlib.contextmanager
def raises_nested(expected_exception: Type[Exception], match: str) -> None:
"""A version of assertRaisesRegex that checks the full traceback.

# recipe modified from https://stackoverflow.com/a/36491341
@contextlib.contextmanager
def replace_stdin(self, target: str):
orig = sys.stdin
sys.stdin = StringIO(target)
Useful for when an exception is raised from another and you wish to
inspect the inner exception.
"""
with pytest.raises(expected_exception) as ctx:
yield
sys.stdin = orig

# recipe from https://stackoverflow.com/a/34333710
@contextlib.contextmanager
def modified_environ(self, *remove, **update):
"""
Temporarily updates the ``os.environ`` dictionary in-place.

The ``os.environ`` dictionary is updated in-place so that
the modification is sure to work in all situations.
exception: Exception = ctx.value
lines = []
while exception:
lines.append(str(exception))
exception = exception.__cause__
text = "\n".join(lines)
assert re.search(match, text), f"pattern `{match}` does not match `{text}`"

:param remove: Environment variables to remove.
:param update: Dictionary of environment variables and values to
add/update.
"""
env = os.environ
update = update or {}
remove = remove or []

# List of environment variables being updated or removed.
stomped = (set(update.keys()) | set(remove)) & set(env.keys())
# Environment variables and values to restore on exit.
update_after = {k: env[k] for k in stomped}
# Environment variables and values to remove on exit.
remove_after = frozenset(k for k in update if k not in env)
@contextlib.contextmanager
def without_import(library, module):
with mock.patch.dict('sys.modules', {library: None}):
reload(module)
yield
reload(module)

try:
env.update(update)
for k in remove:
env.pop(k, None)
yield
finally:
env.update(update_after)
for k in remove_after:
env.pop(k)

# pylint: disable=invalid-name
def assertPathsEqual(self, path1, path2):
"""Assert paths are equal after normalization"""
self.assertTrue(os.path.samefile(path1, path2))
def check_present(
caplog: pytest.LogCaptureFixture,
*conditions: List[tuple],
clear: bool = True,
) -> None:
"""
Check that all desired records exist.
"""
for condition in conditions:
logger, level, message = condition
if isinstance(level, str):
level = getattr(logging, level)
found = any(
logger == logger_ and level == level_ and message.match(message_)
if isinstance(message, re.Pattern)
else message == message_
for logger_, level_, message_ in caplog.record_tuples
)
if not found:
raise ValueError(f"logs did not contain the record {condition}")
if clear:
caplog.clear()
3 changes: 0 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
DATAFILES_PATH = os.path.join(HERE, 'data')


# after we have run all tests, use git to delete the built files in data/


@pytest.fixture(scope='session', autouse=True)
def cleanup_test_files():
"""Remove compiled models and output files after test run."""
Expand Down
3 changes: 3 additions & 0 deletions test/data/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@
*.testbak
*.bak-*
!return_one.hpp
# Ignore temporary files created as part of compilation.
*.o
*.o.tmp
Loading