Skip to content

Commit 5d7475a

Browse files
committed
Merge branch 'main' into monthly-upload
2 parents 24bd51c + 4836a23 commit 5d7475a

File tree

2 files changed

+29
-32
lines changed

2 files changed

+29
-32
lines changed

rctab/main.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""The entrypoint of the FastAPI application."""
22

33
import logging
4+
from contextlib import asynccontextmanager
45
from pathlib import Path
5-
from typing import Any, Callable, Dict, Final
6+
from typing import Any, AsyncIterator, Callable, Dict, Final
67

78
import fastapimsal
89
import secure
@@ -34,13 +35,37 @@
3435

3536
templates = Jinja2Templates(directory=Path("rctab/templates"))
3637

38+
39+
@asynccontextmanager
40+
async def lifespan(_: FastAPI) -> AsyncIterator[None]:
41+
"""Handle setup and teardown."""
42+
await database.connect()
43+
settings = get_settings()
44+
logging.basicConfig(level=settings.log_level)
45+
set_log_handler()
46+
if not settings.ignore_whitelist:
47+
logger = logging.getLogger(__name__)
48+
logger.warning(
49+
"Starting server with subscription whitelist: %s", settings.whitelist
50+
)
51+
52+
yield
53+
54+
logger = logging.getLogger(__name__)
55+
logger.warning("Shutting down server...")
56+
57+
logger.info("Disconnecting from database")
58+
await database.disconnect()
59+
60+
3761
app = FastAPI(
3862
title="RCTab API",
3963
description="API for RCTab",
4064
version="0.1.0",
4165
docs_url=None,
4266
redoc_url=None,
4367
openapi_url=None,
68+
lifespan=lifespan,
4469
)
4570

4671
server = secure.Server().set("Secure")
@@ -76,30 +101,6 @@ async def set_secure_headers(request: Any, call_next: Callable[[Any], Any]) -> A
76101
)
77102

78103

79-
@app.on_event("startup")
80-
async def startup() -> None:
81-
"""Start the server up."""
82-
await database.connect()
83-
settings = get_settings()
84-
logging.basicConfig(level=settings.log_level)
85-
set_log_handler()
86-
if not settings.ignore_whitelist:
87-
logger = logging.getLogger(__name__)
88-
logger.warning(
89-
"Starting server with subscription whitelist: %s", settings.whitelist
90-
)
91-
92-
93-
@app.on_event("shutdown")
94-
async def shutdown() -> None:
95-
"""Shut the server down."""
96-
logger = logging.getLogger(__name__)
97-
logger.warning("Shutting down server...")
98-
99-
logger.info("Disconnecting from database")
100-
await database.disconnect()
101-
102-
103104
@app.exception_handler(UniqueViolationError)
104105
async def unicorn_exception_handler(
105106
_: Request, exc: UniqueViolationError

tests/test_routes/test_routes.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# pylint: disable=redefined-outer-name,
22
import random
33
from datetime import date, timedelta
4-
from typing import Any, AsyncGenerator, Callable, Coroutine, Dict, Optional, Tuple
4+
from typing import Any, AsyncGenerator, Callable, Coroutine, Optional, Tuple
55
from unittest.mock import AsyncMock
66
from uuid import UUID
77

@@ -145,15 +145,11 @@ async def create_subscription(
145145

146146
def make_async_execute(
147147
connection: Engine,
148-
) -> Callable[
149-
[VarArg(Tuple[Any, ...]), KwArg(Dict[str, Any])], Coroutine[Any, Any, ResultProxy]
150-
]:
148+
) -> Callable[[VarArg(Any), KwArg(Any)], Coroutine[Any, Any, ResultProxy]]:
151149
"""We need an async function to patch database.execute() with
152150
but connection.execute() is synchronous so make a wrapper for it."""
153151

154-
async def async_execute(
155-
*args: Tuple[Any, ...], **kwargs: Dict[str, Any]
156-
) -> ResultProxy:
152+
async def async_execute(*args: Any, **kwargs: Any) -> ResultProxy:
157153
"""An async wrapper around connection.execute()."""
158154
return connection.execute(*args, **kwargs) # type: ignore
159155

0 commit comments

Comments
 (0)