Skip to content

Commit fa7f11b

Browse files
zac-liJoanFM
andauthored
feat: support param in sm batch (#6229)
Co-authored-by: Joan Martinez <[email protected]>
1 parent ef3ad20 commit fa7f11b

File tree

9 files changed

+188
-66
lines changed

9 files changed

+188
-66
lines changed

.github/workflows/build-old-docs.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ jobs:
9292
mv ./docs/_build/dirhtml ./${{ matrix.version }}
9393
zip -r /tmp/build.zip ./${{ matrix.version }}/*
9494
- name: Upload built html
95-
uses: actions/upload-artifact@v3
95+
uses: actions/upload-artifact@v4
9696
with:
9797
name: ${{ matrix.version }}
9898
path: /tmp/build.zip
@@ -106,7 +106,7 @@ jobs:
106106
with:
107107
fetch-depth: 1
108108
ref: ${{ inputs.pages_branch }}
109-
- uses: actions/download-artifact@v3
109+
- uses: actions/download-artifact@v4
110110
with:
111111
path: /tmp/artifacts
112112
- name: Clear old builds

.github/workflows/cd.yml

+4-3
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ jobs:
245245
runs-on: ubuntu-latest
246246
steps:
247247
- name: Check out repository
248-
uses: actions/checkout@v2
248+
uses: actions/checkout@v2.5.0
249249
with:
250250
fetch-depth: 200
251251

@@ -347,7 +347,7 @@ jobs:
347347
goarch: arm64
348348
steps:
349349
- name: Check out repository
350-
uses: actions/checkout@v2
350+
uses: actions/checkout@v2.5.0
351351

352352
- name: Set up Python ${{ matrix.python }}
353353
uses: actions/setup-python@v2
@@ -399,8 +399,9 @@ jobs:
399399
python -m cibuildwheel --output-dir dist
400400
401401
- name: Upload wheels as artifacts
402-
uses: actions/upload-artifact@v3
402+
uses: actions/upload-artifact@v4
403403
with:
404+
name: artifacts-${{ strategy.job-index }}
404405
path: dist/*.whl
405406

406407
# comment for now, do it manually if needed

.github/workflows/ci.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ jobs:
651651

652652
steps:
653653
- name: Check out repository
654-
uses: actions/checkout@v2
654+
uses: actions/checkout@v2.5.0
655655

656656
- name: Set up Python ${{ matrix.python }}
657657
uses: actions/setup-python@v2
@@ -688,8 +688,9 @@ jobs:
688688
python -m cibuildwheel --output-dir dist
689689
690690
- name: Upload wheels as artifacts
691-
uses: actions/upload-artifact@v3
691+
uses: actions/upload-artifact@v4
692692
with:
693+
name: artifacts-${{ strategy.job-index }}
693694
path: dist/*.whl
694695

695696
core-test:

.github/workflows/force-release.yml

+6-4
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ jobs:
8686

8787
steps:
8888
- name: Check out repository
89-
uses: actions/checkout@v2
89+
uses: actions/checkout@v2.5.0
9090
with:
9191
fetch-depth: 200
9292

@@ -124,8 +124,9 @@ jobs:
124124
python -m cibuildwheel --output-dir dist
125125
126126
- name: Upload wheels as artifacts
127-
uses: actions/upload-artifact@v3
127+
uses: actions/upload-artifact@v4
128128
with:
129+
name: artifacts-${{ strategy.job-index }}
129130
path: dist/*.whl
130131

131132
regular-release:
@@ -141,9 +142,10 @@ jobs:
141142
with:
142143
python-version: "3.10"
143144
# https://github.com/actions/checkout#fetch-all-tags
144-
- uses: actions/download-artifact@v3
145+
- uses: actions/download-artifact@v4
145146
with:
146-
name: artifact
147+
pattern: artifacts-*
148+
merge-multiple: true
147149
path: dist
148150
- run: |
149151
git fetch --depth=200

jina/serve/runtimes/worker/http_csp_app.py

+100-52
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
1+
from typing import (TYPE_CHECKING, Callable, Dict, List, Literal, Optional,
2+
Union)
23

34
from jina._docarray import docarray_v2
45
from jina.importer import ImportExtensions
@@ -74,7 +75,7 @@ def add_post_route(
7475
input_model,
7576
output_model,
7677
input_doc_list_model=None,
77-
output_doc_list_model=None,
78+
parameters_model=None,
7879
):
7980
import json
8081
from typing import List, Type, Union
@@ -150,54 +151,85 @@ async def post(request: Request):
150151
csv_body = bytes_body.decode('utf-8')
151152
if not is_valid_csv(csv_body):
152153
raise HTTPException(
153-
status_code=400,
154+
status_code=http_status.HTTP_400_BAD_REQUEST,
154155
detail='Invalid CSV input. Please check your input.',
155156
)
156157

157-
def construct_model_from_line(
158-
model: Type[BaseModel], line: List[str]
159-
) -> BaseModel:
158+
159+
def construct_model_from_line(model: Type[BaseModel], line: List[str]) -> BaseModel:
160+
origin = get_origin(model)
161+
# If the model is of type Optional[X], unwrap it to get X
162+
if origin is Union:
163+
# If the model is of type Optional[X], unwrap it to get X
164+
args = get_args(model)
165+
if type(None) in args:
166+
model = args[0]
167+
160168
parsed_fields = {}
161169
model_fields = model.__fields__
162170

163-
for field_str, (field_name, field_info) in zip(
164-
line, model_fields.items()
165-
):
166-
field_type = field_info.outer_type_
167-
168-
# Handle Union types by attempting to arse each potential type
169-
if get_origin(field_type) is Union:
170-
for possible_type in get_args(field_type):
171-
if possible_type is str:
172-
parsed_fields[field_name] = field_str
173-
break
174-
else:
171+
for idx, (field_name, field_info) in enumerate(model_fields.items()):
172+
field_type = field_info.type_
173+
field_str = line[idx] # Corresponding value from the row
174+
175+
try:
176+
# Handle Literal types (e.g., Optional[Literal["value1", "value2"]])
177+
origin = get_origin(field_type)
178+
if origin is Literal:
179+
literal_values = get_args(field_type)
180+
if field_str not in literal_values:
181+
raise HTTPException(
182+
status_code=http_status.HTTP_400_BAD_REQUEST,
183+
detail=f"Invalid value '{field_str}' for field '{field_name}'. Expected one of: {literal_values}"
184+
)
185+
parsed_fields[field_name] = field_str
186+
187+
# Handle Union types (e.g., Optional[int, str])
188+
elif origin is Union:
189+
for possible_type in get_args(field_type):
175190
try:
176-
parsed_fields[field_name] = parse_obj_as(
177-
possible_type, json.loads(field_str)
178-
)
191+
parsed_fields[field_name] = parse_obj_as(possible_type, field_str)
179192
break
180-
except (json.JSONDecodeError, ValidationError):
193+
except (ValueError, TypeError, ValidationError):
181194
continue
182-
# Handle list of nested models
183-
elif get_origin(field_type) is list:
184-
list_item_type = get_args(field_type)[0]
185-
if field_str:
186-
parsed_list = json.loads(field_str)
187-
if issubclass(list_item_type, BaseModel):
188-
parsed_fields[field_name] = parse_obj_as(
189-
List[list_item_type], parsed_list
190-
)
191-
else:
192-
parsed_fields[field_name] = parsed_list
193-
# General parsing attempt for other types
194-
else:
195-
if field_str:
196-
try:
197-
parsed_fields[field_name] = field_info.type_(field_str)
198-
except (ValueError, TypeError):
199-
# Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes
200-
parsed_fields[field_name] = parse_obj_as(field_info.type_, field_str)
195+
196+
# Handle list of nested models (e.g., List[Item])
197+
elif get_origin(field_type) is list:
198+
list_item_type = get_args(field_type)[0]
199+
if field_str:
200+
parsed_list = json.loads(field_str)
201+
if issubclass(list_item_type, BaseModel):
202+
parsed_fields[field_name] = parse_obj_as(List[list_item_type], parsed_list)
203+
else:
204+
parsed_fields[field_name] = parsed_list
205+
206+
# Handle other general types
207+
else:
208+
if field_str:
209+
if field_type == bool:
210+
# Special case: handle "false" and "true" as booleans
211+
if field_str.lower() == "false":
212+
parsed_fields[field_name] = False
213+
elif field_str.lower() == "true":
214+
parsed_fields[field_name] = True
215+
else:
216+
raise HTTPException(
217+
status_code=http_status.HTTP_400_BAD_REQUEST,
218+
detail=f"Invalid value '{field_str}' for boolean field '{field_name}'. Expected 'true' or 'false'."
219+
)
220+
else:
221+
# General case: try converting to the target type
222+
try:
223+
parsed_fields[field_name] = field_type(field_str)
224+
except (ValueError, TypeError):
225+
# Fallback to parse_obj_as when type is more complex, e., AnyUrl or ImageBytes
226+
parsed_fields[field_name] = parse_obj_as(field_type, field_str)
227+
228+
except Exception as e:
229+
raise HTTPException(
230+
status_code=http_status.HTTP_400_BAD_REQUEST,
231+
detail=f"Error parsing value '{field_str}' for field '{field_name}': {str(e)}"
232+
)
201233

202234
return model(**parsed_fields)
203235

@@ -209,25 +241,41 @@ def construct_model_from_line(
209241
# We also expect the csv file to have no quotes and use the escape char '\'
210242
field_names = [f for f in input_doc_list_model.__fields__]
211243
data = []
244+
parameters = None
245+
first_row = True
212246
for line in csv.reader(
213247
StringIO(csv_body),
214248
delimiter=',',
215249
quoting=csv.QUOTE_NONE,
216250
escapechar='\\',
217251
):
218-
if len(line) != len(field_names):
219-
raise HTTPException(
220-
status_code=400,
221-
detail=f'Invalid CSV format. Line {line} doesn\'t match '
222-
f'the expected field order {field_names}.',
223-
)
224-
data.append(construct_model_from_line(input_doc_list_model, line))
225-
226-
return await process(input_model(data=data))
252+
if first_row:
253+
first_row = False
254+
if len(line) > 1 and line[1] == 'params_row': # Check if it's a parameters row by examining the 2nd text in the first line
255+
parameters = construct_model_from_line(parameters_model, line[2:])
256+
else:
257+
if len(line) != len(field_names):
258+
raise HTTPException(
259+
status_code=http_status.HTTP_400_BAD_REQUEST,
260+
detail=f'Invalid CSV format. Line {line} doesn\'t match '
261+
f'the expected field order {field_names}.',
262+
)
263+
data.append(construct_model_from_line(input_doc_list_model, line))
264+
else:
265+
# Treat it as normal data row
266+
if len(line) != len(field_names):
267+
raise HTTPException(
268+
status_code=http_status.HTTP_400_BAD_REQUEST,
269+
detail=f'Invalid CSV format. Line {line} doesn\'t match '
270+
f'the expected field order {field_names}.',
271+
)
272+
data.append(construct_model_from_line(input_doc_list_model, line))
273+
274+
return await process(input_model(data=data, parameters=parameters))
227275

228276
else:
229277
raise HTTPException(
230-
status_code=400,
278+
status_code=http_status.HTTP_400_BAD_REQUEST,
231279
detail=f'Invalid content-type: {content_type}. '
232280
f'Please use either application/json or text/csv.',
233281
)
@@ -273,7 +321,7 @@ def construct_model_from_line(
273321
input_model=endpoint_input_model,
274322
output_model=endpoint_output_model,
275323
input_doc_list_model=input_doc_model,
276-
output_doc_list_model=output_doc_model,
324+
parameters_model=parameters_model,
277325
)
278326

279327
from jina.serve.runtimes.gateway.health_model import JinaHealthModel

tests/integration/docarray_v2/csp/SampleExecutor/executor.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from docarray import BaseDoc, DocList
33
from docarray.typing import NdArray
44
from pydantic import Field, BaseModel
5+
from typing import Optional, Literal
56

67
from jina import Executor, requests
78

@@ -20,7 +21,17 @@ class Config(BaseDoc.Config):
2021

2122

2223
class Parameters(BaseModel):
23-
emb_dim: int
24+
task: Optional[
25+
Literal[
26+
"retrieval.query",
27+
"retrieval.passage",
28+
"text-matching",
29+
"classification",
30+
"separation",
31+
]
32+
] = None
33+
late_chunking: bool = False
34+
dimensions: Optional[int] = None
2435

2536

2637

@@ -46,7 +57,7 @@ def bar(self, docs: DocList[TextDoc], parameters: Parameters, **kwargs) -> DocLi
4657
EmbeddingResponseModel(
4758
id=doc.id,
4859
text=doc.text,
49-
embeddings=np.random.random((1, parameters.emb_dim)),
60+
embeddings=np.random.random((1, parameters.dimensions)),
5061
)
5162
)
5263
return DocList[EmbeddingResponseModel](ret)

tests/integration/docarray_v2/csp/test_sagemaker_embedding.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_provider_sagemaker_pod_inference_parameters():
8989
'data': [
9090
{'text': 'hello world'},
9191
],
92-
'parameters': {'emb_dim': emb_dim}
92+
'parameters': {'dimensions': emb_dim}
9393
},
9494
)
9595
assert resp.status_code == 200
@@ -148,6 +148,50 @@ def test_provider_sagemaker_pod_batch_transform_valid(filename):
148148
assert len(d["embeddings"][0]) == 64
149149

150150

151+
def test_provider_sagemaker_pod_batch_transform_with_params_valid():
152+
args, _ = set_pod_parser().parse_known_args(
153+
[
154+
'--uses',
155+
os.path.join(os.path.dirname(__file__), "SampleExecutor", "config.yml"),
156+
'--provider',
157+
'sagemaker',
158+
"--provider-endpoint",
159+
"encode_parameter",
160+
'serve', # This is added by sagemaker
161+
]
162+
)
163+
with Pod(args):
164+
texts = []
165+
with open(os.path.join(os.path.dirname(__file__), "valid_input_3.csv"), "r") as f:
166+
csv_data = f.read()
167+
168+
csv_reader = csv.reader(io.StringIO(csv_data), delimiter=",", quoting=csv.QUOTE_NONE, escapechar="\\")
169+
170+
# Before comparison, remove the parameters row
171+
next(csv_reader)
172+
173+
for line in csv_reader:
174+
texts.append(line[1])
175+
176+
resp = requests.post(
177+
f"http://localhost:{sagemaker_port}/invocations",
178+
headers={
179+
"accept": "application/json",
180+
"content-type": "text/csv",
181+
},
182+
data=csv_data,
183+
)
184+
assert resp.status_code == 200
185+
resp_json = resp.json()
186+
assert len(resp_json["data"]) == 10
187+
for idx, d in enumerate(resp_json["data"]):
188+
assert d["text"] == texts[idx]
189+
assert len(d["embeddings"][0]) == 2
190+
191+
assert resp_json["parameters"]["late_chunking"] == False
192+
assert resp_json["parameters"]["task"] == "retrieval.query"
193+
194+
151195
def test_provider_sagemaker_pod_batch_transform_invalid():
152196
args, _ = set_pod_parser().parse_known_args(
153197
[

0 commit comments

Comments
 (0)