Commit e3d6632

Refactor on comments
1 parent 6affb63 commit e3d6632

File tree

1 file changed: +67 -41 lines changed

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 67 additions & 41 deletions
@@ -48,6 +48,27 @@ import scala.util.{Failure, Success, Try}
  * retrieve metadata / configure it appropriately at creation time
  */
 trait Format {
+  // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
+  // If subpartition filters are supplied and the format doesn't support it, we throw an error
+  def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String] = Map.empty)(implicit sparkSession: SparkSession): Seq[String] = {
+    if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) {
+      throw new NotImplementedError(s"subPartitionsFilter is not supported on this format")
+    }
+
+    val partitionSeq = partitions(tableName)(sparkSession)
+    partitionSeq.flatMap { partitionMap =>
+      if (
+        subPartitionsFilter.forall {
+          case (k, v) => partitionMap.get(k).contains(v)
+        }
+      ) {
+        partitionMap.get(partitionColumn)
+      } else {
+        None
+      }
+    }
+  }
+
   // Return a sequence for partitions where each partition entry consists of a Map of partition keys to values
   def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]]

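For illustration (not part of this commit): a minimal, self-contained sketch of the filtering logic used by the default primaryPartitions above, with the Spark-backed partitions(tableName) listing replaced by an in-memory Seq. The object name and sample data are hypothetical.

// Hypothetical sketch: mirrors the sub-partition filtering of the default primaryPartitions,
// but takes the partition listing as a plain argument instead of calling partitions(tableName).
object PrimaryPartitionsSketch {
  def primaryPartitions(partitionSeq: Seq[Map[String, String]],
                        partitionColumn: String,
                        subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] =
    partitionSeq.flatMap { partitionMap =>
      // keep a partition only if every sub-partition filter entry matches its value
      if (subPartitionsFilter.forall { case (k, v) => partitionMap.get(k).contains(v) })
        partitionMap.get(partitionColumn) // emit the primary partition value, e.g. the 'ds'
      else
        None
    }

  def main(args: Array[String]): Unit = {
    val listing = Seq(
      Map("ds" -> "2024-01-01", "hr" -> "00"),
      Map("ds" -> "2024-01-01", "hr" -> "01"),
      Map("ds" -> "2024-01-02", "hr" -> "00")
    )
    // prints List(2024-01-01, 2024-01-02): only partitions whose 'hr' sub-partition is "00"
    println(primaryPartitions(listing, "ds", Map("hr" -> "00")))
  }
}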
@@ -56,6 +77,9 @@ trait Format {

   // Help specify the appropriate file format to use in the Spark create table DDL query
   def fileFormatString(format: String): String
+
+  // Does this format support sub partitions filters
+  def supportSubPartitionsFilter: Boolean
 }

 /**
@@ -143,6 +167,9 @@ case class DefaultFormatProvider(sparkSession: SparkSession) extends FormatProvider
 }

 case object Hive extends Format {
+  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(implicit sparkSession: SparkSession): Seq[String] =
+    super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
+
   def parseHivePartition(pstring: String): Map[String, String] = {
     pstring
       .split("/")
@@ -166,24 +193,61 @@ case object Hive extends Format {

   def createTableTypeString: String = ""
   def fileFormatString(format: String): String = s"STORED AS $format"
+
+  override def supportSubPartitionsFilter: Boolean = true
 }

 case object Iceberg extends Format {
+  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(implicit sparkSession: SparkSession): Seq[String] = {
+    if (!supportSubPartitionsFilter && subPartitionsFilter.nonEmpty) {
+      throw new NotImplementedError(s"subPartitionsFilter is not supported on this format")
+    }
+
+    getIcebergPartitions(tableName)
+  }
+
   override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
     throw new NotImplementedError(
       "Multi-partitions retrieval is not supported on Iceberg tables yet." +
         "For single partition retrieval, please use 'partition' method.")
   }

+  private def getIcebergPartitions(tableName: String)(implicit sparkSession: SparkSession): Seq[String] = {
+    val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions")
+    val index = partitionsDf.schema.fieldIndex("partition")
+    if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) {
+      // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718
+      // so we collect and then filter.
+      partitionsDf
+        .select("partition.ds", "partition.hr")
+        .collect()
+        .filter(_.get(1) == null)
+        .map(_.getString(0))
+        .toSeq
+    } else {
+      partitionsDf
+        .select("partition.ds")
+        .collect()
+        .map(_.getString(0))
+        .toSeq
+    }
+  }
+
   def createTableTypeString: String = "USING iceberg"
   def fileFormatString(format: String): String = ""
+
+  override def supportSubPartitionsFilter: Boolean = false
 }

 // The Delta Lake format is compatible with the Delta lake and Spark versions currently supported by the project.
 // Attempting to use newer Delta lake library versions (e.g. 3.2 which works with Spark 3.5) results in errors:
 // java.lang.NoSuchMethodError: 'org.apache.spark.sql.delta.Snapshot org.apache.spark.sql.delta.DeltaLog.update(boolean)'
 // In such cases, you should implement your own FormatProvider built on the newer Delta lake version
 case object DeltaLake extends Format {
+
+  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(implicit sparkSession: SparkSession): Seq[String] =
+    super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
+
   override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
     // delta lake doesn't support the `SHOW PARTITIONS <tableName>` syntax - https://github.com/delta-io/delta/issues/996
     // there's alternative ways to retrieve partitions using the DeltaLog abstraction which is what we have to lean into
@@ -200,6 +264,8 @@ case object DeltaLake extends Format {

   def createTableTypeString: String = "USING DELTA"
   def fileFormatString(format: String): String = ""
+
+  override def supportSubPartitionsFilter: Boolean = true
 }

 case class TableUtils(sparkSession: SparkSession) {
@@ -315,47 +381,7 @@ case class TableUtils(sparkSession: SparkSession) {
   def partitions(tableName: String, subPartitionsFilter: Map[String, String] = Map.empty): Seq[String] = {
     if (!tableExists(tableName)) return Seq.empty[String]
     val format = tableReadFormat(tableName)
-
-    if (format == Iceberg) {
-      if (subPartitionsFilter.nonEmpty) {
-        throw new NotImplementedError("subPartitionsFilter is not supported on Iceberg tables yet.")
-      }
-      return getIcebergPartitions(tableName)
-    }
-
-    val partitionSeq = format.partitions(tableName)(sparkSession)
-    partitionSeq.flatMap { partitionMap =>
-      if (
-        subPartitionsFilter.forall {
-          case (k, v) => partitionMap.get(k).contains(v)
-        }
-      ) {
-        partitionMap.get(partitionColumn)
-      } else {
-        None
-      }
-    }
-  }
-
-  private def getIcebergPartitions(tableName: String): Seq[String] = {
-    val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions")
-    val index = partitionsDf.schema.fieldIndex("partition")
-    if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) {
-      // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718
-      // so we collect and then filter.
-      partitionsDf
-        .select("partition.ds", "partition.hr")
-        .collect()
-        .filter(_.get(1) == null)
-        .map(_.getString(0))
-        .toSeq
-    } else {
-      partitionsDf
-        .select("partition.ds")
-        .collect()
-        .map(_.getString(0))
-        .toSeq
-    }
+    format.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)(sparkSession)
   }

   // Given a table and a query extract the schema of the columns involved as input.

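For illustration only (not in this commit): a hypothetical format showing how an implementation plugs into the refactored Format trait. It supplies partitions, the DDL strings, and supportSubPartitionsFilter, and inherits the default primaryPartitions. The SHOW PARTITIONS parsing below mirrors the Hive case and is an assumption, not project code.

import org.apache.spark.sql.SparkSession

// Hypothetical example format; assumes the Format trait defined in TableUtils.scala is in scope.
case object HypotheticalShowPartitionsFormat extends Format {

  // List partitions via SHOW PARTITIONS and parse 'k1=v1/k2=v2' strings into maps,
  // similar to what Hive.parseHivePartition does.
  override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] =
    sparkSession
      .sql(s"SHOW PARTITIONS $tableName")
      .collect()
      .map { row =>
        row
          .getString(0)
          .split("/")
          .map { kv =>
            val Array(k, v) = kv.split("=", 2)
            k -> v
          }
          .toMap
      }
      .toSeq

  def createTableTypeString: String = ""
  def fileFormatString(format: String): String = s"STORED AS $format"

  // Because this returns true, the inherited primaryPartitions applies subPartitionsFilter
  // instead of throwing NotImplementedError.
  override def supportSubPartitionsFilter: Boolean = true
}

With a format like this returned by the FormatProvider, TableUtils.partitions simply delegates to format.primaryPartitions, as the last hunk of the diff shows.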