Skip to content

fix: remove some nondeterministic timestamp in ChainingFetcherTest #740

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 7, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,24 @@
package ai.chronon.spark.test.fetcher

import ai.chronon.api
import ai.chronon.api._
import ai.chronon.api.Constants.MetadataDataset
import ai.chronon.api.Extensions.{JoinOps, MetadataOps}
import ai.chronon.api.ScalaJavaConversions._
import ai.chronon.api._
import ai.chronon.spark.catalog.TableUtils

import ai.chronon.online.fetcher.Fetcher.Request
import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
import ai.chronon.online.fetcher.Fetcher.Request
import ai.chronon.online.serde.SparkConversions
import ai.chronon.spark.{Join => _, _}
import ai.chronon.spark.catalog.TableUtils
import ai.chronon.spark.Extensions._
import ai.chronon.spark.test.{OnlineUtils, TestUtils}
import ai.chronon.spark.utils.MockApi
import ai.chronon.spark.{Join => _, _}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.junit.Assert.{assertEquals, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec
import org.slf4j.{Logger, LoggerFactory}
import ai.chronon.spark.submission.SparkSessionBuilder

import java.util.TimeZone
import java.util.concurrent.Executors
Expand All @@ -44,14 +43,11 @@ import scala.concurrent.ExecutionContext

class ChainingFetcherTest extends AnyFlatSpec {

import ai.chronon.spark.submission

@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
val sessionName = "ChainingFetcherTest"
val spark: SparkSession = submission.SparkSessionBuilder.build(sessionName, local = true)
val spark: SparkSession = SparkSessionBuilder.build(sessionName, local = true)
private val tableUtils = TableUtils(spark)
TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
def toTs(arg: String): Long = TsUtils.datetimeToTs(arg)

/** This test GroupBy retrieves the latest rating of the listings a user viewed in the last 7 days.
Expand Down Expand Up @@ -265,7 +261,7 @@ class ChainingFetcherTest extends AnyFlatSpec {
val all: Map[String, AnyRef] =
res.request.keys ++
res.values.get ++
Map(tableUtils.partitionColumn -> today) ++
Map(tableUtils.partitionColumn -> endDs) ++
Map(Constants.TimeColumn -> java.lang.Long.valueOf(res.request.atMillis.get))
val values: Array[Any] = columns.map(all.get(_).orNull)
SparkConversions
Expand All @@ -279,7 +275,6 @@ class ChainingFetcherTest extends AnyFlatSpec {

// Compare the fetched response rows against the expected result.
def compareTemporalFetch(joinConf: api.Join,
endDs: String,
expectedDf: DataFrame,
responseRows: Seq[Row],
ignoreCol: String): Unit = {
Expand All @@ -288,9 +283,6 @@ class ChainingFetcherTest extends AnyFlatSpec {
val keyishColumns = keys.toList ++ List(tableUtils.partitionColumn, Constants.TimeColumn)
val responseRdd = tableUtils.sparkSession.sparkContext.parallelize(responseRows.toSeq)
var responseDf = tableUtils.sparkSession.createDataFrame(responseRdd, expectedDf.schema)
if (endDs != today) {
responseDf = responseDf.drop("ds").withColumn("ds", lit(endDs))
}
logger.info("expected:")
expectedDf.show()
logger.info("response:")
Expand Down Expand Up @@ -319,7 +311,7 @@ class ChainingFetcherTest extends AnyFlatSpec {
val namespace = "parent_join_fetch"
val joinConf = generateMutationData(namespace, Accuracy.TEMPORAL)
val (expected, fetcherResponse) = executeFetch(joinConf, "2021-04-15", namespace)
compareTemporalFetch(joinConf, "2021-04-15", expected, fetcherResponse, "user")
compareTemporalFetch(joinConf, expected, fetcherResponse, "user")
}

it should "fetch chaining deterministic" in {
Expand All @@ -328,6 +320,6 @@ class ChainingFetcherTest extends AnyFlatSpec {
assertTrue(chainingJoinConf.joinParts.get(0).groupBy.sources.get(0).isSetJoinSource)

val (expected, fetcherResponse) = executeFetch(chainingJoinConf, "2021-04-18", namespace)
compareTemporalFetch(chainingJoinConf, "2021-04-18", expected, fetcherResponse, "listing")
compareTemporalFetch(chainingJoinConf, expected, fetcherResponse, "listing")
}
}