Skip to content

Commit fb4ed95

Browse files
committed
refactor
1 parent 8dd6339 commit fb4ed95

File tree

1 file changed

+69
-22
lines changed

1 file changed

+69
-22
lines changed

tests/public_models/test_public_models_predicts.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -350,35 +350,76 @@ def test_multimodal_predict_on_public_models(channel):
350350

351351
# Helper functions for the new OpenAI compatible endpoint test
352352
def _call_openai_model(model_id):
353-
"""Attempts to call a model using OpenAI's chat completions and image generation APIs."""
353+
"""
354+
Attempts to call a model using OpenAI's chat completions and image generation APIs,
355+
with an integrated retry mechanism for transient network errors.
356+
"""
354357
client = OpenAI(api_key=API_KEY, base_url="https://api.clarifai.com/v2/ext/openai/v1")
355-
try:
356-
response = client.chat.completions.create(
357-
model=model_id,
358-
messages=[
359-
{"role": "system", "content": "You are a helpful assistant."},
360-
{"role": "user", "content": "Who are you?"},
361-
],
362-
max_tokens=50,
363-
)
364-
assert hasattr(response, "choices") and len(response.choices) > 0, "No choices in response"
365-
return response, None
366-
except Exception as e1:
358+
last_err_chat = None
359+
last_err_image = None
360+
361+
# --- Attempt 1: Chat Completions with Retry ---
362+
for attempt in range(MAX_RETRY_ATTEMPTS):
363+
try:
364+
response = client.chat.completions.create(
365+
model=model_id,
366+
messages=[
367+
{"role": "system", "content": "You are a helpful assistant."},
368+
{"role": "user", "content": "Who are you?"},
369+
],
370+
max_tokens=50,
371+
)
372+
if hasattr(response, 'choices') and response.choices:
373+
return response, None # Success
374+
else:
375+
last_err_chat = ValueError(
376+
f"Chat completions returned no choices. Response: {response}"
377+
)
378+
break # Successful call but no data, so don't retry.
379+
except (APIConnectionError, APITimeoutError, RateLimitError) as e:
380+
last_err_chat = e
381+
if attempt == MAX_RETRY_ATTEMPTS - 1:
382+
break # Last attempt failed, break to move on
383+
print(
384+
f"Retrying chat predict for '{model_id}' after error: {e}. Attempt #{attempt + 1}"
385+
)
386+
time.sleep(attempt + 1) # Exponential backoff
387+
except Exception as e:
388+
last_err_chat = e
389+
break # Non-retriable error, break immediately
390+
391+
# --- Attempt 2: Image Generation with Retry ---
392+
for attempt in range(MAX_RETRY_ATTEMPTS):
367393
try:
368394
response = client.images.generate(
369395
model=model_id,
370396
prompt="A cat and a dog sitting together in a park",
371397
)
372-
assert hasattr(response, "data") and len(response.data) > 0, (
373-
"No image data in response"
398+
if hasattr(response, 'data') and response.data:
399+
return response, None # Success
400+
else:
401+
last_err_image = ValueError(
402+
f"Image generation returned no data. Response: {response}"
403+
)
404+
break # Successful call but no data, so don't retry.
405+
except (APIConnectionError, APITimeoutError, RateLimitError) as e:
406+
last_err_image = e
407+
if attempt == MAX_RETRY_ATTEMPTS - 1:
408+
break # Last attempt failed, break
409+
print(
410+
f"Retrying image predict for '{model_id}' after error: {e}. Attempt #{attempt + 1}"
374411
)
375-
return response, None
376-
except Exception as e2:
377-
return None, f"chat.completions error: {e1}; image.generate error: {e2}"
412+
time.sleep(attempt + 1) # Exponential backoff
413+
except Exception as e:
414+
last_err_image = e
415+
break # Non-retriable error, break immediately
416+
417+
return None, f"chat.completions error: {last_err_chat}; image.generate error: {last_err_image}"
378418

379419

380420
def _list_featured_models(per_page=50):
381421
"""Lists featured models from the Clarifai platform."""
422+
# This function remains unchanged
382423
channel = ClarifaiChannel.get_grpc_channel()
383424
stub = service_pb2_grpc.V2Stub(channel)
384425
auth_metadata = (("authorization", f"Key {API_KEY}"),)
@@ -389,9 +430,16 @@ def _list_featured_models(per_page=50):
389430
return response.models
390431

391432

433+
# The test functions below remain unchanged as the retry logic
434+
# is now encapsulated within the _call_openai_model helper.
435+
436+
392437
# New integrated test
393438
def test_openai_compatible_endpoint_on_featured_models():
394439
"""Tests the OpenAI compatible endpoint with featured models."""
440+
if not API_KEY:
441+
pytest.skip("Skipping test: CLARIFAI_PAT environment variable not set.")
442+
395443
featured_models = _list_featured_models()
396444
failed_models = []
397445

@@ -411,17 +459,16 @@ def test_openai_compatible_endpoint_on_featured_models():
411459
# New integrated async test
412460
async def _call_openai_model_async(model_identifier, session):
413461
"""Async helper to call a single model."""
414-
# Note: The OpenAI library doesn't natively support asyncio for this type of call.
415-
# To make this truly async, one would typically use an async-compatible HTTP client
416-
# like `httpx` or `aiohttp`. For simplicity and to match the provided script's
417-
# library, we run the synchronous `_call_openai_model` in a thread pool executor.
418462
loop = asyncio.get_running_loop()
419463
return await loop.run_in_executor(None, _call_openai_model, model_identifier)
420464

421465

422466
@pytest.mark.asyncio
423467
async def test_openai_compatible_endpoint_on_featured_models_async():
424468
"""Tests the OpenAI compatible endpoint concurrently with featured models."""
469+
if not API_KEY:
470+
pytest.skip("Skipping test: CLARIFAI_PAT environment variable not set.")
471+
425472
featured_models = _list_featured_models()
426473
tasks = []
427474
model_identifiers = []

0 commit comments

Comments
 (0)