|
1 | 1 | """The entrypoint of the FastAPI application."""
|
2 | 2 |
|
3 | 3 | import logging
|
| 4 | +from contextlib import asynccontextmanager |
4 | 5 | from pathlib import Path
|
5 |
| -from typing import Any, Callable, Dict, Final |
| 6 | +from typing import Any, AsyncIterator, Callable, Dict, Final |
6 | 7 |
|
7 | 8 | import fastapimsal
|
8 | 9 | import secure
|
|
34 | 35 |
|
35 | 36 | templates = Jinja2Templates(directory=Path("rctab/templates"))
|
36 | 37 |
|
| 38 | + |
| 39 | +@asynccontextmanager |
| 40 | +async def lifespan(_: FastAPI) -> AsyncIterator[None]: |
| 41 | + """Handle setup and teardown.""" |
| 42 | + await database.connect() |
| 43 | + settings = get_settings() |
| 44 | + logging.basicConfig(level=settings.log_level) |
| 45 | + set_log_handler() |
| 46 | + if not settings.ignore_whitelist: |
| 47 | + logger = logging.getLogger(__name__) |
| 48 | + logger.warning( |
| 49 | + "Starting server with subscription whitelist: %s", settings.whitelist |
| 50 | + ) |
| 51 | + |
| 52 | + yield |
| 53 | + |
| 54 | + logger = logging.getLogger(__name__) |
| 55 | + logger.warning("Shutting down server...") |
| 56 | + |
| 57 | + logger.info("Disconnecting from database") |
| 58 | + await database.disconnect() |
| 59 | + |
| 60 | + |
37 | 61 | app = FastAPI(
|
38 | 62 | title="RCTab API",
|
39 | 63 | description="API for RCTab",
|
40 | 64 | version="0.1.0",
|
41 | 65 | docs_url=None,
|
42 | 66 | redoc_url=None,
|
43 | 67 | openapi_url=None,
|
| 68 | + lifespan=lifespan, |
44 | 69 | )
|
45 | 70 |
|
46 | 71 | server = secure.Server().set("Secure")
|
@@ -76,30 +101,6 @@ async def set_secure_headers(request: Any, call_next: Callable[[Any], Any]) -> A
|
76 | 101 | )
|
77 | 102 |
|
78 | 103 |
|
79 |
| -@app.on_event("startup") |
80 |
| -async def startup() -> None: |
81 |
| - """Start the server up.""" |
82 |
| - await database.connect() |
83 |
| - settings = get_settings() |
84 |
| - logging.basicConfig(level=settings.log_level) |
85 |
| - set_log_handler() |
86 |
| - if not settings.ignore_whitelist: |
87 |
| - logger = logging.getLogger(__name__) |
88 |
| - logger.warning( |
89 |
| - "Starting server with subscription whitelist: %s", settings.whitelist |
90 |
| - ) |
91 |
| - |
92 |
| - |
93 |
| -@app.on_event("shutdown") |
94 |
| -async def shutdown() -> None: |
95 |
| - """Shut the server down.""" |
96 |
| - logger = logging.getLogger(__name__) |
97 |
| - logger.warning("Shutting down server...") |
98 |
| - |
99 |
| - logger.info("Disconnecting from database") |
100 |
| - await database.disconnect() |
101 |
| - |
102 |
| - |
103 | 104 | @app.exception_handler(UniqueViolationError)
|
104 | 105 | async def unicorn_exception_handler(
|
105 | 106 | _: Request, exc: UniqueViolationError
|
|
0 commit comments