
Commit 9c5f949

Merge branch 'main' into vz--copy_planner_2
2 parents 38623a9 + 742958d commit 9c5f949

File tree

10 files changed, +203 -171 lines changed


api/python/ai/chronon/repo/gcp.py

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ def __init__(self, args):
             if args["mode"] == "fetch"
             else gcp_jar_path
         )
-
+
         self._args = args

         super().__init__(args, os.path.expanduser(jar_path))
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+from group_bys.gcp.purchases import v1_dev, v1_test
+
+from ai.chronon.api.ttypes import EventSource, Source
+from ai.chronon.join import Join, JoinPart
+from ai.chronon.query import Query, selects
+
+"""
+This is the "left side" of the join that will comprise our training set. It is responsible for providing the primary keys
+and timestamps for which features will be computed.
+"""
+source = Source(
+    events=EventSource(
+        table="data.checkouts",
+        query=Query(
+            selects=selects(
+                "user_id"
+            ),  # The primary key used to join various GroupBys together
+            time_column="ts",
+        ),  # The event time used to compute feature values as-of
+    )
+)
+
+v1_test = Join(
+    left=source,
+    right_parts=[
+        JoinPart(group_by=v1_test)
+    ],
+)
+
+v1_dev = Join(
+    left=source,
+    right_parts=[
+        JoinPart(group_by=v1_dev)
+    ],
+)

api/src/main/scala/ai/chronon/api/DataType.scala

Lines changed: 6 additions & 0 deletions

@@ -167,6 +167,12 @@ case class StructField(name: String, fieldType: DataType)
 case object DateType extends DataType

 // maps to java.sql.Timestamp
+// maps to java.time.Instant if DATETIME_JAVA8API_ENABLED is true for java8. See spark doc:
+// ```
+// If the configuration property is set to true, java.time.Instant and java.time.LocalDate classes of Java
+// 8 API are used as external types for Catalyst's TimestampType and DateType. If it is set to false,
+// java.sql.Timestamp and java.sql.Date are used for the same purpose.
+// ```
 case object TimestampType extends DataType

 // maps to Array[Any]
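
The new comment quotes the Spark SQL option spark.sql.datetime.java8API.enabled. A minimal sketch of the behavior it describes, mirroring the pattern used in the new BigQueryCatalogTest case further down (assumes a plain local SparkSession; not part of this commit):

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.functions.col

  val spark = SparkSession.builder().master("local[1]").appName("java8api-sketch").getOrCreate()
  val input = spark.createDataFrame(Seq((1, "2023-11-30 12:00:00"))).toDF("id", "ts")

  // With the flag on, collected TimestampType values come back as java.time.Instant ...
  spark.conf.set("spark.sql.datetime.java8API.enabled", "true")
  val java8Value = input.select(col("ts").cast("timestamp")).collect().head.get(0)
  assert(java8Value.isInstanceOf[java.time.Instant])

  // ... and with it off, as java.sql.Timestamp.
  spark.conf.set("spark.sql.datetime.java8API.enabled", "false")
  val legacyValue = input.select(col("ts").cast("timestamp")).collect().head.get(0)
  assert(legacyValue.isInstanceOf[java.sql.Timestamp])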

cloud_gcp/src/test/scala/ai/chronon/integrations/cloud_gcp/BigQueryCatalogTest.scala

Lines changed: 28 additions & 20 deletions

@@ -34,7 +34,7 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
     "spark.chronon.partition.column" -> "ds",
     "spark.hadoop.fs.gs.impl" -> classOf[GoogleHadoopFileSystem].getName,
     "spark.hadoop.fs.AbstractFileSystem.gs.impl" -> classOf[GoogleHadoopFS].getName,
-    "spark.sql.catalogImplementation" -> "in-memory",
+    "spark.sql.catalogImplementation" -> "in-memory"

     // Uncomment to test
     // "spark.sql.defaultCatalog" -> "default_iceberg",

@@ -116,6 +116,19 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
     SparkBigQueryUtil.sparkDateToBigQuery(nonJava8Date)
   }

+  it should "bigquery connector converts spark timestamp regardless of setting" in {
+    val input = spark.createDataFrame(Seq((1, "2025-04-28 12:30:45"))).toDF("id", "ts")
+    spark.conf.set(SQLConf.DATETIME_JAVA8API_ENABLED.key, true)
+    val java8Timestamp = input.select(col("id"), col("ts").cast("timestamp")).collect.take(1).head.get(1)
+    assert(java8Timestamp.isInstanceOf[java.time.Instant])
+    SparkBigQueryUtil.sparkTimestampToBigQuery(java8Timestamp)
+
+    spark.conf.set(SQLConf.DATETIME_JAVA8API_ENABLED.key, false)
+    val nonJava8Timestamp = input.select(col("id"), col("ts").cast("timestamp")).collect.take(1).head.get(1)
+    assert(nonJava8Timestamp.isInstanceOf[java.sql.Timestamp])
+    SparkBigQueryUtil.sparkTimestampToBigQuery(nonJava8Timestamp)
+  }
+
   it should "integration testing bigquery native table" ignore {
     val nativeTable = "data.checkouts"
     val table = tableUtils.loadTable(nativeTable)

@@ -141,9 +154,8 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {

     val singleFilter = tableUtils.loadTable(iceberg, List("ds = '2023-11-30'"))
     val multiFilter = tableUtils.loadTable(iceberg, List("ds = '2023-11-30'", "ds = '2023-11-30'"))
-    assertEquals(
-      singleFilter.select("user_id", "ds").as[(String, String)].collect.toList,
-      multiFilter.select("user_id", "ds").as[(String, String)].collect.toList)
+    assertEquals(singleFilter.select("user_id", "ds").as[(String, String)].collect.toList,
+                 multiFilter.select("user_id", "ds").as[(String, String)].collect.toList)
   }

   it should "integration testing formats" ignore {

@@ -180,37 +192,34 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
     assertTrue(dneFormat.isEmpty)
   }

-
   it should "integration testing bigquery partitions" ignore {
     // TODO(tchow): This test is ignored because it requires a running instance of the bigquery. Need to figure out stubbing locally.
     // to run, set `GOOGLE_APPLICATION_CREDENTIALS=<path_to_application_default_credentials.json>
     val externalPartitions = tableUtils.partitions("data.checkouts_parquet_partitioned")
-    assertEquals(Seq("2023-11-30"), externalPartitions)
+    assertEquals(Seq("2023-11-30"), externalPartitions)
     val nativePartitions = tableUtils.partitions("data.purchases")
     assertEquals(
-      Set(20231118, 20231122, 20231125, 20231102, 20231123, 20231119, 20231130, 20231101, 20231117, 20231110, 20231108, 20231112, 20231115, 20231116, 20231113, 20231104, 20231103, 20231106, 20231121, 20231124, 20231128, 20231109, 20231127, 20231129, 20231126, 20231114, 20231107, 20231111, 20231120, 20231105).map(_.toString), nativePartitions.toSet)
+      Set(20231118, 20231122, 20231125, 20231102, 20231123, 20231119, 20231130, 20231101, 20231117, 20231110, 20231108,
+          20231112, 20231115, 20231116, 20231113, 20231104, 20231103, 20231106, 20231121, 20231124, 20231128, 20231109,
+          20231127, 20231129, 20231126, 20231114, 20231107, 20231111, 20231120, 20231105).map(_.toString),
+      nativePartitions.toSet
+    )

     val df = tableUtils.loadTable("`canary-443022.data`.purchases")
     df.show

-    tableUtils.insertPartitions(
-      df,
-      "data.tchow_test_iceberg",
-      Map(
-        "file_format" -> "PARQUET",
-        "table_type" -> "iceberg"),
-      List("ds"))
-
+    tableUtils.insertPartitions(df,
+                                "data.tchow_test_iceberg",
+                                Map("file_format" -> "PARQUET", "table_type" -> "iceberg"),
+                                List("ds"))

     val icebergCols = spark.catalog.listColumns("data.tchow_test_iceberg")
     val externalCols = spark.catalog.listColumns("data.checkouts_parquet_partitioned")
     val nativeCols = spark.catalog.listColumns("data.purchases")

     val icebergPartitions = spark.sql("SELECT * FROM data.tchow_test_iceberg.partitions")

-
-    val sqlDf = tableUtils.sql(
-      s"""
+    val sqlDf = tableUtils.sql(s"""
       |SELECT ds FROM data.checkouts_parquet_partitioned -- external parquet
       |UNION ALL
       |SELECT ds FROM data.purchases -- bigquery native

@@ -272,8 +281,7 @@ class BigQueryCatalogTest extends AnyFlatSpec with MockitoSugar {
     input.close();

     assertNotNull("Deserialized object should not be null", deserializedObj);
-    assertTrue("Deserialized object should be an instance of GCSFileIO",
-      deserializedObj.isInstanceOf[GCSFileIO]);
+    assertTrue("Deserialized object should be an instance of GCSFileIO", deserializedObj.isInstanceOf[GCSFileIO]);
     assertEquals(original.properties(), deserializedObj.asInstanceOf[GCSFileIO].properties())
   }
 }

online/src/main/scala/ai/chronon/online/serde/AvroConversions.scala

Lines changed: 18 additions & 9 deletions

@@ -35,7 +35,12 @@ object AvroConversions {
   def toAvroValue(value: AnyRef, schema: Schema): Object =
     schema.getType match {
       case Schema.Type.UNION => toAvroValue(value, schema.getTypes.get(1))
-      case Schema.Type.LONG => value.asInstanceOf[Long].asInstanceOf[Object]
+      case Schema.Type.LONG
+          if Option(schema.getLogicalType).map(_.getName).getOrElse("") == LogicalTypes.timestampMillis().getName =>
+        // because we're setting spark.sql.datetime.java8API.enabled to True https://github.com/zipline-ai/chronon/blob/main/spark/src/main/scala/ai/chronon/spark/submission/SparkSessionBuilder.scala#L132,
+        // we'll convert to java.time.Instant
+        value.asInstanceOf[java.time.Instant].asInstanceOf[Object]
+      case Schema.Type.LONG => value.asInstanceOf[Long].asInstanceOf[Object]
       case Schema.Type.INT
           if Option(schema.getLogicalType).map(_.getName).getOrElse("") == LogicalTypes.date().getName =>
         // Avro represents as java.time.LocalDate: https://github.com/apache/avro/blob/fe0261deecf22234bbd09251764152d4bf9a9c4a/lang/java/avro/src/main/java/org/apache/avro/data/TimeConversions.java#L38

@@ -59,7 +64,10 @@ object AvroConversions {
       case Schema.Type.INT
           if Option(schema.getLogicalType).map(_.getName).getOrElse("") == LogicalTypes.date().getName =>
         DateType
-      case Schema.Type.INT => IntType
+      case Schema.Type.INT => IntType
+      case Schema.Type.LONG
+          if Option(schema.getLogicalType).map(_.getName).getOrElse("") == LogicalTypes.timestampMillis().getName =>
+        TimestampType
       case Schema.Type.LONG => LongType
       case Schema.Type.FLOAT => FloatType
       case Schema.Type.DOUBLE => DoubleType

@@ -109,13 +117,14 @@ object AvroConversions {
         assert(keyType == StringType, "Avro only supports string keys for a map")
         Schema.createMap(fromChrononSchema(valueType, nameSet))
       }
-      case StringType  => Schema.create(Schema.Type.STRING)
-      case IntType     => Schema.create(Schema.Type.INT)
-      case LongType    => Schema.create(Schema.Type.LONG)
-      case FloatType   => Schema.create(Schema.Type.FLOAT)
-      case DoubleType  => Schema.create(Schema.Type.DOUBLE)
-      case BinaryType  => Schema.create(Schema.Type.BYTES)
-      case BooleanType => Schema.create(Schema.Type.BOOLEAN)
+      case StringType    => Schema.create(Schema.Type.STRING)
+      case IntType       => Schema.create(Schema.Type.INT)
+      case LongType      => Schema.create(Schema.Type.LONG)
+      case FloatType     => Schema.create(Schema.Type.FLOAT)
+      case DoubleType    => Schema.create(Schema.Type.DOUBLE)
+      case BinaryType    => Schema.create(Schema.Type.BYTES)
+      case BooleanType   => Schema.create(Schema.Type.BOOLEAN)
+      case TimestampType => LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG))
       case DateType =>
         LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT))
       case _ =>
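
A small standalone sketch of the mapping these changes introduce, using only the Avro Java API (no Chronon imports); it is illustrative only, not part of this commit:

  import org.apache.avro.{LogicalTypes, Schema}

  // TimestampType now serializes as an Avro long tagged with the timestamp-millis logical type.
  val tsSchema: Schema = LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG))
  assert(tsSchema.getType == Schema.Type.LONG)
  assert(tsSchema.getLogicalType.getName == LogicalTypes.timestampMillis().getName)

  // A plain long carries no logical type, which is what lets the new guarded
  // `case Schema.Type.LONG if ...` branches tell timestamps apart from ordinary longs.
  val plainLong: Schema = Schema.create(Schema.Type.LONG)
  assert(plainLong.getLogicalType == null)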

scripts/distribution/run_gcp_quickstart.sh

Lines changed: 15 additions & 4 deletions

@@ -71,10 +71,12 @@ if [[ "$ENVIRONMENT" == "canary" ]]; then
   bq rm -f -t canary-443022:data.gcp_purchases_v1_test
   bq rm -f -t canary-443022:data.gcp_purchases_v1_view_test
   bq rm -f -t canary-443022:data.gcp_purchases_v1_test_upload
+  bq rm -f -t canary-443022:data.gcp_training_set_v1_test
 else
   bq rm -f -t canary-443022:data.gcp_purchases_v1_dev
   bq rm -f -t canary-443022:data.gcp_purchases_v1_view_dev
   bq rm -f -t canary-443022:data.gcp_purchases_v1_dev_upload
+  bq rm -f -t canary-443022:data.gcp_training_set_v1_dev
 fi
 #TODO: delete bigtable rows

@@ -127,18 +129,27 @@ zipline compile --chronon-root=$CHRONON_ROOT

 echo -e "${GREEN}<<<<<.....................................BACKFILL.....................................>>>>>\033[0m"
 if [[ "$ENVIRONMENT" == "canary" ]]; then
-  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_test
+  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_test --start-ds 2023-11-01 --end-ds 2023-12-01
 else
-  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_dev
+  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_dev --start-ds 2023-11-01 --end-ds 2023-12-01
 fi

 fail_if_bash_failed $?

 echo -e "${GREEN}<<<<<.....................................BACKFILL-VIEW.....................................>>>>>\033[0m"
 if [[ "$ENVIRONMENT" == "canary" ]]; then
-  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_view_test
+  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_view_test --start-ds 2023-11-01 --end-ds 2023-12-01
 else
-  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_view_dev
+  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/group_bys/gcp/purchases.v1_view_dev --start-ds 2023-11-01 --end-ds 2023-12-01
+fi
+
+fail_if_bash_failed $?
+
+echo -e "${GREEN}<<<<<.....................................BACKFILL-JOIN.....................................>>>>>\033[0m"
+if [[ "$ENVIRONMENT" == "canary" ]]; then
+  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/joins/gcp/training_set.v1_test --start-ds 2023-11-01 --end-ds 2023-12-01
+else
+  zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/joins/gcp/training_set.v1_dev --start-ds 2023-11-01 --end-ds 2023-12-01
 fi

 fail_if_bash_failed $?

spark/src/main/scala/ai/chronon/spark/GroupByUpload.scala

Lines changed: 5 additions & 3 deletions

@@ -95,11 +95,13 @@ class GroupByUpload(endPartition: String, groupBy: GroupBy) extends Serializable
     val irSchema = SparkConversions.fromChrononSchema(sawtoothOnlineAggregator.batchIrSchema)
     val keyBuilder = FastHashing.generateKeyBuilder(groupBy.keyColumns.toArray, groupBy.inputDf.schema)

-    logger.info(s"""
-         |BatchIR Element Size: ${SparkEnv.get.serializer
+    val batchIrElementSize = SparkEnv.get.serializer
       .newInstance()
       .serialize(sawtoothOnlineAggregator.init)
-      .capacity()}
+      .capacity()
+
+    logger.info(s"""
+         |BatchIR Element Size: $batchIrElementSize
          |""".stripMargin)

     val outputRdd = tableUtils
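
This change only hoists the serialized-size measurement out of the interpolated log string. As a hedged sketch (assuming an active SparkEnv, i.e. running inside a Spark job), the same measurement factored into a reusable helper might look like:

  import java.nio.ByteBuffer
  import org.apache.spark.SparkEnv

  // Serialize an object with the job's configured serializer and report its size in bytes.
  def serializedSizeBytes(obj: AnyRef): Int = {
    val buf: ByteBuffer = SparkEnv.get.serializer.newInstance().serialize(obj)
    buf.capacity()
  }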
