diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index b47815c9aa93b..705cc01061bb5 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -1503,24 +1503,14 @@ def async_shutdown(self) -> None: future.set_result(None) self._discovery_event_debouncer.async_shutdown() - async def async_finish_flow( + @callback + def async_flow_removed( self, flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult], - result: ConfigFlowResult, - ) -> ConfigFlowResult: - """Finish a config flow and add an entry. - - This method is called when a flow step returns FlowResultType.ABORT or - FlowResultType.CREATE_ENTRY. - """ + ) -> None: + """Handle a removed config flow.""" flow = cast(ConfigFlow, flow) - # Mark the step as done. - # We do this to avoid a circular dependency where async_finish_flow sets up a - # new entry, which needs the integration to be set up, which is waiting for - # init to be done. - self._set_pending_import_done(flow) - # Clean up issue if this is a reauth flow if flow.context["source"] == SOURCE_REAUTH: if (entry_id := flow.context.get("entry_id")) is not None and ( @@ -1529,6 +1519,18 @@ async def async_finish_flow( issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}" ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id) + async def async_finish_flow( + self, + flow: data_entry_flow.FlowHandler[ConfigFlowContext, ConfigFlowResult], + result: ConfigFlowResult, + ) -> ConfigFlowResult: + """Finish a config flow and add an entry. + + This method is called when a flow step returns FlowResultType.ABORT or + FlowResultType.CREATE_ENTRY. + """ + flow = cast(ConfigFlow, flow) + if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY: # If there's a config entry with a matching unique ID, # update the discovery key. @@ -1567,6 +1569,12 @@ async def async_finish_flow( ) return result + # Mark the step as done. + # We do this to avoid a circular dependency where async_finish_flow sets up a + # new entry, which needs the integration to be set up, which is waiting for + # init to be done. + self._set_pending_import_done(flow) + # Avoid adding a config entry for a integration # that only supports a single config entry, but already has an entry if ( diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 511bab25a7f2d..6a288380cd063 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -207,6 +207,13 @@ async def async_create_flow( Handler key is the domain of the component that we want to set up. """ + @callback + def async_flow_removed( + self, + flow: FlowHandler[_FlowContextT, _FlowResultT, _HandlerT], + ) -> None: + """Handle a removed data entry flow.""" + @abc.abstractmethod async def async_finish_flow( self, @@ -457,6 +464,7 @@ def _async_remove_flow_progress(self, flow_id: str) -> None: """Remove a flow from in progress.""" if (flow := self._progress.pop(flow_id, None)) is None: raise UnknownFlow + self.async_flow_removed(flow) self._async_remove_flow_from_index(flow) flow.async_cancel_progress_task() try: @@ -485,6 +493,10 @@ async def _async_handle_step( description_placeholders=err.description_placeholders, ) + if flow.flow_id not in self._progress: + # The flow was removed during the step + raise UnknownFlow + # Setup the flow handler's preview if needed if result.get("preview") is not None: await self._async_setup_preview(flow) diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index 8f1591cec3bd5..5c2e2aea21509 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -1395,9 +1395,7 @@ async def test_reauth_issue_flow_aborted( issue = await _test_reauth_issue(hass, manager, issue_registry) manager.flow.async_abort(issue.data["flow_id"]) - # This can be considered a bug, we should make sure the issue is always - # removed when the reauth flow is aborted. - assert len(issue_registry.issues) == 1 + assert len(issue_registry.issues) == 0 async def _test_reauth_issue( diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 994d37dcd65f3..bcc40251bad31 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -243,6 +243,23 @@ async def async_step_init(self, user_input=None): assert len(manager.mock_created_entries) == 0 +async def test_abort_calls_async_flow_removed(manager: MockFlowManager) -> None: + """Test abort calling the async_flow_removed FlowManager method.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_abort(reason="reason") + + manager.async_flow_removed = Mock() + await manager.async_init("test") + + manager.async_flow_removed.assert_called_once() + + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 + + async def test_abort_calls_async_remove_with_exception( manager: MockFlowManager, caplog: pytest.LogCaptureFixture ) -> None: @@ -288,13 +305,7 @@ async def async_step_init(self, user_input=None): async def test_create_aborted_flow(manager: MockFlowManager) -> None: - """Test return create_entry from aborted flow. - - Note: The entry is created even if the flow is already aborted, then the - flow raises an UnknownFlow exception. This behavior is not logical, and - we should consider changing it to not create the entry if the flow is - aborted. - """ + """Test return create_entry from aborted flow.""" @manager.mock_reg_handler("test") class TestFlow(data_entry_flow.FlowHandler): @@ -308,14 +319,25 @@ async def async_step_init(self, user_input=None): await manager.async_init("test") assert len(manager.async_progress()) == 0 - # The entry is created even if the flow is aborted - assert len(manager.mock_created_entries) == 1 + # No entry should be created if the flow is aborted + assert len(manager.mock_created_entries) == 0 - entry = manager.mock_created_entries[0] - assert entry["handler"] == "test" - assert entry["title"] == "Test Title" - assert entry["data"] == "Test Data" - assert entry["source"] is None + +async def test_create_calls_async_flow_removed(manager: MockFlowManager) -> None: + """Test create calling the async_flow_removed FlowManager method.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_create_entry(title="Test Title", data="Test Data") + + manager.async_flow_removed = Mock() + await manager.async_init("test") + + manager.async_flow_removed.assert_called_once() + + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 1 async def test_discovery_init_flow(manager: MockFlowManager) -> None: @@ -930,12 +952,34 @@ async def test_configure_raises_unknown_flow_if_not_in_progress( await manager.async_configure("wrong_flow_id") -async def test_abort_raises_unknown_flow_if_not_in_progress( +async def test_manager_abort_raises_unknown_flow_if_not_in_progress( manager: MockFlowManager, ) -> None: """Test abort raises UnknownFlow if the flow is not in progress.""" with pytest.raises(data_entry_flow.UnknownFlow): - await manager.async_abort("wrong_flow_id") + manager.async_abort("wrong_flow_id") + + +async def test_manager_abort_calls_async_flow_removed(manager: MockFlowManager) -> None: + """Test abort calling the async_flow_removed FlowManager method.""" + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + async def async_step_init(self, user_input=None): + return self.async_show_form(step_id="init") + + manager.async_flow_removed = Mock() + result = await manager.async_init("test") + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == "init" + + manager.async_flow_removed.assert_not_called() + + manager.async_abort(result["flow_id"]) + manager.async_flow_removed.assert_called_once() + + assert len(manager.async_progress()) == 0 + assert len(manager.mock_created_entries) == 0 @pytest.mark.parametrize(