28
28
import json
29
29
import uuid
30
30
from typing import Any
31
+ from urllib .parse import urlparse
31
32
32
33
from azure .cosmos .cosmos_client import CosmosClient
33
34
from azure .cosmos .exceptions import CosmosHttpResponseError
34
35
from azure .identity import DefaultAzureCredential
36
+ from azure .mgmt .cosmosdb import CosmosDBManagementClient
35
37
36
- from airflow .exceptions import AirflowBadRequest
38
+ from airflow .exceptions import AirflowBadRequest , AirflowException
37
39
from airflow .hooks .base import BaseHook
38
40
from airflow .providers .microsoft .azure .utils import get_field
39
41
@@ -69,6 +71,14 @@ def get_connection_form_widgets() -> dict[str, Any]:
69
71
"collection_name" : StringField (
70
72
lazy_gettext ("Cosmos Collection Name (optional)" ), widget = BS3TextFieldWidget ()
71
73
),
74
+ "subscription_id" : StringField (
75
+ lazy_gettext ("Subscription ID (optional)" ),
76
+ widget = BS3TextFieldWidget (),
77
+ ),
78
+ "resource_group_name" : StringField (
79
+ lazy_gettext ("Resource Group Name (optional)" ),
80
+ widget = BS3TextFieldWidget (),
81
+ ),
72
82
}
73
83
74
84
@staticmethod
@@ -82,9 +92,11 @@ def get_ui_field_behaviour() -> dict[str, Any]:
82
92
},
83
93
"placeholders" : {
84
94
"login" : "endpoint uri" ,
85
- "password" : "master key" ,
95
+ "password" : "master key (not needed for Azure AD authentication) " ,
86
96
"database_name" : "database name" ,
87
97
"collection_name" : "collection name" ,
98
+ "subscription_id" : "Subscription ID (required for Azure AD authentication)" ,
99
+ "resource_group_name" : "Resource Group Name (required for Azure AD authentication)" ,
88
100
},
89
101
}
90
102
@@ -110,18 +122,29 @@ def get_conn(self) -> CosmosClient:
110
122
conn = self .get_connection (self .conn_id )
111
123
extras = conn .extra_dejson
112
124
endpoint_uri = conn .login
113
- credential : dict [str , Any ] | DefaultAzureCredential
125
+ resource_group_name = self ._get_field (extras , "resource_group_name" )
126
+
114
127
if conn .password :
115
128
master_key = conn .password
116
- credential = {"masterKey" : master_key }
129
+ elif resource_group_name :
130
+ management_client = CosmosDBManagementClient (
131
+ credential = DefaultAzureCredential (),
132
+ subscription_id = self ._get_field (extras , "subscription_id" ),
133
+ )
134
+
135
+ database_account = urlparse (conn .login ).netloc .split ("." )[0 ]
136
+ database_account_keys = management_client .database_accounts .list_keys (
137
+ resource_group_name , database_account
138
+ )
139
+ master_key = database_account_keys .primary_master_key
117
140
else :
118
- credential = DefaultAzureCredential ( )
141
+ raise AirflowException ( "Either password or resource_group_name is required" )
119
142
120
143
self .default_database_name = self ._get_field (extras , "database_name" )
121
144
self .default_collection_name = self ._get_field (extras , "collection_name" )
122
145
123
146
# Initialize the Python Azure Cosmos DB client
124
- self ._conn = CosmosClient (endpoint_uri , credential = credential )
147
+ self ._conn = CosmosClient (endpoint_uri , { "masterKey" : master_key } )
125
148
return self ._conn
126
149
127
150
def __get_database_name (self , database_name : str | None = None ) -> str :
0 commit comments