Skip to content

Commit 2d5d986

Browse files
authored
[TP-1786] Create subtests for each channel in client tests (#220)
### Why * Make it clearer when a test is executed over multiple channels ### How * Add new methods `run_tests_using_channels` & `run_tests_using_channels_async` that use `@pytest.mark.parametrize` to create subtests for each test channel
1 parent e2666eb commit 2d5d986

File tree

1 file changed

+51
-24
lines changed

1 file changed

+51
-24
lines changed

tests/common.py

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import datetime
44
from typing import List, Tuple
55

6+
import pytest
67
from grpc._channel import _Rendezvous
78

89
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
@@ -65,41 +66,73 @@ def metadata(pat: bool = False) -> Tuple[Tuple[str, str], Tuple[str, str]]:
6566
)
6667

6768

68-
def grpc_channel(func):
69+
def run_tests_using_channels(func, channel_keys):
6970
"""
70-
A decorator that runs the test using the gRPC channel.
71+
A decorator that runs the test using given channels.
7172
:param func: The test function.
73+
:param channel_keys: A list of channel keys to use for the test. A subtest is created for each channel.
7274
:return: A function wrapper.
7375
"""
7476

75-
def func_wrapper():
76-
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
77-
channel = ClarifaiChannel.get_insecure_grpc_channel(port=443)
78-
else:
79-
channel = ClarifaiChannel.get_grpc_channel()
80-
func(channel)
77+
@pytest.mark.parametrize('channel_key', channel_keys)
78+
def func_wrapper(channel_key):
79+
return func(get_channel(channel_key))
8180

8281
return func_wrapper
8382

8483

85-
def both_channels(func):
84+
def run_tests_using_channels_async(func, channel_keys):
8685
"""
87-
A decorator that runs the test first using the gRPC channel and then using the JSON channel.
86+
A decorator that runs the test using given channels.
8887
:param func: The test function.
88+
:param channel_keys: A list of channel keys to use for the test. A subtest is created for each channel.
8989
:return: A function wrapper.
9090
"""
9191

92-
def func_wrapper():
92+
@pytest.mark.parametrize('channel_key', channel_keys)
93+
async def func_wrapper(channel_key):
94+
return await func(get_channel(channel_key))
95+
96+
return func_wrapper
97+
98+
99+
def get_channel(channel_key):
100+
if channel_key == "grpc":
93101
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
94-
channel = ClarifaiChannel.get_insecure_grpc_channel(port=443)
102+
return ClarifaiChannel.get_insecure_grpc_channel(port=443)
95103
else:
96-
channel = ClarifaiChannel.get_grpc_channel()
97-
func(channel)
104+
return ClarifaiChannel.get_grpc_channel()
98105

99-
channel = ClarifaiChannel.get_json_channel()
100-
func(channel)
106+
if channel_key == "json":
107+
return ClarifaiChannel.get_json_channel()
101108

102-
return func_wrapper
109+
if channel_key == "aio_grpc":
110+
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
111+
return ClarifaiChannel.get_aio_insecure_grpc_channel(port=443)
112+
else:
113+
return ClarifaiChannel.get_aio_grpc_channel()
114+
115+
raise ValueError(f"Unknown channel {channel_key}")
116+
117+
118+
def grpc_channel(func):
119+
"""
120+
A decorator that runs the test using the gRPC channel.
121+
:param func: The test function.
122+
:return: A function wrapper.
123+
"""
124+
125+
return run_tests_using_channels(func, ["grpc"])
126+
127+
128+
def both_channels(func):
129+
"""
130+
A decorator that runs the test first using the gRPC channel and then using the JSON channel.
131+
:param func: The test function.
132+
:return: A function wrapper.
133+
"""
134+
135+
return run_tests_using_channels(func, ["grpc", "json"])
103136

104137

105138
def asyncio_channel(func):
@@ -109,13 +142,7 @@ def asyncio_channel(func):
109142
:return: A function wrapper.
110143
"""
111144

112-
async def func_wrapper():
113-
channel = ClarifaiChannel.get_aio_grpc_channel()
114-
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
115-
channel = ClarifaiChannel.get_aio_insecure_grpc_channel(port=443)
116-
await func(channel)
117-
118-
return func_wrapper
145+
return run_tests_using_channels_async(func, ["aio_grpc"])
119146

120147

121148
def wait_for_inputs_upload(stub, metadata, input_ids):

0 commit comments

Comments
 (0)