Skip to content

Commit e4a7134

Browse files
committed
update
1 parent 9d17885 commit e4a7134

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

python/mscclpp/gpu_utils_py.cpp

+10-7
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,22 @@ static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType) {
6161
delete self;
6262
};
6363

64-
return nb::capsule(dlManagedTensor, "dltensor", [](void* self) noexcept {
65-
nb::capsule* capsule = static_cast<nb::capsule*>(self);
66-
if (strcmp(capsule->name(), "dltensor") != 0) {
64+
PyObject* dlCapsule = PyCapsule_New(static_cast<void*>(dlManagedTensor), "dltensor", [](PyObject* capsule) {
65+
if (PyCapsule_IsValid(capsule, "used_dltensor")) {
6766
return;
6867
}
69-
DLManagedTensor* tensor = static_cast<DLManagedTensor*>(capsule->data());
70-
if (tensor == nullptr) {
68+
if (!PyCapsule_IsValid(capsule, "dltensor")) {
7169
return;
7270
}
73-
if (tensor->deleter) {
74-
tensor->deleter(tensor);
71+
DLManagedTensor* managedTensor = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule, "dltensor"));
72+
if (managedTensor == nullptr) {
73+
return;
74+
}
75+
if (managedTensor->deleter) {
76+
managedTensor->deleter(managedTensor);
7577
}
7678
});
79+
return nb::steal<nb::capsule>(dlCapsule);
7780
}
7881

7982
void register_gpu_utils(nb::module_& m) {

0 commit comments

Comments
 (0)