Skip to content

feat(app): revised configure_torch_cuda_allocator() & testing strategy #7733

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 5, 2025
26 changes: 18 additions & 8 deletions invokeai/app/util/torch_cuda_allocator.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
import logging
import os
import sys


def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger | None = None):
def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger):
"""Configure the PyTorch CUDA memory allocator. See
https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf for supported
configurations.
"""

# Raise if torch has already been imported; the allocator can only be configured before the first torch import.
if "torch" in sys.modules:
raise RuntimeError("configure_torch_cuda_allocator() must be called before importing torch.")

# Log a warning if the PYTORCH_CUDA_ALLOC_CONF environment variable is already set.
prev_cuda_alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
if prev_cuda_alloc_conf is not None:
raise RuntimeError(
f"Attempted to configure the PyTorch CUDA memory allocator, but PYTORCH_CUDA_ALLOC_CONF is already set to "
f"'{prev_cuda_alloc_conf}'."
)
if prev_cuda_alloc_conf == pytorch_cuda_alloc_conf:
logger.info(
f"PYTORCH_CUDA_ALLOC_CONF is already set to '{pytorch_cuda_alloc_conf}'. Skipping configuration."
)
return
else:
logger.warning(
f"Attempted to configure the PyTorch CUDA memory allocator with '{pytorch_cuda_alloc_conf}', but PYTORCH_CUDA_ALLOC_CONF is already set to "
f"'{prev_cuda_alloc_conf}'. Skipping configuration."
)
return

# Configure the PyTorch CUDA memory allocator.
# NOTE: It is important that this happens before torch is imported.
Expand All @@ -38,5 +49,4 @@ def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging
"not imported before calling configure_torch_cuda_allocator()."
)

if logger is not None:
logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")
logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")
126 changes: 120 additions & 6 deletions tests/app/util/test_torch_cuda_allocator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,127 @@
import pytest
import torch

from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
from tests.dangerously_run_function_in_subprocess import dangerously_run_function_in_subprocess

# These tests are a bit fiddly, because they depend on the import behaviour of torch. They use subprocesses to isolate
# the import behaviour of torch, and then check that the function behaves as expected. We have to hack in some logging
# to check that the tested function is behaving as expected.


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_configures_backend():
    """Test that configure_torch_cuda_allocator() applies the requested allocator backend and logs the result at
    INFO level."""

    def test_func():
        import os

        # Unset the environment variable if it is set so that we can test setting it
        try:
            del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
        except KeyError:
            pass

        from unittest.mock import MagicMock

        from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator

        mock_logger = MagicMock()

        # Set the PyTorch CUDA memory allocator to cudaMallocAsync
        configure_torch_cuda_allocator("backend:cudaMallocAsync", logger=mock_logger)

        # Verify that the PyTorch CUDA memory allocator was configured correctly
        import torch

        assert torch.cuda.get_allocator_backend() == "cudaMallocAsync"

        # Verify that the logger was called with the correct message. The message is printed so the parent process
        # can assert on it via the captured stdout.
        mock_logger.info.assert_called_once()
        args, _kwargs = mock_logger.info.call_args
        logged_message = args[0]
        print(logged_message)

    stdout, _stderr, returncode = dangerously_run_function_in_subprocess(test_func)
    assert returncode == 0
    assert "PyTorch CUDA memory allocator: cudaMallocAsync" in stdout


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_raises_if_torch_already_imported():
    """Test that configure_torch_cuda_allocator() raises a RuntimeError if torch was already imported."""

    def test_func():
        # Deliberately import torch first - this is the error condition under test.
        import torch  # noqa: F401

        from unittest.mock import MagicMock

        from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator

        # Surface the expected RuntimeError's message on stdout so the parent process can assert on it.
        try:
            configure_torch_cuda_allocator("backend:cudaMallocAsync", logger=MagicMock())
        except RuntimeError as err:
            print(err)

    out, _err, exit_code = dangerously_run_function_in_subprocess(test_func)
    assert exit_code == 0
    assert "configure_torch_cuda_allocator() must be called before importing torch." in out


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_warns_if_env_var_is_set_differently():
    """Test that configure_torch_cuda_allocator() logs at WARNING level if PYTORCH_CUDA_ALLOC_CONF is set and doesn't
    match the requested configuration."""

    def test_func():
        import os

        # Pre-set the env var to a value that conflicts with the upcoming request.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:native"

        from unittest.mock import MagicMock

        from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator

        logger = MagicMock()

        # Request a different configuration than the env var holds.
        configure_torch_cuda_allocator("backend:cudaMallocAsync", logger=logger)

        # Exactly one warning should have been emitted; print it so the parent process can assert on it.
        logger.warning.assert_called_once()
        (message,), _ = logger.warning.call_args
        print(message)

    out, _err, exit_code = dangerously_run_function_in_subprocess(test_func)
    assert exit_code == 0
    assert "Attempted to configure the PyTorch CUDA memory allocator with 'backend:cudaMallocAsync'" in out


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
def test_configure_torch_cuda_allocator_logs_if_env_var_is_already_set_correctly():
    """Test that configure_torch_cuda_allocator() logs at INFO level if PYTORCH_CUDA_ALLOC_CONF is set and matches the
    requested configuration."""

    def test_func():
        import os

        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:native"
        from unittest.mock import MagicMock

        from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator

        logger = MagicMock()

        # Requesting the exact value already in the env var should be a no-op that logs at INFO level.
        configure_torch_cuda_allocator("backend:native", logger=logger)

        # Print the single INFO message so the parent process can assert on it.
        logger.info.assert_called_once()
        (message,), _ = logger.info.call_args
        print(message)

    out, _err, exit_code = dangerously_run_function_in_subprocess(test_func)
    assert exit_code == 0
    assert "PYTORCH_CUDA_ALLOC_CONF is already set to 'backend:native'" in out
46 changes: 46 additions & 0 deletions tests/dangerously_run_function_in_subprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import inspect
import subprocess
import sys
import textwrap
from typing import Any, Callable


def dangerously_run_function_in_subprocess(func: Callable[[], Any]) -> tuple[str, str, int]:
"""**Use with caution! This should _only_ be used with trusted code!**

Extracts a function's source and runs it in a separate subprocess. Returns stdout, stderr, and return code
from the subprocess.

This is useful for tests where an isolated environment is required.

The function to be called must not have any arguments and must not have any closures over the scope in which is was
defined.

Any modules that the function depends on must be imported inside the function.
"""

source_code = inspect.getsource(func)

# Must dedent the source code to avoid indentation errors
dedented_source_code = textwrap.dedent(source_code)

# Get the function name so we can call it in the subprocess
func_name = func.__name__

# Create a script that calls the function
script = f"""
import sys

{dedented_source_code}

if __name__ == "__main__":
{func_name}()
"""

result = subprocess.run(
[sys.executable, "-c", textwrap.dedent(script)], # Run the script in a subprocess
capture_output=True, # Capture stdout and stderr
text=True,
)

return result.stdout, result.stderr, result.returncode
57 changes: 57 additions & 0 deletions tests/test_dangerously_run_function_in_subprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from tests.dangerously_run_function_in_subprocess import dangerously_run_function_in_subprocess


def test_simple_function():
    def test_func():
        print("Hello, Test!")

    out, err, code = dangerously_run_function_in_subprocess(test_func)

    # A trivial function runs cleanly and its stdout is captured verbatim.
    assert code == 0
    assert out.strip() == "Hello, Test!"
    assert err == ""
assert stderr == ""


def test_function_with_error():
    def test_func():
        raise ValueError("This is an error")

    _out, err, code = dangerously_run_function_in_subprocess(test_func)

    # An uncaught exception surfaces as a non-zero exit code with the traceback on stderr.
    assert code != 0
    assert "ValueError: This is an error" in err


def test_function_with_imports():
    def test_func():
        import math

        print(math.sqrt(4))

    out, err, code = dangerously_run_function_in_subprocess(test_func)

    # Imports made inside the function body are available when it runs in the subprocess.
    assert code == 0
    assert out.strip() == "2.0"
    assert err == ""


def test_function_with_sys_exit():
    def test_func():
        import sys

        sys.exit(42)

    _out, _err, code = dangerously_run_function_in_subprocess(test_func)

    # A custom exit code set via sys.exit() is propagated from the subprocess.
    assert code == 42


def test_function_with_closure():
    foo = "bar"

    def test_func():
        print(foo)

    _out, _err, code = dangerously_run_function_in_subprocess(test_func)

    # `foo` is a closure variable; it does not exist in the subprocess, so the run must fail (NameError -> exit 1).
    assert code == 1