package ai.chronon.integrations.cloud_gcp

import ai.chronon.spark.Format
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.functions.{col, explode, url_decode}

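/** Format implementation for BigQuery external tables backed by Parquet files on GCS.
  * Partition metadata is resolved by looking up the table's `uris` option in BigQuery's
  * INFORMATION_SCHEMA and letting Spark's partition discovery infer the partition columns
  * and values from the directory layout under that URI.
  */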
case class GCS(project: String) extends Format {

  override def primaryPartitions(tableName: String, partitionColumn: String, subPartitionsFilter: Map[String, String])(
      implicit sparkSession: SparkSession): Seq[String] =
    super.primaryPartitions(tableName, partitionColumn, subPartitionsFilter)

  override def partitions(tableName: String)(implicit sparkSession: SparkSession): Seq[Map[String, String]] = {
    import sparkSession.implicits._

    val tableIdentifier = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
    val table = tableIdentifier.table
    val database = tableIdentifier.database.getOrElse(throw new IllegalArgumentException("database required!"))

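    // Reading with the connector's `query` option requires views to be enabled and a dataset
    // where intermediate results can be materialized.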
    // See: https://github.com/GoogleCloudDataproc/spark-bigquery-connector/issues/434#issuecomment-886156191
    // and: https://cloud.google.com/bigquery/docs/information-schema-intro#limitations
    sparkSession.conf.set("viewsEnabled", "true")
    sparkSession.conf.set("materializationDataset", database)

    // First, grab the URI location from BQ
    val uriSQL =
      s"""
         |select JSON_EXTRACT_STRING_ARRAY(option_value) as option_values from `${project}.${database}.INFORMATION_SCHEMA.TABLE_OPTIONS`
         |WHERE table_name = '${table}' and option_name = 'uris'
         |
         |""".stripMargin

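    // Run the metadata query through the BigQuery connector and flatten the JSON array of URIs
    // into one row per URI before collecting them back to the driver.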
    val uris = sparkSession.read
      .format("bigquery")
      .option("project", project)
      .option("query", uriSQL)
      .load()
      .select(explode(col("option_values")).as("option_value"))
      .select(url_decode(col("option_value")))
      .as[String]
      .collect
      .toList

    assert(uris.length == 1, s"External table ${tableName} must be backed by exactly one URI.")

    /**
     * Given:
     *   hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14
     *   hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28
     *
     * it returns:
     *   PartitionSpec(
     *     partitionColumns = StructType(
     *       StructField(name = "a", dataType = IntegerType, nullable = true),
     *       StructField(name = "b", dataType = StringType, nullable = true),
     *       StructField(name = "c", dataType = DoubleType, nullable = true)),
     *     partitions = Seq(
     *       Partition(
     *         values = Row(1, "hello", 3.14),
     *         path = "hdfs://<host>:<port>/path/to/partition/a=1/b=hello/c=3.14"),
     *       Partition(
     *         values = Row(2, "world", 6.28),
     *         path = "hdfs://<host>:<port>/path/to/partition/a=2/b=world/c=6.28")))
     */
    val partitionSpec = sparkSession.read
      .parquet(uris: _*)
      .queryExecution
      .sparkPlan
      .asInstanceOf[FileSourceScanExec]
      .relation
      .location
      .asInstanceOf[PartitioningAwareFileIndex] // Punch through the layers!!
      .partitionSpec

    val partitionColumns = partitionSpec.partitionColumns
    val partitions = partitionSpec.partitions.map(_.values)

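    // Render each discovered partition as a map of column name -> string value, matching the
    // Seq[Map[String, String]] contract of Format.partitions.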
    partitions
      .map { part =>
        partitionColumns.fields.toList.zipWithIndex.map {
          case (field, idx) =>
            val fieldName = field.name
            val fieldValue = part.get(idx, field.dataType)
            fieldName -> fieldValue.toString // Just going to cast this as a string.
        }.toMap
      }
      .toList
  }

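  // DDL helpers are not applicable here: create-table is unsupported for GCS-backed external
  // tables and there is no file-format clause to emit.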
  def createTableTypeString: String = throw new UnsupportedOperationException("GCS does not support create table")
  def fileFormatString(format: String): String = ""

  override def supportSubPartitionsFilter: Boolean = true

}