Skip to content

Commit 6d7bd95

Browse files
author
Joan Martinez
committed
fix: fix sagemaker csp
1 parent a2b1281 commit 6d7bd95

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

jina/serve/runtimes/worker/http_csp_app.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
if docarray_v2:
1212
from docarray import BaseDoc, DocList
13+
from docarray.utils._internal._typing import safe_issubclass
14+
else:
15+
safe_issubclass = issubclass
1316

1417

1518
def get_fastapi_app(
@@ -157,7 +160,6 @@ async def post(request: Request):
157160
detail='Invalid CSV input. Please check your input.',
158161
)
159162

160-
161163
def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseModel:
162164
origin = get_origin(model)
163165
# If the model is of type Optional[X], unwrap it to get X
@@ -171,7 +173,7 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
171173
model_fields = model.__fields__
172174

173175
for idx, (field_name, field_info) in enumerate(model_fields.items()):
174-
field_type = field_info.type_
176+
field_type = field_info.outer_type_
175177
field_str = line[idx] # Corresponding value from the row
176178

177179
try:
@@ -196,15 +198,22 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
196198
continue
197199

198200
# Handle list of nested models (e.g., List[Item])
199-
elif get_origin(field_type) is list:
201+
elif origin is list:
200202
list_item_type = get_args(field_type)[0]
201203
if field_str:
202204
parsed_list = json.loads(field_str)
203-
if issubclass(list_item_type, BaseModel):
205+
if safe_issubclass(list_item_type, BaseModel): # TODO: use safe issubclass
204206
parsed_fields[field_name] = parse_obj_as(List[list_item_type], parsed_list)
205207
else:
206208
parsed_fields[field_name] = parsed_list
207-
209+
elif safe_issubclass(field_type, DocList):
210+
list_item_type = field_type.doc_type
211+
if field_str:
212+
parsed_list = json.loads(field_str)
213+
if safe_issubclass(list_item_type, BaseDoc): # TODO: use safe issubclass
214+
parsed_fields[field_name] = parse_obj_as(DocList[list_item_type], parsed_list)
215+
else:
216+
parsed_fields[field_name] = parsed_list
208217
# Handle other general types
209218
else:
210219
if field_str:
@@ -222,7 +231,7 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
222231
else:
223232
# General case: try converting to the target type
224233
try:
225-
parsed_fields[field_name] = field_type(field_str)
234+
parsed_fields[field_name] = DocList[field_type](field_str)
226235
except (ValueError, TypeError):
227236
# Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes
228237
parsed_fields[field_name] = parse_obj_as(field_type, field_str)
@@ -253,14 +262,15 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
253262
):
254263
if first_row:
255264
first_row = False
256-
if len(line) > 1 and line[1] == 'params_row': # Check if it's a parameters row by examining the 2nd text in the first line
265+
if len(line) > 1 and line[
266+
1] == 'params_row': # Check if it's a parameters row by examining the 2nd text in the first line
257267
parameters = construct_model_from_line(parameters_model, line[2:])
258268
else:
259269
if len(line) != len(field_names):
260270
raise HTTPException(
261271
status_code=http_status.HTTP_400_BAD_REQUEST,
262272
detail=f'Invalid CSV format. Line {line} doesn\'t match '
263-
f'the expected field order {field_names}.',
273+
f'the expected field order {field_names}.',
264274
)
265275
data.append(construct_model_from_line(input_doc_list_model, line))
266276
else:
@@ -269,7 +279,7 @@ def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseMo
269279
raise HTTPException(
270280
status_code=http_status.HTTP_400_BAD_REQUEST,
271281
detail=f'Invalid CSV format. Line {line} doesn\'t match '
272-
f'the expected field order {field_names}.',
282+
f'the expected field order {field_names}.',
273283
)
274284
data.append(construct_model_from_line(input_doc_list_model, line))
275285

tests/unit/orchestrate/flow/flow-construct/test_flow_yaml_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,14 @@ def test_dump_load_build(monkeypatch):
109109
# validate gateway args (set during build)
110110
assert f['gateway'].args.port == f2['gateway'].args.port
111111

112-
@pytest.mark.skipif('GITHUB_WORKFLOW' in os.env, reason='no specific port test in CI')
112+
@pytest.mark.skipif('GITHUB_WORKFLOW' in os.environ, reason='no specific port test in CI')
113113
def test_load_flow_with_port():
114114
f = Flow.load_config('yaml/test-flow-port.yml')
115115
with f:
116116
assert f.port == 12345
117117

118118

119-
@pytest.mark.skipif('GITHUB_WORKFLOW' in os.env, reason='no specific port test in CI')
119+
@pytest.mark.skipif('GITHUB_WORKFLOW' in os.environ, reason='no specific port test in CI')
120120
def test_load_flow_from_cli():
121121
a = set_flow_parser().parse_args(['--uses', 'yaml/test-flow-port.yml'])
122122
f = Flow.load_config(a.uses)

0 commit comments

Comments
 (0)