Skip to content

Commit 3ca8d09

Browse files
authored
feat(ingest/snowflake): support email_as_user_identifier for queries v2 (datahub-project#12219)
1 parent 172736a commit 3ca8d09

File tree

6 files changed

+132
-40
lines changed

6 files changed

+132
-40
lines changed

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,20 @@ class SnowflakeIdentifierConfig(
138138
description="Whether to convert dataset urns to lowercase.",
139139
)
140140

141-
142-
class SnowflakeUsageConfig(BaseUsageConfig):
143141
email_domain: Optional[str] = pydantic.Field(
144142
default=None,
145143
description="Email domain of your organization so users can be displayed on UI appropriately.",
146144
)
145+
146+
email_as_user_identifier: bool = Field(
147+
default=True,
148+
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
149+
"provided, generates email addresses for snowflake users with unset emails, based on their "
150+
"username.",
151+
)
152+
153+
154+
class SnowflakeUsageConfig(BaseUsageConfig):
147155
apply_view_usage_to_tables: bool = pydantic.Field(
148156
default=False,
149157
description="Whether to apply view's usage to its base tables. If set to True, usage is applied to base tables only.",
@@ -267,13 +275,6 @@ class SnowflakeV2Config(
267275
" Map of share name -> details of share.",
268276
)
269277

270-
email_as_user_identifier: bool = Field(
271-
default=True,
272-
description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is "
273-
"provided, generates email addresses for snowflake users with unset emails, based on their "
274-
"username.",
275-
)
276-
277278
include_assertion_results: bool = Field(
278279
default=False,
279280
description="Whether to ingest assertion run results for assertions created using Datahub"

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@
6666

6767
logger = logging.getLogger(__name__)
6868

69+
# Define a type alias
70+
UserName = str
71+
UserEmail = str
72+
UsersMapping = Dict[UserName, UserEmail]
73+
6974

7075
class SnowflakeQueriesExtractorConfig(ConfigModel):
7176
# TODO: Support stateful ingestion for the time windows.
@@ -114,11 +119,13 @@ class SnowflakeQueriesSourceConfig(
114119
class SnowflakeQueriesExtractorReport(Report):
115120
copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
116121
query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
122+
users_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
117123

118124
audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
119125
sql_aggregator: Optional[SqlAggregatorReport] = None
120126

121127
num_ddl_queries_dropped: int = 0
128+
num_users: int = 0
122129

123130

124131
@dataclass
@@ -225,6 +232,9 @@ def is_allowed_table(self, name: str) -> bool:
225232
def get_workunits_internal(
226233
self,
227234
) -> Iterable[MetadataWorkUnit]:
235+
with self.report.users_fetch_timer:
236+
users = self.fetch_users()
237+
228238
# TODO: Add some logic to check if the cached audit log is stale or not.
229239
audit_log_file = self.local_temp_path / "audit_log.sqlite"
230240
use_cached_audit_log = audit_log_file.exists()
@@ -248,7 +258,7 @@ def get_workunits_internal(
248258
queries.append(entry)
249259

250260
with self.report.query_log_fetch_timer:
251-
for entry in self.fetch_query_log():
261+
for entry in self.fetch_query_log(users):
252262
queries.append(entry)
253263

254264
with self.report.audit_log_load_timer:
@@ -263,6 +273,25 @@ def get_workunits_internal(
263273
shared_connection.close()
264274
audit_log_file.unlink(missing_ok=True)
265275

276+
def fetch_users(self) -> UsersMapping:
277+
users: UsersMapping = dict()
278+
with self.structured_reporter.report_exc("Error fetching users from Snowflake"):
279+
logger.info("Fetching users from Snowflake")
280+
query = SnowflakeQuery.get_all_users()
281+
resp = self.connection.query(query)
282+
283+
for row in resp:
284+
try:
285+
users[row["NAME"]] = row["EMAIL"]
286+
self.report.num_users += 1
287+
except Exception as e:
288+
self.structured_reporter.warning(
289+
"Error parsing user row",
290+
context=f"{row}",
291+
exc=e,
292+
)
293+
return users
294+
266295
def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
267296
# Derived from _populate_external_lineage_from_copy_history.
268297

@@ -298,7 +327,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
298327
yield result
299328

300329
def fetch_query_log(
301-
self,
330+
self, users: UsersMapping
302331
) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]:
303332
query_log_query = _build_enriched_query_log_query(
304333
start_time=self.config.window.start_time,
@@ -319,7 +348,7 @@ def fetch_query_log(
319348

320349
assert isinstance(row, dict)
321350
try:
322-
entry = self._parse_audit_log_row(row)
351+
entry = self._parse_audit_log_row(row, users)
323352
except Exception as e:
324353
self.structured_reporter.warning(
325354
"Error parsing query log row",
@@ -331,7 +360,7 @@ def fetch_query_log(
331360
yield entry
332361

333362
def _parse_audit_log_row(
334-
self, row: Dict[str, Any]
363+
self, row: Dict[str, Any], users: UsersMapping
335364
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]:
336365
json_fields = {
337366
"DIRECT_OBJECTS_ACCESSED",
@@ -430,9 +459,11 @@ def _parse_audit_log_row(
430459
)
431460
)
432461

433-
# TODO: Fetch email addresses from Snowflake to map user -> email
434-
# TODO: Support email_domain fallback for generating user urns.
435-
user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"]))
462+
user = CorpUserUrn(
463+
self.identifiers.get_user_identifier(
464+
res["user_name"], users.get(res["user_name"])
465+
)
466+
)
436467

437468
timestamp: datetime = res["query_start_time"]
438469
timestamp = timestamp.astimezone(timezone.utc)

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -947,4 +947,8 @@ def dmf_assertion_results(start_time_millis: int, end_time_millis: int) -> str:
947947
AND METRIC_NAME ilike '{pattern}' escape '{escape_pattern}'
948948
ORDER BY MEASUREMENT_TIME ASC;
949949
950-
"""
950+
"""
951+
952+
@staticmethod
953+
def get_all_users() -> str:
954+
return """SELECT name as "NAME", email as "EMAIL" FROM SNOWFLAKE.ACCOUNT_USAGE.USERS"""

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -342,10 +342,9 @@ def _map_user_counts(
342342
filtered_user_counts.append(
343343
DatasetUserUsageCounts(
344344
user=make_user_urn(
345-
self.get_user_identifier(
345+
self.identifiers.get_user_identifier(
346346
user_count["user_name"],
347347
user_email,
348-
self.config.email_as_user_identifier,
349348
)
350349
),
351350
count=user_count["total"],
@@ -453,9 +452,7 @@ def _get_operation_aspect_work_unit(
453452
reported_time: int = int(time.time() * 1000)
454453
last_updated_timestamp: int = int(start_time.timestamp() * 1000)
455454
user_urn = make_user_urn(
456-
self.get_user_identifier(
457-
user_name, user_email, self.config.email_as_user_identifier
458-
)
455+
self.identifiers.get_user_identifier(user_name, user_email)
459456
)
460457

461458
# NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,28 @@ def get_quoted_identifier_for_schema(db_name, schema_name):
300300
def get_quoted_identifier_for_table(db_name, schema_name, table_name):
301301
return f'"{db_name}"."{schema_name}"."{table_name}"'
302302

303+
# Note - decide how to construct user urns.
304+
# Historically urns were created using part before @ from user's email.
305+
# Users without email were skipped from both user entries as well as aggregates.
306+
# However email is not mandatory field in snowflake user, user_name is always present.
307+
def get_user_identifier(
308+
self,
309+
user_name: str,
310+
user_email: Optional[str],
311+
) -> str:
312+
if user_email:
313+
return self.snowflake_identifier(
314+
user_email
315+
if self.identifier_config.email_as_user_identifier is True
316+
else user_email.split("@")[0]
317+
)
318+
return self.snowflake_identifier(
319+
f"{user_name}@{self.identifier_config.email_domain}"
320+
if self.identifier_config.email_as_user_identifier is True
321+
and self.identifier_config.email_domain is not None
322+
else user_name
323+
)
324+
303325

304326
class SnowflakeCommonMixin(SnowflakeStructuredReportMixin):
305327
platform = "snowflake"
@@ -315,24 +337,6 @@ def structured_reporter(self) -> SourceReport:
315337
def identifiers(self) -> SnowflakeIdentifierBuilder:
316338
return SnowflakeIdentifierBuilder(self.config, self.report)
317339

318-
# Note - decide how to construct user urns.
319-
# Historically urns were created using part before @ from user's email.
320-
# Users without email were skipped from both user entries as well as aggregates.
321-
# However email is not mandatory field in snowflake user, user_name is always present.
322-
def get_user_identifier(
323-
self,
324-
user_name: str,
325-
user_email: Optional[str],
326-
email_as_user_identifier: bool,
327-
) -> str:
328-
if user_email:
329-
return self.identifiers.snowflake_identifier(
330-
user_email
331-
if email_as_user_identifier is True
332-
else user_email.split("@")[0]
333-
)
334-
return self.identifiers.snowflake_identifier(user_name)
335-
336340
# TODO: Revisit this after stateful ingestion can commit checkpoint
337341
# for failures that do not affect the checkpoint
338342
# TODO: Add additional parameters to match the signature of the .warning and .failure methods

metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,58 @@ def test_source_close_cleans_tmp(snowflake_connect, tmp_path):
2222
# This closes QueriesExtractor which in turn closes SqlParsingAggregator
2323
source.close()
2424
assert len(os.listdir(tmp_path)) == 0
25+
26+
27+
@patch("snowflake.connector.connect")
28+
def test_user_identifiers_email_as_identifier(snowflake_connect, tmp_path):
29+
source = SnowflakeQueriesSource.create(
30+
{
31+
"connection": {
32+
"account_id": "ABC12345.ap-south-1.aws",
33+
"username": "TST_USR",
34+
"password": "TST_PWD",
35+
},
36+
"email_as_user_identifier": True,
37+
"email_domain": "example.com",
38+
},
39+
PipelineContext("run-id"),
40+
)
41+
assert (
42+
        source.identifiers.get_user_identifier("username", "username@example.com")
43+
        == "username@example.com"
44+
)
45+
assert (
46+
source.identifiers.get_user_identifier("username", None)
47+
        == "username@example.com"
48+
)
49+
50+
# We'd do best effort to use email as identifier, but would keep username as is,
51+
# if email can't be formed.
52+
source.identifiers.identifier_config.email_domain = None
53+
54+
assert (
55+
        source.identifiers.get_user_identifier("username", "username@example.com")
56+
        == "username@example.com"
57+
)
58+
59+
assert source.identifiers.get_user_identifier("username", None) == "username"
60+
61+
62+
@patch("snowflake.connector.connect")
63+
def test_user_identifiers_username_as_identifier(snowflake_connect, tmp_path):
64+
source = SnowflakeQueriesSource.create(
65+
{
66+
"connection": {
67+
"account_id": "ABC12345.ap-south-1.aws",
68+
"username": "TST_USR",
69+
"password": "TST_PWD",
70+
},
71+
"email_as_user_identifier": False,
72+
},
73+
PipelineContext("run-id"),
74+
)
75+
assert (
76+
        source.identifiers.get_user_identifier("username", "username@example.com")
77+
== "username"
78+
)
79+
assert source.identifiers.get_user_identifier("username", None) == "username"

0 commit comments

Comments
 (0)