@@ -37,17 +37,6 @@ class DummyCompletions:
37
37
def create (self , * args , ** kwargs ):
38
38
DummyCompletions .last_kwargs = kwargs # capture the kwargs passed in
39
39
return DummyResponse ()
40
-
41
- # Create a dummy client that raises BadRequestError.
42
- class DummyBadRequestCompletions :
43
- def create (self , * args , ** kwargs ):
44
- raise BadRequestError ("error" , code = "context_length_exceeded" )
45
-
46
- class DummyBadRequestClientChat :
47
- completions = DummyBadRequestCompletions ()
48
-
49
- class DummyBadRequestClient :
50
- chat = DummyBadRequestClientChat ()
51
40
52
41
# Dummy client chat now includes a completions attribute.
53
42
class DummyClientChat :
@@ -307,21 +296,10 @@ def test_call_single_tool_branch(monkeypatch):
307
296
308
297
# Define a dummy error class to simulate BadRequestError with a code attribute.
309
298
class DummyBadRequestError (BadRequestError ):
310
- def __init__ (self , message , code ):
299
+ def __init__ (self , message ):
311
300
# Do not call super().__init__ to avoid unexpected keyword errors.
312
- self .code = code
313
301
self .message = message
314
302
315
- # Global flag to capture log_and_print invocation.
316
- log_and_print_called = False
317
- # Global flag to capture log_and_print invocation.
318
- log_and_print_called = False
319
-
320
- def dummy_log_and_print (message ):
321
- global log_and_print_called
322
- log_and_print_called = True
323
- print (f"dummy_log_and_print called with message: { message } " )
324
-
325
303
class DummyThreadCost :
326
304
process_cost = 0.0
327
305
process_input_tokens = 0
@@ -336,22 +314,52 @@ def dummy_retry(*args, **kwargs):
336
314
print ("dummy_retry decorator applied" )
337
315
return lambda f : f
338
316
317
+ # Define a dummy response object with the required attributes.
318
+ class DummyResponseObject :
319
+ request = "dummy_request"
320
+ status_code = 400 # Provide a dummy status code.
321
+ headers = {"content-type" : "application/json" }
322
+
339
323
# Create a dummy client that always raises BadRequestError.
340
324
class DummyBadRequestCompletions :
341
325
def create (self , * args , ** kwargs ):
342
326
print ("DummyBadRequestCompletions.create called" )
343
- raise BadRequestError ("error" , code = "context_length_exceeded" )
327
+ # Instantiate a BadRequestError with a dummy response object.
328
+ err = BadRequestError ("error" , response = DummyResponseObject (), body = {})
329
+ err .code = "context_length_exceeded"
330
+ raise err
344
331
345
332
class DummyBadRequestClientChat :
346
333
completions = DummyBadRequestCompletions ()
347
334
348
335
class DummyBadRequestClient :
349
336
chat = DummyBadRequestClientChat ()
350
337
351
- def test_call_bad_request (monkeypatch ):
352
- global log_and_print_called
353
- log_and_print_called = False
338
+ # Create a dummy client that always raises BadRequestError, with a different 'code' message.
339
+ class DummyBadRequestCompletionsOther :
340
+ def create (self , * args , ** kwargs ):
341
+ print ("DummyBadRequestCompletionsOther.create called" )
342
+ # Instantiate a BadRequestError with a dummy response object.
343
+ err = BadRequestError ("error" , response = DummyResponseObject (), body = {})
344
+ err .code = "some_other_code"
345
+ raise err
346
+
347
+ class DummyBadRequestClientChatOther :
348
+ completions = DummyBadRequestCompletionsOther ()
349
+
350
+ class DummyBadRequestClientOther :
351
+ chat = DummyBadRequestClientChatOther ()
352
+
353
+ def extract_exception_chain (exc ):
354
+ """Utility to walk the __cause__ chain and return a list of exceptions."""
355
+ chain = [exc ]
356
+ while exc .__cause__ is not None :
357
+ exc = exc .__cause__
358
+ chain .append (exc )
359
+ return chain
354
360
361
+ def test_call_bad_request (monkeypatch ):
362
+ # Do not patch log_and_print so that the actual lines in the except block execute.
355
363
# Disable sleep functions so that no real delays occur.
356
364
monkeypatch .setattr ("tenacity.sleep" , dummy_sleep )
357
365
monkeypatch .setattr (time , "sleep" , dummy_sleep )
@@ -363,10 +371,7 @@ def test_call_bad_request(monkeypatch):
363
371
# Replace common.thread_cost with our dummy instance.
364
372
monkeypatch .setattr (common , "thread_cost" , DummyThreadCost ())
365
373
366
- # Patch log_and_print (imported from app.log) to record its call.
367
- monkeypatch .setattr ("app.log.log_and_print" , dummy_log_and_print )
368
-
369
- # Create a dummy client that always raises DummyBadRequestError.
374
+ # Create a dummy client that always raises BadRequestError.
370
375
model = Gpt_o1 ()
371
376
model .client = DummyBadRequestClient ()
372
377
@@ -375,12 +380,39 @@ def test_call_bad_request(monkeypatch):
375
380
print ("Calling model.call with messages:" , messages )
376
381
with pytest .raises (RetryError ) as exc_info :
377
382
model .call (messages , temperature = 1.0 )
378
- # Extract the exception from the final attempt.
379
- last_exception = exc_info .value .last_attempt .exception ()
380
- print ("Last exception caught:" , last_exception )
381
383
382
- # Verify that the last exception has the expected code.
383
- assert isinstance (last_exception , RetryError )
384
- # assert last_exception.code == "context_length_exceeded"
385
- # # Verify that our dummy log_and_print was invoked.
386
- # assert log_and_print_called
384
+ # Extract the last exception from the RetryError chain.
385
+ last_exc = exc_info .value .last_attempt .exception ()
386
+ print ("Final exception from last attempt:" , last_exc )
387
+
388
+ # Walk the cause chain to see if BadRequestError is present.
389
+ chain = extract_exception_chain (last_exc )
390
+ for i , e in enumerate (chain ):
391
+ print (f"Exception in chain [{ i } ]: type={ type (e )} , message={ getattr (e , 'message' , str (e ))} , code={ getattr (e , 'code' , None )} " )
392
+
393
+ # Assert that one exception in the chain is a BadRequestError with the expected code.
394
+ found = any (isinstance (e , BadRequestError ) and getattr (e , "code" , None ) == "context_length_exceeded"
395
+ for e in chain )
396
+ assert found , "BadRequestError with expected code not found in exception chain."
397
+
398
+ # Other tests with different error codes.
399
+ model .client = DummyBadRequestClientOther ()
400
+ messages = [{"role" : "user" , "content" : "Hello" }]
401
+
402
+ print ("Calling model.call with messages:" , messages )
403
+ with pytest .raises (RetryError ) as exc_info :
404
+ model .call (messages , temperature = 1.0 )
405
+
406
+ # Extract the last exception from the RetryError chain.
407
+ last_exc = exc_info .value .last_attempt .exception ()
408
+ print ("Final exception from last attempt:" , last_exc )
409
+
410
+ # Walk the cause chain to see if BadRequestError is present.
411
+ chain = extract_exception_chain (last_exc )
412
+ for i , e in enumerate (chain ):
413
+ print (f"Exception in chain [{ i } ]: type={ type (e )} , message={ getattr (e , 'message' , str (e ))} , code={ getattr (e , 'code' , None )} " )
414
+
415
+ # Assert that one exception in the chain is a BadRequestError with the expected code.
416
+ found = any (isinstance (e , BadRequestError ) and getattr (e , "code" , None ) == "some_other_code"
417
+ for e in chain )
418
+ assert found , "BadRequestError with expected code not found in exception chain."
0 commit comments