chore: add bazel build file for cloud_aws #343


Merged · 13 commits · Feb 12, 2025
2 changes: 1 addition & 1 deletion .bazelrc
@@ -4,4 +4,4 @@ build --java_language_version=11
build --java_runtime_version=11
build --remote_cache=https://storage.googleapis.com/zipline-bazel-cache
test --test_output=errors
-test --test_timeout=900
+test --test_timeout=900
11 changes: 11 additions & 0 deletions .github/workflows/test_scala_non_spark.yaml
@@ -12,6 +12,7 @@ on:
- 'hub/**'
- 'orchestration/**'
- 'service/**'
- 'cloud_aws/**'
- 'cloud_gcp/**'
- '.github/workflows/test_scala_non_spark.yaml'
pull_request:
@@ -25,6 +26,7 @@ on:
- 'hub/**'
- 'orchestration/**'
- 'service/**'
- 'cloud_aws/**'
- 'cloud_gcp/**'
- '.github/workflows/test_scala_non_spark.yaml'

@@ -107,3 +109,12 @@ jobs:
--remote_cache=https://storage.googleapis.com/zipline-bazel-cache \
--google_credentials=bazel-cache-key.json \
//cloud_gcp:tests

- name: Run cloud aws tests
run: |
bazel test \
--remote_cache=https://storage.googleapis.com/zipline-bazel-cache \
--google_credentials=bazel-cache-key.json \
--java_language_version=17 \
--java_runtime_version=17 \
//cloud_aws:tests
53 changes: 53 additions & 0 deletions cloud_aws/BUILD.bazel
@@ -0,0 +1,53 @@
scala_library(
name = "cloud_aws_lib",
srcs = glob(["src/main/**/*.scala"]),
visibility = ["//visibility:public"],
format = True,
deps = [
maven_artifact("software.amazon.awssdk:dynamodb"),
maven_artifact("software.amazon.awssdk:regions"),
maven_artifact("software.amazon.awssdk:aws-core"),
maven_artifact("software.amazon.awssdk:sdk-core"),
maven_artifact("software.amazon.awssdk:utils"),
maven_artifact("com.google.guava:guava"),
maven_artifact("org.slf4j:slf4j-api"),
maven_scala_artifact("org.scala-lang.modules:scala-collection-compat"),
"//spark:lib",
"//online:lib",
"//api:lib",
"//api:thrift_java",
],
)

test_deps = [
":cloud_aws_lib",
"//online:lib",
maven_artifact("software.amazon.awssdk:dynamodb"),
maven_artifact("software.amazon.awssdk:regions"),
maven_artifact("software.amazon.awssdk:aws-core"),
maven_artifact("software.amazon.awssdk:sdk-core"),
maven_artifact("software.amazon.awssdk:utils"),
maven_artifact("software.amazon.awssdk:auth"),
maven_artifact("software.amazon.awssdk:identity-spi"),
maven_scala_artifact("org.typelevel:cats-core"),
maven_artifact("com.amazonaws:DynamoDBLocal"),
maven_scala_artifact("com.chuusai:shapeless"),
] + _CIRCE_DEPS + _SCALA_TEST_DEPS

scala_library(
name = "test_lib",
srcs = glob(["src/test/**/*.scala"]),
format = True,
visibility = ["//visibility:public"],
deps = test_deps,
)


scala_test_suite(
name = "tests",
srcs = glob(["src/test/**/*.scala"]),
# defined in prelude_bazel file
jvm_flags = _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES,
💡 Verification agent


Verify _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES is defined.


🏁 Script executed:

#!/bin/bash
# Search for JVM flags definition
rg "_JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES" --type bazel

Length of output: 459


Undefined _JVM_FLAGS_FOR_ACCESSING_BASE_JAVA_CLASSES identified:
The flag is referenced in multiple BUILD files (cloud_aws, cloud_gcp, flink, online, spark) but no definition is found in the repository. Please define it or update its references.

visibility = ["//visibility:public"],
deps = test_deps + [":test_lib"],
)
AwsApiImpl.scala
@@ -1,27 +1,22 @@
package ai.chronon.integrations.aws

import ai.chronon.online.Api
import ai.chronon.online.ExternalSourceRegistry
import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.online.KVStore
import ai.chronon.online.LoggableResponse
import ai.chronon.online.Serde
import ai.chronon.online._
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.dynamodb.DynamoDbClient

import java.net.URI

/**
* Implementation of Chronon's API interface for AWS. This is a work in progress and currently just covers the
/** Implementation of Chronon's API interface for AWS. This is a work in progress and currently just covers the
* DynamoDB based KV store implementation.
*/
class AwsApiImpl(conf: Map[String, String]) extends Api(conf) {
@transient lazy val ddbClient: DynamoDbClient = {
var builder = DynamoDbClient
.builder()

sys.env.get("AWS_DEFAULT_REGION").foreach { region =>
try {
builder = builder.region(Region.of(region))
builder.region(Region.of(region))
} catch {
case e: IllegalArgumentException =>
throw new IllegalArgumentException(s"Invalid AWS region format: $region", e)
@@ -43,21 +38,18 @@ class AwsApiImpl(conf: Map[String, String]) extends Api(conf) {
new DynamoDBKVStoreImpl(ddbClient)
}
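
Editor's note on the region change above: AWS SDK v2 client builders are mutable and return themselves, so dropping the builder = reassignment is behavior-neutral. A minimal standalone sketch of the equivalent construction (illustrative only, not part of the PR):

import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.dynamodb.DynamoDbClient

// Each builder call mutates the builder in place and returns it,
// so no intermediate reassignment is needed.
val builder = DynamoDbClient.builder()
sys.env.get("AWS_DEFAULT_REGION").foreach(r => builder.region(Region.of(r)))
val ddb: DynamoDbClient = builder.build()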

/**
* The stream decoder method in the AwsApi is currently unimplemented. This needs to be implemented before
/** The stream decoder method in the AwsApi is currently unimplemented. This needs to be implemented before
* we can spin up the Aws streaming Chronon stack
*/
override def streamDecoder(groupByServingInfoParsed: GroupByServingInfoParsed): Serde = ???

/**
* The external registry extension is currently unimplemented. We'll need to implement this prior to spinning up
/** The external registry extension is currently unimplemented. We'll need to implement this prior to spinning up
* a fully functional Chronon serving stack in Aws
* @return
*/
override def externalRegistry: ExternalSourceRegistry = ???

/**
* The logResponse method is currently unimplemented. We'll need to implement this prior to bringing up the
/** The logResponse method is currently unimplemented. We'll need to implement this prior to bringing up the
* fully functional serving stack in Aws which includes logging feature responses to a stream for OOC
*/
override def logResponse(resp: LoggableResponse): Unit = ???
DynamoDBKVStoreImpl.scala
@@ -141,31 +141,29 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore {
// timestamp to use for all get responses when the underlying tables don't have a ts field
val defaultTimestamp = Instant.now().toEpochMilli

val getItemResults = getItemRequestPairs.map {
case (req, getItemReq) =>
Future {
readRateLimiters.computeIfAbsent(req.dataset, _ => RateLimiter.create(defaultReadCapacityUnits)).acquire()
val item: Try[util.Map[String, AttributeValue]] =
handleDynamoDbOperation(metricsContext.withSuffix("multiget"), req.dataset) {
dynamoDbClient.getItem(getItemReq).item()
}

val response = item.map(i => List(i).asJava)
val resultValue: Try[Seq[TimedValue]] = extractTimedValues(response, defaultTimestamp)
GetResponse(req, resultValue)
}
val getItemResults = getItemRequestPairs.map { case (req, getItemReq) =>
Future {
readRateLimiters.computeIfAbsent(req.dataset, _ => RateLimiter.create(defaultReadCapacityUnits)).acquire()
val item: Try[util.Map[String, AttributeValue]] =
handleDynamoDbOperation(metricsContext.withSuffix("multiget"), req.dataset) {
dynamoDbClient.getItem(getItemReq).item()
}

val response = item.map(i => List(i).asJava)
val resultValue: Try[Seq[TimedValue]] = extractTimedValues(response, defaultTimestamp)
GetResponse(req, resultValue)
}
}

val queryResults = queryRequestPairs.map {
case (req, queryRequest) =>
Future {
readRateLimiters.computeIfAbsent(req.dataset, _ => RateLimiter.create(defaultReadCapacityUnits)).acquire()
val responses = handleDynamoDbOperation(metricsContext.withSuffix("query"), req.dataset) {
dynamoDbClient.query(queryRequest).items()
}
val resultValue: Try[Seq[TimedValue]] = extractTimedValues(responses, defaultTimestamp)
GetResponse(req, resultValue)
val queryResults = queryRequestPairs.map { case (req, queryRequest) =>
Future {
readRateLimiters.computeIfAbsent(req.dataset, _ => RateLimiter.create(defaultReadCapacityUnits)).acquire()
val responses = handleDynamoDbOperation(metricsContext.withSuffix("query"), req.dataset) {
dynamoDbClient.query(queryRequest).items()
}
val resultValue: Try[Seq[TimedValue]] = extractTimedValues(responses, defaultTimestamp)
GetResponse(req, resultValue)
}
}

Future.sequence(getItemResults ++ queryResults)
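
For readers skimming the reformat above: the unchanged logic throttles reads per dataset with a lazily created Guava RateLimiter. A self-contained sketch of that idiom — defaultReadCapacityUnits mirrors the name in the diff; the value and object name are illustrative:

import com.google.common.util.concurrent.RateLimiter
import java.util.concurrent.ConcurrentHashMap

object PerDatasetReadThrottle {
  private val defaultReadCapacityUnits = 100.0 // permits per second, illustrative
  private val readRateLimiters = new ConcurrentHashMap[String, RateLimiter]()

  // Blocks until a read permit is available for the dataset; limiters are
  // created on first use, so new datasets need no registration step.
  def acquireRead(dataset: String): Unit =
    readRateLimiters
      .computeIfAbsent(dataset, _ => RateLimiter.create(defaultReadCapacityUnits))
      .acquire()
}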
@@ -224,20 +222,18 @@ class DynamoDBKVStoreImpl(dynamoDbClient: DynamoDbClient) extends KVStore {
(req.dataset, putItemReq)
}

val futureResponses = datasetToWriteRequests.map {
case (dataset, putItemRequest) =>
Future {
writeRateLimiters.computeIfAbsent(dataset, _ => RateLimiter.create(defaultWriteCapacityUnits)).acquire()
handleDynamoDbOperation(metricsContext.withSuffix("multiput"), dataset) {
dynamoDbClient.putItem(putItemRequest)
}.isSuccess
}
val futureResponses = datasetToWriteRequests.map { case (dataset, putItemRequest) =>
Future {
writeRateLimiters.computeIfAbsent(dataset, _ => RateLimiter.create(defaultWriteCapacityUnits)).acquire()
handleDynamoDbOperation(metricsContext.withSuffix("multiput"), dataset) {
dynamoDbClient.putItem(putItemRequest)
}.isSuccess
}
}
Future.sequence(futureResponses)
}

/**
* Implementation of bulkPut is currently a TODO for the DynamoDB store. This involves transforming the underlying
/** Implementation of bulkPut is currently a TODO for the DynamoDB store. This involves transforming the underlying
* Parquet data to Amazon's Ion format + swapping out old table for new (as bulkLoad only writes to new tables)
*/
override def bulkPut(sourceOfflineTable: String, destinationOnlineDataSet: String, partition: String): Unit = ???
LivySubmitter.scala
@@ -0,0 +1,15 @@
package ai.chronon.integrations.aws

import ai.chronon.spark.{JobSubmitter, JobType}

class LivySubmitter extends JobSubmitter {
🛠️ Refactor suggestion

Add AWS Livy client initialization.

Initialize the Livy client in the constructor or a companion object.

-class LivySubmitter extends JobSubmitter {
+class LivySubmitter(
+  livyEndpoint: String,
+  awsRegion: String
+) extends JobSubmitter {
+  private val livyClient = LivyClient.builder()
+    .endpoint(livyEndpoint)
+    .region(awsRegion)
+    .build()


override def submit(jobType: JobType,
jobProperties: Map[String, String],
files: List[String],
args: String*): String = ???

Comment on lines +7 to +11
⚠️ Potential issue

Implement required methods.

All three overridden methods are unimplemented. Please provide implementations for:

  • submit: Job submission via Livy REST API
  • status: Job status checking
  • kill: Job termination

Would you like me to help implement these methods with proper AWS Livy integration?

Also applies to: 12-13, 14-15
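
For reference, a minimal sketch of what submit could do over Livy's REST batches API (POST /batches with file, className, and args). The class name, JSON handling, and error handling below are hypothetical; nothing here is in the PR:

import java.net.URI
import java.net.http.{HttpClient, HttpRequest, HttpResponse}

class LivyRestSubmitter(livyEndpoint: String) {
  private val http = HttpClient.newHttpClient()

  // Submits a batch job and returns Livy's raw JSON response, which
  // contains the batch id used for later status/kill calls.
  def submitBatch(jarFile: String, className: String, args: Seq[String]): String = {
    // Naive JSON escaping; fine for a sketch, use a JSON library in real code.
    val argsJson = args.map(a => "\"" + a + "\"").mkString("[", ",", "]")
    val payload = s"""{"file":"$jarFile","className":"$className","args":$argsJson}"""
    val request = HttpRequest.newBuilder()
      .uri(URI.create(s"$livyEndpoint/batches"))
      .header("Content-Type", "application/json")
      .POST(HttpRequest.BodyPublishers.ofString(payload))
      .build()
    http.send(request, HttpResponse.BodyHandlers.ofString()).body()
  }
}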

override def status(jobId: String): Unit = ???

override def kill(jobId: String): Unit = ???
}
DynamoDBKVStoreTest.scala
@@ -1,40 +1,43 @@
package ai.chronon.integrations.aws

import ai.chronon.online.KVStore.GetRequest
import ai.chronon.online.KVStore.GetResponse
import ai.chronon.online.KVStore.ListRequest
import ai.chronon.online.KVStore.ListValue
import ai.chronon.online.KVStore.PutRequest
import ai.chronon.online.KVStore._
import com.amazonaws.services.dynamodbv2.local.main.ServerRunner
import com.amazonaws.services.dynamodbv2.local.server.DynamoDBProxyServer
import io.circe.generic.auto._
import io.circe.generic.semiauto._
import io.circe.parser._
import io.circe.syntax._
import org.scalatest.BeforeAndAfter
import io.circe.{Decoder, Encoder}
import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers.be
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider
import software.amazon.awssdk.auth.credentials.{AwsBasicCredentials, StaticCredentialsProvider}
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.dynamodb.DynamoDbClient

import java.net.URI
import java.nio.charset.StandardCharsets
import scala.concurrent.Await
import scala.concurrent.duration.DurationInt
import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.util.{Failure, Success, Try}

// different types of tables to store
case class Model(modelId: String, modelName: String, online: Boolean)
case class TimeSeries(joinName: String, featureName: String, tileTs: Long, metric: String, summary: Array[Double])
object DDBTestUtils {

class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfter {
// different types of tables to store
case class Model(modelId: String, modelName: String, online: Boolean)
case class TimeSeries(joinName: String, featureName: String, tileTs: Long, metric: String, summary: Array[Double])

}
class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfterAll {

import DDBTestUtils._
import DynamoDBKVStoreConstants._

implicit val modelEncoder: Encoder[Model] = deriveEncoder[Model]
implicit val modelDecoder: Decoder[Model] = deriveDecoder[Model]
implicit val tsEncoder: Encoder[TimeSeries] = deriveEncoder[TimeSeries]
implicit val tsDecoder: Decoder[TimeSeries] = deriveDecoder[TimeSeries]

var server: DynamoDBProxyServer = _
var client: DynamoDbClient = _
var kvStoreImpl: DynamoDBKVStoreImpl = _
@@ -55,7 +58,7 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfter {
series.asJson.noSpaces.getBytes(StandardCharsets.UTF_8)
}

before {
override def beforeAll(): Unit = {
// Start the local DynamoDB instance
server = ServerRunner.createServerFromCommandLineArgs(Array("-inMemory", "-port", "8000"))
server.start()
@@ -72,9 +75,9 @@
.build()
}

after {
client.close()
server.stop()
override def afterAll(): Unit = {
// client.close()
// server.stop()
}
Comment on lines +78 to 81
⚠️ Potential issue

Cleanup code is commented out.

The cleanup code in afterAll should be uncommented to properly release resources.

-  override def afterAll(): Unit = {
-//    client.close()
-//    server.stop()
-  }
+  override def afterAll(): Unit = {
+    client.close()
+    server.stop()
+  }
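
Beyond the suggested fix, a slightly more defensive variant (a sketch, not part of this PR) tolerates a beforeAll that failed before client/server were assigned and keeps the ScalaTest hook chain intact:

override def afterAll(): Unit = {
  try {
    Option(client).foreach(_.close()) // null if beforeAll failed early
    Option(server).foreach(_.stop())  // stop the local DynamoDB proxy
  } finally {
    super.afterAll() // preserve stacking with other BeforeAndAfterAll traits
  }
}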


// Test creation of a table with primary keys only (e.g. model)
@@ -115,20 +118,20 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfter {
buildModelPutRequest(model, dataset)
}

val putResults = Await.result(kvStore.multiPut(putReqs), 1.second)
val putResults = Await.result(kvStore.multiPut(putReqs), 1.minute)
putResults.length shouldBe putReqs.length
putResults.foreach(r => r shouldBe true)

// call list - first call is only for 10 elements
val listReq1 = ListRequest(dataset, Map(listLimit -> 10))
val listResults1 = Await.result(kvStore.list(listReq1), 1.second)
val listResults1 = Await.result(kvStore.list(listReq1), 1.minute)
listResults1.resultProps.contains(continuationKey) shouldBe true
validateExpectedListResponse(listResults1.values, 10)

// call list - with continuation key
val listReq2 =
ListRequest(dataset, Map(listLimit -> 100, continuationKey -> listResults1.resultProps(continuationKey)))
val listResults2 = Await.result(kvStore.list(listReq2), 1.second)
val listResults2 = Await.result(kvStore.list(listReq2), 1.minute)
listResults2.resultProps.contains(continuationKey) shouldBe false
validateExpectedListResponse(listResults2.values, 100)
}
@@ -148,17 +151,17 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfter {
val putReq2 = buildModelPutRequest(model2, dataset)
val putReq3 = buildModelPutRequest(model3, dataset)

val putResults = Await.result(kvStore.multiPut(Seq(putReq1, putReq2, putReq3)), 1.second)
val putResults = Await.result(kvStore.multiPut(Seq(putReq1, putReq2, putReq3)), 1.minute)
putResults shouldBe Seq(true, true, true)

// let's try and read these
val getReq1 = buildModelGetRequest(model1, dataset)
val getReq2 = buildModelGetRequest(model2, dataset)
val getReq3 = buildModelGetRequest(model3, dataset)

val getResult1 = Await.result(kvStore.multiGet(Seq(getReq1)), 1.second)
val getResult2 = Await.result(kvStore.multiGet(Seq(getReq2)), 1.second)
val getResult3 = Await.result(kvStore.multiGet(Seq(getReq3)), 1.second)
val getResult1 = Await.result(kvStore.multiGet(Seq(getReq1)), 1.minute)
val getResult2 = Await.result(kvStore.multiGet(Seq(getReq2)), 1.minute)
val getResult3 = Await.result(kvStore.multiGet(Seq(getReq3)), 1.minute)

validateExpectedModelResponse(model1, getResult1)
validateExpectedModelResponse(model2, getResult2)
@@ -178,13 +181,13 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfter {

// write to the kv store and confirm the writes were successful
val putRequests = points.map(p => buildTSPutRequest(p, dataset))
val putResult = Await.result(kvStore.multiPut(putRequests), 1.second)
val putResult = Await.result(kvStore.multiPut(putRequests), 1.minute)
putResult.length shouldBe tsRange.length
putResult.foreach(r => r shouldBe true)

// query in time range: 10/05/24 00:00 to 10/10
val getRequest1 = buildTSGetRequest(points.head, dataset, 1728086400000L, 1728518400000L)
val getResult1 = Await.result(kvStore.multiGet(Seq(getRequest1)), 1.second)
val getResult1 = Await.result(kvStore.multiGet(Seq(getRequest1)), 1.minute)
validateExpectedTimeSeriesResponse(points.head, 1728086400000L, 1728518400000L, getResult1)
}

@@ -231,7 +234,7 @@ class DynamoDBKVStoreTest extends AnyFlatSpec with BeforeAndAfter {
private def validateExpectedListResponse(response: Try[Seq[ListValue]], maxElements: Int): Unit = {
response match {
case Success(mSeq) =>
mSeq.length should be <= maxElements
mSeq.length <= maxElements shouldBe true
mSeq.foreach { modelKV =>
val jsonStr = new String(modelKV.valueBytes, StandardCharsets.UTF_8)
val returnedModel = decode[Model](jsonStr)