 package ai.chronon.spark.test
 
 import ai.chronon.api.{DoubleType, IntType, LongType, StringType, StructField, StructType}
+import ai.chronon.spark.SparkSessionBuilder.FormatTestEnvVar
 import ai.chronon.spark.test.TestUtils.makeDf
-import ai.chronon.spark.{DeltaLake, Format, Hive, IncompatibleSchemaException, SparkSessionBuilder, TableUtils}
-import org.apache.spark.SparkContext
+import ai.chronon.spark.{IncompatibleSchemaException, SparkSessionBuilder, TableUtils}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession}
 import org.junit.Assert.{assertEquals, assertTrue}
-import org.scalatest.BeforeAndAfterEach
 import org.scalatest.funsuite.AnyFunSuite
-import org.scalatest.prop.TableDrivenPropertyChecks._
 
 import scala.util.Try
 
-class TestTableUtils(sparkSession: SparkSession, format: Format) extends TableUtils(sparkSession) {
-  override def getWriteFormat: Format = format
-}
-
-class TableUtilsFormatTest extends AnyFunSuite with BeforeAndAfterEach {
+class TableUtilsFormatTest extends AnyFunSuite {
 
   import TableUtilsFormatTest._
 
-  val deltaConfigMap = Map(
-    "spark.sql.extensions" -> "io.delta.sql.DeltaSparkSessionExtension",
-    "spark.sql.catalog.spark_catalog" -> "org.apache.spark.sql.delta.catalog.DeltaCatalog",
-  )
-  val hiveConfigMap = Map.empty[String, String]
-
-  // TODO: include Hive + Iceberg support in these tests
-  val formats =
-    Table(
-      ("format", "configs"),
-      (DeltaLake, deltaConfigMap),
-      (Hive, hiveConfigMap)
-    )
-
-  private def withSparkSession[T](configs: Map[String, String])(test: SparkSession => T): T = {
-    val spark = SparkSessionBuilder.build("TableUtilsFormatTest", local = true, additionalConfig = Some(configs))
-    val sc = SparkContext.getOrCreate()
+  // Read the format we want this instantiation of the test to run via environment vars
+  val format: String = sys.env.getOrElse(FormatTestEnvVar, "hive")
+
+  private def withSparkSession[T](test: SparkSession => T): T = {
+    val spark = SparkSessionBuilder.build("TableUtilsFormatTest", local = true)
     try {
       test(spark)
     } finally {
-      configs.keys.foreach(cfg => sc.getConf.remove(cfg))
       spark.stop()
     }
   }
 
   ignore("test insertion of partitioned data and adding of columns") {
-    forAll(formats) { (format, configs) =>
-      withSparkSession(configs) { spark =>
-        val tableUtils = new TestTableUtils(spark, format)
-
-        val tableName = s"db.test_table_1_$format"
-        spark.sql("CREATE DATABASE IF NOT EXISTS db")
-        val columns1 = Array(
-          StructField("long_field", LongType),
-          StructField("int_field", IntType),
-          StructField("string_field", StringType)
-        )
-        val df1 = makeDf(
-          spark,
-          StructType(
-            tableName,
-            columns1 :+ StructField("ds", StringType)
-          ),
-          List(
-            Row(1L, 2, "3", "2022-10-01")
-          )
-        )
-
-        val df2 = makeDf(
-          spark,
-          StructType(
-            tableName,
-            columns1
-              :+ StructField("double_field", DoubleType)
-              :+ StructField("ds", StringType)
-          ),
-          List(
-            Row(4L, 5, "6", 7.0, "2022-10-02")
-          )
-        )
-        testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
-      }
+    withSparkSession { spark =>
+      val tableUtils = TableUtils(spark)
+
+      val tableName = s"db.test_table_1_$format"
+      spark.sql("CREATE DATABASE IF NOT EXISTS db")
+      val columns1 = Array(
+        StructField("long_field", LongType),
+        StructField("int_field", IntType),
+        StructField("string_field", StringType)
+      )
+      val df1 = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns1 :+ StructField("ds", StringType)
+        ),
+        List(
+          Row(1L, 2, "3", "2022-10-01")
+        )
+      )
+
+      val df2 = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns1
+            :+ StructField("double_field", DoubleType)
+            :+ StructField("ds", StringType)
+        ),
+        List(
+          Row(4L, 5, "6", 7.0, "2022-10-02")
+        )
+      )
+      testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
     }
   }
 
   ignore("test insertion of partitioned data and removal of columns") {
-    forAll(formats) { (format, configs) =>
-      withSparkSession(configs) { spark =>
-        val tableUtils = TableUtils(spark)
-        val tableName = s"db.test_table_2_$format"
-        spark.sql("CREATE DATABASE IF NOT EXISTS db")
-        val columns1 = Array(
-          StructField("long_field", LongType),
-          StructField("int_field", IntType),
-          StructField("string_field", StringType)
-        )
-        val df1 = makeDf(
-          spark,
-          StructType(
-            tableName,
-            columns1
-              :+ StructField("double_field", DoubleType)
-              :+ StructField("ds", StringType)
-          ),
-          List(
-            Row(1L, 2, "3", 4.0, "2022-10-01")
-          )
-        )
-
-        val df2 = makeDf(
-          spark,
-          StructType(
-            tableName,
-            columns1 :+ StructField("ds", StringType)
-          ),
-          List(
-            Row(5L, 6, "7", "2022-10-02")
-          )
-        )
-        testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
-      }
+    withSparkSession { spark =>
+      val tableUtils = TableUtils(spark)
+      val tableName = s"db.test_table_2_$format"
+      spark.sql("CREATE DATABASE IF NOT EXISTS db")
+      val columns1 = Array(
+        StructField("long_field", LongType),
+        StructField("int_field", IntType),
+        StructField("string_field", StringType)
+      )
+      val df1 = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns1
+            :+ StructField("double_field", DoubleType)
+            :+ StructField("ds", StringType)
+        ),
+        List(
+          Row(1L, 2, "3", 4.0, "2022-10-01")
+        )
+      )
+
+      val df2 = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns1 :+ StructField("ds", StringType)
+        ),
+        List(
+          Row(5L, 6, "7", "2022-10-02")
+        )
+      )
+      testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
     }
   }
 
   ignore("test insertion of partitioned data and modification of columns") {
-    forAll(formats) { (format, configs) =>
-      withSparkSession(configs) { spark =>
-        val tableUtils = TableUtils(spark)
-
-        val tableName = s"db.test_table_3_$format"
-        spark.sql("CREATE DATABASE IF NOT EXISTS db")
-        val columns1 = Array(
-          StructField("long_field", LongType),
-          StructField("int_field", IntType)
-        )
-        val df1 = makeDf(
-          spark,
-          StructType(
-            tableName,
-            columns1
-              :+ StructField("string_field", StringType)
-              :+ StructField("ds", StringType)
-          ),
-          List(
-            Row(1L, 2, "3", "2022-10-01")
-          )
-        )
-
-        val df2 = makeDf(
-          spark,
-          StructType(
-            tableName,
-            columns1
-              :+ StructField("string_field", DoubleType) // modified column data type
-              :+ StructField("ds", StringType)
-          ),
-          List(
-            Row(1L, 2, 3.0, "2022-10-02")
-          )
-        )
-
-        testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
-      }
+    withSparkSession { spark =>
+      val tableUtils = TableUtils(spark)
+
+      val tableName = s"db.test_table_3_$format"
+      spark.sql("CREATE DATABASE IF NOT EXISTS db")
+      val columns1 = Array(
+        StructField("long_field", LongType),
+        StructField("int_field", IntType)
+      )
+      val df1 = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns1
+            :+ StructField("string_field", StringType)
+            :+ StructField("ds", StringType)
+        ),
+        List(
+          Row(1L, 2, "3", "2022-10-01")
+        )
+      )
+
+      val df2 = makeDf(
+        spark,
+        StructType(
+          tableName,
+          columns1
+            :+ StructField("string_field", DoubleType) // modified column data type
+            :+ StructField("ds", StringType)
+        ),
+        List(
+          Row(1L, 2, 3.0, "2022-10-02")
+        )
+      )
+
+      testInsertPartitions(spark, tableUtils, tableName, format, df1, df2, ds1 = "2022-10-01", ds2 = "2022-10-02")
     }
   }
 }
@@ -172,7 +147,7 @@ object TableUtilsFormatTest {
   private def testInsertPartitions(spark: SparkSession,
                                    tableUtils: TableUtils,
                                    tableName: String,
-                                   format: Format,
+                                   format: String,
                                    df1: DataFrame,
                                    df2: DataFrame,
                                    ds1: String,
@@ -204,7 +179,8 @@ object TableUtilsFormatTest {
     tableUtils.insertPartitions(df2, tableName, autoExpand = true)
 
     // check that we wrote out a table in the right format
-    assertTrue(tableUtils.tableFormat(tableName) == format)
+    val readTableFormat = tableUtils.tableFormat(tableName).toString
+    assertTrue(s"Mismatch in table format: $readTableFormat; expected: $format", readTableFormat.toLowerCase == format)
 
     // check we have all the partitions written
     val returnedPartitions = tableUtils.partitions(tableName)
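
A note on the shape of this change: the old suite used scalatest's TableDrivenPropertyChecks to run every (format, configs) pair inside a single JVM, but settings like the Delta SQL extension and catalog must be in place when the SparkSession is created and cannot be swapped afterwards (hence the removed sc.getConf.remove(cfg) cleanup). The new suite runs one format per test invocation, selected through the FormatTestEnvVar environment variable with "hive" as the default. Below is a minimal sketch of how the builder side could translate that variable into per-format confs; the two Delta conf values come from the removed deltaConfigMap, while the env var's concrete name ("FORMAT_TEST") and the "deltalake" match key are assumptions, not taken from this diff.

import org.apache.spark.sql.SparkSession

object FormatAwareSessionSketch {
  // Hypothetical value; the diff only shows that SparkSessionBuilder exposes FormatTestEnvVar.
  val FormatTestEnvVar = "FORMAT_TEST"

  def build(name: String): SparkSession = {
    val base = SparkSession.builder().appName(name).master("local[*]")
    sys.env.getOrElse(FormatTestEnvVar, "hive") match {
      case "deltalake" =>
        // Conf values taken from the removed deltaConfigMap above.
        base
          .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
          .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
          .getOrCreate()
      case _ =>
        base.getOrCreate() // plain Hive-backed session
    }
  }
}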
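The reworked format assertion also deserves a gloss: tableUtils.tableFormat(tableName) evidently still returns a Format rather than a String (hence the added .toString), and a Format's toString (for example "Hive" or "DeltaLake") need not match the lowercase spelling arriving via the env var, so the comparison normalizes with .toLowerCase and attaches a descriptive failure message. A toy illustration of the pitfall, using a stand-in case object rather than Chronon's actual Format hierarchy:

object FormatAssertionSketch {
  // Stand-in for a format case object; a case object's toString is its name, "DeltaLake".
  case object DeltaLake

  def main(args: Array[String]): Unit = {
    val readTableFormat = DeltaLake.toString // "DeltaLake"
    val expectedFormat = "deltalake"         // spelling as it would arrive via the env var
    assert(readTableFormat != expectedFormat)             // raw comparison would fail the test
    assert(readTableFormat.toLowerCase == expectedFormat) // normalized comparison passes
  }
}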