Skip to content

Commit cba152e

Browse files
authored
Prevent too large requests (#3048)
* prevent too large requests * reject file upload attempts * lint
1 parent 9a6044b commit cba152e

File tree

3 files changed

+95
-2
lines changed

3 files changed

+95
-2
lines changed

src/zenml/config/server_config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY,
3030
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE,
3131
DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS,
32+
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES,
3233
DEFAULT_ZENML_SERVER_NAME,
3334
DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW,
3435
DEFAULT_ZENML_SERVER_SECURE_HEADERS_CACHE,
@@ -231,6 +232,8 @@ class ServerConfiguration(BaseModel):
231232
auto_activate: Whether to automatically activate the server and create a
232233
default admin user account with an empty password during the initial
233234
deployment.
235+
max_request_body_size_in_bytes: The maximum size of the request body in
236+
bytes. If not specified, the default value of 256 Kb will be used.
234237
"""
235238

236239
deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER
@@ -319,6 +322,10 @@ class ServerConfiguration(BaseModel):
319322

320323
thread_pool_size: int = DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE
321324

325+
max_request_body_size_in_bytes: int = (
326+
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES
327+
)
328+
322329
_deployment_id: Optional[UUID] = None
323330

324331
@model_validator(mode="before")

src/zenml/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
317317
DEFAULT_ZENML_SERVER_SECURE_HEADERS_REPORT_TO = "default"
318318
DEFAULT_ZENML_SERVER_USE_LEGACY_DASHBOARD = False
319319
DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS = 30
320+
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024
320321

321322
# Configurations to decide which resources report their usage and check for
322323
# entitlement in the case of a cloud deployment. Expected Format is this:

src/zenml/zen_server/zen_server_api.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,21 @@
2424
from asyncio.log import logger
2525
from datetime import datetime, timedelta, timezone
2626
from genericpath import isfile
27-
from typing import Any, List
27+
from typing import Any, List, Set
2828

2929
from anyio import to_thread
3030
from fastapi import FastAPI, HTTPException, Request
3131
from fastapi.exceptions import RequestValidationError
3232
from fastapi.responses import ORJSONResponse
3333
from fastapi.staticfiles import StaticFiles
3434
from fastapi.templating import Jinja2Templates
35+
from starlette.middleware.base import (
36+
BaseHTTPMiddleware,
37+
RequestResponseEndpoint,
38+
)
3539
from starlette.middleware.cors import CORSMiddleware
36-
from starlette.responses import FileResponse
40+
from starlette.responses import FileResponse, JSONResponse, Response
41+
from starlette.types import ASGIApp
3742

3843
import zenml
3944
from zenml.analytics import source_context
@@ -143,6 +148,79 @@ def validation_exception_handler(
143148
return ORJSONResponse(error_detail(exc, ValueError), status_code=422)
144149

145150

151+
class RequestBodyLimit(BaseHTTPMiddleware):
152+
"""Limits the size of the request body."""
153+
154+
def __init__(self, app: ASGIApp, max_bytes: int) -> None:
155+
"""Limits the size of the request body.
156+
157+
Args:
158+
app: The FastAPI app.
159+
max_bytes: The maximum size of the request body.
160+
"""
161+
super().__init__(app)
162+
self.max_bytes = max_bytes
163+
164+
async def dispatch(
165+
self, request: Request, call_next: RequestResponseEndpoint
166+
) -> Response:
167+
"""Limits the size of the request body.
168+
169+
Args:
170+
request: The incoming request.
171+
call_next: The next function to be called.
172+
173+
Returns:
174+
The response to the request.
175+
"""
176+
if content_length := request.headers.get("content-length"):
177+
if int(content_length) > self.max_bytes:
178+
return Response(status_code=413) # Request Entity Too Large
179+
return await call_next(request)
180+
181+
182+
class RestrictFileUploadsMiddleware(BaseHTTPMiddleware):
183+
"""Restrict file uploads to certain paths."""
184+
185+
def __init__(self, app: FastAPI, allowed_paths: Set[str]):
186+
"""Restrict file uploads to certain paths.
187+
188+
Args:
189+
app: The FastAPI app.
190+
allowed_paths: The allowed paths.
191+
"""
192+
super().__init__(app)
193+
self.allowed_paths = allowed_paths
194+
195+
async def dispatch(
196+
self, request: Request, call_next: RequestResponseEndpoint
197+
) -> Response:
198+
"""Restrict file uploads to certain paths.
199+
200+
Args:
201+
request: The incoming request.
202+
call_next: The next function to be called.
203+
204+
Returns:
205+
The response to the request.
206+
"""
207+
if request.method == "POST":
208+
content_type = request.headers.get("content-type", "")
209+
if (
210+
"multipart/form-data" in content_type
211+
and request.url.path not in self.allowed_paths
212+
):
213+
return JSONResponse(
214+
status_code=403,
215+
content={
216+
"detail": "File uploads are not allowed on this endpoint."
217+
},
218+
)
219+
return await call_next(request)
220+
221+
222+
ALLOWED_FOR_FILE_UPLOAD: Set[str] = set()
223+
146224
app.add_middleware(
147225
CORSMiddleware,
148226
allow_origins=server_config().cors_allow_origins,
@@ -151,6 +229,13 @@ def validation_exception_handler(
151229
allow_headers=["*"],
152230
)
153231

232+
app.add_middleware(
233+
RequestBodyLimit, max_bytes=server_config().max_request_body_size_in_bytes
234+
)
235+
app.add_middleware(
236+
RestrictFileUploadsMiddleware, allowed_paths=ALLOWED_FOR_FILE_UPLOAD
237+
)
238+
154239

155240
@app.middleware("http")
156241
async def set_secure_headers(request: Request, call_next: Any) -> Any:

0 commit comments

Comments
 (0)