@@ -577,68 +577,62 @@ class TestContextObjectWatchers(unittest.TestCase):
577
577
def context_watcher (self , which_watcher ):
578
578
wid = _testcapi .add_context_watcher (which_watcher )
579
579
try :
580
- yield wid
580
+ switches = _testcapi .get_context_switches (which_watcher )
581
+ except ValueError :
582
+ switches = None
583
+ try :
584
+ yield switches
581
585
finally :
582
586
_testcapi .clear_context_watcher (wid )
583
587
584
- def assert_event_counts (self , exp_enter_0 , exp_exit_0 ,
585
- exp_enter_1 , exp_exit_1 ):
586
- self .assertEqual (
587
- exp_enter_0 , _testcapi .get_context_watcher_num_enter_events (0 ))
588
- self .assertEqual (
589
- exp_exit_0 , _testcapi .get_context_watcher_num_exit_events (0 ))
590
- self .assertEqual (
591
- exp_enter_1 , _testcapi .get_context_watcher_num_enter_events (1 ))
592
- self .assertEqual (
593
- exp_exit_1 , _testcapi .get_context_watcher_num_exit_events (1 ))
588
+ def assert_event_counts (self , want_0 , want_1 ):
589
+ self .assertEqual (len (_testcapi .get_context_switches (0 )), want_0 )
590
+ self .assertEqual (len (_testcapi .get_context_switches (1 )), want_1 )
594
591
595
592
def test_context_object_events_dispatched (self ):
596
593
# verify that all counts are zero before any watchers are registered
597
- self .assert_event_counts (0 , 0 , 0 , 0 )
594
+ self .assert_event_counts (0 , 0 )
598
595
599
596
# verify that all counts remain zero when a context object is
600
597
# entered and exited with no watchers registered
601
598
ctx = contextvars .copy_context ()
602
- ctx .run (self .assert_event_counts , 0 , 0 , 0 , 0 )
603
- self .assert_event_counts (0 , 0 , 0 , 0 )
599
+ ctx .run (self .assert_event_counts , 0 , 0 )
600
+ self .assert_event_counts (0 , 0 )
604
601
605
602
# verify counts are as expected when first watcher is registered
606
603
with self .context_watcher (0 ):
607
- self .assert_event_counts (0 , 0 , 0 , 0 )
608
- ctx .run (self .assert_event_counts , 1 , 0 , 0 , 0 )
609
- self .assert_event_counts (1 , 1 , 0 , 0 )
604
+ self .assert_event_counts (0 , 0 )
605
+ ctx .run (self .assert_event_counts , 1 , 0 )
606
+ self .assert_event_counts (2 , 0 )
610
607
611
608
# again with second watcher registered
612
609
with self .context_watcher (1 ):
613
- self .assert_event_counts (1 , 1 , 0 , 0 )
614
- ctx .run (self .assert_event_counts , 2 , 1 , 1 , 0 )
615
- self .assert_event_counts (2 , 2 , 1 , 1 )
610
+ self .assert_event_counts (2 , 0 )
611
+ ctx .run (self .assert_event_counts , 3 , 1 )
612
+ self .assert_event_counts (4 , 2 )
616
613
617
614
# verify counts are reset and don't change after both watchers are cleared
618
- ctx .run (self .assert_event_counts , 0 , 0 , 0 , 0 )
619
- self .assert_event_counts (0 , 0 , 0 , 0 )
620
-
621
- def test_enter_error (self ):
622
- with self .context_watcher (2 ):
623
- with catch_unraisable_exception () as cm :
624
- ctx = contextvars .copy_context ()
625
- ctx .run (int , 0 )
626
- self .assertEqual (
627
- cm .unraisable .err_msg ,
628
- "Exception ignored in "
629
- f"Py_CONTEXT_EVENT_EXIT watcher callback for { ctx !r} "
630
- )
631
- self .assertEqual (str (cm .unraisable .exc_value ), "boom!" )
632
-
633
- def test_exit_error (self ):
634
- ctx = contextvars .copy_context ()
635
- def _in_context (stack ):
636
- stack .enter_context (self .context_watcher (2 ))
637
-
638
- with catch_unraisable_exception () as cm :
639
- with ExitStack () as stack :
640
- ctx .run (_in_context , stack )
641
- self .assertEqual (str (cm .unraisable .exc_value ), "boom!" )
615
+ ctx .run (self .assert_event_counts , 0 , 0 )
616
+ self .assert_event_counts (0 , 0 )
617
+
618
+ def test_callback_error (self ):
619
+ ctx_outer = contextvars .copy_context ()
620
+ ctx_inner = contextvars .copy_context ()
621
+ unraisables = []
622
+
623
+ def _in_outer ():
624
+ with self .context_watcher (2 ):
625
+ with catch_unraisable_exception () as cm :
626
+ ctx_inner .run (lambda : unraisables .append (cm .unraisable ))
627
+ unraisables .append (cm .unraisable )
628
+
629
+ ctx_outer .run (_in_outer )
630
+ self .assertEqual ([x .err_msg for x in unraisables ],
631
+ ["Exception ignored in Py_CONTEXT_SWITCHED "
632
+ f"watcher callback for { ctx !r} "
633
+ for ctx in [ctx_inner , ctx_outer ]])
634
+ self .assertEqual ([str (x .exc_value ) for x in unraisables ],
635
+ ["boom!" , "boom!" ])
642
636
643
637
def test_clear_out_of_range_watcher_id (self ):
644
638
with self .assertRaisesRegex (ValueError , r"Invalid context watcher ID -1" ):
@@ -654,5 +648,12 @@ def test_allocate_too_many_watchers(self):
654
648
with self .assertRaisesRegex (RuntimeError , r"no more context watcher IDs available" ):
655
649
_testcapi .allocate_too_many_context_watchers ()
656
650
651
+ def test_exit_base_context (self ):
652
+ ctx = contextvars .Context ()
653
+ _testcapi .clear_context_stack ()
654
+ with self .context_watcher (0 ) as switches :
655
+ ctx .run (lambda : None )
656
+ self .assertEqual (switches , [ctx , None ])
657
+
657
658
if __name__ == "__main__" :
658
659
unittest .main ()
0 commit comments