@@ -217,23 +217,8 @@ def __dlpack__(self, *, stream=None):
217
217
218
218
self ._sync_on_requested_stream (stream )
219
219
220
- dl_managed_tensor = Tensor ._create_managed_tensor ()
221
- dl_managed_tensor .dl_tensor .data = self .data_ptr
222
- dl_managed_tensor .dl_tensor .device = DLDevice (
223
- TRITON_MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE [self .memory_type ],
224
- self .memory_type_id ,
225
- )
220
+ dl_managed_tensor = self ._create_managed_tensor ()
226
221
227
- dl_managed_tensor .dl_tensor .dtype = TRITON_TO_DLPACK_DTYPE [self .data_type ]
228
- dl_managed_tensor .dl_tensor .ndim = len (self .shape )
229
- dl_managed_tensor .dl_tensor .shape = (ctypes .c_int64 * len (self .shape ))(
230
- * self .shape
231
- )
232
- dl_managed_tensor .dl_tensor .strides = ctypes .POINTER (ctypes .c_int64 )()
233
- dl_managed_tensor .dl_tensor .byte_offset = 0
234
- dl_managed_tensor .deleter = Tensor ._managed_tensor_deleter
235
-
236
- self ._set_dlpack_manager_ctx (dl_managed_tensor )
237
222
pycapsule = ctypes .pythonapi .PyCapsule_New (
238
223
ctypes .byref (dl_managed_tensor ),
239
224
c_str_dltensor ,
@@ -600,26 +585,39 @@ def _from_numpy(obj: numpy.ndarray | numpy.generic) -> Tensor:
600
585
size = obj .itemsize * obj .size ,
601
586
owner = obj ,
602
587
)
603
-
604
588
return Tensor (data_type , shape , memory_buffer )
605
589
606
- @staticmethod
607
- def _create_managed_tensor ():
590
def _create_managed_tensor(self) -> DLManagedTensor:
    """Allocate and populate a ``DLManagedTensor`` describing this tensor.

    The struct is allocated with ``PyMem_RawMalloc`` so that it outlives
    this call and survives until the DLPack consumer invokes the deleter.
    Lifetime bookkeeping is delegated to a ``_ManagerCtx``, which holds a
    strong reference to ``self`` together with the ctypes shape/strides
    arrays the ``dl_tensor`` fields point at; ``_managed_tensor_deleter``
    releases that context and frees the struct.

    Returns:
        DLManagedTensor: ctypes view over the freshly allocated struct,
        fully populated from this tensor's data pointer, device, dtype
        and shape.

    Raises:
        MemoryError: if the raw allocation fails.
    """
    raw_malloc = ctypes.pythonapi.PyMem_RawMalloc
    # Declare the C signature explicitly: ctypes otherwise assumes an
    # ``int`` return type, which can truncate 64-bit heap addresses.
    # (Idempotent if the prototype is already configured elsewhere.)
    raw_malloc.restype = ctypes.c_void_p
    raw_malloc.argtypes = [ctypes.c_size_t]
    address = raw_malloc(ctypes.sizeof(DLManagedTensor))
    if not address:
        raise MemoryError("PyMem_RawMalloc failed to allocate DLManagedTensor")
    dl_managed_tensor = DLManagedTensor.from_address(address)
    dl_managed_tensor.dl_tensor.data = self.data_ptr
    dl_managed_tensor.dl_tensor.device = DLDevice(
        TRITON_MEMORY_TYPE_TO_DLPACK_DEVICE_TYPE[self.memory_type],
        self.memory_type_id,
    )
    dl_managed_tensor.dl_tensor.dtype = TRITON_TO_DLPACK_DTYPE[self.data_type]
    dl_managed_tensor.dl_tensor.ndim = len(self.shape)
    # The context owns the shape/strides arrays so they stay alive for as
    # long as the consumer holds the managed tensor.
    manager_ctx = _ManagerCtx(self)
    dl_managed_tensor.dl_tensor.shape = manager_ctx.shape
    dl_managed_tensor.dl_tensor.strides = manager_ctx.strides
    dl_managed_tensor.dl_tensor.byte_offset = 0
    dl_managed_tensor.deleter = Tensor._managed_tensor_deleter
    dl_managed_tensor.manager_ctx = manager_ctx.reference()
    return dl_managed_tensor
611
615
612
616
@staticmethod
@ctypes.CFUNCTYPE(None, ctypes.c_void_p)
def _managed_tensor_deleter(handle: int) -> None:
    """C callback the DLPack consumer invokes to destroy a managed tensor.

    Releases the ``_ManagerCtx`` reference taken by
    ``_create_managed_tensor`` (dropping the strong references that kept
    the tensor and its shape array alive) and then frees the
    ``DLManagedTensor`` struct itself.

    Args:
        handle: raw address of the DLManagedTensor, delivered through the
            ``c_void_p`` callback parameter as a Python int (or None for
            NULL).
    """
    if not handle:
        # Defensive: nothing to release for a NULL pointer.
        return
    dl_managed_tensor = DLManagedTensor.from_address(handle)
    _ManagerCtx.release(dl_managed_tensor.manager_ctx)
    # Wrap the address in c_void_p so ctypes passes a full pointer-width
    # argument rather than applying its default C ``int`` conversion.
    ctypes.pythonapi.PyMem_RawFree(ctypes.c_void_p(handle))
624
622
625
623
@staticmethod
@@ -639,14 +637,36 @@ def _pycapsule_deleter(handle: ctypes.c_void_p) -> None:
639
637
print (f"Exception occurred while deleting capsule: { e } " )
640
638
raise e
641
639
642
- def _set_dlpack_manager_ctx (self , dl_managed_tensor ):
643
- tensor_obj = ctypes .py_object (self )
644
- tensor_obj_ptr = ctypes .pointer (tensor_obj )
645
- dl_managed_tensor .manager_ctx = ctypes .cast (tensor_obj_ptr , ctypes .c_void_p )
646
- shape_obj = ctypes .py_object (dl_managed_tensor .dl_tensor .shape )
647
- ctypes .pythonapi .Py_IncRef (tensor_obj )
648
- ctypes .pythonapi .Py_IncRef (shape_obj )
649
-
650
640
# Dispatch table mapping source object types to the converter that turns
# an instance of that type into a Tensor.  A plain dict literal is used
# directly; the previous ``dict({...})`` wrapper only built a redundant
# copy of an already-constructed dict.
_from_converters: ClassVar[dict[type, Callable[[Any], Tensor]]] = {
    numpy.ndarray: _from_numpy,
    numpy.generic: _from_numpy,
    list: _from_list,
}
643
+
644
+
645
class _ManagerCtx:
    """Lifetime anchor for a tensor exported through DLPack.

    Holds a strong reference to the source tensor plus the ctypes shape
    and strides arrays handed out through the DLTensor, so all three stay
    alive until the consumer's deleter calls :meth:`release`.
    """

    def __init__(self, tensor: Tensor) -> None:
        self._tensor = tensor
        self.shape = (ctypes.c_int64 * len(tensor.shape))(*tensor.shape)
        # A NULL strides pointer signals a compact, row-major layout.
        self.strides = ctypes.POINTER(ctypes.c_int64)()

    def reference(self) -> ctypes.c_void_p:
        """Take a new strong reference and return this object's raw address."""
        boxed = ctypes.py_object(self)
        ctypes.pythonapi.Py_IncRef(boxed)

        # ctypes offers no direct conversion from a Python object to a
        # c_void_p.  Rather than leaning on the CPython detail that id()
        # is the object's address, reinterpret the py_object's storage as
        # a pointer-sized value, per:
        #
        # https://groups.google.com/g/dev-python/c/QRRqVC7gkf4/m/zH7l1gTXBwAJ
        as_pointer = ctypes.POINTER(ctypes.c_void_p)(boxed)
        return as_pointer[0]

    @staticmethod
    def release(reference: ctypes.c_void_p) -> None:
        """Drop the strong reference taken by :meth:`reference`."""
        boxed = ctypes.cast(reference, ctypes.py_object)
        ctypes.pythonapi.Py_DecRef(boxed)
0 commit comments