@@ -27,18 +27,39 @@
 
 import unittest
 import time
+from functools import partial
 
-import tritonclient.http as tritonclient
+import tritonclient.http as httpclient
+import tritonclient.grpc as grpcclient
 
 import nvidia_smi
 
 
+class UnifiedClientProxy:
+
+    def __init__(self, client):
+        self.client_ = client
+
+    def __getattr__(self, attr):
+        forward_attr = getattr(self.client_, attr)
+        if type(self.client_) == grpcclient.InferenceServerClient:
+            if attr == "get_model_config":
+                return lambda *args, **kwargs: forward_attr(
+                    *args, **kwargs, as_json=True)["config"]
+            elif attr == "get_inference_statistics":
+                return partial(forward_attr, as_json=True)
+        return forward_attr
+
+
 class MemoryUsageTest(unittest.TestCase):
 
     def setUp(self):
         nvidia_smi.nvmlInit()
         self.gpu_handle_ = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
-        self.client_ = tritonclient.InferenceServerClient(url="localhost:8000")
+        self.http_client_ = httpclient.InferenceServerClient(
+            url="localhost:8000")
+        self.grpc_client_ = grpcclient.InferenceServerClient(
+            url="localhost:8001")
 
     def tearDown(self):
         nvidia_smi.nvmlShutdown()
@@ -55,7 +76,7 @@ def verify_recorded_usage(self, model_stat):
         recorded_gpu_usage = 0
         for usage in model_stat["memory_usage"]:
             if usage["type"] == "GPU":
-                recorded_gpu_usage += usage["byte_size"]
+                recorded_gpu_usage += int(usage["byte_size"])
         # unload and verify recorded usage
         before_total_usage = self.report_used_gpu_memory()
         self.client_.unload_model(model_stat["name"])
@@ -71,13 +92,15 @@ def verify_recorded_usage(self, model_stat):
                 .format(model_stat["name"], usage_delta * 0.9, usage_delta * 1.1,
                         recorded_gpu_usage))
 
-    def test_onnx(self):
+    def test_onnx_http(self):
+        self.client_ = UnifiedClientProxy(self.http_client_)
         model_stats = self.client_.get_inference_statistics()["model_stats"]
         for model_stat in model_stats:
             if self.is_testing_backend(model_stat["name"], "onnxruntime"):
                 self.verify_recorded_usage(model_stat)
 
-    def test_plan(self):
+    def test_plan_grpc(self):
+        self.client_ = UnifiedClientProxy(self.grpc_client_)
         model_stats = self.client_.get_inference_statistics()["model_stats"]
         for model_stat in model_stats:
             if self.is_testing_backend(model_stat["name"], "tensorrt"):
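
For context, here is a minimal usage sketch (not part of the patch; it assumes the UnifiedClientProxy class above is in scope and a Triton server is listening on the default local ports). The proxy forwards every attribute to the wrapped client; for gRPC clients it additionally passes as_json=True to get_inference_statistics and unwraps the "config" field of get_model_config, so both protocols return plain dicts that test code can index identically:

    import tritonclient.http as httpclient
    import tritonclient.grpc as grpcclient

    for client in (httpclient.InferenceServerClient(url="localhost:8000"),
                   grpcclient.InferenceServerClient(url="localhost:8001")):
        proxy = UnifiedClientProxy(client)
        # Both protocols now yield a JSON-style dict here.
        stats = proxy.get_inference_statistics()
        for model_stat in stats["model_stats"]:
            print(model_stat["name"])

Note that protobuf's JSON mapping renders 64-bit integer fields as strings, which is why the patch also casts usage["byte_size"] through int() before accumulating it.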