Skip to content

Commit 94b991c

Browse files
authored
Feat/handle filters in list groups members (#488)
* feat: clarify functions docs * fix: use proper values for test * fix: update conftest to match integrations returned values * fix: properly pass keys for paginator * feat: add support for filters for better performance * fix: lint and fmt
1 parent 3ba60f7 commit 94b991c

File tree

9 files changed

+286
-58
lines changed

9 files changed

+286
-58
lines changed

app/integrations/aws/client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,20 @@ def execute_aws_api_call(service_name, method, paginated=False, **kwargs):
9494
ValueError: If the role_arn is not provided.
9595
"""
9696

97-
role_arn = kwargs.get("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
97+
role_arn = kwargs.pop("role_arn", os.environ.get("AWS_SSO_ROLE_ARN", None))
98+
keys = kwargs.pop("keys", None)
9899
if role_arn is None:
99100
raise ValueError(
100101
"role_arn must be provided either as a keyword argument or as the AWS_SSO_ROLE_ARN environment variable"
101102
)
102103
if service_name is None or method is None:
103104
raise ValueError("The AWS service name and method must be provided")
104105
client = assume_role_client(service_name, role_arn)
105-
kwargs.pop("role_arn", None)
106106
if kwargs:
107107
kwargs = convert_kwargs_to_pascal_case(kwargs)
108108
api_method = getattr(client, method)
109109
if paginated:
110-
return paginator(client, method, **kwargs)
110+
return paginator(client, method, keys, **kwargs)
111111
else:
112112
return api_method(**kwargs)
113113

app/integrations/aws/identity_store.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import logging
33
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors
4+
from utils import filters
45

56
INSTANCE_ID = os.environ.get("AWS_SSO_INSTANCE_ID", "")
67
INSTANCE_ARN = os.environ.get("AWS_SSO_INSTANCE_ARN", "")
@@ -230,9 +231,15 @@ def list_groups_with_memberships(**kwargs):
230231
Returns:
231232
list: A list of group objects with their members.
232233
"""
233-
members_details = kwargs.get("members_details", True)
234-
kwargs.pop("members_details", None)
234+
members_details = kwargs.pop("members_details", True)
235+
groups_filters = kwargs.pop("filters", [])
235236
groups = list_groups(**kwargs)
237+
238+
if not groups:
239+
return []
240+
for filter in groups_filters:
241+
groups = filters.filter_by_condition(groups, filter)
242+
236243
for group in groups:
237244
group["GroupMemberships"] = list_group_memberships(group["GroupId"])
238245
if group["GroupMemberships"] and members_details:

app/integrations/google_workspace/google_directory.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
DEFAULT_GOOGLE_WORKSPACE_CUSTOMER_ID,
88
)
99
from integrations.utils.api import convert_string_to_camel_case
10+
from utils import filters
1011

1112

1213
@handle_google_api_errors
@@ -151,11 +152,15 @@ def list_groups_with_members(**kwargs):
151152
Returns:
152153
list: A list of group objects with members.
153154
"""
154-
members_details = kwargs.get("members_details", True)
155-
kwargs.pop("members_details", None)
155+
members_details = kwargs.pop("members_details", True)
156156
groups = list_groups(**kwargs)
157+
groups_filters = kwargs.pop("filters", [])
157158
if not groups:
158159
return []
160+
161+
for filter in groups_filters:
162+
groups = filters.filter_by_condition(groups, filter)
163+
159164
for group in range(len(groups)):
160165
members = list_group_members(groups[group]["email"])
161166
if members and members_details:

app/tests/conftest.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -115,30 +115,28 @@ def _aws_users(n=3, prefix="", domain="test.com", store_id="d-123412341234"):
115115
@pytest.fixture
116116
def aws_groups():
117117
def _aws_groups(n=3, prefix="", store_id="d-123412341234"):
118-
return {
119-
"Groups": [
120-
{
121-
"GroupId": f"{prefix}aws-group_id{i+1}",
122-
"DisplayName": f"{prefix}group-name{i+1}",
123-
"Description": f"A group to test resolving AWS-group{i+1} memberships",
124-
"IdentityStoreId": f"{store_id}",
125-
}
126-
for i in range(n)
127-
]
128-
}
118+
return [
119+
{
120+
"GroupId": f"{prefix}aws-group_id{i+1}",
121+
"DisplayName": f"{prefix}group-name{i+1}",
122+
"Description": f"A group to test resolving AWS-group{i+1} memberships",
123+
"IdentityStoreId": f"{store_id}",
124+
}
125+
for i in range(n)
126+
]
129127

130128
return _aws_groups
131129

132130

133131
@pytest.fixture
134132
def aws_groups_memberships():
135-
def _aws_groups_memberships(n=3, prefix="", store_id="d-123412341234"):
133+
def _aws_groups_memberships(n=3, prefix="", group_id=1, store_id="d-123412341234"):
136134
return {
137135
"GroupMemberships": [
138136
{
139137
"IdentityStoreId": f"{store_id}",
140138
"MembershipId": f"{prefix}membership_id_{i+1}",
141-
"GroupId": f"{prefix}aws-group_id{i+1}",
139+
"GroupId": f"{prefix}aws-group_id{group_id}",
142140
"MemberId": {
143141
"UserId": f"{prefix}user_id{i+1}",
144142
},
@@ -155,15 +153,16 @@ def aws_groups_w_users(aws_groups, aws_users, aws_groups_memberships):
155153
def _aws_groups_w_users(
156154
n_groups=1, n_users=3, prefix="", domain="test.com", store_id="d-123412341234"
157155
):
158-
groups = aws_groups(n_groups, prefix, store_id)["Groups"]
156+
groups = aws_groups(n_groups, prefix, store_id)
159157
users = aws_users(n_users, prefix, domain, store_id)
160-
memberships = aws_groups_memberships(n_groups, prefix, store_id)[
161-
"GroupMemberships"
162-
]
163-
for group, membership in zip(groups, memberships):
164-
group.update(membership)
158+
for i, group in enumerate(groups):
159+
memberships = aws_groups_memberships(n_users, prefix, i + 1, store_id)[
160+
"GroupMemberships"
161+
]
162+
group.update(memberships[0])
165163
group["GroupMemberships"] = [
166-
{**membership, "MemberId": user} for user in users
164+
{**membership, "MemberId": user}
165+
for user, membership in zip(users, memberships)
167166
]
168167
return groups
169168

app/tests/integrations/aws/test_client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def test_execute_aws_api_call_non_paginated(
222222
):
223223
mock_client = MagicMock()
224224
mock_assume_role_client.return_value = mock_client
225-
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
225+
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
226226
mock_method = MagicMock()
227227
mock_method.return_value = {"key": "value"}
228228
mock_client.some_method = mock_method
@@ -232,7 +232,7 @@ def test_execute_aws_api_call_non_paginated(
232232
)
233233

234234
mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
235-
mock_method.assert_called_once_with(arg1="value1")
235+
mock_method.assert_called_once_with(Arg1="value1")
236236
assert result == {"key": "value"}
237237
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})
238238
mock_paginator.assert_not_called()
@@ -247,15 +247,18 @@ def test_execute_aws_api_call_paginated(
247247
):
248248
mock_client = MagicMock()
249249
mock_assume_role_client.return_value = mock_client
250-
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
250+
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
251251
mock_paginator.return_value = ["value1", "value2", "value3"]
252252

253253
result = aws_client.execute_aws_api_call(
254254
"service_name", "some_method", paginated=True, arg1="value1"
255255
)
256256

257257
mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
258-
mock_paginator.assert_called_once_with(mock_client, "some_method", arg1="value1")
258+
mock_paginator.assert_called_once_with(
259+
mock_client, "some_method", None, Arg1="value1"
260+
)
261+
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})
259262
assert result == ["value1", "value2", "value3"]
260263

261264

@@ -267,7 +270,7 @@ def test_execute_aws_api_call_with_role_arn(
267270
):
268271
mock_client = MagicMock()
269272
mock_assume_role_client.return_value = mock_client
270-
mock_convert_kwargs_to_pascal_case.return_value = {"arg1": "value1"}
273+
mock_convert_kwargs_to_pascal_case.return_value = {"Arg1": "value1"}
271274
mock_method = MagicMock()
272275
mock_method.return_value = {"key": "value"}
273276
mock_client.some_method = mock_method
@@ -277,9 +280,10 @@ def test_execute_aws_api_call_with_role_arn(
277280
)
278281

279282
mock_assume_role_client.assert_called_once_with("service_name", "test_role_arn")
280-
mock_method.assert_called_once_with(arg1="value1")
283+
mock_method.assert_called_once_with(Arg1="value1")
281284
assert result == {"key": "value"}
282285
mock_paginator.assert_not_called()
286+
mock_convert_kwargs_to_pascal_case.assert_called_once_with({"arg1": "value1"})
283287

284288

285289
@patch.dict(os.environ, {"AWS_SSO_ROLE_ARN": "test_role_arn"})

0 commit comments

Comments
 (0)