
Commit 202a545: Fix up

Parent: 62b0ebc

File tree: 1 file changed, +28 −5 lines


qa/L0_device_memory_tracker/test.py (+28 −5)
@@ -27,18 +27,39 @@
 
 import unittest
 import time
+from functools import partial
 
-import tritonclient.http as tritonclient
+import tritonclient.http as httpclient
+import tritonclient.grpc as grpcclient
 
 import nvidia_smi
 
 
+class UnifiedClientProxy:
+
+    def __init__(self, client):
+        self.client_ = client
+
+    def __getattr__(self, attr):
+        forward_attr = getattr(self.client_, attr)
+        if type(self.client_) == grpcclient.InferenceServerClient:
+            if attr == "get_model_config":
+                return lambda *args, **kwargs: forward_attr(
+                    *args, **kwargs, as_json=True)["config"]
+            elif attr == "get_inference_statistics":
+                return partial(forward_attr, as_json=True)
+        return forward_attr
+
+
 class MemoryUsageTest(unittest.TestCase):
 
     def setUp(self):
         nvidia_smi.nvmlInit()
         self.gpu_handle_ = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
-        self.client_ = tritonclient.InferenceServerClient(url="localhost:8000")
+        self.http_client_ = httpclient.InferenceServerClient(
+            url="localhost:8000")
+        self.grpc_client_ = grpcclient.InferenceServerClient(
+            url="localhost:8001")
 
     def tearDown(self):
         nvidia_smi.nvmlShutdown()
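Note on the hunk above: `__getattr__` fires only for attribute names not found on the proxy instance itself, so every client method is transparently forwarded; for the gRPC client, the two calls that would otherwise return protobuf messages are wrapped with `as_json=True` so both protocols hand the tests plain dictionaries. Below is a minimal, self-contained sketch of that forwarding pattern; the `StubGrpcClient` is a hypothetical stand-in for `grpcclient.InferenceServerClient`, not part of the commit:

    class StubGrpcClient:
        """Hypothetical stand-in for grpcclient.InferenceServerClient."""

        def get_model_config(self, name, as_json=False):
            # The real gRPC client returns a protobuf message unless
            # as_json=True is passed.
            return {"config": {"name": name}} if as_json else NotImplemented


    class UnifiedClientProxy:

        def __init__(self, client):
            self.client_ = client

        def __getattr__(self, attr):
            # Invoked only for attributes missing on the proxy itself,
            # i.e. everything except client_.
            forward_attr = getattr(self.client_, attr)
            if isinstance(self.client_, StubGrpcClient):
                if attr == "get_model_config":
                    # Force the JSON form and unwrap it so callers see the
                    # same dict shape the HTTP client returns.
                    return lambda *args, **kwargs: forward_attr(
                        *args, **kwargs, as_json=True)["config"]
            return forward_attr


    print(UnifiedClientProxy(StubGrpcClient()).get_model_config("resnet"))
    # -> {'name': 'resnet'}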
@@ -55,7 +76,7 @@ def verify_recorded_usage(self, model_stat):
         recorded_gpu_usage = 0
         for usage in model_stat["memory_usage"]:
             if usage["type"] == "GPU":
-                recorded_gpu_usage += usage["byte_size"]
+                recorded_gpu_usage += int(usage["byte_size"])
         # unload and verify recorded usage
         before_total_usage = self.report_used_gpu_memory()
         self.client_.unload_model(model_stat["name"])
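Note on the `int(...)` cast above: with the gRPC statistics now fetched via `as_json=True`, protobuf's JSON mapping serializes 64-bit integer fields such as `byte_size` as strings, whereas the HTTP client yields plain numbers; `int()` accepts both. An illustration with made-up values:

    http_usage = {"type": "GPU", "byte_size": 1024}    # HTTP: JSON number
    grpc_usage = {"type": "GPU", "byte_size": "2048"}  # gRPC as_json: uint64 -> str
    total = sum(int(u["byte_size"]) for u in (http_usage, grpc_usage))
    assert total == 3072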
@@ -71,13 +92,15 @@ def verify_recorded_usage(self, model_stat):
             .format(model_stat["name"], usage_delta * 0.9, usage_delta * 1.1,
                     recorded_gpu_usage))
 
-    def test_onnx(self):
+    def test_onnx_http(self):
+        self.client_ = UnifiedClientProxy(self.http_client_)
         model_stats = self.client_.get_inference_statistics()["model_stats"]
         for model_stat in model_stats:
             if self.is_testing_backend(model_stat["name"], "onnxruntime"):
                 self.verify_recorded_usage(model_stat)
 
-    def test_plan(self):
+    def test_plan_grpc(self):
+        self.client_ = UnifiedClientProxy(self.grpc_client_)
         model_stats = self.client_.get_inference_statistics()["model_stats"]
         for model_stat in model_stats:
             if self.is_testing_backend(model_stat["name"], "tensorrt"):
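Design note on the hunk above: each test wraps its protocol-specific client in `UnifiedClientProxy` and assigns it to `self.client_`, so shared helpers such as `verify_recorded_usage` (which calls `unload_model` and reads `memory_usage`) remain protocol-agnostic; only the test method itself picks HTTP (port 8000) or gRPC (port 8001), Triton's default endpoints.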
