3
3
from datetime import datetime
4
4
from typing import List , Tuple
5
5
6
+ import pytest
6
7
from grpc ._channel import _Rendezvous
7
8
8
9
from clarifai_grpc .channel .clarifai_channel import ClarifaiChannel
@@ -65,41 +66,73 @@ def metadata(pat: bool = False) -> Tuple[Tuple[str, str], Tuple[str, str]]:
65
66
)
66
67
67
68
68
- def grpc_channel (func ):
69
+ def run_tests_using_channels (func , channel_keys ):
69
70
"""
70
- A decorator that runs the test using the gRPC channel .
71
+ A decorator that runs the test using given channels .
71
72
:param func: The test function.
73
+ :param channel_keys: A list of channel keys to use for the test. A subtest is created for each channel.
72
74
:return: A function wrapper.
73
75
"""
74
76
75
- def func_wrapper ():
76
- if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
77
- channel = ClarifaiChannel .get_insecure_grpc_channel (port = 443 )
78
- else :
79
- channel = ClarifaiChannel .get_grpc_channel ()
80
- func (channel )
77
+ @pytest .mark .parametrize ('channel_key' , channel_keys )
78
+ def func_wrapper (channel_key ):
79
+ return func (get_channel (channel_key ))
81
80
82
81
return func_wrapper
83
82
84
83
85
- def both_channels (func ):
84
+ def run_tests_using_channels_async (func , channel_keys ):
86
85
"""
87
- A decorator that runs the test first using the gRPC channel and then using the JSON channel .
86
+ A decorator that runs the test using given channels .
88
87
:param func: The test function.
88
+ :param channel_keys: A list of channel keys to use for the test. A subtest is created for each channel.
89
89
:return: A function wrapper.
90
90
"""
91
91
92
- def func_wrapper ():
92
+ @pytest .mark .parametrize ('channel_key' , channel_keys )
93
+ async def func_wrapper (channel_key ):
94
+ return await func (get_channel (channel_key ))
95
+
96
+ return func_wrapper
97
+
98
+
99
+ def get_channel (channel_key ):
100
+ if channel_key == "grpc" :
93
101
if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
94
- channel = ClarifaiChannel .get_insecure_grpc_channel (port = 443 )
102
+ return ClarifaiChannel .get_insecure_grpc_channel (port = 443 )
95
103
else :
96
- channel = ClarifaiChannel .get_grpc_channel ()
97
- func (channel )
104
+ return ClarifaiChannel .get_grpc_channel ()
98
105
99
- channel = ClarifaiChannel . get_json_channel ()
100
- func ( channel )
106
+ if channel_key == "json" :
107
+ return ClarifaiChannel . get_json_channel ( )
101
108
102
- return func_wrapper
109
+ if channel_key == "aio_grpc" :
110
+ if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
111
+ return ClarifaiChannel .get_aio_insecure_grpc_channel (port = 443 )
112
+ else :
113
+ return ClarifaiChannel .get_aio_grpc_channel ()
114
+
115
+ raise ValueError (f"Unknown channel { channel_key } " )
116
+
117
+
118
+ def grpc_channel (func ):
119
+ """
120
+ A decorator that runs the test using the gRPC channel.
121
+ :param func: The test function.
122
+ :return: A function wrapper.
123
+ """
124
+
125
+ return run_tests_using_channels (func , ["grpc" ])
126
+
127
+
128
+ def both_channels (func ):
129
+ """
130
+ A decorator that runs the test first using the gRPC channel and then using the JSON channel.
131
+ :param func: The test function.
132
+ :return: A function wrapper.
133
+ """
134
+
135
+ return run_tests_using_channels (func , ["grpc" , "json" ])
103
136
104
137
105
138
def asyncio_channel (func ):
@@ -109,13 +142,7 @@ def asyncio_channel(func):
109
142
:return: A function wrapper.
110
143
"""
111
144
112
- async def func_wrapper ():
113
- channel = ClarifaiChannel .get_aio_grpc_channel ()
114
- if os .getenv ("CLARIFAI_GRPC_INSECURE" , "False" ).lower () in ("true" , "1" , "t" ):
115
- channel = ClarifaiChannel .get_aio_insecure_grpc_channel (port = 443 )
116
- await func (channel )
117
-
118
- return func_wrapper
145
+ return run_tests_using_channels_async (func , ["aio_grpc" ])
119
146
120
147
121
148
def wait_for_inputs_upload (stub , metadata , input_ids ):
0 commit comments