Skip to content

Commit e2666eb

Browse files
authored
[TP-1781] Add test for MiniCPM-o-2_6-language model (#219)
1 parent a677c61 commit e2666eb

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

tests/common.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,23 @@ def metadata(pat: bool = False) -> Tuple[Tuple[str, str], Tuple[str, str]]:
6565
)
6666

6767

def grpc_channel(func):
    """
    A decorator that runs the test using the gRPC channel.

    The channel is insecure (port 443) when the ``CLARIFAI_GRPC_INSECURE``
    environment variable is set to a truthy value ("true"/"1"/"t",
    case-insensitive); otherwise the regular secure gRPC channel is used.

    :param func: The test function.
    :return: A function wrapper.
    """
    # TODO: move this import to the module's import block.
    import functools

    # functools.wraps preserves the decorated test's __name__/__doc__ so
    # pytest reports show the real test name instead of "func_wrapper".
    @functools.wraps(func)
    def func_wrapper():
        if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
            channel = ClarifaiChannel.get_insecure_grpc_channel(port=443)
        else:
            channel = ClarifaiChannel.get_grpc_channel()
        func(channel)

    return func_wrapper
6885
def both_channels(func):
6986
"""
7087
A decorator that runs the test first using the gRPC channel and then using the JSON channel.
@@ -73,9 +90,10 @@ def both_channels(func):
7390
"""
7491

7592
def func_wrapper():
76-
channel = ClarifaiChannel.get_grpc_channel()
7793
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
7894
channel = ClarifaiChannel.get_insecure_grpc_channel(port=443)
95+
else:
96+
channel = ClarifaiChannel.get_grpc_channel()
7997
func(channel)
8098

8199
channel = ClarifaiChannel.get_json_channel()
@@ -335,6 +353,25 @@ def post_model_outputs_and_maybe_allow_retries(
335353
return response
336354

337355

def _generate_model_outputs(
    stub: service_pb2_grpc.V2Stub,
    request: service_pb2.PostModelOutputsRequest,
    metadata: Tuple,
):
    """
    Call GenerateModelOutputs and yield the streamed responses, retrying
    while the model is still loading.

    Each attempt opens a fresh response stream. If a response arrives with
    status MODEL_LOADING before any real output has been seen, the stream is
    abandoned, we sleep 15s, and a new attempt is made (up to
    MAX_PREDICT_ATTEMPTS attempts). Once any non-loading response has been
    seen, the rest of the stream is yielded and no further attempts follow.

    NOTE(review): if the model never finishes loading within the attempt
    budget this generator yields nothing at all — callers should assert that
    they received at least one response.

    :param stub: The V2Stub to call GenerateModelOutputs on.
    :param request: The PostModelOutputsRequest to send.
    :param metadata: The (auth) metadata tuple for the gRPC call.
    :return: A generator of streamed responses.
    """
    is_model_loaded = False
    # The attempt index itself is unused; we only need the retry budget.
    for _ in range(1, MAX_PREDICT_ATTEMPTS + 1):
        response_iterator = stub.GenerateModelOutputs(request, metadata=metadata)
        for response in response_iterator:
            # Only before the first real output does MODEL_LOADING trigger a
            # retry; after that, responses are yielded unconditionally.
            if not is_model_loaded and response.status.code == status_code_pb2.MODEL_LOADING:
                print(f"Model {request.model_id} is still loading...")
                time.sleep(15)
                break  # abandon this stream and start a new attempt
            is_model_loaded = True
            yield response
        if is_model_loaded:
            break  # stream fully consumed; no retry needed
338375
async def async_post_model_outputs_and_maybe_allow_retries(
339376
stub: service_pb2_grpc.V2Stub,
340377
request: service_pb2.PostModelOutputsRequest,

tests/public_models/public_test_helper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
TEXT_SENTIMENT_MODEL_ID = "multilingual-uncased-sentiment" # bert-based
3232
TEXT_MULTILINGUAL_MODERATION_MODEL_ID = "bdcedc0f8da58c396b7df12f634ef923"
3333

# Model ID of the public MiniCPM-o-2_6-language model (used by the LLM
# streaming-predict test).
TEXT_MINI_CPM_MODEL_ID = "MiniCPM-o-2_6-language"
3435

3536
TRANSLATION_TEST_DATA = {
3637
"ROMANCE": "No me apetece nada estudiar esta noche",
@@ -212,6 +213,10 @@
212213
),
213214
]
214215

# title, model_id, app_id, user_id
TEXT_LLM_MODEL_TITLE_IDS_TUPLE = [
    ("multimodal large language model", TEXT_MINI_CPM_MODEL_ID, "miniCPM", "openbmb")
]
219+
215220
# title, model_id, text, app, user
216221
TEXT_FB_TRANSLATION_MODEL_TITLE_ID_DATA_TUPLE = []
217222
TEXT_HELSINKI_TRANSLATION_MODEL_TITLE_ID_DATA_TUPLE = []

tests/public_models/test_public_models_predicts.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
GENERAL_MODEL_ID,
2020
MAIN_APP_ID,
2121
MAIN_APP_USER_ID,
22+
_generate_model_outputs,
2223
async_post_model_outputs_and_maybe_allow_retries,
2324
async_raise_on_failure,
2425
asyncio_channel,
2526
both_channels,
27+
grpc_channel,
2628
metadata,
2729
post_model_outputs_and_maybe_allow_retries,
2830
raise_on_failure,
@@ -35,6 +37,7 @@
3537
MULTIMODAL_MODEL_TITLE_AND_IDS,
3638
TEXT_FB_TRANSLATION_MODEL_TITLE_ID_DATA_TUPLE,
3739
TEXT_HELSINKI_TRANSLATION_MODEL_TITLE_ID_DATA_TUPLE,
40+
TEXT_LLM_MODEL_TITLE_IDS_TUPLE,
3841
TEXT_MODEL_TITLE_IDS_TUPLE,
3942
TRANSLATION_TEST_DATA,
4043
)
@@ -120,6 +123,60 @@ def test_text_predict_on_public_models(channel):
120123
)
121124

122125

@grpc_channel
def test_text_predict_on_public_llm_models(channel):
    """Stream predictions from the public LLM models and require at least
    one successful response per model."""
    stub = service_pb2_grpc.V2Stub(channel)

    for title, model_id, app_id, user_id in TEXT_LLM_MODEL_TITLE_IDS_TUPLE:
        # Assemble the prompt part and the inference-parameter parts up
        # front, so the request literal below stays flat and readable.
        prompt_part = resources_pb2.Part(
            id="prompt",
            data=resources_pb2.Data(string_value=TRANSLATION_TEST_DATA["EN"]),
        )
        max_tokens_part = resources_pb2.Part(
            id="max_tokens",
            data=resources_pb2.Data(int_value=10),
        )
        temperature_part = resources_pb2.Part(
            id="temperature",
            data=resources_pb2.Data(float_value=0.7),
        )
        top_p_part = resources_pb2.Part(
            id="top_p",
            data=resources_pb2.Data(float_value=0.95),
        )

        request = service_pb2.PostModelOutputsRequest(
            user_app_id=resources_pb2.UserAppIDSet(user_id=user_id, app_id=app_id),
            model_id=model_id,
            inputs=[
                resources_pb2.Input(
                    data=resources_pb2.Data(
                        parts=[prompt_part, max_tokens_part, temperature_part, top_p_part]
                    )
                )
            ],
        )

        responses_count = 0
        for response in _generate_model_outputs(stub, request, metadata(pat=True)):
            responses_count += 1
            raise_on_failure(
                response,
                custom_message=f"Text predict failed for the {title} model (ID: {model_id}).",
            )

        # The retrying generator yields nothing when the model never loads,
        # so an empty stream must fail the test.
        assert responses_count > 0
123180
@asyncio_channel
124181
async def test_text_predict_on_public_models_async(channel):
125182
"""Test non translation text/nlp models.

0 commit comments

Comments
 (0)