1
- # Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1
+ # Copyright 2021-2023 , NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
2
#
3
3
# Redistribution and use in source and binary forms, with or without
4
4
# modification, are permitted provided that the following conditions
@@ -56,6 +56,23 @@ def test_pytorch_dlpack(self):
56
56
self .assertTrue (
57
57
pytorch_tensor .type () == pytorch_tensor_dlpack .type ())
58
58
59
+ # Now let's check that upgraded DLPack implementation also
60
+ # works as expected, i.e. from_dlpack should work with
61
+ # external pytorch tensor directly
62
+
63
+ pb_tensor_upgraded = pb_utils .Tensor .from_dlpack ('test_tensor' ,
64
+ pytorch_tensor )
65
+ self .assertTrue (
66
+ np .all (pb_tensor_upgraded .as_numpy () == pytorch_tensor .numpy ()))
67
+
68
+ # Here we check that `pb_tensor` as a producer, properly
69
+ # invokes `__dlpack__` and `__dlpack_device__`
70
+ pytorch_tensor_dlpack = from_dlpack (pb_tensor_upgraded )
71
+ self .assertTrue (torch .all (pytorch_tensor_dlpack == pytorch_tensor ))
72
+
73
+ self .assertTrue (
74
+ pytorch_tensor .type () == pytorch_tensor_dlpack .type ())
75
+
59
76
def test_non_contiguous_error (self ):
60
77
pytorch_tensor = torch .rand ([20 , 30 ], dtype = torch .float16 )
61
78
@@ -83,6 +100,8 @@ def test_dlpack_string_tensor(self):
83
100
84
101
def test_dlpack_gpu_tensors (self ):
85
102
# Test different dtypes
103
+ # PyTorch does not support DLPack bool type yet:
104
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/DLConvertor.cpp
86
105
pytorch_dtypes = [
87
106
torch .float16 , torch .float32 , torch .float64 , torch .int8 ,
88
107
torch .int16 , torch .int32 , torch .int64 , torch .uint8
@@ -100,27 +119,50 @@ def test_dlpack_gpu_tensors(self):
100
119
# the same
101
120
pytorch_tensor_dlpack = from_dlpack (pb_tensor .to_dlpack ())
102
121
self .assertTrue (torch .all (pytorch_tensor_dlpack == pytorch_tensor ))
103
-
104
- # DLPack does not properly support bool type:
105
- # https://github.com/google/jax/issues/4719
106
- if pytorch_dtype != torch .bool :
107
- self .assertTrue (
108
- pytorch_tensor .type () == pytorch_tensor_dlpack .type ())
109
- else :
110
- self .assertFalse (
111
- pytorch_tensor .type () == pytorch_tensor_dlpack .type ())
122
+ self .assertTrue (
123
+ pytorch_tensor .type () == pytorch_tensor_dlpack .type ())
124
+
125
+ # Now we make sure that updated DLPack implementation works
126
+ # with GPU as well
127
+ pb_tensor = pb_utils .Tensor .from_dlpack ('test_tensor' ,
128
+ pytorch_tensor )
129
+ pytorch_tensor_dlpack = from_dlpack (pb_tensor )
130
+ self .assertTrue (torch .all (pytorch_tensor_dlpack == pytorch_tensor ))
131
+ self .assertTrue (
132
+ pytorch_tensor .type () == pytorch_tensor_dlpack .type ())
133
+
112
134
113
135
def test_dlpack_gpu_numpy (self ):
114
136
# DLPack tesnors that are in GPU cannot be converted to NumPy
115
137
pytorch_tensor = torch .rand ([100 ], dtype = torch .float16 ,
116
138
device = 'cuda' ) * 100
117
139
pb_tensor = pb_utils .Tensor .from_dlpack ('tensor' ,
118
140
to_dlpack (pytorch_tensor ))
141
+ # Make sure that `__dlpack_device__` works as expected
142
+ self .assertTrue (pb_tensor .__dlpack_device__ () == pytorch_tensor .__dlpack_device__ ())
143
+
119
144
with self .assertRaises (Exception ) as e :
120
145
pb_tensor .as_numpy ()
121
146
self .assertTrue (
122
147
str (e .exception ) ==
123
148
'Tensor is stored in GPU and cannot be converted to NumPy.' )
149
+
150
+ def test_dlpack_cpu_numpy (self ):
151
+ # Check compatibiity of PbTensor DLPack implementation
152
+ # with numpy
153
+ pytorch_tensor = torch .rand ([100 ], dtype = torch .float16 ,
154
+ device = 'cpu' ) * 100
155
+ pb_tensor = pb_utils .Tensor .from_dlpack ('tensor' , pytorch_tensor )
156
+ numpy_tensor_dlpack = np .from_dlpack (pb_tensor )
157
+ self .assertTrue (np .all (numpy_tensor_dlpack == pytorch_tensor .numpy ()))
158
+ # Make sure that `__dlpack_device__` works as expected
159
+ self .assertTrue (pb_tensor .__dlpack_device__ () == pytorch_tensor .__dlpack_device__ ())
160
+
161
+ def test_pdtensor_bool_internal_support (self ):
162
+ bool_array = np .asarray ([False , True ])
163
+ bool_tensor = pb_utils .Tensor ('tensor' , bool_array )
164
+ bool_tensor_dlpack = pb_utils .Tensor .from_dlpack ('tensor' , bool_tensor )
165
+ self .assertTrue (np .all (bool_array == bool_tensor_dlpack .as_numpy ()))
124
166
125
167
126
168
class TritonPythonModel :
0 commit comments