1
1
import os
2
+ import pytest
2
3
import time
3
4
from datetime import datetime
4
5
from typing import List , Tuple
@@ -65,41 +66,58 @@ def metadata(pat: bool = False) -> Tuple[Tuple[str, str], Tuple[str, str]]:
65
66
)
66
67
67
68
68
- def grpc_channel (func ):
69
+ def run_tests_using_channels (func , channel_keys ):
69
70
"""
70
- A decorator that runs the test using the gRPC channel .
71
+ A decorator that runs the test using given channels .
71
72
:param func: The test function.
73
+ :param channel_keys: A list of channel keys to use for the test. A subtest is created for each channel.
72
74
:return: A function wrapper.
73
75
"""
74
76
75
- def func_wrapper ():
77
+ @pytest .mark .parametrize ('channel_key' , channel_keys )
78
+ def func_wrapper (channel_key ):
79
+ return func (get_channel (channel_key ))
80
+
81
+ return func_wrapper
82
+
83
+
84
+ def get_channel (channel_key ):
85
+ if channel_key == "grpc" :
76
86
if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
77
- channel = ClarifaiChannel .get_insecure_grpc_channel (port = 443 )
87
+ return ClarifaiChannel .get_insecure_grpc_channel (port = 443 )
78
88
else :
79
- channel = ClarifaiChannel .get_grpc_channel ()
80
- func (channel )
89
+ return ClarifaiChannel .get_grpc_channel ()
81
90
82
- return func_wrapper
91
+ if channel_key == "json" :
92
+ return ClarifaiChannel .get_json_channel ()
83
93
94
+ if channel_key == "asyncio" :
95
+ if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
96
+ return ClarifaiChannel .get_aio_insecure_grpc_channel (port = 443 )
97
+ else :
98
+ return ClarifaiChannel .get_aio_grpc_channel ()
84
99
85
- def both_channels (func ):
100
+ raise ValueError (f"Unknown channel { channel_key } " )
101
+
102
+
103
+ def grpc_channel (func ):
86
104
"""
87
- A decorator that runs the test first using the gRPC channel and then using the JSON channel.
105
+ A decorator that runs the test using the gRPC channel.
88
106
:param func: The test function.
89
107
:return: A function wrapper.
90
108
"""
91
109
92
- def func_wrapper ():
93
- if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
94
- channel = ClarifaiChannel .get_insecure_grpc_channel (port = 443 )
95
- else :
96
- channel = ClarifaiChannel .get_grpc_channel ()
97
- func (channel )
110
+ return run_tests_using_channels (func , ["grpc" ])
98
111
99
- channel = ClarifaiChannel .get_json_channel ()
100
- func (channel )
101
112
102
- return func_wrapper
113
+ def both_channels (func ):
114
+ """
115
+ A decorator that runs the test first using the gRPC channel and then using the JSON channel.
116
+ :param func: The test function.
117
+ :return: A function wrapper.
118
+ """
119
+
120
+ return run_tests_using_channels (func , ["grpc" , "json" ])
103
121
104
122
105
123
def asyncio_channel (func ):
@@ -109,13 +127,7 @@ def asyncio_channel(func):
109
127
:return: A function wrapper.
110
128
"""
111
129
112
- async def func_wrapper ():
113
- channel = ClarifaiChannel .get_aio_grpc_channel ()
114
- if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
115
- channel = ClarifaiChannel .get_aio_insecure_grpc_channel (port = 443 )
116
- await func (channel )
117
-
118
- return func_wrapper
130
+ return run_tests_using_channels (func , ["asyncio" ])
119
131
120
132
121
133
def wait_for_inputs_upload (stub , metadata , input_ids ):
0 commit comments