File tree: 2 files changed (+9 -6 lines changed)
spark/src/main/scala/ai/chronon/spark: 2 files changed (+9 -6 lines changed)
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
118
118
def loadTable (tableName : String , rangeWheres : Seq [String ] = List .empty[String ]): DataFrame = {
119
119
tableFormatProvider
120
120
.readFormat(tableName)
121
- .map(_.table(tableName, combinePredicates (rangeWheres))(sparkSession))
121
+ .map(_.table(tableName, andPredicates (rangeWheres))(sparkSession))
122
122
.getOrElse(
123
123
throw new RuntimeException (s " Could not load table: ${tableName} with partition filter: ${rangeWheres}" ))
124
124
}
@@ -568,7 +568,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
568
568
}
569
569
}
570
570
571
- private def combinePredicates (predicates : Seq [String ]): String = {
571
+ private def andPredicates (predicates : Seq [String ]): String = {
572
572
val whereStr = predicates.map(p => s " ( $p) " ).mkString(" AND " )
573
573
logger.info(s """ Where str: $whereStr""" )
574
574
whereStr
@@ -597,7 +597,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
597
597
if (selects.nonEmpty) df = df.selectExpr(selects : _* )
598
598
599
599
if (wheres.nonEmpty) {
600
- val whereStr = combinePredicates (wheres)
600
+ val whereStr = andPredicates (wheres)
601
601
logger.info(s """ Where str: $whereStr""" )
602
602
df = df.where(whereStr)
603
603
}
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,12 @@ trait Format {
9
9
@ transient protected lazy val logger : Logger = LoggerFactory .getLogger(getClass)
10
10
11
11
def table (tableName : String , partitionFilters : String )(implicit sparkSession : SparkSession ): DataFrame = {
12
- sparkSession.read
13
- .table(tableName)
14
- .where(partitionFilters)
12
+ val df = sparkSession.read.table(tableName)
13
+ if (partitionFilters.isEmpty) {
14
+ df
15
+ } else {
16
+ df.where(partitionFilters)
17
+ }
15
18
}
16
19
17
20
// Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
You can’t perform that action at this time.
0 commit comments