fix: remove references to custom json, float essential apis to top #492

Merged · 14 commits · Mar 13, 2025

Changes from 3 commits
9 changes: 7 additions & 2 deletions api/src/main/scala/ai/chronon/api/Builders.scala
@@ -264,14 +264,14 @@ object Builders {
online: Boolean = false,
production: Boolean = false,
customJson: String = null,
dependencies: Seq[String] = null,
namespace: String = null,
team: String = null,
samplePercent: Double = 100,
consistencySamplePercent: Double = 5,
tableProperties: Map[String, String] = Map.empty,
historicalBackfill: Boolean = true,
driftSpec: DriftSpec = null
driftSpec: DriftSpec = null,
additionalOutputPartitionColumns: Seq[String] = Seq.empty
): MetaData = {
val result = new MetaData()
result.setName(name)
@@ -298,6 +298,11 @@ object Builders {
result.setTableProperties(tableProperties.toJava)
if (driftSpec != null)
result.setDriftSpec(driftSpec)

if (additionalOutputPartitionColumns.nonEmpty) {
result.setAdditionalOutputPartitionColumns(additionalOutputPartitionColumns.toJava)
}

result
}
}
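A quick usage sketch of the new builder parameter (a minimal example; the name, namespace, and column names are illustrative):

import ai.chronon.api.Builders

// Output will be partitioned by the date column plus region and device.
val metaData = Builders.MetaData(
  name = "test.additional_partitions",
  namespace = "chronon_test",
  additionalOutputPartitionColumns = Seq("region", "device")
)
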
20 changes: 3 additions & 17 deletions api/src/main/scala/ai/chronon/api/Extensions.scala
@@ -154,20 +154,6 @@ object Extensions {
@deprecated("Use `name` instead.")
def nameToFilePath: String = metaData.name.replaceFirst("\\.", "/")

// helper function to extract values from customJson
def customJsonLookUp(key: String): Any = {
if (metaData.customJson == null) return null
val mapper = new ObjectMapper()
val typeRef = new TypeReference[java.util.HashMap[String, Object]]() {}
val jMap: java.util.Map[String, Object] = mapper.readValue(metaData.customJson, typeRef)
jMap.toScala.get(key).orNull
}

def owningTeam: String = {
val teamOverride = Try(customJsonLookUp(Constants.TeamOverride).asInstanceOf[String]).toOption
teamOverride.getOrElse(metaData.team)
}

// if drift spec is set but tile size is not set, default to 30 minutes
def driftTileSize: Option[Window] = {
Option(metaData.getDriftSpec) match {
@@ -462,9 +448,9 @@ object Extensions {

// Check if tiling is enabled for a given GroupBy. Tiling is on when streamWriteStrategy is SIMPLE_TILES or unset.
def isTilingEnabled: Boolean =
Contributor: can delete I think as this was moved to flagStore

Contributor Author: done

groupBy.getMetaData.customJsonLookUp("enable_tiling") match {
case s: Boolean => s
case _ => false
Option(groupBy.getMetaData.streamWriteStrategy) match {
case Some(StreamWriteStrategy.SIMPLE_TILES) | None => true
case _ => false
}

def semanticHash: String = {
18 changes: 0 additions & 18 deletions api/src/test/scala/ai/chronon/api/test/ExtensionsTest.scala
@@ -41,24 +41,6 @@ class ExtensionsTest extends AnyFlatSpec {
)
}

it should "owning team" in {
val metadata =
Builders.MetaData(
customJson = "{\"check_consistency\": true, \"lag\": 0, \"team_override\": \"ml_infra\"}",
team = "chronon"
)

assertEquals(
"ml_infra",
metadata.owningTeam
)

assertEquals(
"chronon",
metadata.team
)
}

it should "row identifier" in {
val labelPart = Builders.LabelPart();
val res = labelPart.rowIdentifier(Arrays.asList("yoyo", "yujia"), "ds")
20 changes: 16 additions & 4 deletions api/thrift/api.thrift
@@ -249,13 +249,19 @@ struct MetaData {

4: optional string outputNamespace

5: optional map<string, string> tableProperties
/**
* By default the output is partitioned only by the date column - set via "spark.chronon.partition.column".
* With this set, the output is additionally partitioned by the specified columns.
**/
5: optional list<string> additionalOutputPartitionColumns

6: optional map<string, string> tableProperties
Comment on lines +256 to +258

Contributor: Should we keep tableProperties the same field number as before (5)

Contributor Author: thought outputPartitionCols are more important than table props. Safe to change these for now actually.

Contributor: I guess we should call out that folks need to recompile their existing configs right?

Contributor Author: good call - we need to release a wheel and cut them over. was working on it separately. (basically I am doing the compile for them)

// tag_key -> tag_value - tags allow for repository wide querying, deprecations etc
// this is object level tag - applies to all columns produced by the object - GroupBy, Join, Model etc
6: optional map<string, string> tags
20: optional map<string, string> tags
Contributor: Field number here too?

Contributor Author: the spacing in field nums allows for adding new fields in the right order. so if we find some other thing later, we can add it in the right place instead of at the end.

// column -> tag_key -> tag_value
7: optional map<string, map<string, string>> columnTags
21: optional map<string, map<string, string>> columnTags

// marking this as true means that the conf can be served online
// once marked online, a conf cannot be changed - compiling the conf won't be allowed
@@ -284,9 +290,15 @@ struct MetaData {

# information that needs to be present on every physical node
204: optional common.ExecutionInfo executionInfo
}

300: optional StreamWriteStrategy streamWriteStrategy
}

enum StreamWriteStrategy {
RAW,
SIMPLE_TILES,
CUMULATIVE_AND_TILES,
}

// Equivalent to a FeatureSet in chronon terms
struct GroupBy {
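For context, a minimal sketch of how the new streamWriteStrategy field replaces the old enable_tiling customJson flag, mirroring the isTilingEnabled change above (assumes the standard thrift-generated accessors on MetaData):

import ai.chronon.api.{MetaData, StreamWriteStrategy}

// Tiling is on when the strategy is SIMPLE_TILES or left unset,
// matching the pattern match in Extensions.isTilingEnabled.
def tilingEnabled(metaData: MetaData): Boolean =
  Option(metaData.streamWriteStrategy) match {
    case Some(StreamWriteStrategy.SIMPLE_TILES) | None => true
    case _                                             => false
  }

val md = new MetaData()
assert(tilingEnabled(md)) // unset defaults to tiled writes

md.setStreamWriteStrategy(StreamWriteStrategy.RAW)
assert(!tilingEnabled(md)) // RAW opts out of tiling
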
18 changes: 4 additions & 14 deletions online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala
@@ -25,22 +25,12 @@ object MetadataEndPoint {
val ConfByKeyEndPointName = "CHRONON_METADATA"
val NameByTeamEndPointName = "CHRONON_ENTITY_BY_TEAM"

private def getTeamFromMetadata(metaData: MetaData): String = {
val team = metaData.team
if (metaData.customJson != null && metaData.customJson.nonEmpty) {
implicit val formats = DefaultFormats
val customJson = parse(metaData.customJson)
val teamFromJson: String = (customJson \ "team_override").extractOpt[String].getOrElse("")
if (teamFromJson.nonEmpty) teamFromJson else team
} else team
}

private def parseTeam[Conf <: TBase[_, _]: Manifest: ClassTag](conf: Conf): String = {
conf match {
case join: Join => "joins/" + getTeamFromMetadata(join.metaData)
case groupBy: GroupBy => "group_bys/" + getTeamFromMetadata(groupBy.metaData)
case stagingQuery: StagingQuery => "staging_queries/" + getTeamFromMetadata(stagingQuery.metaData)
case model: Model => "models/" + getTeamFromMetadata(model.metaData)
case join: Join => "joins/" + join.metaData.team
case groupBy: GroupBy => "group_bys/" + groupBy.metaData.team
case stagingQuery: StagingQuery => "staging_queries/" + stagingQuery.metaData.team
case model: Model => "models/" + model.metaData.team
case _ =>
logger.error(s"Failed to parse team from $conf")
throw new Exception(s"Failed to parse team from $conf")
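A small sketch of the keys parseTeam now produces (team names illustrative; assumes the thrift-generated setMetaData/metaData accessors on Join):

import ai.chronon.api.{Builders, Join}

// The endpoint key now depends only on metaData.team;
// a customJson {"team_override": ...} no longer takes effect.
val join = new Join()
join.setMetaData(Builders.MetaData(name = "my.join", team = "chronon"))
val key = "joins/" + join.metaData.team // "joins/chronon"
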
6 changes: 3 additions & 3 deletions online/src/main/scala/ai/chronon/online/Metrics.scala
@@ -98,7 +98,7 @@ object Metrics {
environment = environment,
join = join.metaData.cleanName,
production = join.metaData.isProduction,
team = join.metaData.owningTeam
team = join.metaData.team
)
}

@@ -108,7 +108,7 @@
groupBy = groupBy.metaData.cleanName,
production = groupBy.metaData.isProduction,
accuracy = groupBy.inferredAccuracy,
team = groupBy.metaData.owningTeam,
team = groupBy.metaData.team,
join = groupBy.sources.toScala
.find(_.isSetJoinSource)
.map(_.getJoinSource.join.metaData.cleanName)
@@ -127,7 +127,7 @@
environment = environment,
groupBy = stagingQuery.metaData.cleanName,
production = stagingQuery.metaData.isProduction,
team = stagingQuery.metaData.owningTeam
team = stagingQuery.metaData.team
)
}

6 changes: 1 addition & 5 deletions spark/src/main/scala/ai/chronon/spark/StagingQuery.scala
@@ -35,11 +35,7 @@ class StagingQuery(stagingQueryConf: api.StagingQuery, endPartition: String, tab
.orNull

private val partitionCols: Seq[String] =
Seq(Option(stagingQueryConf.getPartitionColumn).getOrElse(tableUtils.partitionColumn)) ++
Option(stagingQueryConf.metaData.customJsonLookUp(key = "additional_partition_cols"))
.getOrElse(new java.util.ArrayList[String]())
.asInstanceOf[java.util.ArrayList[String]]
.toScala
Seq(tableUtils.partitionColumn) ++ stagingQueryConf.metaData.additionalOutputPartitionColumns.toScala
Contributor Author: this is the main change that ben needs


def computeStagingQuery(stepDays: Option[Int] = None,
enableAutoExpand: Option[Boolean] = Some(true),
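To make the effect concrete: with additionalOutputPartitionColumns = Seq("region", "device"), partitionCols resolves to the date column followed by region and device. A minimal sketch of the resulting partitioned write, assuming a plain Spark partitionBy path (the actual write goes through TableUtils and may differ):

import org.apache.spark.sql.{DataFrame, SaveMode}

// Hypothetical write step: the date partition column first,
// then any additionalOutputPartitionColumns from the MetaData.
def writePartitioned(df: DataFrame, outputTable: String, partitionCols: Seq[String]): Unit =
  df.write
    .mode(SaveMode.Overwrite)
    .partitionBy(partitionCols: _*)
    .saveAsTable(outputTable)

// e.g. writePartitioned(result, "ns.test_additional_partition_cols", Seq("ds", "region", "device"))
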
87 changes: 85 additions & 2 deletions spark/src/test/scala/ai/chronon/spark/test/StagingQueryTest.scala
@@ -25,14 +25,14 @@ import ai.chronon.spark.SparkSessionBuilder
import ai.chronon.spark.StagingQuery
import ai.chronon.spark.TableUtils
import org.apache.spark.sql.SparkSession
import org.junit.Assert.assertEquals
import org.junit.Assert.{assertEquals, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec
import org.slf4j.Logger
import org.slf4j.LoggerFactory

class StagingQueryTest extends AnyFlatSpec {
@transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)
lazy val spark: SparkSession = SparkSessionBuilder.build("StagingQueryTest", local = true)
implicit lazy val spark: SparkSession = SparkSessionBuilder.build("StagingQueryTest", local = true)
implicit private val tableUtils: TableUtils = TableUtils(spark)

private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
@@ -285,4 +285,87 @@
}
assertEquals(0, diff.count())
}

private def getPartitionColumnNames(tableName: String)(implicit spark: SparkSession): Seq[String] = {
// Get the catalog table information
val tableIdentifier = spark.sessionState.sqlParser.parseTableIdentifier(tableName)
val catalogTable = spark.sessionState.catalog.getTableMetadata(tableIdentifier)

// Extract partition column names from the table schema
catalogTable.partitionColumnNames
}

it should "handle additional output partition columns" in {
val schema = List(
Column("user", StringType, 10),
Column("region", StringType, 5, nullRate = 0.0), // partition columns cannot have null
Column("device", StringType, 3, nullRate = 0.0), // partition columns cannot have null
Column("session_length", IntType, 1000)
)

// Generate test data with columns that can be used for additional partitioning
val df = DataFrameGen
.events(spark, schema, count = 10000, partitions = 20)
.dropDuplicates("ts")
logger.info("Generated test data for additional partition columns:")
df.show()

val tableName = s"$namespace.test_additional_partition_cols"
df.save(tableName)

// Define a staging query with multiple additional partition columns
val stagingQueryConf = Builders.StagingQuery(
query = s"select * from $tableName WHERE ds BETWEEN {{ start_date }} AND {{ end_date }}",
startPartition = ninetyDaysAgo,
metaData = Builders.MetaData(
name = "test.additional_partitions",
namespace = namespace,
additionalOutputPartitionColumns = Seq("region", "device"), // Explicitly specify additional partition columns
tableProperties = Map("key" -> "val")
)
)

val stagingQuery = new StagingQuery(stagingQueryConf, today, tableUtils)
stagingQuery.computeStagingQuery(stepDays = Option(30))

// Verify the data was written correctly
val expected = tableUtils.sql(
s"select * from $tableName where ds between '$ninetyDaysAgo' and '$today'"
)

val computed = tableUtils.sql(s"select * from ${stagingQueryConf.metaData.outputTable}")
val diff = Comparison.sideBySide(expected, computed, List("user", "ts", "ds"))

val diffCount = diff.count()
if (diffCount > 0) {
logger.info("Different rows between expected and computed")

logger.info("Expected rows")
expected.show()

logger.info("Computed rows")
computed.show()

logger.info("Diff rows (SxS)")
diff.show()
}

assertEquals(0, diff.count())

// Verify the table was created with the additional partition columns
val tableDesc = spark.sql(s"DESCRIBE ${stagingQueryConf.metaData.outputTable}")
val partitionInfo = spark.sql(s"SHOW PARTITIONS ${stagingQueryConf.metaData.outputTable}")

logger.info("Table description:")
tableDesc.show()
logger.info("Partition information:")
partitionInfo.show()

// Get the partition column names from the table metadata
val partitionColumnNames = getPartitionColumnNames(stagingQueryConf.metaData.outputTable)(spark)

// Verify all expected partition columns are present
val expectedPartitionCols = Seq(tableUtils.partitionColumn, "region", "device")
assertEquals(expectedPartitionCols.toSet, partitionColumnNames.toSet)
}
}