refactor: otaproxy: implement resource limit for requests handling and cache r/w, refactor cache_streaming with r/w thread pools #575

Open
wants to merge 13 commits into base: main
Changes from 11 commits
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
"pydantic-settings<3,>=2.3",
"pyyaml<7,>=6.0.1",
"requests==2.32.4",
"simple-sqlite3-orm<0.11,>=0.10",
"simple-sqlite3-orm~=0.12",
"typing-extensions>=4.6.3",
"urllib3>=2.2.2,<2.5",
"uvicorn[standard]>=0.30,<0.35",
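Note: ~=0.12 is the PEP 440 compatible-release operator, equivalent to >=0.12,<1.0, so this loosens the previous <0.11,>=0.10 pin to any 0.x release from 0.12 onward.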
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,7 +12,7 @@ pydantic<3,>=2.10
pydantic-settings<3,>=2.3
pyyaml<7,>=6.0.1
requests==2.32.4
simple-sqlite3-orm<0.11,>=0.10
simple-sqlite3-orm~=0.12
typing-extensions>=4.6.3
urllib3>=2.2.2,<2.5
uvicorn[standard]>=0.30,<0.35
6 changes: 2 additions & 4 deletions src/ota_metadata/legacy2/metadata.py
@@ -23,7 +23,6 @@

"""


from __future__ import annotations

import contextlib
@@ -439,7 +438,6 @@ def connect_rstable(self) -> sqlite3.Connection:


class ResourceMeta:

def __init__(
self,
*,
Expand Down Expand Up @@ -472,7 +470,7 @@ def resources_count(self) -> int:
)

try:
_query = _orm.orm_execute(_sql_stmt)
_query = _orm.orm_execute(_sql_stmt, row_factory=sqlite3.Row)
Member Author:
By default, the ORM applies a custom row_factory that tries to convert the raw result into a cache DB entry, but here we run a SELECT count(*) query, so the result is not a DB entry.
Although the custom row_factory detects whether the raw result is actually an entry and, if not, returns it as-is, it is better to use sqlite3.Row as the row_factory when we know we are not selecting DB entries in the first place.

# NOTE: return value of fetchone will be a tuple, and here
# the first and only value of the tuple is the total nums of entries.
assert _query # should be something like ((<int>,),)
@@ -496,7 +494,7 @@ def resources_size_sum(self) -> int:
)

try:
_query = _orm.orm_execute(_sql_stmt)
_query = _orm.orm_execute(_sql_stmt, row_factory=sqlite3.Row)
# NOTE: return value of fetchone will be a tuple, and here
# the first and only value of the tuple is the total nums of entries.
assert _query # should be something like ((<int>,),)
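To illustrate the comment above, a minimal sketch using plain sqlite3 (not the project's ORM; the table name and values are illustrative): for an aggregate query such as SELECT count(*), the fetched row is not a table entry, so sqlite3.Row is the natural row_factory.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE resources (digest TEXT, size INTEGER)")
conn.executemany(
    "INSERT INTO resources VALUES (?, ?)",
    [("a", 10), ("b", 20), ("c", 30)],
)
conn.row_factory = sqlite3.Row

# fetchone() returns a single row; its first (and only) column is the count.
row = conn.execute("SELECT count(*) FROM resources").fetchone()
assert row[0] == 3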
11 changes: 10 additions & 1 deletion src/ota_proxy/__init__.py
@@ -45,8 +45,11 @@ def run_otaproxy(
enable_https: bool,
external_cache_mnt_point: str | None = None,
):
import asyncio

import anyio
import uvicorn
import uvloop

from . import App, OTACache

@@ -70,4 +73,10 @@
http="h11",
)
_server = uvicorn.Server(_config)
anyio.run(_server.serve, backend="asyncio", backend_options={"use_uvloop": True})

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
anyio.run(
_server.serve,
backend="asyncio",
backend_options={"loop_factory": uvloop.new_event_loop},
Member Author:
The anyio-recommended way to set up uvloop.

)
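For reference, a self-contained sketch of the same setup, assuming anyio>=4 (which accepts the loop_factory backend option) and uvloop are installed; the main coroutine is a stand-in for uvicorn's _server.serve:

import asyncio

import anyio
import uvloop

async def main() -> None:
    # With loop_factory=uvloop.new_event_loop, the running loop is a uvloop loop.
    print(type(asyncio.get_running_loop()))

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
anyio.run(main, backend="asyncio", backend_options={"loop_factory": uvloop.new_event_loop})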
5 changes: 4 additions & 1 deletion src/ota_proxy/__main__.py
@@ -21,7 +21,7 @@
from . import run_otaproxy
from .config import config as cfg

logger = logging.getLogger(__name__)
logger = logging.getLogger("ota_proxy")

if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -74,6 +74,9 @@
)
args = parser.parse_args()

# suppress logging from third-party deps
logging.basicConfig(level=logging.CRITICAL)
logger.setLevel(logging.INFO)
Member Author (on lines +77 to +79):
When starting otaproxy standalone, configure logging here. Note that we suppress third-party deps' logging and set the ota_proxy logger level to INFO.

logger.info(f"launch ota_proxy at {args.host}:{args.port}")
run_otaproxy(
host=args.host,
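A minimal sketch of how this suppression works: logging.basicConfig pins the root logger to CRITICAL, which loggers without an explicit level inherit, while the ota_proxy logger is lowered back to INFO (the third-party logger name below is purely illustrative):

import logging

# Root logger at CRITICAL: third-party loggers without their own level inherit it.
logging.basicConfig(level=logging.CRITICAL)

logger = logging.getLogger("ota_proxy")
logger.setLevel(logging.INFO)

logging.getLogger("some_third_party_dep").warning("suppressed")  # filtered out
logger.info("still emitted: ota_proxy INFO and above pass through")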
182 changes: 94 additions & 88 deletions src/ota_proxy/cache_control_header.py
Member Author:
Cleanup and refactoring of cache header parsing/exporting. test_cache_control_headers.py ensures that the new implementation behaves the same as the previous one.

@@ -12,42 +12,94 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass, fields
from typing import ClassVar, Dict, List
import logging
from dataclasses import dataclass
from io import StringIO
from typing import ClassVar, Optional, TypedDict

from typing_extensions import Self
from typing_extensions import Unpack

from otaclient_common._typing import copy_callable_typehint_to_method
from otaclient_common._logging import get_burst_suppressed_logger

_FIELDS = "_fields"
logger = logging.getLogger(__name__)
# NOTE: for request_error, only allow max 6 lines of logging per 30 seconds
burst_suppressed_logger = get_burst_suppressed_logger(f"{__name__}.header_parse_error")

VALID_DIRECTORIES = set(
["no_cache", "retry_caching", "file_sha256", "file_compression_alg"]
)
HEADER_LOWERCASE = "ota-file-cache-control"
HEADER_DIR_SEPARATOR = ","

@dataclass
class _HeaderDef:
# ------ Header definition ------ #
# NOTE: according to RFC7230, the header name is case-insensitive,
# so for convenience during code implementation, we always use lower-case
# header name.
HEADER_LOWERCASE: ClassVar[str] = "ota-file-cache-control"
HEADER_DIR_SEPARATOR: ClassVar[str] = ","

# ------ Directives definition ------ #
no_cache: bool = False
retry_caching: bool = False
class OTAFileCacheDirTypedDict(TypedDict, total=False):
no_cache: bool
retry_caching: bool
# added in revision 2:
file_sha256: str = ""
file_compression_alg: str = ""
file_sha256: Optional[str]
file_compression_alg: Optional[str]


def parse_header(_input: str) -> OTAFileCacheControl:
if not _input:
return OTAFileCacheControl()

_res = OTAFileCacheControl()
for c in _input.strip().split(HEADER_DIR_SEPARATOR):
k, *v = c.strip().split("=", maxsplit=1)
if k not in VALID_DIRECTORIES:
burst_suppressed_logger.warning(f"get unknown directory, ignore: {c}")
continue
setattr(_res, k, v[0] if v else True)
return _res


def _parse_header_asdict(_input: str) -> OTAFileCacheDirTypedDict:
if not _input:
return {}

_res: OTAFileCacheDirTypedDict = {}
for c in _input.strip().split(HEADER_DIR_SEPARATOR):
k, *v = c.strip().split("=", maxsplit=1)
if k not in VALID_DIRECTORIES:
burst_suppressed_logger.warning(f"get unknown directory, ignore: {c}")
continue
_res[k] = v[0] if v else True
return _res


def export_kwargs_as_header_string(**kwargs: Unpack[OTAFileCacheDirTypedDict]) -> str:
"""Directly export header str from a list of directive pairs."""
if not kwargs:
return ""

with StringIO() as buffer:
for k, v in kwargs.items():
if k not in VALID_DIRECTORIES:
burst_suppressed_logger.warning(f"get unknown directory, ignore: {k}")
continue
if not v:
continue

buffer.write(k if isinstance(v, bool) and v else f"{k}={v}")
buffer.write(HEADER_DIR_SEPARATOR)
return buffer.getvalue().strip(HEADER_DIR_SEPARATOR)


def __init_subclass__(cls) -> None:
_fields = {}
for f in fields(cls):
_fields[f.name] = f.type
setattr(cls, _FIELDS, _fields)
def update_header_str(_input: str, **kwargs: Unpack[OTAFileCacheDirTypedDict]) -> str:
"""Update input header string with input directive pairs."""
if not kwargs:
return _input

_res = _parse_header_asdict(_input)
_res.update(kwargs)
return export_kwargs_as_header_string(**_res)


@dataclass
class OTAFileCacheControl(_HeaderDef):
class OTAFileCacheControl:
"""Custom header for ota file caching control policies.

format:
@@ -62,68 +114,22 @@ class OTAFileCacheControl(_HeaderDef):
file_compression_alg: the compression alg used for the OTA file
"""

@classmethod
def parse_header(cls, _input: str) -> Self:
_fields: Dict[str, type] = getattr(cls, _FIELDS)
_parsed_directives = {}
for _raw_directive in _input.split(cls.HEADER_DIR_SEPARATOR):
if not (_parsed := _raw_directive.strip().split("=", maxsplit=1)):
continue

key = _parsed[0].strip()
if not (_field_type := _fields.get(key)):
continue

if _field_type is bool:
_parsed_directives[key] = True
elif len(_parsed) == 2 and (value := _parsed[1].strip()):
_parsed_directives[key] = value
return cls(**_parsed_directives)

@classmethod
@copy_callable_typehint_to_method(_HeaderDef)
def export_kwargs_as_header(cls, **kwargs) -> str:
"""Directly export header str from a list of directive pairs."""
_fields: Dict[str, type] = getattr(cls, _FIELDS)
_directives: List[str] = []
for key, value in kwargs.items():
if key not in _fields:
continue

if isinstance(value, bool) and value:
_directives.append(key)
elif value: # str field
_directives.append(f"{key}={value}")
return cls.HEADER_DIR_SEPARATOR.join(_directives)

@classmethod
def update_header_str(cls, _input: str, **kwargs) -> str:
"""Update input header string with input directive pairs.

Current used directives:
1. no_cache
2. retry_caching
3. file_sha256
4. file_compression_alg
"""
_fields: Dict[str, type] = getattr(cls, _FIELDS)
_parsed_directives = {}
for _raw_directive in _input.split(cls.HEADER_DIR_SEPARATOR):
if not (_parsed := _raw_directive.strip().split("=", maxsplit=1)):
continue
key = _parsed[0].strip()
if key not in _fields:
continue
_parsed_directives[key] = _raw_directive

for _key, value in kwargs.items():
if not (_field_type := _fields.get(_key)):
continue
# ------ Header definition ------ #
# NOTE: according to RFC7230, the header name is case-insensitive,
# so for convenience during code implementation, we always use lower-case
# header name.
HEADER_LOWERCASE: ClassVar[str] = HEADER_LOWERCASE
HEADER_DIR_SEPARATOR: ClassVar[str] = HEADER_DIR_SEPARATOR

if _field_type is bool and value:
_parsed_directives[_key] = _key
elif value:
_parsed_directives[_key] = f"{_key}={value}"
else: # remove False or empty directives
_parsed_directives.pop(_key, None)
return cls.HEADER_DIR_SEPARATOR.join(_parsed_directives.values())
# ------ Directives definition ------ #
no_cache: bool = False
retry_caching: bool = False
# added in revision 2:
file_sha256: Optional[str] = None
file_compression_alg: Optional[str] = None

# TODO: (20250618): to not change the callers of these methods,
# currently just register these methods under OTAFileCacheControl class.
parse_header = staticmethod(parse_header)
export_kwargs_as_header = staticmethod(export_kwargs_as_header_string)
update_header_str = staticmethod(update_header_str)
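A usage sketch of the refactored helpers, based on the behavior defined in the diff above (directive values are illustrative):

from ota_proxy.cache_control_header import OTAFileCacheControl

# Export directives to a header string; True booleans become bare directives.
hdr = OTAFileCacheControl.export_kwargs_as_header(no_cache=True, file_sha256="abc123")
assert hdr == "no_cache,file_sha256=abc123"

# Parse it back into an OTAFileCacheControl instance.
parsed = OTAFileCacheControl.parse_header(hdr)
assert parsed.no_cache is True and parsed.file_sha256 == "abc123"

# Update an existing header string with additional directive pairs.
assert (
    OTAFileCacheControl.update_header_str(hdr, retry_caching=True)
    == "no_cache,file_sha256=abc123,retry_caching"
)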