@@ -290,6 +290,7 @@ def create_trans_fn(fn_name, fn_trafo_args):
290
290
# we need to create a scope-function from our class for the given method
291
291
@functools .wraps (fn )
292
292
def wrapped_fn (self , * args , ** kwargs ):
293
+ state = self ._state .export ()
293
294
# make a scope-function to transform
294
295
def core_fn (scopes , * args , ** kwargs ):
295
296
# make a clone of self using its arguments
@@ -301,7 +302,7 @@ def core_fn(scopes, *args, **kwargs):
301
302
# we reference module_class, not self.__class__ to avoid infinite loop
302
303
cloned = module_class (parent = None , ** attrs )
303
304
cloned , args , kwargs = set_module_scopes (cloned , args , kwargs , scopes )
304
- object .__setattr__ (cloned , '_state' , self . _state .export ()) # pylint: disable=protected-access
305
+ object .__setattr__ (cloned , '_state' , state .export ()) # pylint: disable=protected-access
305
306
res = fn (cloned , * args , ** kwargs )
306
307
self ._state .reimport (cloned ._state ) # pylint: disable=protected-access
307
308
_test_transformed_return_values (res , fn_name )
@@ -343,12 +344,13 @@ def decorator_lift_transform(transform, class_fn, *trafo_args,
343
344
prewrapped_fns = [wrap_method_once (class_fn ) for class_fn in class_fns ]
344
345
@functools .wraps (prewrapped_fns [0 ])
345
346
def wrapped_fn (self , * args , ** kwargs ):
347
+ state = self ._state .export ()
346
348
# make a scope-function to transform
347
349
def core_fn (prewrapped_fn , class_fn , scopes , * args , ** kwargs ):
348
350
if not multi_scope :
349
351
scopes = [scopes ]
350
352
cloned , args , kwargs = set_module_scopes (self , args , kwargs , scopes )
351
- object .__setattr__ (cloned , '_state' , self . _state .export ()) # pylint: disable=protected-access
353
+ object .__setattr__ (cloned , '_state' , state .export ()) # pylint: disable=protected-access
352
354
res = prewrapped_fn (cloned , * args , ** kwargs )
353
355
self ._state .reimport (cloned ._state ) # pylint: disable=protected-access
354
356
_test_transformed_return_values (res , getattr (class_fn , '__name__' , None ))
0 commit comments