Skip to content

Abort reauth flows on config entry reload #140931

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 10, 2025
Merged
6 changes: 2 additions & 4 deletions homeassistant/components/tplink/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ async def async_step_pick_device(
)

async def _async_reload_requires_auth_entries(self) -> None:
"""Reload any in progress config flow that now have credentials."""
"""Reload all config entries after auth update."""
_config_entries = self.hass.config_entries

if self.source == SOURCE_REAUTH:
Expand All @@ -579,11 +579,9 @@ async def _async_reload_requires_auth_entries(self) -> None:
context = flow["context"]
if context.get("source") != SOURCE_REAUTH:
continue
entry_id: str = context["entry_id"]
entry_id = context["entry_id"]
if entry := _config_entries.async_get_entry(entry_id):
await _config_entries.async_reload(entry.entry_id)
if entry.state is ConfigEntryState.LOADED:
_config_entries.flow.async_abort(flow["flow_id"])

@callback
def _async_create_or_update_entry_from_device(
Expand Down
28 changes: 17 additions & 11 deletions homeassistant/config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1513,10 +1513,9 @@ def async_flow_removed(

# Clean up issue if this is a reauth flow
if flow.context["source"] == SOURCE_REAUTH:
if (entry_id := flow.context.get("entry_id")) is not None and (
entry := self.config_entries.async_get_entry(entry_id)
) is not None:
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
if (entry_id := flow.context.get("entry_id")) is not None:
# The config entry's domain is flow.handler
issue_id = f"config_entry_reauth_{flow.handler}_{entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)

async def async_finish_flow(
Expand Down Expand Up @@ -2097,13 +2096,7 @@ def _async_clean_up(self, entry: ConfigEntry) -> None:
# If the configuration entry is removed during reauth, it should
# abort any reauth flow that is active for the removed entry and
# linked issues.
for progress_flow in self.hass.config_entries.flow.async_progress_by_handler(
entry.domain, match_context={"entry_id": entry_id, "source": SOURCE_REAUTH}
):
if "flow_id" in progress_flow:
self.hass.config_entries.flow.async_abort(progress_flow["flow_id"])
issue_id = f"config_entry_reauth_{entry.domain}_{entry.entry_id}"
ir.async_delete_issue(self.hass, HOMEASSISTANT_DOMAIN, issue_id)
_abort_reauth_flows(self.hass, entry.domain, entry_id)

self._async_dispatch(ConfigEntryChange.REMOVED, entry)

Expand Down Expand Up @@ -2235,6 +2228,9 @@ async def async_reload(self, entry_id: str) -> bool:
# attempts.
entry.async_cancel_retry_setup()

# Abort any in-progress reauth flow and linked issues
_abort_reauth_flows(self.hass, entry.domain, entry_id)

if entry.domain not in self.hass.config.components:
# If the component is not loaded, just load it as
# the config entry will be loaded as well. We need
Expand Down Expand Up @@ -3766,3 +3762,13 @@ async def _async_get_flow_handler(
return handler

raise data_entry_flow.UnknownHandler


@callback
def _abort_reauth_flows(hass: HomeAssistant, domain: str, entry_id: str) -> None:
"""Abort reauth flows for an entry."""
for progress_flow in hass.config_entries.flow.async_progress_by_handler(
domain, match_context={"entry_id": entry_id, "source": SOURCE_REAUTH}
):
if "flow_id" in progress_flow:
hass.config_entries.flow.async_abort(progress_flow["flow_id"])
9 changes: 6 additions & 3 deletions homeassistant/data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,11 @@ async def _async_handle_step(
)

if flow.flow_id not in self._progress:
# The flow was removed during the step
raise UnknownFlow
# The flow was removed during the step, raise UnknownFlow
# unless the result is an abort
if result["type"] != FlowResultType.ABORT:
raise UnknownFlow
return result

# Setup the flow handler's preview if needed
if result.get("preview") is not None:
Expand Down Expand Up @@ -547,7 +550,7 @@ def schedule_configure(_: asyncio.Task) -> None:
flow.cur_step = result
return result

# Abort and Success results both finish the flow
# Abort and Success results both finish the flow.
self._async_remove_flow_progress(flow.flow_id)

return result
Expand Down
36 changes: 35 additions & 1 deletion tests/test_config_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ async def test_remove_entry_cancels_reauth(
manager: config_entries.ConfigEntries,
issue_registry: ir.IssueRegistry,
) -> None:
"""Tests that removing a config entry, also aborts existing reauth flows."""
"""Tests that removing a config entry also aborts existing reauth flows."""
entry = MockConfigEntry(title="test_title", domain="test")

mock_setup_entry = AsyncMock(side_effect=ConfigEntryAuthFailed())
Expand All @@ -722,6 +722,40 @@ async def test_remove_entry_cancels_reauth(
assert not issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id)


async def test_reload_entry_cancels_reauth(
hass: HomeAssistant,
manager: config_entries.ConfigEntries,
issue_registry: ir.IssueRegistry,
) -> None:
"""Tests that reloading a config entry also aborts existing reauth flows."""
entry = MockConfigEntry(title="test_title", domain="test")

mock_setup_entry = AsyncMock(side_effect=ConfigEntryAuthFailed())
mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry))
mock_platform(hass, "test.config_flow", None)

entry.add_to_hass(hass)
await manager.async_setup(entry.entry_id)
await hass.async_block_till_done()

flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 1
assert flows[0]["context"]["entry_id"] == entry.entry_id
assert flows[0]["context"]["source"] == config_entries.SOURCE_REAUTH
assert entry.state is config_entries.ConfigEntryState.SETUP_ERROR

issue_id = f"config_entry_reauth_test_{entry.entry_id}"
assert issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id)

mock_setup_entry.return_value = True
mock_setup_entry.side_effect = None
await manager.async_reload(entry.entry_id)

flows = hass.config_entries.flow.async_progress_by_handler("test")
assert len(flows) == 0
assert not issue_registry.async_get_issue(HOMEASSISTANT_DOMAIN, issue_id)


async def test_remove_entry_handles_callback_error(
hass: HomeAssistant, manager: config_entries.ConfigEntries
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_data_entry_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ async def async_step_init(self, user_input=None):
manager.async_abort(self.flow_id)
return self.async_abort(reason="blah")

with pytest.raises(data_entry_flow.UnknownFlow):
await manager.async_init("test")
form = await manager.async_init("test")
assert form["reason"] == "blah"
assert len(manager.async_progress()) == 0
assert len(manager.mock_created_entries) == 0

Expand Down
Loading