Skip to content

Add list & join schema fetcher APIs #431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions api/src/main/scala/ai/chronon/api/Constants.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,14 @@ object Constants {
val GroupByKeyword = "group_bys"
val StagingQueryKeyword = "staging_queries"
val ModelKeyword = "models"

// KV store related constants
// continuation key to help with list pagination
val ContinuationKey: String = "continuation-key"

// Limit of max number of entries to return in a list call
val ListLimit: String = "limit"

// List entity type
val ListEntityType: String = "entity_type"
}
1 change: 1 addition & 0 deletions cloud_aws/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ scala_library(

test_deps = [
":cloud_aws_lib",
"//api:lib",
"//online:lib",
maven_artifact("software.amazon.awssdk:dynamodb"),
maven_artifact("software.amazon.awssdk:regions"),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.chronon.integrations.aws

import ai.chronon.api.Constants
import ai.chronon.api.Constants.{ContinuationKey, ListLimit}
import ai.chronon.api.ScalaJavaConversions._
import ai.chronon.online.KVStore
import ai.chronon.online.KVStore.GetResponse
Expand Down Expand Up @@ -36,7 +37,6 @@ import java.util.concurrent.ConcurrentHashMap
import scala.concurrent.Future
import scala.util.Success
import scala.util.Try

import scala.collection.Seq

object DynamoDBKVStoreConstants {
Expand All @@ -49,12 +49,6 @@ object DynamoDBKVStoreConstants {
// Optional field that indicates if this table is meant to be time sorted in Dynamo or not
val isTimedSorted = "is-time-sorted"

// Limit of max number of entries to return in a list call
val listLimit = "limit"

// continuation key to help with list pagination
val continuationKey = "continuation-key"

// Name of the partition key column to use
val partitionKeyColumn = "keyBytes"

Expand Down Expand Up @@ -172,13 +166,13 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore {
}

override def list(request: ListRequest): Future[ListResponse] = {
val listLimit = request.props.get(DynamoDBKVStoreConstants.listLimit) match {
val listLimit = request.props.get(ListLimit) match {
case Some(value: Int) => value
case Some(value: String) => value.toInt
case _ => 100
}

val maybeExclusiveStartKey = request.props.get(continuationKey)
val maybeExclusiveStartKey = request.props.get(ContinuationKey)
val maybeExclusiveStartKeyAttribute = maybeExclusiveStartKey.map { k =>
AttributeValue.builder.b(SdkBytes.fromByteArray(k.asInstanceOf[Array[Byte]])).build
}
Expand All @@ -199,7 +193,7 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore {
case Success(scanResponse) if scanResponse.hasLastEvaluatedKey =>
val lastEvalKey = scanResponse.lastEvaluatedKey().toScala.get(partitionKeyColumn)
lastEvalKey match {
case Some(av) => ListResponse(request, resultElements, Map(continuationKey -> av.b().asByteArray()))
case Some(av) => ListResponse(request, resultElements, Map(ContinuationKey -> av.b().asByteArray()))
case _ => noPagesLeftResponse
}
case _ => noPagesLeftResponse
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.chronon.integrations.aws

import ai.chronon.api.Constants.{ContinuationKey, ListLimit}
import ai.chronon.online.KVStore._
import com.amazonaws.services.dynamodbv2.local.main.ServerRunner
import com.amazonaws.services.dynamodbv2.local.server.DynamoDBProxyServer
Expand Down Expand Up @@ -124,16 +125,16 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll {
putResults.foreach(r => r shouldBe true)

// call list - first call is only for 10 elements
val listReq1 = ListRequest(dataset, Map(listLimit -> 10))
val listReq1 = ListRequest(dataset, Map(ListLimit -> 10))
val listResults1 = Await.result(kvStore.list(listReq1), 1.minute)
listResults1.resultProps.contains(continuationKey) shouldBe true
listResults1.resultProps.contains(ContinuationKey) shouldBe true
validateExpectedListResponse(listResults1.values, 10)

// call list - with continuation key
val listReq2 =
ListRequest(dataset, Map(listLimit -> 100, continuationKey -> listResults1.resultProps(continuationKey)))
ListRequest(dataset, Map(ListLimit -> 100, ContinuationKey -> listResults1.resultProps(ContinuationKey)))
val listResults2 = Await.result(kvStore.list(listReq2), 1.minute)
listResults2.resultProps.contains(continuationKey) shouldBe false
listResults2.resultProps.contains(ContinuationKey) shouldBe false
validateExpectedListResponse(listResults2.values, 100)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.chronon.integrations.cloud_gcp

import ai.chronon.api.Constants.{ContinuationKey, ListEntityType, ListLimit}
import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.Extensions.StringOps
import ai.chronon.api.Extensions.WindowOps
Expand Down Expand Up @@ -211,13 +212,14 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
override def list(request: ListRequest): Future[ListResponse] = {
logger.info(s"Performing list for ${request.dataset}")

val listLimit = request.props.get(BigTableKVStore.listLimit) match {
val listLimit = request.props.get(ListLimit) match {
case Some(value: Int) => value
case Some(value: String) => value.toInt
case _ => defaultListLimit
}

val maybeStartKey = request.props.get(continuationKey)
val maybeListEntityType = request.props.get(ListEntityType)
val maybeStartKey = request.props.get(ContinuationKey)

val query = Query
.create(mapDatasetToTable(request.dataset))
Expand All @@ -227,9 +229,15 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
.filter(Filters.FILTERS.limit().cellsPerRow(1))
.limit(listLimit)

// if we got a start row key, lets wire it up
maybeStartKey.foreach { startKey =>
query.range(ByteStringRange.unbounded().startOpen(ByteString.copyFrom(startKey.asInstanceOf[Array[Byte]])))
(maybeStartKey, maybeListEntityType) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat!

case (Some(startKey), _) =>
// we have a start key, we use that to pick up from where we left off
query.range(ByteStringRange.unbounded().startOpen(ByteString.copyFrom(startKey.asInstanceOf[Array[Byte]])))
case (None, Some(listEntityType)) =>
val startRowKey = buildRowKey(s"$listEntityType/".getBytes(Charset.forName("UTF-8")), request.dataset)
query.range(ByteStringRange.unbounded().startOpen(ByteString.copyFrom(startRowKey)))
case _ =>
logger.info("No start key or list entity type provided. Starting from the beginning")
}

val startTs = System.currentTimeMillis()
Expand All @@ -253,7 +261,7 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
if (listValues.size < listLimit) {
Map.empty // last page, we're done
} else
Map(continuationKey -> listValues.last.keyBytes)
Map(ContinuationKey -> listValues.last.keyBytes)

ListResponse(request, Success(listValues), propsMap)

Expand Down Expand Up @@ -410,12 +418,6 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,

object BigTableKVStore {

// continuation key to help with list pagination
val continuationKey: String = "continuationKey"

// Limit of max number of entries to return in a list call
val listLimit: String = "limit"

// Default list limit
val defaultListLimit: Int = 100

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.chronon.integrations.cloud_gcp

import ai.chronon.api.Constants.{ContinuationKey, GroupByKeyword, JoinKeyword, ListEntityType, ListLimit}
import ai.chronon.api.TilingUtils
import ai.chronon.online.KVStore.GetRequest
import ai.chronon.online.KVStore.GetResponse
Expand Down Expand Up @@ -176,21 +177,21 @@ class BigTableKVStoreTest extends AnyFlatSpec with BeforeAndAfter {

// let's try and read these
val limit = 10
val listReq1 = ListRequest(dataset, Map(listLimit -> limit))
val listReq1 = ListRequest(dataset, Map(ListLimit -> limit))

val listResult1 = Await.result(kvStore.list(listReq1), 1.second)
listResult1.values.isSuccess shouldBe true
listResult1.resultProps.contains(BigTableKVStore.continuationKey) shouldBe true
listResult1.resultProps.contains(ContinuationKey) shouldBe true
val listValues1 = listResult1.values.get
listValues1.size shouldBe limit

// another call, bigger limit
val limit2 = 1000
val continuationKey = listResult1.resultProps(BigTableKVStore.continuationKey)
val listReq2 = ListRequest(dataset, Map(listLimit -> limit2, BigTableKVStore.continuationKey -> continuationKey))
val continuationKey = listResult1.resultProps(ContinuationKey)
val listReq2 = ListRequest(dataset, Map(ListLimit -> limit2, ContinuationKey -> continuationKey))
val listResult2 = Await.result(kvStore.list(listReq2), 1.second)
listResult2.values.isSuccess shouldBe true
listResult2.resultProps.contains(BigTableKVStore.continuationKey) shouldBe false
listResult2.resultProps.contains(ContinuationKey) shouldBe false
val listValues2 = listResult2.values.get
listValues2.size shouldBe (putReqs.size - limit)

Expand All @@ -201,6 +202,53 @@ class BigTableKVStoreTest extends AnyFlatSpec with BeforeAndAfter {
.toSet
}

it should "list entity types with pagination" in {
val dataset = "metadata"
val kvStore = new BigTableKVStoreImpl(dataClient, adminClient)
kvStore.create(dataset)

val putGrpByReqs = (0 until 50).map { i =>
val key = s"$GroupByKeyword/gbkey-$i"
val value = s"""{"name": "name-$i", "age": $i}"""
PutRequest(key.getBytes, value.getBytes, dataset, None)
}

val putJoinReqs = (0 until 50).map { i =>
val key = s"$JoinKeyword/joinkey-$i"
val value = s"""{"name": "name-$i", "age": $i}"""
PutRequest(key.getBytes, value.getBytes, dataset, None)
}

val putResults = Await.result(kvStore.multiPut(putGrpByReqs ++ putJoinReqs), 1.second)
putResults.foreach(r => r shouldBe true)

// let's try and read just the joins
val limit = 10
val listReq1 = ListRequest(dataset, Map(ListLimit -> limit, ListEntityType -> JoinKeyword))

val listResult1 = Await.result(kvStore.list(listReq1), 1.second)
listResult1.values.isSuccess shouldBe true
listResult1.resultProps.contains(ContinuationKey) shouldBe true
val listValues1 = listResult1.values.get
listValues1.size shouldBe limit

// another call, bigger limit
val limit2 = 1000
val continuationKey = listResult1.resultProps(ContinuationKey)
val listReq2 = ListRequest(dataset, Map(ListLimit -> limit2, ContinuationKey -> continuationKey))
val listResult2 = Await.result(kvStore.list(listReq2), 1.second)
listResult2.values.isSuccess shouldBe true
listResult2.resultProps.contains(ContinuationKey) shouldBe false
val listValues2 = listResult2.values.get
listValues2.size shouldBe (putJoinReqs.size - limit)

// lets collect all the keys and confirm we got everything
val allKeys = (listValues1 ++ listValues2).map(v => new String(v.keyBytes, StandardCharsets.UTF_8))
allKeys.toSet shouldBe putJoinReqs
.map(r => new String(buildRowKey(r.keyBytes, r.dataset), StandardCharsets.UTF_8))
.toSet
}

it should "multiput failures" in {
val mockDataClient = mock[BigtableDataClient](withSettings().mockMaker("mock-maker-inline"))
val mockAdminClient = mock[BigtableTableAdminClient]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ cd $CHRONON_ROOT_DIR
echo "Building jars"

bazel build //cloud_gcp:cloud_gcp_lib_deploy.jar
bazel build //cloud_aws:cloud_aws_lib_deploy.jar
bazel build //service:service_assembly_deploy.jar

CLOUD_GCP_JAR="$CHRONON_ROOT_DIR/bazel-bin/cloud_gcp/cloud_gcp_lib_deploy.jar"
CLOUD_AWS_JAR="$CHRONON_ROOT_DIR/bazel-bin/cloud_aws/cloud_aws_lib_deploy.jar"
SERVICE_JAR="$CHRONON_ROOT_DIR/bazel-bin/service/service_assembly_deploy.jar"

if [ ! -f "$CLOUD_GCP_JAR" ]; then
Expand All @@ -45,6 +47,11 @@ if [ ! -f "$SERVICE_JAR" ]; then
exit 1
fi

if [ ! -f "$CLOUD_AWS_JAR" ]; then
echo "$CLOUD_AWS_JAR not found"
exit 1
fi

# We copy to build output as the docker build can't access the bazel-bin (as its a symlink)
echo "Copying jars to build_output"
mkdir -p build_output
Expand Down
2 changes: 2 additions & 0 deletions online/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ test_deps = [
scala_library(
name = "test_lib",
srcs = glob(["src/test/**/*.scala"]),
resources = glob(["src/test/resources/**/*"]),
format = select({
"//tools/config:scala_2_13": False, # Disable for 2.13
"//conditions:default": True, # Enable for other versions
Expand All @@ -81,6 +82,7 @@ scala_library(
scala_test_suite(
name = "tests",
srcs = glob(["src/test/**/*.scala"]),
resources = glob(["src/test/resources/**/*"]),
jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
visibility = ["//visibility:public"],
deps = test_deps + [":test_lib"],
Expand Down
14 changes: 14 additions & 0 deletions online/src/main/java/ai/chronon/online/JavaFetcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package ai.chronon.online;

import ai.chronon.api.ScalaJavaConversions;
import ai.chronon.online.fetcher.Fetcher;
import ai.chronon.online.fetcher.FetcherResponseWithTs;
import scala.collection.Iterator;
Expand All @@ -25,6 +26,7 @@
import scala.compat.java8.FutureConverters;
import scala.concurrent.Future;
import scala.concurrent.ExecutionContext;
import scala.util.Try;

import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -170,6 +172,18 @@ public CompletableFuture<List<JavaResponse>> fetchJoin(List<JavaRequest> request
return convertResponsesWithTs(scalaResponses, false, startTs);
}

public CompletableFuture<List<String>> listJoins(boolean isOnline) {
// Get responses from the fetcher
Future<Seq<String>> scalaResponses = this.fetcher.metadataStore().listJoins(isOnline);
// convert to Java friendly types
return FutureConverters.toJava(scalaResponses).toCompletableFuture().thenApply(ScalaJavaConversions::toJava);
}

public JTry<JavaJoinSchemaResponse> fetchJoinSchema(String joinName) {
Try<Fetcher.JoinSchemaResponse> scalaResponse = this.fetcher.fetchJoinSchema(joinName);
return JTry.fromScala(scalaResponse).map(JavaJoinSchemaResponse::new);
}

private void instrument(List<String> requestNames, boolean isGroupBy, String metricName, Long startTs) {
long endTs = System.currentTimeMillis();
for (String s : requestNames) {
Expand Down
32 changes: 32 additions & 0 deletions online/src/main/java/ai/chronon/online/JavaJoinSchemaResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package ai.chronon.online;

import ai.chronon.online.fetcher.Fetcher;

public class JavaJoinSchemaResponse {
public String joinName;
public String keySchema;
public String valueSchema;
public String schemaHash;

public JavaJoinSchemaResponse(String joinName, String keySchema, String valueSchema, String schemaHash) {
this.joinName = joinName;
this.keySchema = keySchema;
this.valueSchema = valueSchema;
this.schemaHash = schemaHash;
}

public JavaJoinSchemaResponse(Fetcher.JoinSchemaResponse scalaResponse){
this.joinName = scalaResponse.joinName();
this.keySchema = scalaResponse.keySchema();
this.valueSchema = scalaResponse.valueSchema();
this.schemaHash = scalaResponse.schemaHash();
}

public Fetcher.JoinSchemaResponse toScala() {
return new Fetcher.JoinSchemaResponse(
joinName,
keySchema,
valueSchema,
schemaHash);
}
}
6 changes: 3 additions & 3 deletions online/src/main/scala/ai/chronon/online/JoinCodec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import com.google.gson.Gson
case class JoinCodec(conf: JoinOps,
keySchema: StructType,
baseValueSchema: StructType,
keyCodec: serde.AvroCodec,
baseValueCodec: serde.AvroCodec)
keyCodec: AvroCodec,
baseValueCodec: AvroCodec)
extends Serializable {

@transient lazy val valueSchema: StructType = {
Expand Down Expand Up @@ -89,7 +89,7 @@ case class JoinCodec(conf: JoinOps,

object JoinCodec {

def buildLoggingSchema(joinName: String, keyCodec: serde.AvroCodec, valueCodec: serde.AvroCodec): String = {
def buildLoggingSchema(joinName: String, keyCodec: AvroCodec, valueCodec: AvroCodec): String = {
val schemaMap = Map(
"join_name" -> joinName,
"key_schema" -> keyCodec.schemaStr,
Expand Down
1 change: 1 addition & 0 deletions online/src/main/scala/ai/chronon/online/Metrics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object Metrics {
type Environment = String
val MetaDataFetching = "metadata.fetch"
val JoinFetching = "join.fetch"
val JoinSchemaFetching = "join.schema.fetch"
val GroupByFetching = "group_by.fetch"
val GroupByUpload = "group_by.upload"
val GroupByStreaming = "group_by.streaming"
Expand Down
Loading