Skip to content

Abort reauth flows on config entry reload #140931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 10, 2025
Merged
6 changes: 2 additions & 4 deletions homeassistant/components/tplink/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ async def async_step_pick_device(
)

async def _async_reload_requires_auth_entries(self) -> None:
"""Reload any in progress config flow that now have credentials."""
"""Reload all config entries after auth update."""
_config_entries = self.hass.config_entries

if self.source == SOURCE_REAUTH:
Expand All @@ -579,11 +579,9 @@ async def _async_reload_requires_auth_entries(self) -> None:
context = flow["context"]
if context.get("source") != SOURCE_REAUTH:
continue
entry_id: str = context["entry_id"]
entry_id = context["entry_id"]
if entry := _config_entries.async_get_entry(entry_id):
await _config_entries.async_reload(entry.entry_id)
if entry.state is ConfigEntryState.LOADED:
_config_entries.flow.async_abort(flow["flow_id"])

@callback
def _async_create_or_update_entry_from_device(
Expand Down
33 changes: 24 additions & 9 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,8 +1538,7 @@ async def async_finish_flow(
if (entry_id := flow.context.get("entry_id")) is not None and (
entry := self.config_entries.async_get_entry(entry_id)
) is not None:
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
_remove_reauth_issue(self.hass, entry.domain, entry_id)

if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
# If there's an ignored config entry with a matching unique ID,
Expand Down Expand Up @@ -2128,13 +2127,8 @@ def _async_clean_up(self, entry: ConfigEntry) -> None:
# If the configuration entry is removed during reauth, it should
# abort any reauth flow that is active for the removed entry and
# linked issues.
for progress_flow in self.hass.config_entries.flow.async_progress_by_handler(
entry.domain, match_context={"entry_id": entry_id, "source": SOURCE_REAUTH}
):
if "flow_id" in progress_flow:
self.hass.config_entries.flow.async_abort(progress_flow["flow_id"])
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
_abort_reauth_flows(self.hass, entry.domain, entry_id)
_remove_reauth_issue(self.hass, entry.domain, entry_id)

self._async_dispatch(ConfigEntryChange.REMOVED, entry)

Expand Down Expand Up @@ -2266,6 +2260,10 @@ async def async_reload(self, entry_id: str) -> bool:
# attempts.
entry.async_cancel_retry_setup()

# Abort any in-progress reauth flow and linked issues
_abort_reauth_flows(self.hass, entry.domain, entry_id)
_remove_reauth_issue(self.hass, entry.domain, entry_id)

if entry.domain not in self.hass.config.components:
# If the component is not loaded, just load it as
# the config entry will be loaded as well. We need
Expand Down Expand Up @@ -3827,3 +3825,20 @@ async def _async_get_flow_handler(
return handler

raise data_entry_flow.UnknownHandler


@callback
def _abort_reauth_flows(hass: HomeAssistant, domain: str, entry_id: str) -> None:
"""Abort reauth flows for an entry."""
for progress_flow in hass.config_entries.flow.async_progress_by_handler(
domain, match_context={"entry_id": entry_id, "source": SOURCE_REAUTH}
):
if "flow_id" in progress_flow:
hass.config_entries.flow.async_abort(progress_flow["flow_id"])


@callback
def _remove_reauth_issue(hass: HomeAssistant, domain: str, entry_id: str) -> None:
"""Remove reauth issue."""
issue_id = f"config_entry_reauth_{domain}_{entry_id}"
ir.async_delete_issue(hass, HOMEASSISTANT_DOMAIN, issue_id)
9 changes: 7 additions & 2 deletions homeassistant/data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,13 @@ def schedule_configure(_: asyncio.Task) -> None:
flow.cur_step = result
return result

# Abort and Success results both finish the flow
self._async_remove_flow_progress(flow.flow_id)
# Abort and Success results both finish the flow.
# Suppress UnknownFlow in case the flow is already aborted
try:
self._async_remove_flow_progress(flow.flow_id)
except UnknownFlow:
if result["type"] != FlowResultType.ABORT:
raise

return result

Expand Down
36 changes: 35 additions & 1 deletion tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@ async def test_remove_entry_cancels_reauth(
manager: config_entries.ConfigEntries,
issue_registry: ir.IssueRegistry,
) -> None:
"""Tests that removing a config entry, also aborts existing reauth flows."""
"""Tests that removing a config entry also aborts existing reauth flows."""
entry = MockConfigEntry(title="test_title", domain="test")

mock_setup_entry = AsyncMock(side_effect=ConfigEntryAuthFailed())
Expand All @@ -723,6 +723,40 @@ async def test_remove_entry_cancels_reauth(
assert not issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id)


async def test_reload_entry_cancels_reauth(
hass: HomeAssistant,
manager: config_entries.ConfigEntries,
issue_registry: ir.IssueRegistry,
) -> None:
"""Tests that reloading a config entry also aborts existing reauth flows."""
entry = MockConfigEntry(title="test_title", domain="test")

mock_setup_entry = AsyncMock(side_effect=ConfigEntryAuthFailed())
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
mock_platform(hass, "test.config_flow", None)

entry.add_to_hass(hass)
await manager.async_setup(entry.entry_id)
await hass.async_block_till_done()

flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 1
assert flows[0]["context"]["entry_id"] == entry.entry_id
assert flows[0]["context"]["source"] == config_entries.SOURCE_REAUTH
assert entry.state is config_entries.ConfigEntryState.SETUP_ERROR

issue_id = f"config_entry_reauth_test_{entry.entry_id}"
assert issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id)

mock_setup_entry.return_value = True
mock_setup_entry.side_effect = None
await manager.async_reload(entry.entry_id)

flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 0
assert not issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id)


async def test_remove_entry_handles_callback_error(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,21 @@ async def async_step_init(self, user_input=None):
assert len(manager.mock_created_entries) == 0


async def test_abort_aborted_flow(manager: MockFlowManager) -> None:
"""Test return abort from aborted flow."""

@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
async def async_step_init(self, user_input=None):
manager.async_abort(self.flow_id)
return self.async_abort(reason="blah")

form = await manager.async_init("test")
assert form["reason"] == "blah"
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0


async def test_abort_calls_async_remove(manager: MockFlowManager) -> None:
"""Test abort calling the async_remove FlowHandler method."""

Expand Down Expand Up @@ -217,6 +232,31 @@ async def async_step_init(self, user_input=None):
assert entry["source"] is None


async def test_create_aborted_flow(manager: MockFlowManager) -> None:
"""Test return create_entry from aborted flow."""

@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 5

async def async_step_init(self, user_input=None):
manager.async_abort(self.flow_id)
return self.async_create_entry(title="Test Title", data="Test Data")

with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_init("test")
assert len(manager.async_progress()) == 0

# The entry is created even if the flow is aborted
assert len(manager.mock_created_entries) == 1

entry = manager.mock_created_entries[0]
assert entry["handler"] == "test"
assert entry["title"] == "Test Title"
assert entry["data"] == "Test Data"
assert entry["source"] is None
Copy link
Contributor Author

@emontnemery emontnemery Apr 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior is quite illogical; an entry is created without checking if the flow still exists, but there's then an error.
The behavior is not changed by this PR, but I think we should look into changing it in a follow-up.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you please add a comment to this test explaining that.



async def test_discovery_init_flow(manager: MockFlowManager) -> None:
"""Test a flow initialized by discovery."""

Expand Down
Loading