Skip to content

Commit 98c616e

Browse files
authored
Add value info struct to join schema response to pull feature -> key mapping (#728)
## Summary

Updating the JoinSchemaResponse to include a mapping from feature -> listing key. This PR updates our JoinSchemaResponse to include a value info case class with these details.

## Checklist
- [X] Added Unit Tests
- [X] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->
## Summary by CodeRabbit
- **New Features**
  - Added detailed metadata for join value fields, including feature names, group names, prefixes, left keys, and schema descriptions, now available in join schema responses.
- **Bug Fixes**
  - Improved consistency and validation between join configuration keys and value field metadata.
- **Tests**
  - Enhanced and added tests to validate the presence and correctness of value field metadata in join schema responses.
  - Introduced new test suites covering fetcher failure scenarios and metadata store functionality.
  - Refactored existing fetcher tests to use external utility methods for data generation.
  - Added utility methods for generating deterministic, random, and event-only test data configurations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 0545a42 commit 98c616e

File tree

9 files changed

+789
-587
lines changed

9 files changed

+789
-587
lines changed

online/src/main/java/ai/chronon/online/JavaJoinSchemaResponse.java

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,56 @@
22

33
import ai.chronon.online.fetcher.Fetcher;
44

5+
import java.util.Arrays;
6+
57
public class JavaJoinSchemaResponse {
68
public String joinName;
79
public String keySchema;
810
public String valueSchema;
911
public String schemaHash;
12+
public ValueInfo[] valueInfos;
13+
14+
public static class ValueInfo {
15+
public String fullName;
16+
public String groupName;
17+
public String prefix;
18+
public String[] leftKeys;
19+
public String schemaString;
20+
public ValueInfo(String fullName, String groupName, String prefix, String[] leftKeys, String schemaString) {
21+
this.fullName = fullName;
22+
this.groupName = groupName;
23+
this.prefix = prefix;
24+
this.leftKeys = leftKeys;
25+
this.schemaString = schemaString;
26+
}
27+
}
1028

11-
public JavaJoinSchemaResponse(String joinName, String keySchema, String valueSchema, String schemaHash) {
29+
public JavaJoinSchemaResponse(String joinName, String keySchema, String valueSchema, String schemaHash, ValueInfo[] valueInfos) {
1230
this.joinName = joinName;
1331
this.keySchema = keySchema;
1432
this.valueSchema = valueSchema;
1533
this.schemaHash = schemaHash;
34+
this.valueInfos = valueInfos;
1635
}
1736

1837
public JavaJoinSchemaResponse(Fetcher.JoinSchemaResponse scalaResponse){
1938
this.joinName = scalaResponse.joinName();
2039
this.keySchema = scalaResponse.keySchema();
2140
this.valueSchema = scalaResponse.valueSchema();
2241
this.schemaHash = scalaResponse.schemaHash();
42+
this.valueInfos = Arrays.stream(scalaResponse.valueInfos())
43+
.map(v -> new ValueInfo(v.fullName(), v.groupName(), v.prefix(), v.leftKeys(), v.schemaString()))
44+
.toArray(ValueInfo[]::new);
2345
}
2446

2547
public Fetcher.JoinSchemaResponse toScala() {
2648
return new Fetcher.JoinSchemaResponse(
2749
joinName,
2850
keySchema,
2951
valueSchema,
30-
schemaHash);
52+
schemaHash,
53+
Arrays.stream(valueInfos)
54+
.map(v -> new JoinCodec.ValueInfo(v.fullName, v.groupName, v.prefix, v.leftKeys, v.schemaString))
55+
.toArray(JoinCodec.ValueInfo[]::new));
3156
}
3257
}

online/src/main/scala/ai/chronon/online/JoinCodec.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ case class JoinCodec(conf: JoinOps,
3636
baseValueSchema: StructType,
3737
keyCodec: AvroCodec,
3838
baseValueCodec: AvroCodec,
39+
valueInfos: Array[JoinCodec.ValueInfo],
3940
hasPartialFailure: Boolean = false)
4041
extends Serializable {
4142

@@ -98,4 +99,17 @@ object JoinCodec {
9899
)
99100
new Gson().toJson(schemaMap.toJava)
100101
}
102+
103+
  /** Tracks details on the feature values that the join is producing.
    *
    * @param fullName - Full feature name (e.g. prefix_groupName_featureName)
    * @param groupName - Name of the group (GroupBy name / derivation / external part name)
    * @param prefix - Prefix for the group
    * @param leftKeys - Keys needed to look up this feature
    * @param schemaString - User friendly schema string for the feature
    */
  // NOTE(review): leftKeys is an Array, so the generated case-class equals/hashCode
  // compare it by reference — use sameElements when comparing ValueInfo instances for content.
  case class ValueInfo(fullName: String,
                       groupName: String,
                       prefix: String,
                       leftKeys: Array[String],
                       schemaString: String)
101115
}

online/src/main/scala/ai/chronon/online/fetcher/Fetcher.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,13 @@ object Fetcher {
7979
* @param keySchema - Avro schema string for the key
8080
* @param valueSchema - Avro schema string for the value
8181
* @param schemaHash - Hash of the join schema payload (used to track updates to key / value schema fields or types)
82+
* @param valueInfos - Per feature column metadata (e.g. group name, corresponding left lookup keys, ..)
8283
*/
83-
case class JoinSchemaResponse(joinName: String, keySchema: String, valueSchema: String, schemaHash: String)
84+
  // Response payload for join schema requests. valueInfos carries per-feature-column
  // metadata (group name, corresponding left lookup keys, ...) — see JoinCodec.ValueInfo.
  case class JoinSchemaResponse(joinName: String,
                                keySchema: String,
                                valueSchema: String,
                                schemaHash: String,
                                valueInfos: Array[JoinCodec.ValueInfo])
8489
}
8590

8691
private[online] case class FetcherResponseWithTs(responses: Seq[Fetcher.Response], endTs: Long)
@@ -501,7 +506,8 @@ class Fetcher(val kvStore: KVStore,
501506
val response = JoinSchemaResponse(joinName,
502507
joinCodec.keyCodec.schemaStr,
503508
joinCodec.valueCodec.schemaStr,
504-
joinCodec.loggingSchemaHash)
509+
joinCodec.loggingSchemaHash,
510+
joinCodec.valueInfos.toArray)
505511
if (joinCodec.hasPartialFailure) {
506512
joinCodecCache.refresh(joinName)
507513
}

online/src/main/scala/ai/chronon/online/fetcher/MetadataStore.scala

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,23 @@ class MetadataStore(fetchContext: FetchContext) {
292292
def buildJoinCodec(joinConf: Join, refreshOnFail: Boolean): JoinCodec = {
293293
val keyFields = new mutable.LinkedHashSet[StructField]
294294
val valueFields = new mutable.ListBuffer[StructField]
295+
val valueInfos = mutable.ListBuffer.empty[JoinCodec.ValueInfo]
295296
var hasPartialFailure = false
296297
// collect keyFields and valueFields from joinParts/GroupBys
297298
joinConf.joinPartOps.foreach { joinPart =>
298299
getGroupByServingInfo(joinPart.groupBy.metaData.getName)
299300
.map { servingInfo =>
300301
val (keys, values) = buildJoinPartCodec(joinPart, servingInfo)
302+
301303
keys.foreach(k => keyFields.add(k))
302304
values.foreach(v => valueFields.append(v))
305+
306+
val leftKeys = keys.map(_.name).map(joinPart.rightToLeft)
307+
values.foreach { v =>
308+
val schemaString = SparkConversions.fromChrononType(v.fieldType).catalogString
309+
valueInfos.append(JoinCodec
310+
.ValueInfo(v.name, joinPart.groupBy.metaData.getName, joinPart.prefix, leftKeys.toArray, schemaString))
311+
}
303312
}
304313
.recoverWith {
305314
case exception: Throwable => {
@@ -332,9 +341,18 @@ class MetadataStore(fetchContext: FetchContext) {
332341
.fields
333342
.map(f => StructField(prefix + f.name, f.fieldType))
334343

335-
buildFields(source.getKeySchema).foreach(f =>
336-
keyFields.add(f.copy(name = part.rightToLeft.getOrElse(f.name, f.name))))
337-
buildFields(source.getValueSchema, part.fullName + "_").foreach(f => valueFields.append(f))
344+
val keyStructFields =
345+
buildFields(source.getKeySchema).map(f => f.copy(name = part.rightToLeft.getOrElse(f.name, f.name)))
346+
keyStructFields.foreach(keyFields.add)
347+
val leftKeys = keyStructFields.map(_.name)
348+
349+
buildFields(source.getValueSchema, part.fullName + "_").foreach { f =>
350+
val schemaString = SparkConversions.fromChrononType(f.fieldType).catalogString
351+
valueInfos.append(
352+
JoinCodec.ValueInfo(f.name, part.source.metadata.getName, part.prefix, leftKeys.toArray, schemaString))
353+
valueFields.append(f)
354+
}
355+
338356
}
339357
}
340358

@@ -343,7 +361,7 @@ class MetadataStore(fetchContext: FetchContext) {
343361
val keyCodec = AvroCodec.of(AvroConversions.fromChrononSchema(keySchema).toString)
344362
val baseValueSchema = StructType(s"${joinName.sanitize}_value", valueFields.toArray)
345363
val baseValueCodec = serde.AvroCodec.of(AvroConversions.fromChrononSchema(baseValueSchema).toString)
346-
JoinCodec(joinConf, keySchema, baseValueSchema, keyCodec, baseValueCodec, hasPartialFailure)
364+
JoinCodec(joinConf, keySchema, baseValueSchema, keyCodec, baseValueCodec, valueInfos.toArray, hasPartialFailure)
347365
}
348366

349367
def getSchemaFromKVStore(dataset: String, key: String): serde.AvroCodec = {

service/src/test/java/ai/chronon/service/handlers/JoinSchemaHandlerTest.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import ai.chronon.online.JTry;
44
import ai.chronon.online.JavaFetcher;
55
import ai.chronon.online.JavaJoinSchemaResponse;
6+
import ai.chronon.online.JoinCodec;
67
import io.vertx.core.Vertx;
78
import io.vertx.core.http.HttpServerResponse;
89
import io.vertx.core.json.JsonArray;
@@ -59,8 +60,10 @@ public void testSuccessfulRequest(TestContext context) {
5960
Async async = context.async();
6061

6162
String avroSchemaString = "{\"type\":\"record\",\"name\":\"User\",\"namespace\":\"com.example\",\"fields\":[{\"name\":\"id\",\"type\":\"string\"}]}";
62-
63-
JavaJoinSchemaResponse joinSchemaResponse = new JavaJoinSchemaResponse("user_join", avroSchemaString, avroSchemaString, "fakeschemaHash");
63+
String [] keys = {"user_id"};
64+
JavaJoinSchemaResponse.ValueInfo valueInfo = new JavaJoinSchemaResponse.ValueInfo("my_groupby_feature_1", "my_groupby", "", keys, "foo");
65+
JavaJoinSchemaResponse.ValueInfo[] valueInfos = {valueInfo};
66+
JavaJoinSchemaResponse joinSchemaResponse = new JavaJoinSchemaResponse("user_join", avroSchemaString, avroSchemaString, "fakeschemaHash", valueInfos);
6467
JTry<JavaJoinSchemaResponse> joinSchemaResponseTry = JTry.success(joinSchemaResponse);
6568

6669
// Set up mocks
@@ -93,6 +96,15 @@ public void testSuccessfulRequest(TestContext context) {
9396
String valueSchema = actualResponse.getString("valueSchema");
9497
context.assertEquals(valueSchema, avroSchemaString);
9598

99+
// sanity check the value info payload
100+
JsonArray valueInfoArray = actualResponse.getJsonArray("valueInfos");
101+
context.assertEquals(valueInfoArray.size(), 1);
102+
JsonObject valueInfoJson = valueInfoArray.getJsonObject(0);
103+
context.assertEquals(valueInfoJson.getString("fullName"), "my_groupby_feature_1");
104+
JsonArray leftKeysArray = valueInfoJson.getJsonArray("leftKeys");
105+
context.assertEquals(leftKeysArray.size(), 1);
106+
context.assertEquals(leftKeysArray.getString(0), "user_id");
107+
96108
// confirm we can parse the avro schema fine
97109
new Schema.Parser().parse(keySchema);
98110
new Schema.Parser().parse(valueSchema);
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package ai.chronon.spark.test.fetcher

import ai.chronon.api.Constants.MetadataDataset
import ai.chronon.api.Extensions.JoinOps
import ai.chronon.online.fetcher
import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
import ai.chronon.online.fetcher.Fetcher.Request
import ai.chronon.spark.catalog.TableUtils
import ai.chronon.spark.submission
import ai.chronon.spark.test.OnlineUtils
import ai.chronon.spark.utils.MockApi
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.junit.Assert.{assertEquals, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec
import org.slf4j.{Logger, LoggerFactory}

import java.util.TimeZone
import java.util.concurrent.Executors
import scala.concurrent.ExecutionContext

/** Exercises fetcher failure modes: requests with no keys, and partial KVStore failures. */
class FetcherFailureTest extends AnyFlatSpec {

  val sessionName = "FetcherFailureTest"
  val spark: SparkSession = submission.SparkSessionBuilder.build(sessionName, local = true)
  private val tableUtils = TableUtils(spark)

  private val topic = "test_topic"
  // Pin the JVM timezone so partition-date calculations are stable across environments.
  TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
  private val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
  private val yesterday = tableUtils.partitionSpec.before(today)

  @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass)

  // test soft-fail on missing keys
  it should "test empty request" in {
    val namespace = "empty_request"
    val joinConf = FetcherTestUtil.generateRandomData(namespace, tableUtils, spark, topic, today, yesterday, 5, 5)
    implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
    // Derive the store name from sessionName/namespace instead of hard-coding the string.
    val kvStoreFunc = () => OnlineUtils.buildInMemoryKVStore(s"$sessionName#$namespace")
    val inMemoryKvStore = kvStoreFunc()
    val mockApi = new MockApi(kvStoreFunc, namespace)

    // Consistency: instantiate via the imported MetadataStore (was `new fetcher.MetadataStore`
    // here but `new MetadataStore` in the test below).
    val metadataStore = new MetadataStore(FetchContext(inMemoryKvStore))
    inMemoryKvStore.create(MetadataDataset)
    metadataStore.putJoinConf(joinConf)

    val request = Request(joinConf.metaData.name, Map.empty)
    val (responses, _) = FetcherTestUtil.joinResponses(spark, Array(request), mockApi)
    val responseMap = responses.head.values.get

    logger.info("====== Empty request response map ======")
    logger.info(responseMap.toString)
    // With empty keys, both derivation attempts fail, adding two extra exception entries.
    val derivationExceptionTypes = Seq("derivation_fetch_exception", "derivation_rename_exception")
    assertEquals(joinConf.joinParts.size() + derivationExceptionTypes.size, responseMap.size)
    assertTrue(responseMap.keys.forall(_.endsWith("_exception")))
  }

  it should "test KVStore partial failure" in {
    val namespace = "test_kv_store_partial_failure"
    val joinConf = FetcherTestUtil.generateRandomData(namespace, tableUtils, spark, topic, today, yesterday, 5, 5)
    implicit val executionContext: ExecutionContext = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1))
    val kvStoreFunc = () =>
      OnlineUtils.buildInMemoryKVStore(s"$sessionName#$namespace", hardFailureOnInvalidDataset = true)
    val inMemoryKvStore = kvStoreFunc()
    val mockApi = new MockApi(kvStoreFunc, namespace)

    val metadataStore = new MetadataStore(FetchContext(inMemoryKvStore))
    inMemoryKvStore.create(MetadataDataset)
    metadataStore.putJoinConf(joinConf)

    // Build a fully-populated key map from the first row of the queries table.
    val keys = joinConf.leftKeyCols
    val keyData = spark.table(s"$namespace.queries_table").select(keys.map(col): _*).head
    val keyMap = keys.indices.map { idx =>
      keys(idx) -> keyData.get(idx).asInstanceOf[AnyRef]
    }.toMap

    val request = Request(joinConf.metaData.name, keyMap)
    val (responses, _) = FetcherTestUtil.joinResponses(spark, Array(request), mockApi)
    val responseMap = responses.head.values.get
    // Every join part should surface its failure as a `<prefix>_exception` entry.
    val exceptionKeys = joinConf.joinPartOps.map(jp => jp.fullPrefix + "_exception")
    exceptionKeys.foreach(k => assertTrue(responseMap.contains(k)))
  }

}

0 commit comments

Comments (0)