Skip to content

Commit df74553

Browse files
authored
Refactor account_url use in WasbHook (#32980)
* Refactor account_url use in WasbHook. This PR moves the account_url setting to one place. I tested this by connecting to Azure using the different authentication methods; however, I was not able to connect using the tenant_id in the extra field. This looks like a bug, because ClientSecretCredential is not among the credentials that BlobServiceClient accepts. The accepted credentials include AzureNamedKeyCredential, AzureSasCredential, and AsyncTokenCredential, so this will need further debugging. * fixup! Refactor account_url use in WasbHook
1 parent 5f5293f commit df74553

File tree

2 files changed

+18
-16
lines changed
  • airflow/providers/microsoft/azure/hooks
  • tests/providers/microsoft/azure/hooks

2 files changed

+18
-16
lines changed

airflow/providers/microsoft/azure/hooks/wasb.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def get_ui_field_behaviour() -> dict[str, Any]:
132132
"relabeling": {
133133
"login": "Blob Storage Login (optional)",
134134
"password": "Blob Storage Key (optional)",
135-
"host": "Account Name (Active Directory Auth)",
135+
"host": "Account URL (Active Directory Auth)",
136136
},
137137
"placeholders": {
138138
"login": "account name",
@@ -154,7 +154,7 @@ def __init__(
154154
super().__init__()
155155
self.conn_id = wasb_conn_id
156156
self.public_read = public_read
157-
self.blob_service_client = self.get_conn()
157+
self.blob_service_client: BlobServiceClient = self.get_conn()
158158

159159
logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy")
160160
try:
@@ -184,15 +184,19 @@ def get_conn(self) -> BlobServiceClient:
184184
# connection_string auth takes priority
185185
return BlobServiceClient.from_connection_string(connection_string, **extra)
186186

187+
account_url = (
188+
conn.host
189+
if conn.host and conn.host.startswith("https://")
190+
else f"https://{conn.login}.blob.core.windows.net/"
191+
)
192+
187193
tenant = self._get_field(extra, "tenant_id")
188194
if tenant:
189195
# use Active Directory auth
190196
app_id = conn.login
191197
app_secret = conn.password
192198
token_credential = ClientSecretCredential(tenant, app_id, app_secret, **client_secret_auth_config)
193-
return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)
194-
195-
account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
199+
return BlobServiceClient(account_url=account_url, credential=token_credential, **extra)
196200

197201
if self.public_read:
198202
# Here we use anonymous public read
@@ -210,19 +214,13 @@ def get_conn(self) -> BlobServiceClient:
210214
if sas_token.startswith("https"):
211215
return BlobServiceClient(account_url=sas_token, **extra)
212216
else:
213-
if not account_url.startswith("https://"):
214-
# TODO: require url in the host field in the next major version?
215-
account_url = f"https://{conn.login}.blob.core.windows.net"
216217
return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra)
217218

218219
# Fall back to old auth (password) or use managed identity if not provided.
219220
credential = conn.password
220221
if not credential:
221222
credential = DefaultAzureCredential()
222223
self.log.info("Using DefaultAzureCredential as credential")
223-
if not account_url.startswith("https://"):
224-
# TODO: require url in the host field in the next major version?
225-
account_url = f"https://{conn.login}.blob.core.windows.net/"
226224
return BlobServiceClient(
227225
account_url=account_url,
228226
credential=credential,
@@ -589,6 +587,12 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
589587
)
590588
return self.blob_service_client
591589

590+
account_url = (
591+
conn.host
592+
if conn.host and conn.host.startswith("https://")
593+
else f"https://{conn.login}.blob.core.windows.net/"
594+
)
595+
592596
tenant = self._get_field(extra, "tenant_id")
593597
if tenant:
594598
# use Active Directory auth
@@ -598,12 +602,10 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
598602
tenant, app_id, app_secret, **client_secret_auth_config
599603
)
600604
self.blob_service_client = AsyncBlobServiceClient(
601-
account_url=conn.host, credential=token_credential, **extra # type:ignore[arg-type]
605+
account_url=account_url, credential=token_credential, **extra # type:ignore[arg-type]
602606
)
603607
return self.blob_service_client
604608

605-
account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/"
606-
607609
if self.public_read:
608610
# Here we use anonymous public read
609611
# more info
@@ -625,7 +627,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient:
625627
self.blob_service_client = AsyncBlobServiceClient(account_url=sas_token, **extra)
626628
else:
627629
self.blob_service_client = AsyncBlobServiceClient(
628-
account_url=f"{account_url}/{sas_token}", **extra
630+
account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra
629631
)
630632
return self.blob_service_client
631633

tests/providers/microsoft/azure/hooks/test_wasb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def test_azure_directory_connection(self, mock_get_conn, mock_credential, mock_b
223223
authority=self.client_secret_auth_config["authority"],
224224
)
225225
mock_blob_service_client.assert_called_once_with(
226-
account_url=conn.host,
226+
account_url=f"https://{conn.login}.blob.core.windows.net/",
227227
credential=mock_credential.return_value,
228228
tenant_id=conn.extra_dejson["tenant_id"],
229229
proxies=conn.extra_dejson["proxies"],

0 commit comments

Comments
 (0)