Commit 2b6579c

some additional changes

1 parent fbfd09c

3 files changed: +27 / -48 lines

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigQueryFormat.scala

Lines changed: 12 additions & 10 deletions

@@ -5,9 +5,9 @@ import ai.chronon.spark.FormatProvider
 import ai.chronon.spark.Hive
 import com.google.cloud.bigquery.BigQueryOptions
 import com.google.cloud.bigquery.ExternalTableDefinition
+import com.google.cloud.bigquery.FormatOptions
 import com.google.cloud.bigquery.StandardTableDefinition
 import com.google.cloud.bigquery.connector.common.BigQueryUtil
-import com.google.cloud.bigquery.{TableId => BTableId}
 import com.google.cloud.spark.bigquery.repackaged.com.google.cloud.bigquery.TableId
 import org.apache.spark.sql.SparkSession
 
@@ -19,8 +19,8 @@ case class GCPFormatProvider(sparkSession: SparkSession) extends FormatProvider
 
   override def resolveTableName(tableName: String): String = {
     format(tableName: String) match {
-      case BQuery(_) => tableName
       case GCS(_, uri, _) => uri
+      case _ => tableName
     }
   }
   override def readFormat(tableName: String): Format = format(tableName)
@@ -39,15 +39,16 @@ case class GCPFormatProvider(sparkSession: SparkSession) extends FormatProvider
     // Active project in the gcloud CLI configuration.
     // No default project: An error will occur if no project ID is available.
 
-    val unshadedTI: BTableId =
-      BTableId.of(bqOptions.getProjectId, btTableIdentifier.getDataset, btTableIdentifier.getTable)
-
-    val tableOpt = Option(bigQueryClient.getTable(unshadedTI))
+    val tableOpt = Option(bigQueryClient.getTable(btTableIdentifier.getDataset, btTableIdentifier.getTable))
     tableOpt match {
       case Some(table) => {
         if (table.getDefinition.isInstanceOf[ExternalTableDefinition]) {
-          import com.google.cloud.bigquery.FormatOptions
-          val uris = table.getDefinition.asInstanceOf[ExternalTableDefinition].getSourceUris.asScala.toList
+          val uris = table.getDefinition
+            .asInstanceOf[ExternalTableDefinition]
+            .getSourceUris
+            .asScala
+            .toList
+            .map((uri) => uri.stripSuffix("/*") + "/")
 
           assert(uris.length == 1, s"External table ${tableName} can be backed by only one URI.")
 
@@ -56,8 +57,9 @@ case class GCPFormatProvider(sparkSession: SparkSession) extends FormatProvider
             .getFormatOptions
             .asInstanceOf[FormatOptions]
             .getType
-          GCS(unshadedTI.getProject, uris.head, formatStr)
-        } else if (table.getDefinition.isInstanceOf[StandardTableDefinition]) BQuery(unshadedTI.getProject)
+
+          GCS(table.getTableId.getProject, uris.head, formatStr)
+        } else if (table.getDefinition.isInstanceOf[StandardTableDefinition]) BQuery(table.getTableId.getProject)
         else throw new IllegalStateException(s"Cannot support table of type: ${table.getDefinition}")
       }
       case None => Hive
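
Note on the URI handling above: BigQuery external table definitions typically register their source data with a trailing wildcard (gs://bucket/path/*), while Spark's file reader expects a directory path. The new .map((uri) => uri.stripSuffix("/*") + "/") step bridges the two. A minimal sketch of the transformation, using a hypothetical bucket and path:

// Hypothetical source URI, as BigQuery would report it for an external table.
val sourceUris = List("gs://example-bucket/warehouse/events/*")

// Same normalization as the diff: drop the trailing wildcard and keep an
// explicit trailing slash so the URI reads as a directory.
val normalized = sourceUris.map(uri => uri.stripSuffix("/*") + "/")

println(normalized) // List(gs://example-bucket/warehouse/events/)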

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/GCSFormat.scala

Lines changed: 14 additions & 38 deletions

@@ -4,9 +4,9 @@ import ai.chronon.spark.Format
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.explode
-import org.apache.spark.sql.functions.url_decode
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.Row
 
 case class GCS(project: String, sourceUri: String, format: String) extends Format {
 
@@ -17,37 +17,6 @@ case class GCS(project: String, sourceUri: String, format: String) extends Format {
     super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)
 
   override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
-    import sparkSession.implicits._
-
-    val tableIdentifier = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
-    val table = tableIdentifier.table
-    val database = tableIdentifier.database.getOrElse(throw new IllegalArgumentException("database required!"))
-
-    // See: https://github.com/GoogleCloudDataproc/spark-bigquery-connector/issues/434#issuecomment-886156191
-    // and: https://cloud.google.com/bigquery/docs/information-schema-intro#limitations
-    sparkSession.conf.set("viewsEnabled", "true")
-    sparkSession.conf.set("materializationDataset", database)
-
-    // First, grab the URI location from BQ
-    val uriSQL =
-      s"""
-         |select JSON_EXTRACT_STRING_ARRAY(option_value) as option_values from `${project}.${database}.INFORMATION_SCHEMA.TABLE_OPTIONS`
-         |WHERE table_name = '${table}' and option_name = 'uris'
-         |
-         |""".stripMargin
-
-    val uris = sparkSession.read
-      .format("bigquery")
-      .option("project", project)
-      .option("query", uriSQL)
-      .load()
-      .select(explode(col("option_values")).as("option_value"))
-      .select(url_decode(col("option_value")))
-      .as[String]
-      .collect
-      .toList
-
-    assert(uris.length == 1, s"External table ${tableName} can be backed by only one URI.")
 
     /**
      * Given:
@@ -70,7 +39,7 @@ case class GCS(project: String, sourceUri: String, format: String) extends Format {
      *
      */
     val partitionSpec = sparkSession.read
-      .parquet(uris: _*)
+      .parquet(sourceUri)
       .queryExecution
       .sparkPlan
       .asInstanceOf[FileSourceScanExec]
@@ -82,16 +51,23 @@ case class GCS(project: String, sourceUri: String, format: String) extends Format {
     val partitionColumns = partitionSpec.partitionColumns
     val partitions = partitionSpec.partitions.map(_.values)
 
-    partitions
+    val deserializer =
+      Encoders.row(partitionColumns).asInstanceOf[ExpressionEncoder[Row]].resolveAndBind().createDeserializer()
+
+    val roundTripped = sparkSession
+      .createDataFrame(sparkSession.sparkContext.parallelize(partitions.map(deserializer)), partitionColumns)
+      .collect
+      .toList
+
+    roundTripped
       .map((part) =>
         partitionColumns.fields.toList.zipWithIndex.map {
           case (field, idx) => {
            val fieldName = field.name
-           val fieldValue = part.get(idx, field.dataType)
+           val fieldValue = part.get(idx)
            fieldName -> fieldValue.toString // Just going to cast this as a string.
          }
        }.toMap)
-      .toList
   }
 
   def createTableTypeString: String = throw new UnsupportedOperationException("GCS does not support create table")
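
Why the encoder round trip above: FileSourceScanExec exposes partition values as Catalyst InternalRows, whose fields use internal representations (UTF8String rather than java.lang.String, for instance), which is why the old code needed the typed accessor part.get(idx, field.dataType). Deserializing each InternalRow back to a public Row lets the code call plain part.get(idx). A minimal sketch of the deserializer step under assumed Spark 3.x APIs, with a hypothetical two-column partition schema:

import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical partition schema: a date string column and an integer hour column.
val partitionSchema = StructType(Seq(
  StructField("ds", StringType),
  StructField("hr", IntegerType)))

// Mirror the diff: build a deserializer from Catalyst's InternalRow to a public Row.
val deserializer = Encoders.row(partitionSchema)
  .asInstanceOf[ExpressionEncoder[Row]]
  .resolveAndBind()
  .createDeserializer()

// InternalRow carries UTF8String values; the deserializer converts them back to
// external types, so Row.get(idx) yields a java.lang.String and a java.lang.Integer.
val internal = InternalRow(UTF8String.fromString("2024-01-01"), 7)
val row: Row = deserializer(internal)
println(row.get(0)) // 2024-01-01
println(row.get(1)) // 7

The commit then rebuilds a DataFrame from the deserialized rows (createDataFrame over a parallelized collection) before collecting, which keeps the partition values in plain JVM types end to end.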

spark/src/main/scala/ai/chronon/spark/Format.scala

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@ trait FormatProvider extends Serializable {
   def sparkSession: SparkSession
   def readFormat(tableName: String): Format
   def writeFormat(tableName: String): Format
+
   def resolveTableName(tableName: String) = tableName
 }
 