Skip to content

Commit b27d06b

Browse files
authored
[dvs] Add generic polling utility (sonic-net#1233)
Signed-off-by: Danny Allen <[email protected]>
1 parent 177a6b1 commit b27d06b

File tree

2 files changed

+121
-143
lines changed

2 files changed

+121
-143
lines changed

tests/dvslib/dvs_common.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
dvs_common contains common infrastructure for writing tests for the
3+
virtual switch.
4+
"""
5+
6+
import collections
7+
import time
8+
9+
_PollingConfig = collections.namedtuple('PollingConfig', 'polling_interval timeout strict')
10+
11+
class PollingConfig(_PollingConfig):
12+
"""
13+
PollingConfig provides parameters that are used to control the behavior
14+
for polling functions.
15+
16+
Params:
17+
polling_interval (int): How often to poll, in seconds.
18+
timeout (int): The maximum amount of time to wait, in seconds.
19+
strict (bool): If the strict flag is set, reaching the timeout
20+
will cause tests to fail (e.g. assert False)
21+
"""
22+
23+
pass
24+
25+
def wait_for_result(polling_function, polling_config):
26+
"""
27+
wait_for_result will periodically run `polling_function`
28+
using the parameters described in `polling_config` and return the
29+
output of the polling function.
30+
31+
Args:
32+
polling_config (PollingConfig): The parameters to use to poll
33+
the db.
34+
polling_function (Callable[[], (bool, Any)]): The function being
35+
polled. The function takes no arguments and must return a
36+
status which indicates if the function was succesful or
37+
not, as well as some return value.
38+
39+
Returns:
40+
Any: The output of the polling function, if it is succesful,
41+
None otherwise.
42+
"""
43+
if polling_config.polling_interval == 0:
44+
iterations = 1
45+
else:
46+
iterations = int(polling_config.timeout // polling_config.polling_interval) + 1
47+
48+
for _ in range(iterations):
49+
(status, result) = polling_function()
50+
51+
if status:
52+
return result
53+
54+
time.sleep(polling_config.polling_interval)
55+
56+
if polling_config.strict:
57+
assert False
58+
59+
return None

tests/dvslib/dvs_database.py

Lines changed: 62 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,8 @@
22
dvs_database contains utilities for interacting with redis when writing
33
tests for the virtual switch.
44
"""
5-
from __future__ import print_function
6-
7-
import time
8-
import collections
9-
105
from swsscommon import swsscommon
11-
12-
13-
# PollingConfig provides parameters that are used to control polling behavior
14-
# when accessing redis:
15-
# - polling_interval: how often to check for updates in redis
16-
# - timeout: the max amount of time to wait for updates in redis
17-
# - strict: if the strict flag is set, failure to receive updates will cause
18-
# the polling method to cause tests to fail (e.g. assert False)
19-
PollingConfig = collections.namedtuple('PollingConfig', 'polling_interval timeout strict')
6+
from dvslib.dvs_common import wait_for_result, PollingConfig
207

218

229
class DVSDatabase(object):
@@ -56,27 +43,27 @@ def create_entry(self, table_name, key, entry):
5643
formatted_entry = swsscommon.FieldValuePairs(entry.items())
5744
table.set(key, formatted_entry)
5845

59-
def wait_for_entry(self, table_name, key,
60-
polling_config=DEFAULT_POLLING_CONFIG):
46+
def get_entry(self, table_name, key):
6147
"""
62-
Gets the entry stored at `key` in the specified table. This method
63-
will wait for the entry to exist.
48+
Gets the entry stored at `key` in the specified table.
6449
6550
Args:
6651
table_name (str): The name of the table where the entry is
6752
stored.
6853
key (str): The key that maps to the entry being retrieved.
69-
polling_config (PollingConfig): The parameters to use to poll
70-
the db.
7154
7255
Returns:
7356
Dict[str, str]: The entry stored at `key`. If no entry is found,
7457
then an empty Dict will be returned.
75-
7658
"""
7759

78-
access_function = self._get_entry_access_function(table_name, key, True)
79-
return self._db_poll(polling_config, access_function)
60+
table = swsscommon.Table(self.db_connection, table_name)
61+
(status, fv_pairs) = table.get(key)
62+
63+
if not status:
64+
return {}
65+
66+
return dict(fv_pairs)
8067

8168
def delete_entry(self, table_name, key):
8269
"""
@@ -91,6 +78,49 @@ def delete_entry(self, table_name, key):
9178
table = swsscommon.Table(self.db_connection, table_name)
9279
table._del(key) # pylint: disable=protected-access
9380

81+
def get_keys(self, table_name):
82+
"""
83+
Gets all of the keys stored in the specified table.
84+
85+
Args:
86+
table_name (str): The name of the table from which to fetch
87+
the keys.
88+
89+
Returns:
90+
List[str]: The keys stored in the table. If no keys are found,
91+
then an empty List will be returned.
92+
"""
93+
94+
table = swsscommon.Table(self.db_connection, table_name)
95+
keys = table.getKeys()
96+
97+
return keys if keys else []
98+
99+
def wait_for_entry(self, table_name, key,
100+
polling_config=DEFAULT_POLLING_CONFIG):
101+
"""
102+
Gets the entry stored at `key` in the specified table. This method
103+
will wait for the entry to exist.
104+
105+
Args:
106+
table_name (str): The name of the table where the entry is
107+
stored.
108+
key (str): The key that maps to the entry being retrieved.
109+
polling_config (PollingConfig): The parameters to use to poll
110+
the db.
111+
112+
Returns:
113+
Dict[str, str]: The entry stored at `key`. If no entry is found,
114+
then an empty Dict will be returned.
115+
116+
"""
117+
118+
def _access_function():
119+
fv_pairs = self.get_entry(table_name, key)
120+
return (bool(fv_pairs), fv_pairs)
121+
122+
return wait_for_result(_access_function, polling_config)
123+
94124
def wait_for_empty_entry(self,
95125
table_name,
96126
key,
@@ -109,8 +139,11 @@ def wait_for_empty_entry(self,
109139
bool: True if no entry exists at `key`, False otherwise.
110140
"""
111141

112-
access_function = self._get_entry_access_function(table_name, key, False)
113-
return not self._db_poll(polling_config, access_function)
142+
def _access_function():
143+
fv_pairs = self.get_entry(table_name, key)
144+
return (not fv_pairs, fv_pairs)
145+
146+
return wait_for_result(_access_function, polling_config)
114147

115148
def wait_for_n_keys(self,
116149
table_name,
@@ -133,122 +166,8 @@ def wait_for_n_keys(self,
133166
then an empty List will be returned.
134167
"""
135168

136-
access_function = self._get_keys_access_function(table_name, num_keys)
137-
return self._db_poll(polling_config, access_function)
138-
139-
def _get_keys_access_function(self, table_name, num_keys):
140-
"""
141-
Generates an access function to check for `num_keys` in the given
142-
table and return the list of keys if successful.
143-
144-
Args:
145-
table_name (str): The name of the table from which to fetch
146-
the keys.
147-
num_keys (int): The number of keys to check for in the table.
148-
If this is set to None, then this function will just return
149-
whatever keys are in the table.
150-
151-
Returns:
152-
Callable([[], (bool, List[str])]): A function that can be
153-
called to access the database.
154-
155-
If `num_keys` keys are found in the given table, or left
156-
unspecified, then the function will return True along with
157-
the list of keys that were found. Otherwise, the function will
158-
return False and some undefined list of keys.
159-
"""
160-
161-
table = swsscommon.Table(self.db_connection, table_name)
162-
163-
def _accessor():
164-
keys = table.getKeys()
165-
if not keys:
166-
keys = []
167-
168-
if not num_keys and num_keys != 0:
169-
status = True
170-
else:
171-
status = len(keys) == num_keys
172-
173-
return (status, keys)
174-
175-
return _accessor
176-
177-
def _get_entry_access_function(self, table_name, key, expect_entry):
178-
"""
179-
Generates an access function to check for existence of an entry
180-
at `key` and return it if successful.
181-
182-
Args:
183-
table_name (str): The name of the table from which to fetch
184-
the entry.
185-
key (str): The key that maps to the entry being retrieved.
186-
expect_entry (bool): Whether or not we expect to see an entry
187-
at `key`.
188-
189-
Returns:
190-
Callable([[], (bool, Dict[str, str])]): A function that can be
191-
called to access the database.
192-
193-
If `expect_entry` is set and an entry is found, then the
194-
function will return True along with the entry that was found.
195-
196-
If `expect_entry` is not set and no entry is found, then the
197-
function will return True along with an empty Dict.
198-
199-
In all other cases, the function will return False with some
200-
undefined Dict.
201-
"""
202-
203-
table = swsscommon.Table(self.db_connection, table_name)
204-
205-
def _accessor():
206-
(status, fv_pairs) = table.get(key)
207-
208-
status = expect_entry == status
209-
210-
if fv_pairs:
211-
entry = dict(fv_pairs)
212-
else:
213-
entry = {}
214-
215-
return (status, entry)
216-
217-
return _accessor
218-
219-
@staticmethod
220-
def _db_poll(polling_config, access_function):
221-
"""
222-
_db_poll will periodically run `access_function` on the database
223-
using the parameters described in `polling_config` and return the
224-
output of the access function.
225-
226-
Args:
227-
polling_config (PollingConfig): The parameters to use to poll
228-
the db.
229-
access_function (Callable[[], (bool, Any)]): The function used
230-
for polling the db. Note that the function must return a
231-
status which indicates if the function was succesful or
232-
not, as well as some return value.
233-
234-
Returns:
235-
Any: The output of the access function, if it is succesful,
236-
None otherwise.
237-
"""
238-
if polling_config.polling_interval == 0:
239-
iterations = 1
240-
else:
241-
iterations = int(polling_config.timeout // polling_config.polling_interval) + 1
242-
243-
for _ in range(iterations):
244-
(status, result) = access_function()
245-
246-
if status:
247-
return result
248-
249-
time.sleep(polling_config.polling_interval)
250-
251-
if polling_config.strict:
252-
assert False
169+
def _access_function():
170+
keys = self.get_keys(table_name)
171+
return (len(keys) == num_keys, keys)
253172

254-
return None
173+
return wait_for_result(_access_function, polling_config)

0 commit comments

Comments
 (0)