Skip to content

Commit df0047e

Browse files
committed
[TP-1786] Create subtests for each channel in client tests
1 parent e2666eb commit df0047e

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

tests/common.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import pytest
23
import time
34
from datetime import datetime
45
from typing import List, Tuple
@@ -65,41 +66,58 @@ def metadata(pat: bool = False) -> Tuple[Tuple[str, str], Tuple[str, str]]:
6566
)
6667

6768

68-
def grpc_channel(func):
69+
def run_tests_using_channels(func, channel_keys):
6970
"""
70-
A decorator that runs the test using the gRPC channel.
71+
A decorator that runs the test using given channels.
7172
:param func: The test function.
73+
:param channel_keys: A list of channel keys to use for the test. A subtest is created for each channel.
7274
:return: A function wrapper.
7375
"""
7476

75-
def func_wrapper():
77+
@pytest.mark.parametrize('channel_key', channel_keys)
78+
def func_wrapper(channel_key):
79+
return func(get_channel(channel_key))
80+
81+
return func_wrapper
82+
83+
84+
def get_channel(channel_key):
85+
if channel_key == "grpc":
7686
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
77-
channel = ClarifaiChannel.get_insecure_grpc_channel(port=443)
87+
return ClarifaiChannel.get_insecure_grpc_channel(port=443)
7888
else:
79-
channel = ClarifaiChannel.get_grpc_channel()
80-
func(channel)
89+
return ClarifaiChannel.get_grpc_channel()
8190

82-
return func_wrapper
91+
if channel_key == "json":
92+
return ClarifaiChannel.get_json_channel()
8393

94+
if channel_key == "asyncio":
95+
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
96+
return ClarifaiChannel.get_aio_insecure_grpc_channel(port=443)
97+
else:
98+
return ClarifaiChannel.get_aio_grpc_channel()
8499

85-
def both_channels(func):
100+
raise ValueError(f"Unknown channel {channel_key}")
101+
102+
103+
def grpc_channel(func):
86104
"""
87-
A decorator that runs the test first using the gRPC channel and then using the JSON channel.
105+
A decorator that runs the test using the gRPC channel.
88106
:param func: The test function.
89107
:return: A function wrapper.
90108
"""
91109

92-
def func_wrapper():
93-
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
94-
channel = ClarifaiChannel.get_insecure_grpc_channel(port=443)
95-
else:
96-
channel = ClarifaiChannel.get_grpc_channel()
97-
func(channel)
110+
return run_tests_using_channels(func, ["grpc"])
98111

99-
channel = ClarifaiChannel.get_json_channel()
100-
func(channel)
101112

102-
return func_wrapper
113+
def both_channels(func):
114+
"""
115+
A decorator that runs the test first using the gRPC channel and then using the JSON channel.
116+
:param func: The test function.
117+
:return: A function wrapper.
118+
"""
119+
120+
return run_tests_using_channels(func, ["grpc", "json"])
103121

104122

105123
def asyncio_channel(func):
@@ -109,13 +127,7 @@ def asyncio_channel(func):
109127
:return: A function wrapper.
110128
"""
111129

112-
async def func_wrapper():
113-
channel = ClarifaiChannel.get_aio_grpc_channel()
114-
if os.getenv("CLARIFAI_GRPC_INSECURE", "False").lower() in ("true", "1", "t"):
115-
channel = ClarifaiChannel.get_aio_insecure_grpc_channel(port=443)
116-
await func(channel)
117-
118-
return func_wrapper
130+
return run_tests_using_channels(func, ["asyncio"])
119131

120132

121133
def wait_for_inputs_upload(stub, metadata, input_ids):

0 commit comments

Comments
 (0)