24
24
from asyncio .log import logger
25
25
from datetime import datetime , timedelta , timezone
26
26
from genericpath import isfile
27
- from typing import Any , List
27
+ from typing import Any , List , Set
28
28
29
29
from anyio import to_thread
30
30
from fastapi import FastAPI , HTTPException , Request
31
31
from fastapi .exceptions import RequestValidationError
32
32
from fastapi .responses import ORJSONResponse
33
33
from fastapi .staticfiles import StaticFiles
34
34
from fastapi .templating import Jinja2Templates
35
+ from starlette .middleware .base import (
36
+ BaseHTTPMiddleware ,
37
+ RequestResponseEndpoint ,
38
+ )
35
39
from starlette .middleware .cors import CORSMiddleware
36
- from starlette .responses import FileResponse
40
+ from starlette .responses import FileResponse , JSONResponse , Response
41
+ from starlette .types import ASGIApp
37
42
38
43
import zenml
39
44
from zenml .analytics import source_context
@@ -143,6 +148,79 @@ def validation_exception_handler(
143
148
return ORJSONResponse (error_detail (exc , ValueError ), status_code = 422 )
144
149
145
150
151
+ class RequestBodyLimit (BaseHTTPMiddleware ):
152
+ """Limits the size of the request body."""
153
+
154
+ def __init__ (self , app : ASGIApp , max_bytes : int ) -> None :
155
+ """Limits the size of the request body.
156
+
157
+ Args:
158
+ app: The FastAPI app.
159
+ max_bytes: The maximum size of the request body.
160
+ """
161
+ super ().__init__ (app )
162
+ self .max_bytes = max_bytes
163
+
164
+ async def dispatch (
165
+ self , request : Request , call_next : RequestResponseEndpoint
166
+ ) -> Response :
167
+ """Limits the size of the request body.
168
+
169
+ Args:
170
+ request: The incoming request.
171
+ call_next: The next function to be called.
172
+
173
+ Returns:
174
+ The response to the request.
175
+ """
176
+ if content_length := request .headers .get ("content-length" ):
177
+ if int (content_length ) > self .max_bytes :
178
+ return Response (status_code = 413 ) # Request Entity Too Large
179
+ return await call_next (request )
180
+
181
+
182
+ class RestrictFileUploadsMiddleware (BaseHTTPMiddleware ):
183
+ """Restrict file uploads to certain paths."""
184
+
185
+ def __init__ (self , app : FastAPI , allowed_paths : Set [str ]):
186
+ """Restrict file uploads to certain paths.
187
+
188
+ Args:
189
+ app: The FastAPI app.
190
+ allowed_paths: The allowed paths.
191
+ """
192
+ super ().__init__ (app )
193
+ self .allowed_paths = allowed_paths
194
+
195
+ async def dispatch (
196
+ self , request : Request , call_next : RequestResponseEndpoint
197
+ ) -> Response :
198
+ """Restrict file uploads to certain paths.
199
+
200
+ Args:
201
+ request: The incoming request.
202
+ call_next: The next function to be called.
203
+
204
+ Returns:
205
+ The response to the request.
206
+ """
207
+ if request .method == "POST" :
208
+ content_type = request .headers .get ("content-type" , "" )
209
+ if (
210
+ "multipart/form-data" in content_type
211
+ and request .url .path not in self .allowed_paths
212
+ ):
213
+ return JSONResponse (
214
+ status_code = 403 ,
215
+ content = {
216
+ "detail" : "File uploads are not allowed on this endpoint."
217
+ },
218
+ )
219
+ return await call_next (request )
220
+
221
+
222
+ ALLOWED_FOR_FILE_UPLOAD : Set [str ] = set ()
223
+
146
224
app .add_middleware (
147
225
CORSMiddleware ,
148
226
allow_origins = server_config ().cors_allow_origins ,
@@ -151,6 +229,13 @@ def validation_exception_handler(
151
229
allow_headers = ["*" ],
152
230
)
153
231
232
+ app .add_middleware (
233
+ RequestBodyLimit , max_bytes = server_config ().max_request_body_size_in_bytes
234
+ )
235
+ app .add_middleware (
236
+ RestrictFileUploadsMiddleware , allowed_paths = ALLOWED_FOR_FILE_UPLOAD
237
+ )
238
+
154
239
155
240
@app .middleware ("http" )
156
241
async def set_secure_headers (request : Request , call_next : Any ) -> Any :
0 commit comments