Skip to content

Commit a57a126

Browse files
committed
pick new otaproxy from refactor/otaproxy_main
1 parent b8fd418 commit a57a126

File tree

12 files changed

+1120
-739
lines changed

12 files changed

+1120
-739
lines changed

src/ota_proxy/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,11 @@ def run_otaproxy(
4545
enable_https: bool,
4646
external_cache_mnt_point: str | None = None,
4747
):
48+
import asyncio
49+
4850
import anyio
4951
import uvicorn
52+
import uvloop
5053

5154
from . import App, OTACache
5255

@@ -70,4 +73,10 @@ def run_otaproxy(
7073
http="h11",
7174
)
7275
_server = uvicorn.Server(_config)
73-
anyio.run(_server.serve, backend="asyncio", backend_options={"use_uvloop": True})
76+
77+
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
78+
anyio.run(
79+
_server.serve,
80+
backend="asyncio",
81+
backend_options={"loop_factory": uvloop.new_event_loop},
82+
)

src/ota_proxy/__main__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from . import run_otaproxy
2222
from .config import config as cfg
2323

24-
logger = logging.getLogger(__name__)
24+
logger = logging.getLogger("ota_proxy")
2525

2626
if __name__ == "__main__":
2727
parser = argparse.ArgumentParser(
@@ -74,6 +74,9 @@
7474
)
7575
args = parser.parse_args()
7676

77+
# suppress logging from third-party deps
78+
logging.basicConfig(level=logging.CRITICAL)
79+
logger.setLevel(logging.INFO)
7780
logger.info(f"launch ota_proxy at {args.host}:{args.port}")
7881
run_otaproxy(
7982
host=args.host,

src/ota_proxy/cache_control_header.py

Lines changed: 94 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,94 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
1516

16-
from dataclasses import dataclass, fields
17-
from typing import ClassVar, Dict, List
17+
import logging
18+
from dataclasses import dataclass
19+
from io import StringIO
20+
from typing import ClassVar, Optional, TypedDict
1821

19-
from typing_extensions import Self
22+
from typing_extensions import Unpack
2023

21-
from otaclient_common._typing import copy_callable_typehint_to_method
24+
from otaclient_common._logging import get_burst_suppressed_logger
2225

23-
_FIELDS = "_fields"
26+
logger = logging.getLogger(__name__)
27+
# NOTE: for request_error, only allow max 6 lines of logging per 30 seconds
28+
burst_suppressed_logger = get_burst_suppressed_logger(f"{__name__}.header_parse_error")
2429

30+
VALID_DIRECTORIES = set(
31+
["no_cache", "retry_caching", "file_sha256", "file_compression_alg"]
32+
)
33+
HEADER_LOWERCASE = "ota-file-cache-control"
34+
HEADER_DIR_SEPARATOR = ","
2535

26-
@dataclass
27-
class _HeaderDef:
28-
# ------ Header definition ------ #
29-
# NOTE: according to RFC7230, the header name is case-insensitive,
30-
# so for convenience during code implementation, we always use lower-case
31-
# header name.
32-
HEADER_LOWERCASE: ClassVar[str] = "ota-file-cache-control"
33-
HEADER_DIR_SEPARATOR: ClassVar[str] = ","
3436

35-
# ------ Directives definition ------ #
36-
no_cache: bool = False
37-
retry_caching: bool = False
37+
class OTAFileCacheDirTypedDict(TypedDict, total=False):
38+
no_cache: bool
39+
retry_caching: bool
3840
# added in revision 2:
39-
file_sha256: str = ""
40-
file_compression_alg: str = ""
41+
file_sha256: Optional[str]
42+
file_compression_alg: Optional[str]
43+
44+
45+
def parse_header(_input: str) -> OTAFileCacheControl:
46+
if not _input:
47+
return OTAFileCacheControl()
48+
49+
_res = OTAFileCacheControl()
50+
for c in _input.strip().split(HEADER_DIR_SEPARATOR):
51+
k, *v = c.strip().split("=", maxsplit=1)
52+
if k not in VALID_DIRECTORIES:
53+
burst_suppressed_logger.warning(f"get unknown directory, ignore: {c}")
54+
continue
55+
setattr(_res, k, v[0] if v else True)
56+
return _res
57+
58+
59+
def _parse_header_asdict(_input: str) -> OTAFileCacheDirTypedDict:
60+
if not _input:
61+
return {}
62+
63+
_res: OTAFileCacheDirTypedDict = {}
64+
for c in _input.strip().split(HEADER_DIR_SEPARATOR):
65+
k, *v = c.strip().split("=", maxsplit=1)
66+
if k not in VALID_DIRECTORIES:
67+
burst_suppressed_logger.warning(f"get unknown directory, ignore: {c}")
68+
continue
69+
_res[k] = v[0] if v else True
70+
return _res
71+
72+
73+
def export_kwargs_as_header_string(**kwargs: Unpack[OTAFileCacheDirTypedDict]) -> str:
74+
"""Directly export header str from a list of directive pairs."""
75+
if not kwargs:
76+
return ""
77+
78+
with StringIO() as buffer:
79+
for k, v in kwargs.items():
80+
if k not in VALID_DIRECTORIES:
81+
burst_suppressed_logger.warning(f"get unknown directory, ignore: {k}")
82+
continue
83+
if not v:
84+
continue
85+
86+
buffer.write(k if isinstance(v, bool) and v else f"{k}={v}")
87+
buffer.write(HEADER_DIR_SEPARATOR)
88+
return buffer.getvalue().strip(HEADER_DIR_SEPARATOR)
89+
4190

42-
def __init_subclass__(cls) -> None:
43-
_fields = {}
44-
for f in fields(cls):
45-
_fields[f.name] = f.type
46-
setattr(cls, _FIELDS, _fields)
91+
def update_header_str(_input: str, **kwargs: Unpack[OTAFileCacheDirTypedDict]) -> str:
92+
"""Update input header string with input directive pairs."""
93+
if not kwargs:
94+
return _input
95+
96+
_res = _parse_header_asdict(_input)
97+
_res.update(kwargs)
98+
return export_kwargs_as_header_string(**_res)
4799

48100

49101
@dataclass
50-
class OTAFileCacheControl(_HeaderDef):
102+
class OTAFileCacheControl:
51103
"""Custom header for ota file caching control policies.
52104
53105
format:
@@ -62,68 +114,22 @@ class OTAFileCacheControl(_HeaderDef):
62114
file_compression_alg: the compression alg used for the OTA file
63115
"""
64116

65-
@classmethod
66-
def parse_header(cls, _input: str) -> Self:
67-
_fields: Dict[str, type] = getattr(cls, _FIELDS)
68-
_parsed_directives = {}
69-
for _raw_directive in _input.split(cls.HEADER_DIR_SEPARATOR):
70-
if not (_parsed := _raw_directive.strip().split("=", maxsplit=1)):
71-
continue
72-
73-
key = _parsed[0].strip()
74-
if not (_field_type := _fields.get(key)):
75-
continue
76-
77-
if _field_type is bool:
78-
_parsed_directives[key] = True
79-
elif len(_parsed) == 2 and (value := _parsed[1].strip()):
80-
_parsed_directives[key] = value
81-
return cls(**_parsed_directives)
82-
83-
@classmethod
84-
@copy_callable_typehint_to_method(_HeaderDef)
85-
def export_kwargs_as_header(cls, **kwargs) -> str:
86-
"""Directly export header str from a list of directive pairs."""
87-
_fields: Dict[str, type] = getattr(cls, _FIELDS)
88-
_directives: List[str] = []
89-
for key, value in kwargs.items():
90-
if key not in _fields:
91-
continue
92-
93-
if isinstance(value, bool) and value:
94-
_directives.append(key)
95-
elif value: # str field
96-
_directives.append(f"{key}={value}")
97-
return cls.HEADER_DIR_SEPARATOR.join(_directives)
98-
99-
@classmethod
100-
def update_header_str(cls, _input: str, **kwargs) -> str:
101-
"""Update input header string with input directive pairs.
102-
103-
Current used directives:
104-
1. no_cache
105-
2. retry_caching
106-
3. file_sha256
107-
4. file_compression_alg
108-
"""
109-
_fields: Dict[str, type] = getattr(cls, _FIELDS)
110-
_parsed_directives = {}
111-
for _raw_directive in _input.split(cls.HEADER_DIR_SEPARATOR):
112-
if not (_parsed := _raw_directive.strip().split("=", maxsplit=1)):
113-
continue
114-
key = _parsed[0].strip()
115-
if key not in _fields:
116-
continue
117-
_parsed_directives[key] = _raw_directive
118-
119-
for _key, value in kwargs.items():
120-
if not (_field_type := _fields.get(_key)):
121-
continue
117+
# ------ Header definition ------ #
118+
# NOTE: according to RFC7230, the header name is case-insensitive,
119+
# so for convenience during code implementation, we always use lower-case
120+
# header name.
121+
HEADER_LOWERCASE: ClassVar[str] = HEADER_LOWERCASE
122+
HEADER_DIR_SEPARATOR: ClassVar[str] = HEADER_DIR_SEPARATOR
122123

123-
if _field_type is bool and value:
124-
_parsed_directives[_key] = _key
125-
elif value:
126-
_parsed_directives[_key] = f"{_key}={value}"
127-
else: # remove False or empty directives
128-
_parsed_directives.pop(_key, None)
129-
return cls.HEADER_DIR_SEPARATOR.join(_parsed_directives.values())
124+
# ------ Directives definition ------ #
125+
no_cache: bool = False
126+
retry_caching: bool = False
127+
# added in revision 2:
128+
file_sha256: Optional[str] = None
129+
file_compression_alg: Optional[str] = None
130+
131+
# TODO: (20250618): to not change the callers of these methods,
132+
# currently just register these methods under OTAFileCacheControl class.
133+
parse_header = staticmethod(parse_header)
134+
export_kwargs_as_header = staticmethod(export_kwargs_as_header_string)
135+
update_header_str = staticmethod(update_header_str)

0 commit comments

Comments
 (0)