|
1 | 1 | import base64
|
2 | 2 | import re
|
3 |
| -import unittest |
4 | 3 | from unittest.mock import MagicMock, patch
|
5 | 4 |
|
| 5 | +from challenges.models import Challenge |
| 6 | +from django.test import TestCase |
6 | 7 | from jobs.aws_utils import generate_aws_eks_bearer_token
|
7 | 8 |
|
8 | 9 |
|
9 |
| -class TestGenerateAWSEksBearerToken(unittest.TestCase): |
| 10 | +class TestGenerateAWSEksBearerToken(TestCase): |
| 11 | + def setUp(self): |
| 12 | + """Set up common test data and mock objects""" |
| 13 | + self.cluster_name = "test-cluster" |
| 14 | + self.challenge = MagicMock(spec=Challenge) |
| 15 | + self.challenge.id = "challenge-id" |
| 16 | + |
| 17 | + self.aws_credentials = { |
| 18 | + "AWS_ACCESS_KEY_ID": "fake_access_key", |
| 19 | + "AWS_SECRET_ACCESS_KEY": "fake_secret_key", |
| 20 | + "AWS_REGION": "us-west-2", |
| 21 | + } |
| 22 | + |
| 23 | + self.mock_signed_url = "https://sts.us-west-2.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=test" |
10 | 24 |
|
11 | 25 | @patch("jobs.aws_utils.get_aws_credentials_for_challenge")
|
12 | 26 | @patch("jobs.aws_utils.boto3.Session")
|
13 | 27 | @patch("jobs.aws_utils.RequestSigner")
|
14 | 28 | def test_generate_aws_eks_bearer_token(
|
15 | 29 | self, MockRequestSigner, MockSession, MockGetAwsCredentials
|
16 | 30 | ):
|
17 |
| - # Mock AWS credentials |
18 |
| - MockGetAwsCredentials.return_value = { |
19 |
| - "AWS_ACCESS_KEY_ID": "fake_access_key", |
20 |
| - "AWS_SECRET_ACCESS_KEY": "fake_secret_key", |
21 |
| - "AWS_REGION": "us-west-2", |
22 |
| - } |
23 |
| - |
24 |
| - # Mock the Session and its client method |
| 31 | + MockGetAwsCredentials.return_value = self.aws_credentials |
25 | 32 | mock_session = MagicMock()
|
26 | 33 | mock_client = MagicMock()
|
27 | 34 | mock_session.client.return_value = mock_client
|
28 | 35 | mock_client.meta.service_model.service_id = "STS"
|
29 | 36 | MockSession.return_value = mock_session
|
30 | 37 |
|
31 |
| - # Mock RequestSigner and its generate_presigned_url method |
32 | 38 | mock_signer = MagicMock()
|
33 |
| - mock_signer.generate_presigned_url.return_value = "https://signed.url" |
| 39 | + mock_signer.generate_presigned_url.return_value = self.mock_signed_url |
34 | 40 | MockRequestSigner.return_value = mock_signer
|
35 | 41 |
|
36 |
| - # Define test input |
37 |
| - cluster_name = "test-cluster" |
38 |
| - |
39 |
| - class Challenge: |
40 |
| - id = "challenge-id" |
41 |
| - |
42 |
| - challenge = Challenge() |
43 |
| - |
44 |
| - # Call the function to test |
45 |
| - token = generate_aws_eks_bearer_token(cluster_name, challenge) |
| 42 | + token = generate_aws_eks_bearer_token( |
| 43 | + self.cluster_name, self.challenge |
| 44 | + ) |
46 | 45 |
|
47 |
| - # Expected results |
48 |
| - expected_signed_url = "https://signed.url" |
49 | 46 | expected_base64_url = base64.urlsafe_b64encode(
|
50 |
| - expected_signed_url.encode("utf-8") |
| 47 | + self.mock_signed_url.encode("utf-8") |
51 | 48 | ).decode("utf-8")
|
52 | 49 | expected_bearer_token = "k8s-aws-v1." + re.sub(
|
53 | 50 | r"=*", "", expected_base64_url
|
54 | 51 | )
|
55 | 52 |
|
56 |
| - # Assertions |
57 |
| - MockGetAwsCredentials.assert_called_once_with("challenge-id") |
| 53 | + MockGetAwsCredentials.assert_called_once_with(self.challenge.id) |
58 | 54 | MockSession.assert_called_once_with(
|
59 |
| - aws_access_key_id="fake_access_key", |
60 |
| - aws_secret_access_key="fake_secret_key", |
| 55 | + aws_access_key_id=self.aws_credentials["AWS_ACCESS_KEY_ID"], |
| 56 | + aws_secret_access_key=self.aws_credentials[ |
| 57 | + "AWS_SECRET_ACCESS_KEY" |
| 58 | + ], |
61 | 59 | )
|
62 | 60 | mock_session.client.assert_called_once_with(
|
63 |
| - "sts", region_name="us-west-2" |
| 61 | + "sts", region_name=self.aws_credentials["AWS_REGION"] |
64 | 62 | )
|
65 | 63 | MockRequestSigner.assert_called_once_with(
|
66 | 64 | "STS",
|
67 |
| - "us-west-2", |
| 65 | + self.aws_credentials["AWS_REGION"], |
68 | 66 | "sts",
|
69 | 67 | "v4",
|
70 | 68 | mock_session.get_credentials(),
|
71 | 69 | mock_session.events,
|
72 | 70 | )
|
| 71 | + |
| 72 | + expected_params = { |
| 73 | + "method": "GET", |
| 74 | + "url": f"https://sts.{self.aws_credentials['AWS_REGION']}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", |
| 75 | + "body": {}, |
| 76 | + "headers": {"x-k8s-aws-id": self.cluster_name}, |
| 77 | + "context": {}, |
| 78 | + } |
| 79 | + |
73 | 80 | mock_signer.generate_presigned_url.assert_called_once_with(
|
74 |
| - { |
75 |
| - "method": "GET", |
76 |
| - "url": "https://sts.us-west-2.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", |
77 |
| - "body": {}, |
78 |
| - "headers": {"x-k8s-aws-id": cluster_name}, |
79 |
| - "context": {}, |
80 |
| - }, |
81 |
| - region_name="us-west-2", |
| 81 | + expected_params, |
| 82 | + region_name=self.aws_credentials["AWS_REGION"], |
82 | 83 | expires_in=60,
|
83 | 84 | operation_name="",
|
84 | 85 | )
|
| 86 | + |
85 | 87 | self.assertEqual(token, expected_bearer_token)
|
| 88 | + |
| 89 | + @patch("jobs.aws_utils.get_aws_credentials_for_challenge") |
| 90 | + @patch("jobs.aws_utils.boto3.Session") |
| 91 | + def test_generate_aws_eks_bearer_token_with_context_manager( |
| 92 | + self, MockSession, MockGetAwsCredentials |
| 93 | + ): |
| 94 | + """Test using context manager approach from test-jobs-aws-utils branch""" |
| 95 | + MockGetAwsCredentials.return_value = self.aws_credentials |
| 96 | + |
| 97 | + mock_session_instance = MagicMock() |
| 98 | + MockSession.return_value = mock_session_instance |
| 99 | + |
| 100 | + mock_client = MagicMock() |
| 101 | + mock_session_instance.client.return_value = mock_client |
| 102 | + mock_client.meta.service_model.service_id = "sts" |
| 103 | + |
| 104 | + mock_signer = MagicMock() |
| 105 | + mock_signer.generate_presigned_url.return_value = self.mock_signed_url |
| 106 | + |
| 107 | + with patch("jobs.aws_utils.RequestSigner", return_value=mock_signer): |
| 108 | + result = generate_aws_eks_bearer_token( |
| 109 | + self.cluster_name, self.challenge |
| 110 | + ) |
| 111 | + |
| 112 | + expected_base64 = base64.urlsafe_b64encode( |
| 113 | + self.mock_signed_url.encode("utf-8") |
| 114 | + ).decode("utf-8") |
| 115 | + expected_token = "k8s-aws-v1." + re.sub(r"=*", "", expected_base64) |
| 116 | + |
| 117 | + self.assertEqual(result, expected_token) |
0 commit comments