Skip to content

Commit c7ea287

Browse files
committed
fix(updater): infinite loop when fetch url does not depend on version
1 parent 75ae298 commit c7ea287

File tree

2 files changed

+63
-8
lines changed

2 files changed

+63
-8
lines changed

src/ops2deb/updater.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,16 @@ class BaseUpdateStrategy:
1919
def __init__(self, client: httpx.AsyncClient):
2020
self.client = client
2121

22-
async def _try_version(self, blueprint: Blueprint, version: str) -> bool:
23-
if not (url := blueprint.render_fetch_url(version=version)):
22+
async def try_version(self, blueprint: Blueprint, version: str) -> bool:
23+
# No need to waste an HTTP call when called with the current blueprint version
24+
if blueprint.version == version:
25+
return True
26+
27+
# Fetch url does not depend on blueprint version or blueprint has no fetch
28+
url = blueprint.render_fetch_url(version=version)
29+
if url == blueprint.render_fetch_url() or url is None:
2430
return False
31+
2532
logger.debug(f"{self.__class__.__name__} - {blueprint.name} - Trying {url}")
2633
try:
2734
response = await self.client.head(url)
@@ -53,7 +60,7 @@ async def _try_a_few_patches(
5360
) -> Version | None:
5461
for i in range(0, 3):
5562
version = version.bump_patch()
56-
if await self._try_version(blueprint, str(version)) is True:
63+
if await self.try_version(blueprint, str(version)) is True:
5764
return version
5865
return None
5966

@@ -64,7 +71,7 @@ async def _try_versions(
6471
version_part: str,
6572
) -> Version:
6673
bumped_version = getattr(version, f"bump_{version_part}")()
67-
if await self._try_version(blueprint, str(bumped_version)) is False:
74+
if await self.try_version(blueprint, str(bumped_version)) is False:
6875
if version_part != "patch":
6976
if (
7077
result := await self._try_a_few_patches(blueprint, bumped_version)
@@ -145,7 +152,7 @@ async def __call__(self, blueprint: Blueprint) -> str:
145152
version = tag_name if not tag_name.startswith("v") else tag_name[1:]
146153
if Version.isvalid(version) and Version.isvalid(blueprint.version):
147154
version = str(max(Version.parse(version), Version.parse(blueprint.version)))
148-
if await self._try_version(blueprint, version) is False:
155+
if await self.try_version(blueprint, version) is False:
149156
raise Ops2debUpdaterError(
150157
f"Failed to determine latest release URL (latest tag is {tag_name})"
151158
)

tests/test_updater.py

+51-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from typing import Optional
22

3+
import httpx
34
import pytest
45
from fastapi import FastAPI, HTTPException
56
from httpx import AsyncClient
67
from starlette.responses import JSONResponse
78

89
from ops2deb.exceptions import Ops2debUpdaterError
910
from ops2deb.logger import enable_debug
10-
from ops2deb.updater import GenericUpdateStrategy, GithubUpdateStrategy
11+
from ops2deb.updater import (
12+
BaseUpdateStrategy,
13+
GenericUpdateStrategy,
14+
GithubUpdateStrategy,
15+
)
1116

1217
enable_debug(True)
1318

@@ -47,6 +52,23 @@ def github_release_api():
4752
return _github_app_factory
4853

4954

55+
async def test_try_version__returns_false_when_fetch_url_does_not_depend_on_the_blueprint_version( # noqa: E501
56+
blueprint_factory,
57+
):
58+
# Given
59+
blueprint = blueprint_factory(
60+
version="1.0.0",
61+
fetch="http://test/releases/1.0.0/some-app.tar.gz",
62+
)
63+
update_strategy = BaseUpdateStrategy(httpx.AsyncClient())
64+
65+
# When
66+
result = await update_strategy.try_version(blueprint, "2.0.0")
67+
68+
# Then
69+
assert result is False
70+
71+
5072
@pytest.mark.parametrize(
5173
"versions,expected_result",
5274
[
@@ -60,16 +82,42 @@ def github_release_api():
6082
(["1.0.0", "1.0.1", "1.0.2", "1.1.0", "1.1.1"], "1.1.1"),
6183
],
6284
)
63-
async def test_generic_update_strategy_should_find_expected_blueprint_release(
85+
async def test_generic_update_strategy_finds_latest_release_version(
6486
blueprint_factory, app_factory, versions, expected_result
6587
):
88+
# Given
6689
blueprint = blueprint_factory(
90+
version="1.0.0",
6791
fetch="http://test/releases/{{version}}/some-app.tar.gz",
6892
)
6993
app = app_factory(versions)
94+
95+
# When
96+
async with AsyncClient(app=app) as client:
97+
update_strategy = GenericUpdateStrategy(client)
98+
latest_version = await update_strategy(blueprint)
99+
100+
# Then
101+
assert latest_version == expected_result
102+
103+
104+
async def test_generic_update_strategy_finds_latest_release_version_when_version_has_prerelease_part( # noqa: E501
105+
blueprint_factory, app_factory
106+
):
107+
# Given
108+
blueprint = blueprint_factory(
109+
version="1.0.0-pre",
110+
fetch="http://test/releases/{{version}}/some-app.tar.gz",
111+
)
112+
app = app_factory(["2.0.0"])
113+
114+
# When
70115
async with AsyncClient(app=app) as client:
71116
update_strategy = GenericUpdateStrategy(client)
72-
assert await update_strategy(blueprint) == expected_result
117+
latest_version = await update_strategy(blueprint)
118+
119+
# Then
120+
assert latest_version == "2.0.0"
73121

74122

75123
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)