26
26
platform_name ,
27
27
support_status ,
28
28
)
29
+ from datahub .ingestion .api .source import StructuredLogLevel
29
30
from datahub .ingestion .api .workunit import MetadataWorkUnit
30
31
from datahub .ingestion .source .aws .s3_util import make_s3_urn
31
32
from datahub .ingestion .source .common .subtypes import DatasetContainerSubTypes
35
36
register_custom_type ,
36
37
)
37
38
from datahub .ingestion .source .sql .sql_config import SQLCommonConfig , make_sqlalchemy_uri
39
+ from datahub .ingestion .source .sql .sql_report import SQLSourceReport
38
40
from datahub .ingestion .source .sql .sql_utils import (
39
41
add_table_to_schema_container ,
40
42
gen_database_container ,
48
50
get_schema_fields_for_sqlalchemy_column ,
49
51
)
50
52
53
try:
    from typing_extensions import override
except ImportError:
    # typing_extensions is unavailable: fall back to a no-op decorator so
    # `@override` annotations still work (the marker is purely informational).
    _DecoratedFunc = typing.TypeVar(
        "_DecoratedFunc", bound=typing.Callable[..., typing.Any]
    )

    def override(f: _DecoratedFunc, /) -> _DecoratedFunc:  # noqa: F811
        """Identity stand-in for `typing_extensions.override`."""
        return f
51
62
logger = logging .getLogger (__name__ )
52
63
53
64
assert STRUCT , "required type modules are not available"
@@ -322,12 +333,15 @@ class AthenaSource(SQLAlchemySource):
322
333
- Profiling when enabled.
323
334
"""
324
335
325
# NOTE(review): members of AthenaSource; formatting reconstructed from a
# mangled diff paste.
config: AthenaConfig
report: SQLSourceReport

def __init__(self, config, ctx):
    """Initialize the Athena source.

    The PyAthena cursor is created lazily elsewhere; the partition cache is
    deliberately a per-instance attribute rather than a shared class-level
    mutable default.
    """
    super().__init__(config, ctx, "athena")
    # schema -> (table -> partition info), filled while scanning tables.
    self.table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {}
    self.cursor: Optional[BaseCursor] = None
331
345
@classmethod
332
346
def create (cls , config_dict , ctx ):
333
347
config = AthenaConfig .parse_obj (config_dict )
@@ -452,41 +466,50 @@ def add_table_to_schema_container(
452
466
)
453
467
454
468
# The database/schema filter in the connection string does not appear to be
# honored by the driver, so filter the schema list here as a workaround.
@override
def get_schema_names(self, inspector: Inspector) -> List[str]:
    """Return all schema names, or only the configured database if one is set."""
    athena_config = typing.cast(AthenaConfig, self.config)
    names = inspector.get_schema_names()
    if not athena_config.database:
        return names
    return [name for name in names if name == athena_config.database]
461
476
462
- # Overwrite to get partitions
477
@classmethod
def _casted_partition_key(cls, key: str) -> str:
    """Wrap *key* in a VARCHAR cast.

    Athena may throw an error when concatenating or comparing partition
    keys of non-string types, so every key is normalized to VARCHAR before
    it is used in a query.
    """
    return "CAST({} as VARCHAR)".format(key)
482
+
483
+ @override
463
484
def get_partitions (
464
485
self , inspector : Inspector , schema : str , table : str
465
- ) -> List [str ]:
466
- partitions = []
467
-
468
- athena_config = typing .cast (AthenaConfig , self .config )
469
-
470
- if not athena_config .extract_partitions :
471
- return []
486
+ ) -> Optional [List [str ]]:
487
+ if not self .config .extract_partitions :
488
+ return None
472
489
473
490
if not self .cursor :
474
- return []
491
+ return None
475
492
476
493
metadata : AthenaTableMetadata = self .cursor .get_table_metadata (
477
494
table_name = table , schema_name = schema
478
495
)
479
496
480
- if metadata .partition_keys :
481
- for key in metadata .partition_keys :
482
- if key .name :
483
- partitions .append (key .name )
484
-
485
- if not partitions :
486
- return []
497
+ partitions = []
498
+ for key in metadata .partition_keys :
499
+ if key .name :
500
+ partitions .append (key .name )
501
+ if not partitions :
502
+ return []
487
503
488
- # We create an artiificaial concatenated partition key to be able to query max partition easier
489
- part_concat = "|| '-' ||" .join (partitions )
504
+ with self .report .report_exc (
505
+ message = "Failed to extract partition details" ,
506
+ context = f"{ schema } .{ table } " ,
507
+ level = StructuredLogLevel .WARN ,
508
+ ):
509
+ # We create an artifical concatenated partition key to be able to query max partition easier
510
+ part_concat = " || '-' || " .join (
511
+ self ._casted_partition_key (key ) for key in partitions
512
+ )
490
513
max_partition_query = f'select { "," .join (partitions )} from "{ schema } "."{ table } $partitions" where { part_concat } = (select max({ part_concat } ) from "{ schema } "."{ table } $partitions")'
491
514
ret = self .cursor .execute (max_partition_query )
492
515
max_partition : Dict [str , str ] = {}
@@ -500,9 +523,8 @@ def get_partitions(
500
523
partitions = partitions ,
501
524
max_partition = max_partition ,
502
525
)
503
- return partitions
504
526
505
- return []
527
+ return partitions
506
528
507
529
# Overwrite to modify the creation of schema fields
508
530
def get_schema_fields_for_column (
@@ -551,7 +573,9 @@ def generate_partition_profiler_query(
551
573
if partition and partition .max_partition :
552
574
max_partition_filters = []
553
575
for key , value in partition .max_partition .items ():
554
- max_partition_filters .append (f"CAST({ key } as VARCHAR) = '{ value } '" )
576
+ max_partition_filters .append (
577
+ f"{ self ._casted_partition_key (key )} = '{ value } '"
578
+ )
555
579
max_partition = str (partition .max_partition )
556
580
return (
557
581
max_partition ,
0 commit comments