Skip to content

Commit ff27789

Browse files
rebase
Co-authored-by: Thomas Chow <[email protected]>
1 parent ef69ef0 commit ff27789

File tree

8 files changed

+16
-48
lines changed

8 files changed

+16
-48
lines changed

cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,6 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
6666
val nativeTable = "data.sample_native"
6767
val table = tableUtils.loadTable(nativeTable)
6868
table.show
69-
val partitioned = tableUtils.isPartitioned(nativeTable)
70-
println(partitioned)
7169
// val database = tableUtils.createDatabase("test_database")
7270
val allParts = tableUtils.allPartitions(nativeTable)
7371
println(allParts)
@@ -80,8 +78,6 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
8078
println(bs)
8179
val table = tableUtils.loadTable(externalTable)
8280
table.show
83-
val partitioned = tableUtils.isPartitioned(externalTable)
84-
println(partitioned)
8581
// val database = tableUtils.createDatabase("test_database")
8682
val allParts = tableUtils.allPartitions(externalTable)
8783
println(allParts)

spark/src/main/scala/ai/chronon/spark/Extensions.scala

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -162,10 +162,6 @@ object Extensions {
162162
sortByCols = sortByCols)
163163
}
164164

165-
def saveUnPartitioned(tableName: String, tableProperties: Map[String, String] = null): Unit = {
166-
TableUtils(df.sparkSession).insertUnPartitioned(df, tableName, tableProperties)
167-
}
168-
169165
def prefixColumnNames(prefix: String, columns: Seq[String]): DataFrame = {
170166
columns.foldLeft(df) { (renamedDf, key) =>
171167
renamedDf.withColumnRenamed(key, s"${prefix}_$key")

spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -258,7 +258,7 @@ object GroupByUpload {
258258
kvDf
259259
.union(metaDf)
260260
.withColumn("ds", lit(endDs))
261-
.saveUnPartitioned(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps)
261+
.save(groupByConf.metaData.uploadTable, groupByConf.metaData.tableProps, partitionColumns = List.empty)
262262

263263
val kvDfReloaded = tableUtils
264264
.loadTable(groupByConf.metaData.uploadTable)

spark/src/main/scala/ai/chronon/spark/StagingQuery.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ class StagingQuery(stagingQueryConf: api.StagingQuery, endPartition: String, tab
5151
}
5252
// the input table is not partitioned, usually for data testing or for kaggle demos
5353
if (stagingQueryConf.startPartition == null) {
54-
tableUtils.sql(stagingQueryConf.query).saveUnPartitioned(outputTable)
54+
tableUtils.sql(stagingQueryConf.query).save(outputTable, partitionColumns = List.empty)
5555
} else {
5656
val overrideStart = overrideStartPartition.getOrElse(stagingQueryConf.startPartition)
5757
val unfilledRanges =

spark/src/main/scala/ai/chronon/spark/TableUtils.scala

Lines changed: 9 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ import ai.chronon.spark.TableUtils.{
3030
TableCreationStatus
3131
}
3232
import ai.chronon.spark.format.CreationUtils.alterTablePropertiesSql
33-
import ai.chronon.spark.format.{DefaultFormatProvider, Format, FormatProvider}
33+
import ai.chronon.spark.format.{DefaultFormatProvider, FormatProvider}
3434
import org.apache.hadoop.hive.metastore.api.AlreadyExistsException
3535
import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
3636
import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
@@ -111,7 +111,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
111111
private val aggregationParallelism: Int = sparkSession.conf.get("spark.chronon.group_by.parallelism", "1000").toInt
112112

113113
sparkSession.sparkContext.setLogLevel("ERROR")
114-
// converts String-s like "a=b/c=d" to Map("a" -> "b", "c" -> "d")
115114

116115
def preAggRepartition(df: DataFrame): DataFrame =
117116
if (df.rdd.getNumPartitions < aggregationParallelism) {
@@ -122,7 +121,7 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
122121

123122
def tableReachable(tableName: String): Boolean = {
124123
try {
125-
tableReadFormat(tableName).isDefined
124+
tableFormatProvider.readFormat(tableName).isDefined
126125
} catch {
127126
case ex: Exception =>
128127
logger.info(s"""Couldn't reach $tableName. Error: ${ex.getMessage.red}
@@ -137,12 +136,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
137136
sparkSession.read.load(DataPointer.from(tableName, sparkSession))
138137
}
139138

140-
def isPartitioned(tableName: String): Boolean = {
141-
// TODO: use proper way to detect if a table is partitioned or not
142-
val schema = getSchemaFromTable(tableName)
143-
schema.fieldNames.contains(partitionColumn)
144-
}
145-
146139
// Needs provider
147140
def createDatabase(database: String): Boolean = {
148141
try {
@@ -159,17 +152,16 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
159152
}
160153
}
161154

162-
def tableReadFormat(tableName: String): Option[Format] = tableFormatProvider.readFormat(tableName)
163-
164-
// Needs provider
165155
// return all specified partition columns in a table in format of Map[partitionName, PartitionValue]
166156
def allPartitions(tableName: String, partitionColumnsFilter: Seq[String] = Seq.empty): Seq[Map[String, String]] = {
167157

168158
if (!tableReachable(tableName)) return Seq.empty[Map[String, String]]
169159

170-
val format = tableReadFormat(tableName).getOrElse(
171-
throw new IllegalStateException(
172-
s"Could not determine read format of table ${tableName}. It is no longer reachable."))
160+
val format = tableFormatProvider
161+
.readFormat(tableName)
162+
.getOrElse(
163+
throw new IllegalStateException(
164+
s"Could not determine read format of table ${tableName}. It is no longer reachable."))
173165
val partitionSeq = format.partitions(tableName)(sparkSession)
174166

175167
if (partitionColumnsFilter.isEmpty) {
@@ -189,7 +181,8 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
189181
subPartitionsFilter: Map[String, String] = Map.empty,
190182
partitionColumnName: String = partitionColumn): Seq[String] = {
191183

192-
tableReadFormat(tableName)
184+
tableFormatProvider
185+
.readFormat(tableName)
193186
.map((format) => {
194187
val partitions = format.primaryPartitions(tableName, partitionColumnName, subPartitionsFilter)(sparkSession)
195188

@@ -385,23 +378,6 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
385378
}
386379
}
387380

388-
// Needs provider
389-
def insertUnPartitioned(df: DataFrame,
390-
tableName: String,
391-
tableProperties: Map[String, String] = null,
392-
saveMode: SaveMode = SaveMode.Overwrite,
393-
fileFormat: String = "PARQUET"): Unit = {
394-
395-
val creationStatus = createTable(df, tableName, Seq.empty[String], tableProperties, fileFormat)
396-
397-
creationStatus match {
398-
case TableUtils.TableCreatedWithoutInitialData | TableUtils.TableAlreadyExists =>
399-
repartitionAndWrite(df, tableName, saveMode, None, partitionColumns = Seq.empty)
400-
case TableUtils.TableCreatedWithInitialData =>
401-
}
402-
403-
}
404-
405381
def columnSizeEstimator(dataType: DataType): Long = {
406382
dataType match {
407383
// TODO: improve upon this very basic estimate approach

spark/src/main/scala/ai/chronon/spark/stats/CompareJob.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,13 +79,13 @@ class CompareJob(
7979
logger.info("Saving comparison output..")
8080
logger.info(
8181
s"Comparison schema ${compareDf.schema.fields.map(sb => (sb.name, sb.dataType)).toMap.mkString("\n - ")}")
82-
compareDf.saveUnPartitioned(comparisonTableName, tableProps)
82+
compareDf.save(comparisonTableName, tableProps, partitionColumns = List.empty)
8383

8484
// Save the metrics table
8585
logger.info("Saving metrics output..")
8686
val metricsDf = metricsTimedKvRdd.toFlatDf
8787
logger.info(s"Metrics schema ${metricsDf.schema.fields.map(sb => (sb.name, sb.dataType)).toMap.mkString("\n - ")}")
88-
metricsDf.saveUnPartitioned(metricsTableName, tableProps)
88+
metricsDf.save(metricsTableName, tableProps, partitionColumns = List.empty)
8989

9090
logger.info("Printing basic comparison results..")
9191
logger.info("(Note: This is just an estimation and not a detailed analysis of results)")

spark/src/main/scala/ai/chronon/spark/utils/PartitionRunner.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ class PartitionRunner[T](verb: String,
119119
if (outputDf.columns.contains(tu.partitionColumn)) {
120120
outputDf.save(outputTable)
121121
} else {
122-
outputDf.saveUnPartitioned(outputTable)
122+
outputDf.save(outputTable, partitionColumns = List.empty)
123123
}
124124
println(s"""
125125
|Finished computing range ${i + 1}/$n

spark/src/test/scala/ai/chronon/spark/test/TableUtilsFormatTest.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ class TableUtilsFormatTest extends AnyFlatSpec {
147147
it should "return empty read format if table doesn't exist" in {
148148
val dbName = s"db_${System.currentTimeMillis()}"
149149
val tableName = s"$dbName.test_table_nonexistent_$format"
150-
assertTrue(tableUtils.tableReadFormat(tableName).isEmpty)
150+
assertTrue(tableUtils.tableFormatProvider.readFormat(tableName).isEmpty)
151151
assertFalse(tableUtils.tableReachable(tableName))
152152
}
153153
}
@@ -188,7 +188,7 @@ object TableUtilsFormatTest {
188188
tableUtils.insertPartitions(df2, tableName, autoExpand = true)
189189

190190
// check that we wrote out a table in the right format
191-
val readTableFormat = tableUtils.tableReadFormat(tableName).get.toString
191+
val readTableFormat = tableUtils.tableFormatProvider.readFormat(tableName).get.toString
192192
assertTrue(s"Mismatch in table format: $readTableFormat; expected: $format", readTableFormat.toLowerCase == format)
193193

194194
// check we have all the partitions written

0 commit comments

Comments
 (0)