
Commit a31b900

BigTable / Fetcher updates - use closeAsync, setTimeouts, allow bulkRead / readRows choice (#630)
## Summary

BigTable / Fetcher updates: use `closeAsync` when flushing the bulk read batcher, set client timeouts, and allow choosing between bulk reads and single `readRows` queries for multi-gets.

## Checklist

- [ ] Added Unit Tests
- [ ] Covered by existing CI
- [ ] Integration tested
- [ ] Documentation update

## Summary by CodeRabbit

- **New Features**
  - Introduced bulk read support and refined time series data filtering to enhance data retrieval performance and accuracy.
  - Improved client retry mechanisms for more robust operations under varying conditions.
- **Refactor**
  - Streamlined the handling of combined responses and updated latency metric reporting for clearer performance insights.
  - Centralized configuration access for environment variables and settings.
- **Tests**
  - Expanded property-based testing across multiple configurations to ensure consistent and reliable functionality.
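The bulk-read vs. `readRows` choice is driven entirely by the new `conf` constructor parameter. Below is a minimal wiring sketch, assuming a caller that already has a `BigtableDataClient`; the project/instance names are placeholders, and treating `GcpApiImpl.UseBulkReadRows` as a plain string map key is an assumption, since this diff only shows the flag being read via `GcpApiImpl.getOptional`.

```scala
import com.google.cloud.bigtable.data.v2.BigtableDataClient

// Placeholder project/instance names; not part of this change.
val dataClient = BigtableDataClient.create("my-project", "my-instance")

// Default behaviour: the flag is absent, so forall(_.toBoolean) evaluates to true
// and multiGet goes through the bulk read batcher path (bulkReadRowsMultiGet).
val bulkStore = new BigTableKVStoreImpl(dataClient, conf = Map.empty)

// Opt out of bulk reads: an explicit "false" routes multiGet through the single
// readRows query path (readRowsMultiGet) instead.
// GcpApiImpl.UseBulkReadRows is assumed here to be the String config key.
val queryStore =
  new BigTableKVStoreImpl(dataClient, conf = Map(GcpApiImpl.UseBulkReadRows -> "false"))
```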
1 parent 143ef00 commit a31b900

File tree

4 files changed: +672 −401 lines


cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreImpl.scala

Lines changed: 165 additions & 23 deletions
@@ -27,7 +27,7 @@ import com.google.cloud.bigtable.admin.v2.BigtableTableAdminClient
 import com.google.cloud.bigtable.admin.v2.models.CreateTableRequest
 import com.google.cloud.bigtable.admin.v2.models.GCRules
 import com.google.cloud.bigtable.data.v2.BigtableDataClient
-import com.google.cloud.bigtable.data.v2.models.{Filters, Query, Row, RowMutation, TableId => BTTableId}
+import com.google.cloud.bigtable.data.v2.models.{Filters, Query, Row, RowMutation, TargetId, TableId => BTTableId}
 import com.google.cloud.bigtable.data.v2.models.Range.ByteStringRange
 import com.google.cloud.bigtable.data.v2.models.Range.TimestampRange
 import com.google.protobuf.ByteString
@@ -37,6 +37,7 @@ import org.threeten.bp.Duration
 
 import java.nio.charset.Charset
 import java.util
+import scala.collection.mutable.ArrayBuffer
 import scala.compat.java8.FutureConverters
 import scala.concurrent.Future
 import scala.concurrent.duration._
@@ -69,7 +70,8 @@ import scala.collection.{Seq, mutable}
   */
 class BigTableKVStoreImpl(dataClient: BigtableDataClient,
                           maybeAdminClient: Option[BigtableTableAdminClient] = None,
-                          maybeBigQueryClient: Option[BigQuery] = None)
+                          maybeBigQueryClient: Option[BigQuery] = None,
+                          conf: Map[String, String] = Map.empty)
     extends KVStore {
 
   @transient override lazy val logger: Logger = LoggerFactory.getLogger(getClass)
@@ -88,6 +90,8 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
 
   protected val metricsContext: Metrics.Context = Metrics.Context(Metrics.Environment.KVStore).withSuffix("bigtable")
 
+  private val useBulkReadRows = GcpApiImpl.getOptional(GcpApiImpl.UseBulkReadRows, conf).forall(_.toBoolean)
+
   override def create(dataset: String): Unit = create(dataset, Map.empty)
 
   override def create(dataset: String, props: Map[String, Any]): Unit = {
@@ -122,13 +126,27 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
   }
 
   override def multiGet(requests: Seq[KVStore.GetRequest]): Future[Seq[KVStore.GetResponse]] = {
-    logger.info(s"Performing multi-get for ${requests.size} requests")
+    logger.debug(s"Performing multi-get for ${requests.size} requests with useBulkReadRows: $useBulkReadRows")
 
     // Group requests by dataset to minimize the number of BigTable calls
     val requestsByDataset = requests.groupBy(_.dataset)
 
     // For each dataset, make a single query with all relevant row keys
-    val datasetFutures = requestsByDataset.map { case (dataset, datasetRequests) =>
+    val datasetFutures =
+      if (useBulkReadRows) {
+        // Use bulk read for all datasets
+        bulkReadRowsMultiGet(requestsByDataset)
+      } else {
+        // Use read for all datasets
+        readRowsMultiGet(requestsByDataset)
+      }
+    // Combine results from all datasets
+    Future.sequence(datasetFutures).map(_.flatten)
+  }
+
+  private def bulkReadRowsMultiGet(
+      requestsByDataset: Map[String, Seq[KVStore.GetRequest]]): Seq[Future[Seq[KVStore.GetResponse]]] = {
+    requestsByDataset.map { case (dataset, datasetRequests) =>
       val targetId = mapDatasetToTable(dataset)
       val filter =
         Filters.FILTERS
@@ -173,19 +191,9 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
       val startTs = System.currentTimeMillis()
 
       // Make a single BigTable call for all rows in this dataset
-      val batcher = dataClient.newBulkReadRowsBatcher(targetId, filter)
-      val rowApiFutures: Seq[ApiFuture[Row]] =
-        requestsWithRowKeys.map(_._2).flatMap(rowKeys => rowKeys.map(batcher.add))
-
-      val apiFutureList: ApiFuture[util.List[Row]] = ApiFutures.allAsList(rowApiFutures.asJava)
-      val completableFuture = ApiFutureUtils.toCompletableFuture(apiFutureList)
-      val scalaResultFuture = FutureConverters.toScala(completableFuture)
-
-      // close batcher to prevent new work from being added + flush any pending calls
-      batcher.close()
-
+      val resultFuture = readAsyncBatches(dataClient, targetId, filter, requestsWithRowKeys)
       // Process all results at once
-      scalaResultFuture
+      resultFuture
         .map { rows =>
           metricsContext.distribution("multiGet.latency", System.currentTimeMillis() - startTs, s"dataset:$dataset")
           metricsContext.increment("multiGet.successes", s"dataset:$dataset")
@@ -220,9 +228,113 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
         }
       }
     }.toSeq
+  }
 
-    // Combine results from all datasets
-    Future.sequence(datasetFutures).map(_.flatten)
+  private def readAsyncBatches(
+      dataClient: BigtableDataClient,
+      targetId: TargetId,
+      filter: Filters.Filter,
+      requestsWithRowKeys: Seq[(KVStore.GetRequest, ArrayBuffer[ByteString])]): Future[util.List[Row]] = {
+    // Make a single BigTable call for all rows in this dataset
+    val batcher = dataClient.newBulkReadRowsBatcher(targetId, filter)
+    val rowApiFutures: Seq[ApiFuture[Row]] =
+      requestsWithRowKeys.map(_._2).flatMap(rowKeys => rowKeys.map(batcher.add))
+
+    val apiFutureList: ApiFuture[util.List[Row]] = ApiFutures.allAsList(rowApiFutures.asJava)
+    val scalaResultFuture = googleFutureToScalaFuture(apiFutureList)
+
+    // close batcher to prevent new work from being added + flush any pending calls
+    val closeFuture = batcher.closeAsync()
+    val closeScalaFuture = googleFutureToScalaFuture(closeFuture)
+
+    // order matters - we need to ensure the close op is done as that triggers flushes of pending work which we
+    // need as part of the final Future[List[Row]]
+    for {
+      _ <- closeScalaFuture // close the batcher (which flushes pending work)
+      rows <- scalaResultFuture // get the results
+    } yield rows
+  }
+
+  private def readRowsMultiGet(
+      requestsByDataset: Map[String, Seq[KVStore.GetRequest]]): Seq[Future[Seq[KVStore.GetResponse]]] = {
+    requestsByDataset.map { case (dataset, datasetRequests) =>
+      // Create a single query for all requests in this dataset
+      val query = Query
+        .create(mapDatasetToTable(dataset))
+        .filter(Filters.FILTERS.family().exactMatch(ColumnFamilyString))
+        .filter(Filters.FILTERS.qualifier().exactMatch(ColumnFamilyQualifierString))
+
+      // Track which request corresponds to which row key(s)
+      val requestsWithRowKeys = datasetRequests.map { request =>
+        val tableType = getTableType(dataset)
+        val rowKeys = new mutable.ArrayBuffer[ByteString]()
+        // Apply the appropriate filters based on request type
+        (request.startTsMillis, tableType) match {
+          case (Some(startTs), TileSummaries) =>
+            val endTime = request.endTsMillis.getOrElse(System.currentTimeMillis())
+            // Use existing method to add row keys
+            val (_, addedRowKeys) = setQueryTimeSeriesFilters(query, startTs, endTime, request.keyBytes, dataset)
+            rowKeys ++= addedRowKeys
+
+          case (Some(startTs), StreamingTable) =>
+            val tileKey = TilingUtils.deserializeTileKey(request.keyBytes)
+            val tileSizeMs = tileKey.tileSizeMillis
+            val baseKeyBytes = tileKey.keyBytes.asScala.map(_.asInstanceOf[Byte])
+            val endTime = request.endTsMillis.getOrElse(System.currentTimeMillis())
+
+            // Use existing method to add row keys
+            val (_, addedRowKeys) =
+              setQueryTimeSeriesFilters(query, startTs, endTime, baseKeyBytes, dataset, Some(tileSizeMs))
+            rowKeys ++= addedRowKeys
+
+          case _ =>
+            // For non-timeseries data, just add the single row key
+            val baseRowKey = buildRowKey(request.keyBytes, dataset)
+            query.rowKey(ByteString.copyFrom(baseRowKey))
+            query.filter(Filters.FILTERS.limit().cellsPerRow(1))
+            rowKeys.append(ByteString.copyFrom(baseRowKey))
+        }
+
+        (request, rowKeys)
+      }
+      val startTs = System.currentTimeMillis()
+
+      // Make a single BigTable call for all rows in this dataset
+      val apiFuture = dataClient.readRowsCallable().all().futureCall(query)
+      val scalaResultFuture = googleFutureToScalaFuture(apiFuture)
+
+      // Process all results at once
+      scalaResultFuture
+        .map { rows =>
+          metricsContext.distribution("multiGet.latency", System.currentTimeMillis() - startTs, s"dataset:$dataset")
+          metricsContext.increment("multiGet.successes", s"dataset:$dataset")
+
+          // Create a map for quick lookup by row key
+          val rowKeyToRowMap = rows.asScala.map(row => row.getKey() -> row).toMap
+
+          // Map back to original requests
+          requestsWithRowKeys.map { case (request, rowKeys) =>
+            // Get all cells from all row keys for this request
+            val timedValues = rowKeys.flatMap { rowKey =>
+              rowKeyToRowMap.get(rowKey).toSeq.flatMap { row =>
+                row.getCells(ColumnFamilyString, ColumnFamilyQualifier).asScala.map { cell =>
+                  KVStore.TimedValue(cell.getValue.toByteArray, cell.getTimestamp / 1000)
+                }
+              }
+            }
+
+            KVStore.GetResponse(request, Success(timedValues))
+          }
+        }
+        .recover { case e: Exception =>
+          logger.error("Error getting values", e)
+          metricsContext.increment("multiGet.bigtable_errors", s"exception:${e.getClass.getName},dataset:$dataset")
+          // If the batch fails, return failures for all requests in the batch
+          datasetRequests.map { request =>
+            KVStore.GetResponse(request, Failure(e))
+          }
+        }
+    }.toSeq
   }
 
   private def buildRowKeysForTimeranges(chainFilter: Filters.ChainFilter,
@@ -254,6 +366,34 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
     rowKeyByteStrings
   }
 
+  private def setQueryTimeSeriesFilters(query: Query,
+                                        startTs: Long,
+                                        endTs: Long,
+                                        keyBytes: Seq[Byte],
+                                        dataset: String,
+                                        maybeTileSize: Option[Long] = None): (Query, Iterable[ByteString]) = {
+    // we need to generate a rowkey corresponding to each day from the startTs to now
+    val millisPerDay = 1.day.toMillis
+
+    val startDay = startTs - (startTs % millisPerDay)
+    val endDay = endTs - (endTs % millisPerDay)
+    // get the rowKeys
+    val rowKeyByteStrings =
+      (startDay to endDay by millisPerDay).map(dayTs => {
+        val rowKey =
+          maybeTileSize
+            .map(tileSize => buildTiledRowKey(keyBytes, dataset, dayTs, tileSize))
+            .getOrElse(buildRowKey(keyBytes, dataset, Some(dayTs)))
+        val rowKeyByteString = ByteString.copyFrom(rowKey)
+        query.rowKey(rowKeyByteString)
+        rowKeyByteString
+      })
+
+    // Bigtable uses microseconds, and we need to scan from startTs (millis) to endTs (millis)
+    (query.filter(Filters.FILTERS.timestamp().range().startClosed(startTs * 1000).endClosed(endTs * 1000)),
+     rowKeyByteStrings)
+  }
+
   override def list(request: ListRequest): Future[ListResponse] = {
     logger.info(s"Performing list for ${request.dataset}")
 
@@ -287,9 +427,7 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
 
     val startTs = System.currentTimeMillis()
     val rowsApiFuture = dataClient.readRowsCallable().all.futureCall(query)
-
-    val rowCompletableFuture = ApiFutureUtils.toCompletableFuture(rowsApiFuture)
-    val rowsScalaFuture = FutureConverters.toScala(rowCompletableFuture)
+    val rowsScalaFuture = googleFutureToScalaFuture(rowsApiFuture)
 
     rowsScalaFuture
       .map { rows =>
@@ -355,8 +493,7 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
 
     val startTs = System.currentTimeMillis()
     val mutateApiFuture = dataClient.mutateRowAsync(mutation)
-    val completableFuture = ApiFutureUtils.toCompletableFuture(mutateApiFuture)
-    val scalaFuture = FutureConverters.toScala(completableFuture)
+    val scalaFuture = googleFutureToScalaFuture(mutateApiFuture)
     scalaFuture
       .map { _ =>
         metricsContext.distribution("multiPut.latency",
@@ -519,6 +656,11 @@ object BigTableKVStore {
     }
   }
 
+  def googleFutureToScalaFuture[T](apiFuture: ApiFuture[T]): Future[T] = {
+    val completableFuture = ApiFutureUtils.toCompletableFuture(apiFuture)
+    FutureConverters.toScala(completableFuture)
+  }
+
   val ColumnFamilyString: String = "cf"
   val ColumnFamilyQualifierString: String = "value"
   val ColumnFamilyQualifier: ByteString = ByteString.copyFromUtf8(ColumnFamilyQualifierString)
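For the time-series cases above, `setQueryTimeSeriesFilters` fans a single logical key out into one row key per day bucket in the requested window. Here is a small self-contained sketch of just that bucketing arithmetic, assuming the row-key construction is left out (`buildRowKey` / `buildTiledRowKey` are not shown in this diff):

```scala
import scala.concurrent.duration._

object DayBucketSketch {
  // One bucket per calendar day (UTC) touched by [startTs, endTs], mirroring
  // the startDay/endDay computation in setQueryTimeSeriesFilters.
  val millisPerDay: Long = 1.day.toMillis

  def dayBuckets(startTs: Long, endTs: Long): Seq[Long] = {
    val startDay = startTs - (startTs % millisPerDay)
    val endDay = endTs - (endTs % millisPerDay)
    (startDay to endDay by millisPerDay).toVector
  }

  def main(args: Array[String]): Unit = {
    // 2024-01-01T12:00Z to 2024-01-03T00:00Z touches three day buckets,
    // because the end bound lands exactly on the Jan 3 bucket boundary.
    println(dayBuckets(1704110400000L, 1704240000000L))
    // Vector(1704067200000, 1704153600000, 1704240000000)

    // Bigtable cell timestamps are in microseconds, hence the "* 1000" applied to
    // the millisecond bounds when the real method builds its timestamp range filter.
  }
}
```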
