package ai.chronon.integrations.cloud_gcp

import ai.chronon.spark.SparkSessionBuilder
import org.apache.spark.sql.Row
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.junit.Assert.assertEquals
import org.scalatest.funsuite.AnyFunSuite

import java.nio.file.Files

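/** Tests for [[GCS]] partition discovery: verifies that the `partitions` method
  * parses Hive-style partition directories into key/value maps.
  */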
class GCSFormatTest extends AnyFunSuite {

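  // Local SparkSession shared across all tests in this suite; built lazily on first use.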
  lazy val spark: SparkSession = SparkSessionBuilder.build(
    "BigQuerySparkTest",
    local = true
  )

  test("partitions method should return correctly parsed partitions as maps") {

    val testData = List(
      ("20241223", "b", "c"),
      ("20241224", "e", "f"),
      ("20241225", "h", "i")
    )

    val dir = Files.createTempDirectory("spark-test-output").toFile
    dir.deleteOnExit()

    val df = spark.createDataFrame(testData).toDF("ds", "first", "second")
    df.write.partitionBy("ds").format("parquet").mode(SaveMode.Overwrite).save(dir.getAbsolutePath)
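    // partitionBy("ds") writes Hive-style directories (ds=20241223/, ...) under the temp dir,
    // which is the layout the GCS format inspects. The table name argument is unused here:
    // partitions are discovered from the sourceUri alone.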
    val gcsFormat = GCS(project = "test-project", sourceUri = dir.getAbsolutePath, fileFormat = "parquet")
    val partitions = gcsFormat.partitions("unused_table")(spark)

    assertEquals(Set(Map("ds" -> "20241223"), Map("ds" -> "20241224"), Map("ds" -> "20241225")), partitions.toSet)

  }

  test("partitions method should handle empty partitions gracefully") {

    val testData = List(
      ("20241223", "b", "c"),
      ("20241224", "e", "f"),
      ("20241225", "h", "i")
    )

    val dir = Files.createTempDirectory("spark-test-output").toFile
    dir.deleteOnExit()

    val df = spark.createDataFrame(testData).toDF("ds", "first", "second")
    df.write.format("parquet").mode(SaveMode.Overwrite).save(dir.getAbsolutePath)
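    // Written without partitionBy, so no partition directories exist and the
    // result should be empty rather than an error.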
    val gcsFormat = GCS(project = "test-project", sourceUri = dir.getAbsolutePath, fileFormat = "parquet")
    val partitions = gcsFormat.partitions("unused_table")(spark)

    assertEquals(Set.empty, partitions.toSet)

  }

  test("partitions method should handle date types") {
    val testData = List(
      Row("2024-12-23", "b", "c"),
      Row("2024-12-24", "e", "f"),
      Row("2024-12-25", "h", "i")
    )

    val dir = Files.createTempDirectory("spark-test-output").toFile
    dir.deleteOnExit()

    val schema = StructType(
      Seq(
        StructField("ds", StringType, nullable = true),
        StructField("first", StringType, nullable = true),
        StructField("second", StringType, nullable = true)
      ))
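    // Cast the string ds column to a DateType via to_date so the partition
    // column is written as a date rather than a plain string.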
    val df =
      spark
        .createDataFrame(spark.sparkContext.parallelize(testData), schema)
        .select(to_date(col("ds"), "yyyy-MM-dd").as("ds"), col("first"), col("second"))
    df.write.format("parquet").partitionBy("ds").mode(SaveMode.Overwrite).save(dir.getAbsolutePath)
    val gcsFormat = GCS(project = "test-project", sourceUri = dir.getAbsolutePath, fileFormat = "parquet")
    val partitions = gcsFormat.partitions("unused_table")(spark)

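    // Date-typed partition values are expected back in their yyyy-MM-dd string form.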
    assertEquals(Set(Map("ds" -> "2024-12-23"), Map("ds" -> "2024-12-24"), Map("ds" -> "2024-12-25")), partitions.toSet)

  }
}