@@ -48,6 +48,27 @@ import scala.util.{Failure, Success, Try}
   * retrieve metadata / configure it appropriately at creation time
   */
 trait Format {
+  // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
+  // If sub-partition filters are supplied and the format doesn't support them, we throw an error
+  def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String] = Map.empty)(implicit sparkSession: SparkSession): Seq[String] = {
+    if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) {
+      throw new NotImplementedError(s"subPartitionsFilter is not supported on this format")
+    }
+
+    val partitionSeq = partitions(tableName)(sparkSession)
+    partitionSeq.flatMap { partitionMap =>
+      if (
+        subPartitionsFilter.forall {
+          case (k, v) => partitionMap.get(k).contains(v)
+        }
+      ) {
+        partitionMap.get(partitionColumn)
+      } else {
+        None
+      }
+    }
+  }
+
   // Return a sequence for partitions where each partition entry consists of a Map of partition keys to values
   def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]]
 
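For context only (not part of the diff): the default `primaryPartitions` above keeps each partition map that matches every sub-partition filter entry, then projects the primary partition column. A self-contained sketch of that filtering on made-up, in-memory data:

```scala
// Standalone sketch of the default filtering logic; the partition maps, filter,
// and column name below are invented for illustration, not taken from the PR.
object PrimaryPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val partitionSeq: Seq[Map[String, String]] = Seq(
      Map("ds" -> "2024-01-01", "hr" -> "00"),
      Map("ds" -> "2024-01-01", "hr" -> "01"),
      Map("ds" -> "2024-01-02", "hr" -> "00")
    )
    val subPartitionsFilter = Map("hr" -> "00")
    val partitionColumn = "ds"

    // Same shape as the trait's default implementation: keep maps matching every
    // filter entry, then project the primary partition column.
    val primary = partitionSeq.flatMap { partitionMap =>
      if (subPartitionsFilter.forall { case (k, v) => partitionMap.get(k).contains(v) })
        partitionMap.get(partitionColumn)
      else None
    }

    println(primary) // List(2024-01-01, 2024-01-02)
  }
}
```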
@@ -56,6 +77,9 @@ trait Format {
 
   // Help specify the appropriate file format to use in the Spark create table DDL query
   def fileFormatString(format: String): String
+
+  // Does this format support sub-partition filters
+  def supportSubPartitionsFilter: Boolean
 }
 
 /**
@@ -143,6 +167,9 @@ case class DefaultFormatProvider(sparkSession: SparkSession) extends FormatProvi
 }
 
 case object Hive extends Format {
+  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(implicit sparkSession: SparkSession): Seq[String] =
+    super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
+
   def parseHivePartition(pstring: String): Map[String, String] = {
     pstring
       .split("/")
@@ -166,24 +193,61 @@ case object Hive extends Format {
 
   def createTableTypeString: String = ""
   def fileFormatString(format: String): String = s"STORED AS $format"
+
+  override def supportSubPartitionsFilter: Boolean = true
 }
 
 case object Iceberg extends Format {
+  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(implicit sparkSession: SparkSession): Seq[String] = {
+    if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) {
+      throw new NotImplementedError(s"subPartitionsFilter is not supported on this format")
+    }
+
+    getIcebergPartitions(tableName)
+  }
+
   override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
     throw new NotImplementedError(
       "Multi-partitions retrieval is not supported on Iceberg tables yet." +
         " For single partition retrieval, please use 'partition' method.")
   }
 
+  private def getIcebergPartitions(tableName: String)(implicit sparkSession: SparkSession): Seq[String] = {
+    val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions")
+    val index = partitionsDf.schema.fieldIndex("partition")
+    if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) {
+      // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718
+      // so we collect and then filter.
+      partitionsDf
+        .select("partition.ds", "partition.hr")
+        .collect()
+        .filter(_.get(1) == null)
+        .map(_.getString(0))
+        .toSeq
+    } else {
+      partitionsDf
+        .select("partition.ds")
+        .collect()
+        .map(_.getString(0))
+        .toSeq
+    }
+  }
+
   def createTableTypeString: String = "USING iceberg"
   def fileFormatString(format: String): String = ""
+
+  override def supportSubPartitionsFilter: Boolean = false
 }
 
 // The Delta Lake format is compatible with the Delta lake and Spark versions currently supported by the project.
 // Attempting to use newer Delta lake library versions (e.g. 3.2 which works with Spark 3.5) results in errors:
 // java.lang.NoSuchMethodError: 'org.apache.spark.sql.delta.Snapshot org.apache.spark.sql.delta.DeltaLog.update(boolean)'
 // In such cases, you should implement your own FormatProvider built on the newer Delta lake version
 case object DeltaLake extends Format {
+
+  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(implicit sparkSession: SparkSession): Seq[String] =
+    super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
+
   override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
     // delta lake doesn't support the `SHOW PARTITIONS <tableName>` syntax - https://github.com/delta-io/delta/issues/996
     // there's alternative ways to retrieve partitions using the DeltaLog abstraction which is what we have to lean into
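Illustrative usage of the Iceberg overrides above (not part of the diff; the table name, package placement, and SparkSession setup are assumptions): because `supportSubPartitionsFilter` is false, any non-empty filter fails fast, while the unfiltered call reads the `<table>.partitions` metadata table directly instead of going through `partitions`, which still throws.

```scala
import org.apache.spark.sql.SparkSession

object IcebergPrimaryPartitionsSketch {
  def main(args: Array[String]): Unit = {
    // Assumes an Iceberg catalog is configured and "db.events" is partitioned by ds (and possibly hr).
    implicit val spark: SparkSession = SparkSession.builder().appName("sketch").getOrCreate()

    // Served from the "db.events.partitions" metadata table via getIcebergPartitions.
    val days: Seq[String] = Iceberg.primaryPartitions("db.events", partitionColumn = "ds")

    // Would throw NotImplementedError: Iceberg reports supportSubPartitionsFilter = false.
    // Iceberg.primaryPartitions("db.events", "ds", Map("hr" -> "00"))

    println(days)
  }
}
```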
@@ -200,6 +264,8 @@ case object DeltaLake extends Format {
 
   def createTableTypeString: String = "USING DELTA"
   def fileFormatString(format: String): String = ""
+
+  override def supportSubPartitionsFilter: Boolean = true
 }
 
 case class TableUtils(sparkSession: SparkSession) {
@@ -315,47 +381,7 @@ case class TableUtils(sparkSession: SparkSession) {
   def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
     if (!tableExists(tableName)) return Seq.empty[String]
     val format = tableReadFormat(tableName)
-
-    if (format == Iceberg) {
-      if (subPartitionsFilter.nonEmpty) {
-        throw new NotImplementedError("subPartitionsFilter is not supported on Iceberg tables yet.")
-      }
-      return getIcebergPartitions(tableName)
-    }
-
-    val partitionSeq = format.partitions(tableName)(sparkSession)
-    partitionSeq.flatMap { partitionMap =>
-      if (
-        subPartitionsFilter.forall {
-          case (k, v) => partitionMap.get(k).contains(v)
-        }
-      ) {
-        partitionMap.get(partitionColumn)
-      } else {
-        None
-      }
-    }
-  }
-
-  private def getIcebergPartitions(tableName: String): Seq[String] = {
-    val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions")
-    val index = partitionsDf.schema.fieldIndex("partition")
-    if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) {
-      // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718
-      // so we collect and then filter.
-      partitionsDf
-        .select("partition.ds", "partition.hr")
-        .collect()
-        .filter(_.get(1) == null)
-        .map(_.getString(0))
-        .toSeq
-    } else {
-      partitionsDf
-        .select("partition.ds")
-        .collect()
-        .map(_.getString(0))
-        .toSeq
-    }
+    format.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)(sparkSession)
   }
 
   // Given a table and a query extract the schema of the columns involved as input.
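Net effect for callers (a sketch with hypothetical table names, not taken from the PR): `TableUtils.partitions` keeps its signature and delegates to the table's format, so Hive/Delta filtering behaves as before, while the Iceberg restriction is now enforced inside `Iceberg.primaryPartitions` rather than in `TableUtils`.

```scala
// Assumes a SparkSession `spark` and that the tables below exist; names are invented.
val tableUtils = TableUtils(spark)

// Hive or Delta Lake table: sub-partition filters are applied as before.
val hiveDays = tableUtils.partitions("db.hive_events", Map("hr" -> "23"))

// Iceberg table: unfiltered listing works; a non-empty filter still throws,
// but the error now originates in Iceberg.primaryPartitions.
val icebergDays = tableUtils.partitions("db.iceberg_events")
```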