1
- from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Union
1
+ from typing import (TYPE_CHECKING , Callable , Dict , List , Literal , Optional ,
2
+ Union )
2
3
3
4
from jina ._docarray import docarray_v2
4
5
from jina .importer import ImportExtensions
@@ -74,7 +75,7 @@ def add_post_route(
74
75
input_model ,
75
76
output_model ,
76
77
input_doc_list_model = None ,
77
- output_doc_list_model = None ,
78
+ parameters_model = None ,
78
79
):
79
80
import json
80
81
from typing import List , Type , Union
@@ -150,54 +151,85 @@ async def post(request: Request):
150
151
csv_body = bytes_body .decode ('utf-8' )
151
152
if not is_valid_csv (csv_body ):
152
153
raise HTTPException (
153
- status_code = 400 ,
154
+ status_code = http_status . HTTP_400_BAD_REQUEST ,
154
155
detail = 'Invalid CSV input. Please check your input.' ,
155
156
)
156
157
157
- def construct_model_from_line (
158
- model : Type [BaseModel ], line : List [str ]
159
- ) -> BaseModel :
158
+
159
+ def construct_model_from_line (model : Type [BaseModel ], line : List [str ]) -> BaseModel :
160
+ origin = get_origin (model )
161
+ # If the model is of type Optional[X], unwrap it to get X
162
+ if origin is Union :
163
+ # Optional[X] is Union[X, None]: when None is among the members, use the first member as the model
164
+ args = get_args (model )
165
+ if type (None ) in args :
166
+ model = args [0 ]
167
+
160
168
parsed_fields = {}
161
169
model_fields = model .__fields__
162
170
163
- for field_str , (field_name , field_info ) in zip (
164
- line , model_fields .items ()
165
- ):
166
- field_type = field_info .outer_type_
167
-
168
- # Handle Union types by attempting to arse each potential type
169
- if get_origin (field_type ) is Union :
170
- for possible_type in get_args (field_type ):
171
- if possible_type is str :
172
- parsed_fields [field_name ] = field_str
173
- break
174
- else :
171
+ for idx , (field_name , field_info ) in enumerate (model_fields .items ()):
172
+ field_type = field_info .type_
173
+ field_str = line [idx ] # Corresponding value from the row
174
+
175
+ try :
176
+ # Handle Literal types (e.g., Optional[Literal["value1", "value2"]])
177
+ origin = get_origin (field_type )
178
+ if origin is Literal :
179
+ literal_values = get_args (field_type )
180
+ if field_str not in literal_values :
181
+ raise HTTPException (
182
+ status_code = http_status .HTTP_400_BAD_REQUEST ,
183
+ detail = f"Invalid value '{ field_str } ' for field '{ field_name } '. Expected one of: { literal_values } "
184
+ )
185
+ parsed_fields [field_name ] = field_str
186
+
187
+ # Handle Union types (e.g., Union[int, str])
188
+ elif origin is Union :
189
+ for possible_type in get_args (field_type ):
175
190
try :
176
- parsed_fields [field_name ] = parse_obj_as (
177
- possible_type , json .loads (field_str )
178
- )
191
+ parsed_fields [field_name ] = parse_obj_as (possible_type , field_str )
179
192
break
180
- except (json . JSONDecodeError , ValidationError ):
193
+ except (ValueError , TypeError , ValidationError ):
181
194
continue
182
- # Handle list of nested models
183
- elif get_origin (field_type ) is list :
184
- list_item_type = get_args (field_type )[0 ]
185
- if field_str :
186
- parsed_list = json .loads (field_str )
187
- if issubclass (list_item_type , BaseModel ):
188
- parsed_fields [field_name ] = parse_obj_as (
189
- List [list_item_type ], parsed_list
190
- )
191
- else :
192
- parsed_fields [field_name ] = parsed_list
193
- # General parsing attempt for other types
194
- else :
195
- if field_str :
196
- try :
197
- parsed_fields [field_name ] = field_info .type_ (field_str )
198
- except (ValueError , TypeError ):
199
- # Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes
200
- parsed_fields [field_name ] = parse_obj_as (field_info .type_ , field_str )
195
+
196
+ # Handle list of nested models (e.g., List[Item])
197
+ elif get_origin (field_type ) is list :
198
+ list_item_type = get_args (field_type )[0 ]
199
+ if field_str :
200
+ parsed_list = json .loads (field_str )
201
+ if issubclass (list_item_type , BaseModel ):
202
+ parsed_fields [field_name ] = parse_obj_as (List [list_item_type ], parsed_list )
203
+ else :
204
+ parsed_fields [field_name ] = parsed_list
205
+
206
+ # Handle other general types
207
+ else :
208
+ if field_str :
209
+ if field_type == bool :
210
+ # Special case: handle "false" and "true" as booleans
211
+ if field_str .lower () == "false" :
212
+ parsed_fields [field_name ] = False
213
+ elif field_str .lower () == "true" :
214
+ parsed_fields [field_name ] = True
215
+ else :
216
+ raise HTTPException (
217
+ status_code = http_status .HTTP_400_BAD_REQUEST ,
218
+ detail = f"Invalid value '{ field_str } ' for boolean field '{ field_name } '. Expected 'true' or 'false'."
219
+ )
220
+ else :
221
+ # General case: try converting to the target type
222
+ try :
223
+ parsed_fields [field_name ] = field_type (field_str )
224
+ except (ValueError , TypeError ):
225
+ # Fallback to parse_obj_as when type is more complex, e.g., AnyUrl or ImageBytes
226
+ parsed_fields [field_name ] = parse_obj_as (field_type , field_str )
227
+
228
+ except Exception as e :
229
+ raise HTTPException (
230
+ status_code = http_status .HTTP_400_BAD_REQUEST ,
231
+ detail = f"Error parsing value '{ field_str } ' for field '{ field_name } ': { str (e )} "
232
+ )
201
233
202
234
return model (** parsed_fields )
203
235
@@ -209,25 +241,41 @@ def construct_model_from_line(
209
241
# We also expect the csv file to have no quotes and use the escape char '\'
210
242
field_names = [f for f in input_doc_list_model .__fields__ ]
211
243
data = []
244
+ parameters = None
245
+ first_row = True
212
246
for line in csv .reader (
213
247
StringIO (csv_body ),
214
248
delimiter = ',' ,
215
249
quoting = csv .QUOTE_NONE ,
216
250
escapechar = '\\ ' ,
217
251
):
218
- if len (line ) != len (field_names ):
219
- raise HTTPException (
220
- status_code = 400 ,
221
- detail = f'Invalid CSV format. Line { line } doesn\' t match '
222
- f'the expected field order { field_names } .' ,
223
- )
224
- data .append (construct_model_from_line (input_doc_list_model , line ))
225
-
226
- return await process (input_model (data = data ))
252
+ if first_row :
253
+ first_row = False
254
+ if len (line ) > 1 and line [1 ] == 'params_row' : # A parameters row is identified by 'params_row' in the second column of the first line
255
+ parameters = construct_model_from_line (parameters_model , line [2 :])
256
+ else :
257
+ if len (line ) != len (field_names ):
258
+ raise HTTPException (
259
+ status_code = http_status .HTTP_400_BAD_REQUEST ,
260
+ detail = f'Invalid CSV format. Line { line } doesn\' t match '
261
+ f'the expected field order { field_names } .' ,
262
+ )
263
+ data .append (construct_model_from_line (input_doc_list_model , line ))
264
+ else :
265
+ # Treat it as normal data row
266
+ if len (line ) != len (field_names ):
267
+ raise HTTPException (
268
+ status_code = http_status .HTTP_400_BAD_REQUEST ,
269
+ detail = f'Invalid CSV format. Line { line } doesn\' t match '
270
+ f'the expected field order { field_names } .' ,
271
+ )
272
+ data .append (construct_model_from_line (input_doc_list_model , line ))
273
+
274
+ return await process (input_model (data = data , parameters = parameters ))
227
275
228
276
else :
229
277
raise HTTPException (
230
- status_code = 400 ,
278
+ status_code = http_status . HTTP_400_BAD_REQUEST ,
231
279
detail = f'Invalid content-type: { content_type } . '
232
280
f'Please use either application/json or text/csv.' ,
233
281
)
@@ -273,7 +321,7 @@ def construct_model_from_line(
273
321
input_model = endpoint_input_model ,
274
322
output_model = endpoint_output_model ,
275
323
input_doc_list_model = input_doc_model ,
276
- output_doc_list_model = output_doc_model ,
324
+ parameters_model = parameters_model ,
277
325
)
278
326
279
327
from jina .serve .runtimes .gateway .health_model import JinaHealthModel
0 commit comments