Skip to content

Commit 02cf98c

Browse files
committed
feat: use CosmosDBManagementClient to authenticate cosmos client through DefaultAzureCredential
1 parent 38f0329 commit 02cf98c

File tree

5 files changed

+34
-7
lines changed

5 files changed

+34
-7
lines changed

airflow/providers/microsoft/azure/hooks/cosmos.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,14 @@
2828
import json
2929
import uuid
3030
from typing import Any
31+
from urllib.parse import urlparse
3132

3233
from azure.cosmos.cosmos_client import CosmosClient
3334
from azure.cosmos.exceptions import CosmosHttpResponseError
3435
from azure.identity import DefaultAzureCredential
36+
from azure.mgmt.cosmosdb import CosmosDBManagementClient
3537

36-
from airflow.exceptions import AirflowBadRequest
38+
from airflow.exceptions import AirflowBadRequest, AirflowException
3739
from airflow.hooks.base import BaseHook
3840
from airflow.providers.microsoft.azure.utils import get_field
3941

@@ -69,6 +71,14 @@ def get_connection_form_widgets() -> dict[str, Any]:
6971
"collection_name": StringField(
7072
lazy_gettext("Cosmos Collection Name (optional)"), widget=BS3TextFieldWidget()
7173
),
74+
"subscription_id": StringField(
75+
lazy_gettext("Subscription ID (optional)"),
76+
widget=BS3TextFieldWidget(),
77+
),
78+
"resource_group_name": StringField(
79+
lazy_gettext("Resource Group Name (optional)"),
80+
widget=BS3TextFieldWidget(),
81+
),
7282
}
7383

7484
@staticmethod
@@ -82,9 +92,11 @@ def get_ui_field_behaviour() -> dict[str, Any]:
8292
},
8393
"placeholders": {
8494
"login": "endpoint uri",
85-
"password": "master key",
95+
"password": "master key (not needed for Azure AD authentication)",
8696
"database_name": "database name",
8797
"collection_name": "collection name",
98+
"subscription_id": "Subscription ID (required for Azure AD authentication)",
99+
"resource_group_name": "Resource Group Name (required for Azure AD authentication)",
88100
},
89101
}
90102

@@ -110,18 +122,29 @@ def get_conn(self) -> CosmosClient:
110122
conn = self.get_connection(self.conn_id)
111123
extras = conn.extra_dejson
112124
endpoint_uri = conn.login
113-
credential: dict[str, Any] | DefaultAzureCredential
125+
resource_group_name = self._get_field(extras, "resource_group_name")
126+
114127
if conn.password:
115128
master_key = conn.password
116-
credential = {"masterKey": master_key}
129+
elif resource_group_name:
130+
management_client = CosmosDBManagementClient(
131+
credential=DefaultAzureCredential(),
132+
subscription_id=self._get_field(extras, "subscription_id"),
133+
)
134+
135+
database_account = urlparse(conn.login).netloc.split(".")[0]
136+
database_account_keys = management_client.database_accounts.list_keys(
137+
resource_group_name, database_account
138+
)
139+
master_key = database_account_keys.primary_master_key
117140
else:
118-
credential = DefaultAzureCredential()
141+
raise AirflowException("Either password or resource_group_name is required")
119142

120143
self.default_database_name = self._get_field(extras, "database_name")
121144
self.default_collection_name = self._get_field(extras, "collection_name")
122145

123146
# Initialize the Python Azure Cosmos DB client
124-
self._conn = CosmosClient(endpoint_uri, credential=credential)
147+
self._conn = CosmosClient(endpoint_uri, {"masterKey": master_key})
125148
return self._conn
126149

127150
def __get_database_name(self, database_name: str | None = None) -> str:

airflow/providers/microsoft/azure/provider.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ dependencies:
6666
- apache-airflow>=2.4.0
6767
- azure-batch>=8.0.0
6868
- azure-cosmos>=4.0.0
69+
- azure-mgmt-cosmosdb
6970
- azure-datalake-store>=0.0.45
7071
- azure-identity>=1.3.1
7172
- azure-keyvault-secrets>=4.1.0

docs/apache-airflow-providers-microsoft-azure/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ PIP package Version required
107107
``apache-airflow`` ``>=2.4.0``
108108
``azure-batch`` ``>=8.0.0``
109109
``azure-cosmos`` ``>=4.0.0``
110+
``azure-mgmt-cosmosdb``
110111
``azure-datalake-store`` ``>=0.0.45``
111112
``azure-identity`` ``>=1.3.1``
112113
``azure-keyvault-secrets`` ``>=4.1.0``

generated/provider_dependencies.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@
560560
"azure-keyvault-secrets>=4.1.0",
561561
"azure-kusto-data>=4.1.0",
562562
"azure-mgmt-containerinstance>=1.5.0,<2.0",
563+
"azure-mgmt-cosmosdb",
563564
"azure-mgmt-datafactory>=1.0.0,<2.0",
564565
"azure-mgmt-datalake-store>=0.5.0",
565566
"azure-mgmt-resource>=2.2.0",

tests/providers/microsoft/azure/hooks/test_azure_cosmos.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535

3636
class TestAzureCosmosDbHook:
37-
3837
# Set up an environment to test with
3938
def setup_method(self):
4039
# set up some test variables
@@ -266,6 +265,8 @@ def test_get_ui_field_behaviour_placeholders(self):
266265
"password",
267266
"database_name",
268267
"collection_name",
268+
"subscription_id",
269+
"resource_group_name",
269270
]
270271
if get_provider_min_airflow_version("apache-airflow-providers-microsoft-azure") >= (2, 5):
271272
raise Exception(

0 commit comments

Comments
 (0)