Commit 935b614

fix the UTs
Co-authored-by: Thomas Chow <[email protected]>
1 parent 0b51360 commit 935b614

2 files changed: 9 additions(+), 6 deletions(-)

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 3 additions & 3 deletions
@@ -118,7 +118,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   def loadTable(tableName: String, rangeWheres: Seq[String] = List.empty[String]): DataFrame = {
     tableFormatProvider
       .readFormat(tableName)
-      .map(_.table(tableName, combinePredicates(rangeWheres))(sparkSession))
+      .map(_.table(tableName, andPredicates(rangeWheres))(sparkSession))
       .getOrElse(
         throw new RuntimeException(s"Could not load table: ${tableName} with partition filter: ${rangeWheres}"))
   }
@@ -568,7 +568,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     }
   }

-  private def combinePredicates(predicates: Seq[String]): String = {
+  private def andPredicates(predicates: Seq[String]): String = {
     val whereStr = predicates.map(p => s"($p)").mkString(" AND ")
     logger.info(s"""Where str: $whereStr""")
     whereStr
@@ -597,7 +597,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     if (selects.nonEmpty) df = df.selectExpr(selects: _*)

     if (wheres.nonEmpty) {
-      val whereStr = combinePredicates(wheres)
+      val whereStr = andPredicates(wheres)
       logger.info(s"""Where str: $whereStr""")
       df = df.where(whereStr)
     }
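
For context, the renamed helper simply parenthesizes each predicate and joins them with AND. A minimal sketch of its behavior, with the helper body copied from the diff above (the sample `ds` predicates are illustrative, not from the commit):

def andPredicates(predicates: Seq[String]): String =
  predicates.map(p => s"($p)").mkString(" AND ")

andPredicates(Seq("ds >= '2024-01-01'", "ds <= '2024-01-31'"))
// => "(ds >= '2024-01-01') AND (ds <= '2024-01-31')"
andPredicates(Seq.empty)
// => "" — an empty filter string, which Format.table below now guards against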

spark/src/main/scala/ai/chronon/spark/format/Format.scala

Lines changed: 6 additions & 3 deletions
@@ -9,9 +9,12 @@ trait Format {
   @transient protected lazy val logger: Logger = LoggerFactory.getLogger(getClass)

   def table(tableName: String, partitionFilters: String)(implicit sparkSession: SparkSession): DataFrame = {
-    sparkSession.read
-      .table(tableName)
-      .where(partitionFilters)
+    val df = sparkSession.read.table(tableName)
+    if (partitionFilters.isEmpty) {
+      df
+    } else {
+      df.where(partitionFilters)
+    }
   }

   // Return the primary partitions (based on the 'partitionColumn') filtered down by sub-partition filters if provided
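
The empty-string check is the substance of the fix: calling DataFrame.where with an empty condition string makes Spark's SQL parser throw, which is presumably what the failing unit tests hit when no partition filters were supplied. A minimal, self-contained sketch of the same guard pattern (the method name loadFiltered and the sample table/filter are hypothetical, not part of the commit):

import org.apache.spark.sql.{DataFrame, SparkSession}

def loadFiltered(tableName: String, partitionFilters: String)(implicit spark: SparkSession): DataFrame = {
  // where("") would fail to parse, so only apply a non-empty filter
  val df = spark.read.table(tableName)
  if (partitionFilters.isEmpty) df else df.where(partitionFilters)
}

// Usage (hypothetical table and filter):
// loadFiltered("db.events", "(ds = '2024-01-01')")(spark)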
