Skip to content

Commit b1dbf66

Browse files
author
WangGLJoseph
committed
improve coverage of test_gpt.py
1 parent 507cc10 commit b1dbf66

File tree

1 file changed

+71
-39
lines changed

1 file changed

+71
-39
lines changed

test/app/model/test_gpt.py

Lines changed: 71 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,6 @@ class DummyCompletions:
3737
def create(self, *args, **kwargs):
3838
DummyCompletions.last_kwargs = kwargs # capture the kwargs passed in
3939
return DummyResponse()
40-
41-
# Create a dummy client that raises BadRequestError.
42-
class DummyBadRequestCompletions:
43-
def create(self, *args, **kwargs):
44-
raise BadRequestError("error", code="context_length_exceeded")
45-
46-
class DummyBadRequestClientChat:
47-
completions = DummyBadRequestCompletions()
48-
49-
class DummyBadRequestClient:
50-
chat = DummyBadRequestClientChat()
5140

5241
# Dummy client chat now includes a completions attribute.
5342
class DummyClientChat:
@@ -307,21 +296,10 @@ def test_call_single_tool_branch(monkeypatch):
307296

308297
# Define a dummy error class to simulate BadRequestError with a code attribute.
309298
class DummyBadRequestError(BadRequestError):
310-
def __init__(self, message, code):
299+
def __init__(self, message):
311300
# Do not call super().__init__ to avoid unexpected keyword errors.
312-
self.code = code
313301
self.message = message
314302

315-
# Global flag to capture log_and_print invocation.
316-
log_and_print_called = False
317-
# Global flag to capture log_and_print invocation.
318-
log_and_print_called = False
319-
320-
def dummy_log_and_print(message):
321-
global log_and_print_called
322-
log_and_print_called = True
323-
print(f"dummy_log_and_print called with message: {message}")
324-
325303
class DummyThreadCost:
326304
process_cost = 0.0
327305
process_input_tokens = 0
@@ -336,22 +314,52 @@ def dummy_retry(*args, **kwargs):
336314
print("dummy_retry decorator applied")
337315
return lambda f: f
338316

317+
# Define a dummy response object with the required attributes.
318+
class DummyResponseObject:
319+
request = "dummy_request"
320+
status_code = 400 # Provide a dummy status code.
321+
headers = {"content-type": "application/json"}
322+
339323
# Create a dummy client that always raises BadRequestError.
340324
class DummyBadRequestCompletions:
341325
def create(self, *args, **kwargs):
342326
print("DummyBadRequestCompletions.create called")
343-
raise BadRequestError("error", code="context_length_exceeded")
327+
# Instantiate a BadRequestError with a dummy response object.
328+
err = BadRequestError("error", response=DummyResponseObject(), body={})
329+
err.code = "context_length_exceeded"
330+
raise err
344331

345332
class DummyBadRequestClientChat:
346333
completions = DummyBadRequestCompletions()
347334

348335
class DummyBadRequestClient:
349336
chat = DummyBadRequestClientChat()
350337

351-
def test_call_bad_request(monkeypatch):
352-
global log_and_print_called
353-
log_and_print_called = False
338+
# Create a dummy client that always raises BadRequestError, with a different 'code' message.
339+
class DummyBadRequestCompletionsOther:
340+
def create(self, *args, **kwargs):
341+
print("DummyBadRequestCompletionsOther.create called")
342+
# Instantiate a BadRequestError with a dummy response object.
343+
err = BadRequestError("error", response=DummyResponseObject(), body={})
344+
err.code = "some_other_code"
345+
raise err
346+
347+
class DummyBadRequestClientChatOther:
348+
completions = DummyBadRequestCompletionsOther()
349+
350+
class DummyBadRequestClientOther:
351+
chat = DummyBadRequestClientChatOther()
352+
353+
def extract_exception_chain(exc):
354+
"""Utility to walk the __cause__ chain and return a list of exceptions."""
355+
chain = [exc]
356+
while exc.__cause__ is not None:
357+
exc = exc.__cause__
358+
chain.append(exc)
359+
return chain
354360

361+
def test_call_bad_request(monkeypatch):
362+
# Do not patch log_and_print so that the actual lines in the except block execute.
355363
# Disable sleep functions so that no real delays occur.
356364
monkeypatch.setattr("tenacity.sleep", dummy_sleep)
357365
monkeypatch.setattr(time, "sleep", dummy_sleep)
@@ -363,10 +371,7 @@ def test_call_bad_request(monkeypatch):
363371
# Replace common.thread_cost with our dummy instance.
364372
monkeypatch.setattr(common, "thread_cost", DummyThreadCost())
365373

366-
# Patch log_and_print (imported from app.log) to record its call.
367-
monkeypatch.setattr("app.log.log_and_print", dummy_log_and_print)
368-
369-
# Create a dummy client that always raises DummyBadRequestError.
374+
# Create a dummy client that always raises BadRequestError.
370375
model = Gpt_o1()
371376
model.client = DummyBadRequestClient()
372377

@@ -375,12 +380,39 @@ def test_call_bad_request(monkeypatch):
375380
print("Calling model.call with messages:", messages)
376381
with pytest.raises(RetryError) as exc_info:
377382
model.call(messages, temperature=1.0)
378-
# Extract the exception from the final attempt.
379-
last_exception = exc_info.value.last_attempt.exception()
380-
print("Last exception caught:", last_exception)
381383

382-
# Verify that the last exception has the expected code.
383-
assert isinstance(last_exception, RetryError)
384-
# assert last_exception.code == "context_length_exceeded"
385-
# # Verify that our dummy log_and_print was invoked.
386-
# assert log_and_print_called
384+
# Extract the last exception from the RetryError chain.
385+
last_exc = exc_info.value.last_attempt.exception()
386+
print("Final exception from last attempt:", last_exc)
387+
388+
# Walk the cause chain to see if BadRequestError is present.
389+
chain = extract_exception_chain(last_exc)
390+
for i, e in enumerate(chain):
391+
print(f"Exception in chain [{i}]: type={type(e)}, message={getattr(e, 'message', str(e))}, code={getattr(e, 'code', None)}")
392+
393+
# Assert that one exception in the chain is a BadRequestError with the expected code.
394+
found = any(isinstance(e, BadRequestError) and getattr(e, "code", None) == "context_length_exceeded"
395+
for e in chain)
396+
assert found, "BadRequestError with expected code not found in exception chain."
397+
398+
# Other tests with different error codes.
399+
model.client = DummyBadRequestClientOther()
400+
messages = [{"role": "user", "content": "Hello"}]
401+
402+
print("Calling model.call with messages:", messages)
403+
with pytest.raises(RetryError) as exc_info:
404+
model.call(messages, temperature=1.0)
405+
406+
# Extract the last exception from the RetryError chain.
407+
last_exc = exc_info.value.last_attempt.exception()
408+
print("Final exception from last attempt:", last_exc)
409+
410+
# Walk the cause chain to see if BadRequestError is present.
411+
chain = extract_exception_chain(last_exc)
412+
for i, e in enumerate(chain):
413+
print(f"Exception in chain [{i}]: type={type(e)}, message={getattr(e, 'message', str(e))}, code={getattr(e, 'code', None)}")
414+
415+
# Assert that one exception in the chain is a BadRequestError with the expected code.
416+
found = any(isinstance(e, BadRequestError) and getattr(e, "code", None) == "some_other_code"
417+
for e in chain)
418+
assert found, "BadRequestError with expected code not found in exception chain."

0 commit comments

Comments
 (0)