Skip to content

Commit f344314

Browse files
authored
Abort if a flow is removed during a step (#142138)
* Abort if a flow is removed during a step * Reorganize code * Only call _set_pending_import_done if an entry is created * Try a new approach * Add tests * Update tests
1 parent 7f4d178 commit f344314

File tree

4 files changed

+95
-33
lines changed

4 files changed

+95
-33
lines changed

homeassistant/config_entries.py

+22-14
Original file line numberDiff line numberDiff line change
@@ -1503,24 +1503,14 @@ def async_shutdown(self) -> None:
15031503
future.set_result(None)
15041504
self._discovery_event_debouncer.async_shutdown()
15051505

1506-
async def async_finish_flow(
1506+
@callback
1507+
def async_flow_removed(
15071508
self,
15081509
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
1509-
result: ConfigFlowResult,
1510-
) -> ConfigFlowResult:
1511-
"""Finish a config flow and add an entry.
1512-
1513-
This method is called when a flow step returns FlowResultType.ABORT or
1514-
FlowResultType.CREATE_ENTRY.
1515-
"""
1510+
) -> None:
1511+
"""Handle a removed config flow."""
15161512
flow = cast(ConfigFlow, flow)
15171513

1518-
# Mark the step as done.
1519-
# We do this to avoid a circular dependency where async_finish_flow sets up a
1520-
# new entry, which needs the integration to be set up, which is waiting for
1521-
# init to be done.
1522-
self._set_pending_import_done(flow)
1523-
15241514
# Clean up issue if this is a reauth flow
15251515
if flow.context["source"] == SOURCE_REAUTH:
15261516
if (entry_id := flow.context.get("entry_id")) is not None and (
@@ -1529,6 +1519,18 @@ async def async_finish_flow(
15291519
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
15301520
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
15311521

1522+
async def async_finish_flow(
1523+
self,
1524+
flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult],
1525+
result: ConfigFlowResult,
1526+
) -> ConfigFlowResult:
1527+
"""Finish a config flow and add an entry.
1528+
1529+
This method is called when a flow step returns FlowResultType.ABORT or
1530+
FlowResultType.CREATE_ENTRY.
1531+
"""
1532+
flow = cast(ConfigFlow, flow)
1533+
15321534
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
15331535
# If there's a config entry with a matching unique ID,
15341536
# update the discovery key.
@@ -1567,6 +1569,12 @@ async def async_finish_flow(
15671569
)
15681570
return result
15691571

1572+
# Mark the step as done.
1573+
# We do this to avoid a circular dependency where async_finish_flow sets up a
1574+
# new entry, which needs the integration to be set up, which is waiting for
1575+
# init to be done.
1576+
self._set_pending_import_done(flow)
1577+
15701578
# Avoid adding a config entry for a integration
15711579
# that only supports a single config entry, but already has an entry
15721580
if (

homeassistant/data_entry_flow.py

+12
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ async def async_create_flow(
207207
Handler key is the domain of the component that we want to set up.
208208
"""
209209

210+
@callback
211+
def async_flow_removed(
212+
self,
213+
flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT],
214+
) -> None:
215+
"""Handle a removed data entry flow."""
216+
210217
@abc.abstractmethod
211218
async def async_finish_flow(
212219
self,
@@ -457,6 +464,7 @@ def _async_remove_flow_progress(self, flow_id: str) -> None:
457464
"""Remove a flow from in progress."""
458465
if (flow := self._progress.pop(flow_id, None)) is None:
459466
raise UnknownFlow
467+
self.async_flow_removed(flow)
460468
self._async_remove_flow_from_index(flow)
461469
flow.async_cancel_progress_task()
462470
try:
@@ -485,6 +493,10 @@ async def _async_handle_step(
485493
description_placeholders=err.description_placeholders,
486494
)
487495

496+
if flow.flow_id not in self._progress:
497+
# The flow was removed during the step
498+
raise UnknownFlow
499+
488500
# Setup the flow handler's preview if needed
489501
if result.get("preview") is not None:
490502
await self._async_setup_preview(flow)

tests/test_config_entries.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1395,9 +1395,7 @@ async def test_reauth_issue_flow_aborted(
13951395
issue = await _test_reauth_issue(hass, manager, issue_registry)
13961396

13971397
manager.flow.async_abort(issue.data["flow_id"])
1398-
# This can be considered a bug, we should make sure the issue is always
1399-
# removed when the reauth flow is aborted.
1400-
assert len(issue_registry.issues) == 1
1398+
assert len(issue_registry.issues) == 0
14011399

14021400

14031401
async def _test_reauth_issue(

tests/test_data_entry_flow.py

+60-16
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,23 @@ async def async_step_init(self, user_input=None):
243243
assert len(manager.mock_created_entries) == 0
244244

245245

246+
async def test_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
247+
"""Test abort calling the async_flow_removed FlowManager method."""
248+
249+
@manager.mock_reg_handler("test")
250+
class TestFlow(data_entry_flow.FlowHandler):
251+
async def async_step_init(self, user_input=None):
252+
return self.async_abort(reason="reason")
253+
254+
manager.async_flow_removed = Mock()
255+
await manager.async_init("test")
256+
257+
manager.async_flow_removed.assert_called_once()
258+
259+
assert len(manager.async_progress()) == 0
260+
assert len(manager.mock_created_entries) == 0
261+
262+
246263
async def test_abort_calls_async_remove_with_exception(
247264
manager: MockFlowManager, caplog: pytest.LogCaptureFixture
248265
) -> None:
@@ -288,13 +305,7 @@ async def async_step_init(self, user_input=None):
288305

289306

290307
async def test_create_aborted_flow(manager: MockFlowManager) -> None:
291-
"""Test return create_entry from aborted flow.
292-
293-
Note: The entry is created even if the flow is already aborted, then the
294-
flow raises an UnknownFlow exception. This behavior is not logical, and
295-
we should consider changing it to not create the entry if the flow is
296-
aborted.
297-
"""
308+
"""Test return create_entry from aborted flow."""
298309

299310
@manager.mock_reg_handler("test")
300311
class TestFlow(data_entry_flow.FlowHandler):
@@ -308,14 +319,25 @@ async def async_step_init(self, user_input=None):
308319
await manager.async_init("test")
309320
assert len(manager.async_progress()) == 0
310321

311-
# The entry is created even if the flow is aborted
312-
assert len(manager.mock_created_entries) == 1
322+
# No entry should be created if the flow is aborted
323+
assert len(manager.mock_created_entries) == 0
313324

314-
entry = manager.mock_created_entries[0]
315-
assert entry["handler"] == "test"
316-
assert entry["title"] == "Test Title"
317-
assert entry["data"] == "Test Data"
318-
assert entry["source"] is None
325+
326+
async def test_create_calls_async_flow_removed(manager: MockFlowManager) -> None:
327+
"""Test create calling the async_flow_removed FlowManager method."""
328+
329+
@manager.mock_reg_handler("test")
330+
class TestFlow(data_entry_flow.FlowHandler):
331+
async def async_step_init(self, user_input=None):
332+
return self.async_create_entry(title="Test Title", data="Test Data")
333+
334+
manager.async_flow_removed = Mock()
335+
await manager.async_init("test")
336+
337+
manager.async_flow_removed.assert_called_once()
338+
339+
assert len(manager.async_progress()) == 0
340+
assert len(manager.mock_created_entries) == 1
319341

320342

321343
async def test_discovery_init_flow(manager: MockFlowManager) -> None:
@@ -930,12 +952,34 @@ async def test_configure_raises_unknown_flow_if_not_in_progress(
930952
await manager.async_configure("wrong_flow_id")
931953

932954

933-
async def test_abort_raises_unknown_flow_if_not_in_progress(
955+
async def test_manager_abort_raises_unknown_flow_if_not_in_progress(
934956
manager: MockFlowManager,
935957
) -> None:
936958
"""Test abort raises UnknownFlow if the flow is not in progress."""
937959
with pytest.raises(data_entry_flow.UnknownFlow):
938-
await manager.async_abort("wrong_flow_id")
960+
manager.async_abort("wrong_flow_id")
961+
962+
963+
async def test_manager_abort_calls_async_flow_removed(manager: MockFlowManager) -> None:
964+
"""Test abort calling the async_flow_removed FlowManager method."""
965+
966+
@manager.mock_reg_handler("test")
967+
class TestFlow(data_entry_flow.FlowHandler):
968+
async def async_step_init(self, user_input=None):
969+
return self.async_show_form(step_id="init")
970+
971+
manager.async_flow_removed = Mock()
972+
result = await manager.async_init("test")
973+
assert result["type"] == data_entry_flow.FlowResultType.FORM
974+
assert result["step_id"] == "init"
975+
976+
manager.async_flow_removed.assert_not_called()
977+
978+
manager.async_abort(result["flow_id"])
979+
manager.async_flow_removed.assert_called_once()
980+
981+
assert len(manager.async_progress()) == 0
982+
assert len(manager.mock_created_entries) == 0
939983

940984

941985
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)