Skip to content

Commit a46de1e

Browse files
authored
feat(ingest/athena): handle partition fetching errors (datahub-project#11966)
1 parent a92c6b2 commit a46de1e

File tree

2 files changed

+88
-25
lines changed

2 files changed

+88
-25
lines changed

metadata-ingestion/src/datahub/ingestion/source/sql/athena.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
platform_name,
2727
support_status,
2828
)
29+
from datahub.ingestion.api.source import StructuredLogLevel
2930
from datahub.ingestion.api.workunit import MetadataWorkUnit
3031
from datahub.ingestion.source.aws.s3_util import make_s3_urn
3132
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
@@ -35,6 +36,7 @@
3536
register_custom_type,
3637
)
3738
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, make_sqlalchemy_uri
39+
from datahub.ingestion.source.sql.sql_report import SQLSourceReport
3840
from datahub.ingestion.source.sql.sql_utils import (
3941
add_table_to_schema_container,
4042
gen_database_container,
@@ -48,6 +50,15 @@
4850
get_schema_fields_for_sqlalchemy_column,
4951
)
5052

53+
try:
54+
from typing_extensions import override
55+
except ImportError:
56+
_F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any])
57+
58+
def override(f: _F, /) -> _F: # noqa: F811
59+
return f
60+
61+
5162
logger = logging.getLogger(__name__)
5263

5364
assert STRUCT, "required type modules are not available"
@@ -322,12 +333,15 @@ class AthenaSource(SQLAlchemySource):
322333
- Profiling when enabled.
323334
"""
324335

325-
table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
336+
config: AthenaConfig
337+
report: SQLSourceReport
326338

327339
def __init__(self, config, ctx):
328340
super().__init__(config, ctx, "athena")
329341
self.cursor: Optional[BaseCursor] = None
330342

343+
self.table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
344+
331345
@classmethod
332346
def create(cls, config_dict, ctx):
333347
config = AthenaConfig.parse_obj(config_dict)
@@ -452,41 +466,50 @@ def add_table_to_schema_container(
452466
)
453467

454468
# It seems like the database/schema filter in the connection string does not work, and this is a workaround for that
469+
@override
455470
def get_schema_names(self, inspector: Inspector) -> List[str]:
456471
athena_config = typing.cast(AthenaConfig, self.config)
457472
schemas = inspector.get_schema_names()
458473
if athena_config.database:
459474
return [schema for schema in schemas if schema == athena_config.database]
460475
return schemas
461476

462-
# Overwrite to get partitions
477+
@classmethod
478+
def _casted_partition_key(cls, key: str) -> str:
479+
# We need to cast the partition keys to a VARCHAR, since otherwise
480+
# Athena may throw an error during concatenation / comparison.
481+
return f"CAST({key} as VARCHAR)"
482+
483+
@override
463484
def get_partitions(
464485
self, inspector: Inspector, schema: str, table: str
465-
) -> List[str]:
466-
partitions = []
467-
468-
athena_config = typing.cast(AthenaConfig, self.config)
469-
470-
if not athena_config.extract_partitions:
471-
return []
486+
) -> Optional[List[str]]:
487+
if not self.config.extract_partitions:
488+
return None
472489

473490
if not self.cursor:
474-
return []
491+
return None
475492

476493
metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
477494
table_name=table, schema_name=schema
478495
)
479496

480-
if metadata.partition_keys:
481-
for key in metadata.partition_keys:
482-
if key.name:
483-
partitions.append(key.name)
484-
485-
if not partitions:
486-
return []
497+
partitions = []
498+
for key in metadata.partition_keys:
499+
if key.name:
500+
partitions.append(key.name)
501+
if not partitions:
502+
return []
487503

488-
# We create an artiificaial concatenated partition key to be able to query max partition easier
489-
part_concat = "|| '-' ||".join(partitions)
504+
with self.report.report_exc(
505+
message="Failed to extract partition details",
506+
context=f"{schema}.{table}",
507+
level=StructuredLogLevel.WARN,
508+
):
509+
# We create an artificial concatenated partition key to be able to query max partition easier
510+
part_concat = " || '-' || ".join(
511+
self._casted_partition_key(key) for key in partitions
512+
)
490513
max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")'
491514
ret = self.cursor.execute(max_partition_query)
492515
max_partition: Dict[str, str] = {}
@@ -500,9 +523,8 @@ def get_partitions(
500523
partitions=partitions,
501524
max_partition=max_partition,
502525
)
503-
return partitions
504526

505-
return []
527+
return partitions
506528

507529
# Overwrite to modify the creation of schema fields
508530
def get_schema_fields_for_column(
@@ -551,7 +573,9 @@ def generate_partition_profiler_query(
551573
if partition and partition.max_partition:
552574
max_partition_filters = []
553575
for key, value in partition.max_partition.items():
554-
max_partition_filters.append(f"CAST({key} as VARCHAR) = '{value}'")
576+
max_partition_filters.append(
577+
f"{self._casted_partition_key(key)} = '{value}'"
578+
)
555579
max_partition = str(partition.max_partition)
556580
return (
557581
max_partition,

metadata-ingestion/tests/unit/test_athena_source.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def test_athena_get_table_properties():
9393
"CreateTime": datetime.now(),
9494
"LastAccessTime": datetime.now(),
9595
"PartitionKeys": [
96-
{"Name": "testKey", "Type": "string", "Comment": "testComment"}
96+
{"Name": "year", "Type": "string", "Comment": "testComment"},
97+
{"Name": "month", "Type": "string", "Comment": "testComment"},
9798
],
9899
"Parameters": {
99100
"comment": "testComment",
@@ -112,8 +113,18 @@ def test_athena_get_table_properties():
112113
response=table_metadata
113114
)
114115

116+
# Mock partition query results
117+
mock_cursor.execute.return_value.description = [
118+
["year"],
119+
["month"],
120+
]
121+
mock_cursor.execute.return_value.__iter__.return_value = [["2023", "12"]]
122+
115123
ctx = PipelineContext(run_id="test")
116124
source = AthenaSource(config=config, ctx=ctx)
125+
source.cursor = mock_cursor
126+
127+
# Test table properties
117128
description, custom_properties, location = source.get_table_properties(
118129
inspector=mock_inspector, table=table, schema=schema
119130
)
@@ -124,13 +135,35 @@ def test_athena_get_table_properties():
124135
"last_access_time": "2020-04-14 07:00:00",
125136
"location": "s3://testLocation",
126137
"outputformat": "testOutputFormat",
127-
"partition_keys": '[{"name": "testKey", "type": "string", "comment": "testComment"}]',
138+
"partition_keys": '[{"name": "year", "type": "string", "comment": "testComment"}, {"name": "month", "type": "string", "comment": "testComment"}]',
128139
"serde.serialization.lib": "testSerde",
129140
"table_type": "testType",
130141
}
131-
132142
assert location == make_s3_urn("s3://testLocation", "PROD")
133143

144+
# Test partition functionality
145+
partitions = source.get_partitions(
146+
inspector=mock_inspector, schema=schema, table=table
147+
)
148+
assert partitions == ["year", "month"]
149+
150+
# Verify the correct SQL query was generated for partitions
151+
expected_query = """\
152+
select year,month from "test_schema"."test_table$partitions" \
153+
where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \
154+
(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \
155+
from "test_schema"."test_table$partitions")"""
156+
mock_cursor.execute.assert_called_once()
157+
actual_query = mock_cursor.execute.call_args[0][0]
158+
assert actual_query == expected_query
159+
160+
# Verify partition cache was populated correctly
161+
assert source.table_partition_cache[schema][table].partitions == partitions
162+
assert source.table_partition_cache[schema][table].max_partition == {
163+
"year": "2023",
164+
"month": "12",
165+
}
166+
134167

135168
def test_get_column_type_simple_types():
136169
assert isinstance(
@@ -214,3 +247,9 @@ def test_column_type_complex_combination():
214247
assert isinstance(
215248
result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][1], types.String
216249
)
250+
251+
252+
def test_casted_partition_key():
253+
from datahub.ingestion.source.sql.athena import AthenaSource
254+
255+
assert AthenaSource._casted_partition_key("test_col") == "CAST(test_col as VARCHAR)"

0 commit comments

Comments
 (0)