27
27
import sys
28
28
29
29
sys .path .append ("../common" )
30
-
31
30
import json
32
31
import unittest
33
-
32
+ import tritonclient .http as httpclient
33
+ import tritonclient .grpc as grpcclient
34
+ import numpy as np
34
35
import test_util as tu
36
+ import time
35
37
38
+ EXPECTED_NUM_SPANS = 10
36
39
37
40
class OpenTelemetryTest (tu .TestResultCollector ):
38
41
39
42
def setUp (self ):
40
- with open ('trace_collector.log' , 'rt' ) as f :
41
- data = f .read ()
42
-
43
+ while True :
44
+ with open ('trace_collector.log' , 'rt' ) as f :
45
+ data = f .read ()
46
+ if data .count ("resource_spans" ) != EXPECTED_NUM_SPANS :
47
+ time .sleep (5 )
48
+ continue
49
+ else :
50
+ break
51
+
43
52
data = data .split ('\n ' )
44
- full_spans = [entry for entry in data if "resource_spans" in entry ]
53
+ full_spans = [entry . split ( 'POST' )[ 0 ] for entry in data if "resource_spans" in entry ]
45
54
self .spans = []
46
55
for span in full_spans :
47
56
span = json .loads (span )
48
57
self .spans .append (
49
58
span ["resource_spans" ][0 ]['scope_spans' ][0 ]['spans' ][0 ])
50
59
51
- self .model_name = "simple"
60
+ self .simple_model_name = "simple"
61
+ self .ensemble_model_name = "ensemble_add_sub_int32_int32_int32"
52
62
self .root_span = "InferRequest"
53
63
54
64
def _check_events (self , span_name , events ):
@@ -102,7 +112,7 @@ def _check_events(self, span_name, events):
102
112
self .assertFalse (
103
113
all (entry in events for entry in compute_events ))
104
114
105
- elif span_name == self .model_name :
115
+ elif span_name == self .simple_model_name :
106
116
# Check that all request related events (and only them)
107
117
# are recorded in request span
108
118
self .assertTrue (all (entry in events for entry in request_events ))
@@ -131,14 +141,15 @@ def test_spans(self):
131
141
parsed_spans .append (span_name )
132
142
133
143
# There should be 6 spans in total:
134
- # 3 for http request and 3 for grpc request.
135
- self .assertEqual (len (self .spans ), 6 )
136
- # We should have 2 compute spans
137
- self .assertEqual (parsed_spans .count ("compute" ), 2 )
138
- # 2 request spans (named simple - same as our model name)
139
- self .assertEqual (parsed_spans .count (self .model_name ), 2 )
140
- # 2 root spans
141
- self .assertEqual (parsed_spans .count (self .root_span ), 2 )
144
+ # 3 for http request, 3 for grpc request, 4 for ensemble
145
+ self .assertEqual (len (self .spans ), 10 )
146
+ # We should have 3 compute spans
147
+ self .assertEqual (parsed_spans .count ("compute" ), 3 )
148
+ # 4 request spans (3 named simple - same as our model name, 1 ensemble)
149
+ self .assertEqual (parsed_spans .count (self .simple_model_name ), 3 )
150
+ self .assertEqual (parsed_spans .count (self .ensemble_model_name ), 1 )
151
+ # 3 root spans
152
+ self .assertEqual (parsed_spans .count (self .root_span ), 3 )
142
153
143
154
def test_nested_spans (self ):
144
155
@@ -156,9 +167,9 @@ def test_nested_spans(self):
156
167
self .spans [2 ],
157
168
"root span has a parent_span_id specified" )
158
169
159
- # Last 3 spans in `self.spans` belong to GRPC request
170
+ # Next 3 spans in `self.spans` belong to GRPC request
160
171
# Order of spans and their relationship described earlier
161
- for child , parent in zip (self .spans [3 :], self .spans [4 :]):
172
+ for child , parent in zip (self .spans [3 :6 ], self .spans [4 :6 ]):
162
173
self ._check_parent (child , parent )
163
174
164
175
# root_span should not have `parent_span_id` field
@@ -167,6 +178,48 @@ def test_nested_spans(self):
167
178
self .spans [5 ],
168
179
"root span has a parent_span_id specified" )
169
180
181
+ # Final 4 spans in `self.spans` belong to ensemble request
182
+ # Order of spans: compute span - request span - request span - root span
183
+ for child , parent in zip (self .spans [6 :10 ], self .spans [7 :10 ]):
184
+ self ._check_parent (child , parent )
185
+
186
+ # root_span should not have `parent_span_id` field
187
+ self .assertNotIn (
188
+ 'parent_span_id' ,
189
+ self .spans [9 ],
190
+ "root span has a parent_span_id specified" )
191
+
192
+ def prepare_data (client ):
193
+
194
+ inputs = []
195
+ outputs = []
196
+ input0_data = np .full (shape = (1 , 16 ), fill_value = - 1 , dtype = np .int32 )
197
+ input1_data = np .full (shape = (1 , 16 ), fill_value = - 1 , dtype = np .int32 )
198
+
199
+ inputs .append (client .InferInput ('INPUT0' , [1 , 16 ], "INT32" ))
200
+ inputs .append (client .InferInput ('INPUT1' , [1 , 16 ], "INT32" ))
201
+
202
+ # Initialize the data
203
+ inputs [0 ].set_data_from_numpy (input0_data )
204
+ inputs [1 ].set_data_from_numpy (input1_data )
205
+
206
+ return inputs
207
+
208
+ def prepare_traces ():
209
+
210
+ triton_client_http = httpclient .InferenceServerClient ("localhost:8000" ,
211
+ verbose = True )
212
+ triton_client_grpc = grpcclient .InferenceServerClient ("localhost:8001" ,
213
+ verbose = True )
214
+ inputs = prepare_data (httpclient )
215
+ triton_client_http .infer ("simple" ,inputs )
216
+
217
+ inputs = prepare_data (grpcclient )
218
+ triton_client_grpc .infer ("simple" , inputs )
219
+
220
+ inputs = prepare_data (httpclient )
221
+ triton_client_http .infer ("ensemble_add_sub_int32_int32_int32" , inputs )
222
+
170
223
171
224
if __name__ == '__main__' :
172
225
unittest .main ()
0 commit comments