10
10
11
11
if docarray_v2 :
12
12
from docarray import BaseDoc , DocList
13
+ from docarray .utils ._internal ._typing import safe_issubclass
14
+ else :
15
+ safe_issubclass = issubclass
13
16
14
17
15
18
def get_fastapi_app (
@@ -157,7 +160,6 @@ async def post(request: Request):
157
160
detail = 'Invalid CSV input. Please check your input.' ,
158
161
)
159
162
160
-
161
163
def construct_model_from_line (model : Type [BaseModel ], line : List [str ]) -> BaseModel :
162
164
origin = get_origin (model )
163
165
# If the model is of type Optional[X], unwrap it to get X
@@ -171,7 +173,7 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
171
173
model_fields = model .__fields__
172
174
173
175
for idx , (field_name , field_info ) in enumerate (model_fields .items ()):
174
- field_type = field_info .type_
176
+ field_type = field_info .outer_type_
175
177
field_str = line [idx ] # Corresponding value from the row
176
178
177
179
try :
@@ -196,15 +198,22 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
196
198
continue
197
199
198
200
# Handle list of nested models (e.g., List[Item])
199
- elif get_origin ( field_type ) is list :
201
+ elif origin is list :
200
202
list_item_type = get_args (field_type )[0 ]
201
203
if field_str :
202
204
parsed_list = json .loads (field_str )
203
- if issubclass (list_item_type , BaseModel ):
205
+ if safe_issubclass (list_item_type , BaseModel ): # TODO: use safe issubclass
204
206
parsed_fields [field_name ] = parse_obj_as (List [list_item_type ], parsed_list )
205
207
else :
206
208
parsed_fields [field_name ] = parsed_list
207
-
209
+ elif safe_issubclass (field_type , DocList ):
210
+ list_item_type = field_type .doc_type
211
+ if field_str :
212
+ parsed_list = json .loads (field_str )
213
+ if safe_issubclass (list_item_type , BaseDoc ): # TODO: use safe issubclass
214
+ parsed_fields [field_name ] = parse_obj_as (DocList [list_item_type ], parsed_list )
215
+ else :
216
+ parsed_fields [field_name ] = parsed_list
208
217
# Handle other general types
209
218
else :
210
219
if field_str :
@@ -222,7 +231,7 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
222
231
else :
223
232
# General case: try converting to the target type
224
233
try :
225
- parsed_fields [field_name ] = field_type (field_str )
234
+ parsed_fields [field_name ] = DocList [ field_type ] (field_str )
226
235
except (ValueError , TypeError ):
227
236
# Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes
228
237
parsed_fields [field_name ] = parse_obj_as (field_type , field_str )
@@ -253,14 +262,15 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
253
262
):
254
263
if first_row :
255
264
first_row = False
256
- if len (line ) > 1 and line [1 ] == 'params_row' : # Check if it's a parameters row by examining the 2nd text in the first line
265
+ if len (line ) > 1 and line [
266
+ 1 ] == 'params_row' : # Check if it's a parameters row by examining the 2nd text in the first line
257
267
parameters = construct_model_from_line (parameters_model , line [2 :])
258
268
else :
259
269
if len (line ) != len (field_names ):
260
270
raise HTTPException (
261
271
status_code = http_status .HTTP_400_BAD_REQUEST ,
262
272
detail = f'Invalid CSV format. Line { line } doesn\' t match '
263
- f'the expected field order { field_names } .' ,
273
+ f'the expected field order { field_names } .' ,
264
274
)
265
275
data .append (construct_model_from_line (input_doc_list_model , line ))
266
276
else :
@@ -269,7 +279,7 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
269
279
raise HTTPException (
270
280
status_code = http_status .HTTP_400_BAD_REQUEST ,
271
281
detail = f'Invalid CSV format. Line { line } doesn\' t match '
272
- f'the expected field order { field_names } .' ,
282
+ f'the expected field order { field_names } .' ,
273
283
)
274
284
data .append (construct_model_from_line (input_doc_list_model , line ))
275
285
0 commit comments