Commit b7abbbd

chore: slim down tableutils (#458)
## Summary

Slims down `TableUtils`: the `isPartitioned` helper, the `saveUnPartitioned` / `insertUnPartitioned` save path, and the `tableReadFormat` delegate are removed. All saves now go through the unified `save` method with an explicit list of partition columns (empty for unpartitioned output), and read formats are fetched from `tableFormatProvider` directly.

## Checklist

- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **Refactor**
  - Replaced the legacy unpartitioned saving functionality with a unified save method that handles partition columns explicitly; saving a DataFrame now always requires a partition-column specification (an empty list for unpartitioned tables).
- **Tests**
  - Removed the now-unnecessary partition checks, and updated method calls to reflect how table formats are accessed, keeping the validation of expected outcomes intact.

---------

Co-authored-by: Thomas Chow <[email protected]>
1 parent 38fe7b3 commit b7abbbd
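For downstream code, the migration is mechanical: the removed `DataFrame` helper becomes a call to the unified `save` with an explicit partition-column list. A minimal before/after sketch based on the call sites in this diff (the table name and `tableProps` are placeholders):

```scala
// Before (removed in this commit): implicit helper on DataFrame
df.saveUnPartitioned("my_namespace.my_table", tableProps)

// After: the unified save; an empty partition-column list means
// the output table is written unpartitioned
df.save("my_namespace.my_table", tableProps, partitionColumns = List.empty)
```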

File tree

8 files changed: +16 -48 lines

8 files changed

+16
-48
lines changed

cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala

Lines changed: 0 additions & 4 deletions

```diff
@@ -66,8 +66,6 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
     val nativeTable = "data.sample_native"
     val table = tableUtils.loadTable(nativeTable)
     table.show
-    val partitioned = tableUtils.isPartitioned(nativeTable)
-    println(partitioned)
     // val database = tableUtils.createDatabase("test_database")
     val allParts = tableUtils.allPartitions(nativeTable)
     println(allParts)
@@ -80,8 +78,6 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
     println(bs)
     val table = tableUtils.loadTable(externalTable)
     table.show
-    val partitioned = tableUtils.isPartitioned(externalTable)
-    println(partitioned)
     // val database = tableUtils.createDatabase("test_database")
     val allParts = tableUtils.allPartitions(externalTable)
     println(allParts)
```
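Tests that still want the removed `isPartitioned` check can inline the deleted one-liner, which simply looked for the partition column in the table schema (its own TODO noted this was not a proper partition check). A sketch mirroring that deleted implementation, assuming `loadTable(...).schema` returns the same schema the old helper read:

```scala
// Equivalent of the removed tableUtils.isPartitioned(nativeTable):
// true iff the default partition column appears in the table's schema.
val partitioned = tableUtils
  .loadTable(nativeTable)
  .schema
  .fieldNames
  .contains(tableUtils.partitionColumn)
```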

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 0 additions & 4 deletions

```diff
@@ -162,10 +162,6 @@ object Extensions {
                sortByCols = sortByCols)
     }

-    def saveUnPartitioned(tableName: String, tableProperties: Map[String, String] = null): Unit = {
-      TableUtils(df.sparkSession).insertUnPartitioned(df, tableName, tableProperties)
-    }
-
     def prefixColumnNames(prefix: String, columns: Seq[String]): DataFrame = {
       columns.foldLeft(df) { (renamedDf, key) =>
         renamedDf.withColumnRenamed(key, s"${prefix}_$key")
```

spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -258,7 +258,7 @@ object GroupByUpload {
     kvDf
       .union(metaDf)
       .withColumn("ds", lit(endDs))
-      .saveUnPartitioned(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps)
+      .save(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps, partitionColumns = List.empty)

     val kvDfReloaded = tableUtils
       .loadTable(groupByConf.metaData.uploadTable)
```

spark/src/main/scala/ai/chronon/spark/StagingQuery.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -51,7 +51,7 @@ class StagingQuery(stagingQueryConf: api.StagingQuery, endPartition: String, tab
     }
     // the input table is not partitioned, usually for data testing or for kaggle demos
     if (stagingQueryConf.startPartition == null) {
-      tableUtils.sql(stagingQueryConf.query).saveUnPartitioned(outputTable)
+      tableUtils.sql(stagingQueryConf.query).save(outputTable, partitionColumns = List.empty)
     } else {
       val overrideStart = overrideStartPartition.getOrElse(stagingQueryConf.startPartition)
       val unfilledRanges =
```

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 9 additions & 33 deletions

```diff
@@ -30,7 +30,7 @@ import ai.chronon.spark.TableUtils.{
   TableCreationStatus
 }
 import ai.chronon.spark.format.CreationUtils.alterTablePropertiesSql
-import ai.chronon.spark.format.{DefaultFormatProvider, Format, FormatProvider}
+import ai.chronon.spark.format.{DefaultFormatProvider, FormatProvider}
 import org.apache.hadoop.hive.metastore.api.AlreadyExistsException
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
@@ -111,7 +111,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
   private val aggregationParallelism: Int = sparkSession.conf.get("spark.chronon.group_by.parallelism", "1000").toInt

   sparkSession.sparkContext.setLogLevel("ERROR")
-  // converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d")

   def preAggRepartition(df: DataFrame): DataFrame =
     if (df.rdd.getNumPartitions < aggregationParallelism) {
@@ -122,7 +121,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable

   def tableReachable(tableName: String): Boolean = {
     try {
-      tableReadFormat(tableName).isDefined
+      tableFormatProvider.readFormat(tableName).isDefined
     } catch {
       case ex: Exception =>
         logger.info(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
@@ -137,12 +136,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     sparkSession.read.load(DataPointer.from(tableName, sparkSession))
   }

-  def isPartitioned(tableName: String): Boolean = {
-    // TODO: use proper way to detect if a table is partitioned or not
-    val schema = getSchemaFromTable(tableName)
-    schema.fieldNames.contains(partitionColumn)
-  }
-
   // Needs provider
   def createDatabase(database: String): Boolean = {
     try {
@@ -159,17 +152,16 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     }
   }

-  def tableReadFormat(tableName: String): Option[Format] = tableFormatProvider.readFormat(tableName)
-
-  // Needs provider
   // return all specified partition columns in a table in format of Map[partitionName, PartitionValue]
   def allPartitions(tableName: String, partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = {

     if (!tableReachable(tableName)) return Seq.empty[Map[String, String]]

-    val format = tableReadFormat(tableName).getOrElse(
-      throw new IllegalStateException(
-        s"Could not determine read format of table ${tableName}. It is no longer reachable."))
+    val format = tableFormatProvider
+      .readFormat(tableName)
+      .getOrElse(
+        throw new IllegalStateException(
+          s"Could not determine read format of table ${tableName}. It is no longer reachable."))
     val partitionSeq = format.partitions(tableName)(sparkSession)

     if (partitionColumnsFilter.isEmpty) {
@@ -189,7 +181,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
                       subPartitionsFilter: Map[String, String] = Map.empty,
                       partitionColumnName: String = partitionColumn): Seq[String] = {

-    tableReadFormat(tableName)
+    tableFormatProvider
+      .readFormat(tableName)
       .map((format) => {
         val partitions = format.primaryPartitions(tableName, partitionColumnName, subPartitionsFilter)(sparkSession)

@@ -385,23 +378,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
     }
   }

-  // Needs provider
-  def insertUnPartitioned(df: DataFrame,
-                          tableName: String,
-                          tableProperties: Map[String, String] = null,
-                          saveMode: SaveMode = SaveMode.Overwrite,
-                          fileFormat: String = "PARQUET"): Unit = {
-
-    val creationStatus = createTable(df, tableName, Seq.empty[String], tableProperties, fileFormat)
-
-    creationStatus match {
-      case TableUtils.TableCreatedWithoutInitialData | TableUtils.TableAlreadyExists =>
-        repartitionAndWrite(df, tableName, saveMode, None, partitionColumns = Seq.empty)
-      case TableUtils.TableCreatedWithInitialData =>
-    }
-
-  }
-
   def columnSizeEstimator(dataType: DataType): Long = {
     dataType match {
       // TODO: improve upon this very basic estimate approach
```
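With the one-line `tableReadFormat` delegate gone, callers reach the format provider directly, as the test updates further down do. A sketch of the equivalent lookup (`"db.table"` is a placeholder):

```scala
import ai.chronon.spark.format.Format

// Before (removed): tableUtils.tableReadFormat("db.table")
// After: ask the provider directly; None means the format could
// not be determined (e.g. the table does not exist).
val fmt: Option[Format] = tableUtils.tableFormatProvider.readFormat("db.table")
```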

spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -79,13 +79,13 @@ class CompareJob(
     logger.info("Saving comparison output..")
     logger.info(
       s"Comparison schema ${compareDf.schema.fields.map(sb => (sb.name, sb.dataType)).toMap.mkString("\n - ")}")
-    compareDf.saveUnPartitioned(comparisonTableName, tableProps)
+    compareDf.save(comparisonTableName, tableProps, partitionColumns = List.empty)

     // Save the metrics table
     logger.info("Saving metrics output..")
     val metricsDf = metricsTimedKvRdd.toFlatDf
     logger.info(s"Metrics schema ${metricsDf.schema.fields.map(sb => (sb.name, sb.dataType)).toMap.mkString("\n - ")}")
-    metricsDf.saveUnPartitioned(metricsTableName, tableProps)
+    metricsDf.save(metricsTableName, tableProps, partitionColumns = List.empty)

     logger.info("Printing basic comparison results..")
     logger.info("(Note: This is just an estimation and not a detailed analysis of results)")
```

spark/src/main/scala/ai/chronon/spark/utils/PartitionRunner.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -119,7 +119,7 @@ class PartitionRunner[T](verb: String,
     if (outputDf.columns.contains(tu.partitionColumn)) {
       outputDf.save(outputTable)
     } else {
-      outputDf.saveUnPartitioned(outputTable)
+      outputDf.save(outputTable, partitionColumns = List.empty)
     }
     println(s"""
                |Finished computing range ${i + 1}/$n
```

spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -147,7 +147,7 @@ class TableUtilsFormatTest extends AnyFlatSpec {
   it should "return empty read format if table doesn't exist" in {
     val dbName = s"db_${System.currentTimeMillis()}"
     val tableName = s"$dbName.test_table_nonexistent_$format"
-    assertTrue(tableUtils.tableReadFormat(tableName).isEmpty)
+    assertTrue(tableUtils.tableFormatProvider.readFormat(tableName).isEmpty)
     assertFalse(tableUtils.tableReachable(tableName))
   }
 }
@@ -188,7 +188,7 @@ object TableUtilsFormatTest {
     tableUtils.insertPartitions(df2, tableName, autoExpand = true)

     // check that we wrote out a table in the right format
-    val readTableFormat = tableUtils.tableReadFormat(tableName).get.toString
+    val readTableFormat = tableUtils.tableFormatProvider.readFormat(tableName).get.toString
     assertTrue(s"Mismatch in table format: $readTableFormat; expected: $format", readTableFormat.toLowerCase == format)

     // check we have all the partitions written
```
