Skip to content

Abort if a flow is removed during a step #142138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,24 +1503,14 @@ def async_shutdown(self) -> None:
future.set_result(None)
self._discovery_event_debouncer.async_shutdown()

async def async_finish_flow(
@callback
def async_flow_removed(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish a config flow and add an entry.
This method is called when a flow step returns FlowResultType.ABORT or
FlowResultType.CREATE_ENTRY.
"""
) -> None:
"""Handle a removed config flow."""
Comment on lines +1506 to +1511
Copy link
Contributor Author

@emontnemery emontnemery Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of this method is to unconditionally remove reauth issues when a reauth flow is removed. We could optionally also include updating of the discovery keys, but I don't think that's needed.

flow = cast(ConfigFlow, flow)

# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
self._set_pending_import_done(flow)

# Clean up issue if this is a reauth flow
if flow.context["source"] == SOURCE_REAUTH:
if (entry_id := flow.context.get("entry_id")) is not None and (
Expand All @@ -1529,6 +1519,18 @@ async def async_finish_flow(
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)

async def async_finish_flow(
self,
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish a config flow and add an entry.
This method is called when a flow step returns FlowResultType.ABORT or
FlowResultType.CREATE_ENTRY.
"""
flow = cast(ConfigFlow, flow)

if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
# If there's a config entry with a matching unique ID,
# update the discovery key.
Expand Down Expand Up @@ -1567,6 +1569,12 @@ async def async_finish_flow(
)
return result

# Mark the step as done.
# We do this to avoid a circular dependency where async_finish_flow sets up a
# new entry, which needs the integration to be set up, which is waiting for
# init to be done.
self._set_pending_import_done(flow)

# Avoid adding a config entry for a integration
# that only supports a single config entry, but already has an entry
if (
Expand Down
12 changes: 12 additions & 0 deletions homeassistant/data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,13 @@ async def async_create_flow(
Handler key is the domain of the component that we want to set up.
"""

@callback
def async_flow_removed(
self,
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
) -> None:
"""Handle a removed data entry flow."""

@abc.abstractmethod
async def async_finish_flow(
self,
Expand Down Expand Up @@ -457,6 +464,7 @@ def _async_remove_flow_progress(self, flow_id: str) -> None:
"""Remove a flow from in progress."""
if (flow := self._progress.pop(flow_id, None)) is None:
raise UnknownFlow
self.async_flow_removed(flow)
self._async_remove_flow_from_index(flow)
flow.async_cancel_progress_task()
try:
Expand Down Expand Up @@ -485,6 +493,10 @@ async def _async_handle_step(
description_placeholders=err.description_placeholders,
)

if flow.flow_id not in self._progress:
# The flow was removed during the step
raise UnknownFlow

# Setup the flow handler's preview if needed
if result.get("preview") is not None:
await self._async_setup_preview(flow)
Expand Down
4 changes: 1 addition & 3 deletions tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,9 +1395,7 @@ async def test_reauth_issue_flow_aborted(
issue = await _test_reauth_issue(hass, manager, issue_registry)

manager.flow.async_abort(issue.data["flow_id"])
# This can be considered a bug, we should make sure the issue is always
# removed when the reauth flow is aborted.
assert len(issue_registry.issues) == 1
assert len(issue_registry.issues) == 0


async def _test_reauth_issue(
Expand Down
76 changes: 60 additions & 16 deletions tests/test_data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,23 @@ async def async_step_init(self, user_input=None):
assert len(manager.mock_created_entries) == 0


async def test_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test abort calling the async_flow_removed FlowManager method."""

@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_abort(reason="reason")

manager.async_flow_removed = Mock()
await manager.async_init("test")

manager.async_flow_removed.assert_called_once()

assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0


async def test_abort_calls_async_remove_with_exception(
manager: MockFlowManager, caplog: pytest.LogCaptureFixture
) -> None:
Expand Down Expand Up @@ -288,13 +305,7 @@ async def async_step_init(self, user_input=None):


async def test_create_aborted_flow(manager: MockFlowManager) -> None:
"""Test return create_entry from aborted flow.
Note: The entry is created even if the flow is already aborted, then the
flow raises an UnknownFlow exception. This behavior is not logical, and
we should consider changing it to not create the entry if the flow is
aborted.
"""
"""Test return create_entry from aborted flow."""

@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
Expand All @@ -308,14 +319,25 @@ async def async_step_init(self, user_input=None):
await manager.async_init("test")
assert len(manager.async_progress()) == 0

# The entry is created even if the flow is aborted
assert len(manager.mock_created_entries) == 1
# No entry should be created if the flow is aborted
assert len(manager.mock_created_entries) == 0

entry = manager.mock_created_entries[0]
assert entry["handler"] == "test"
assert entry["title"] == "Test Title"
assert entry["data"] == "Test Data"
assert entry["source"] is None

async def test_create_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test create calling the async_flow_removed FlowManager method."""

@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_create_entry(title="Test Title", data="Test Data")

manager.async_flow_removed = Mock()
await manager.async_init("test")

manager.async_flow_removed.assert_called_once()

assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 1


async def test_discovery_init_flow(manager: MockFlowManager) -> None:
Expand Down Expand Up @@ -930,12 +952,34 @@ async def test_configure_raises_unknown_flow_if_not_in_progress(
await manager.async_configure("wrong_flow_id")


async def test_abort_raises_unknown_flow_if_not_in_progress(
async def test_manager_abort_raises_unknown_flow_if_not_in_progress(
manager: MockFlowManager,
) -> None:
"""Test abort raises UnknownFlow if the flow is not in progress."""
with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_abort("wrong_flow_id")
manager.async_abort("wrong_flow_id")


async def test_manager_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
"""Test abort calling the async_flow_removed FlowManager method."""

@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
return self.async_show_form(step_id="init")

manager.async_flow_removed = Mock()
result = await manager.async_init("test")
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "init"

manager.async_flow_removed.assert_not_called()

manager.async_abort(result["flow_id"])
manager.async_flow_removed.assert_called_once()

assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0


@pytest.mark.parametrize(
Expand Down