Skip to content

Commit 6da08b0

Browse files
committed
Add iteration support to dynamo
This builds on top of #52199 by updating the dynamo backend to implement backend.BackendWithItems. Both GetRange and DeleteRange were refactored to call Items instead of getAllRecords to unify logic and vet the implementation of Items. The custom pagination logic to retrieve items was also removed in favor of the built in query paginator from the aws sdk. In addition to simplifying logic this also removed some extraneous sorting.
1 parent 1f69fd5 commit 6da08b0

File tree

1 file changed

+133
-156
lines changed

1 file changed

+133
-156
lines changed

lib/backend/dynamo/dynamodbbk.go

Lines changed: 133 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ package dynamo
2121
import (
2222
"context"
2323
"errors"
24+
"iter"
2425
"log/slog"
2526
"net/http"
26-
"sort"
2727
"strconv"
2828
"sync/atomic"
2929
"time"
@@ -548,6 +548,95 @@ func (b *Backend) Update(ctx context.Context, item backend.Item) (*backend.Lease
548548
return backend.NewLease(item), nil
549549
}
550550

551+
func (b *Backend) Items(ctx context.Context, params backend.IterateParams) iter.Seq2[backend.Item, error] {
552+
if params.StartKey.IsZero() {
553+
err := trace.BadParameter("missing parameter startKey")
554+
return func(yield func(backend.Item, error) bool) { yield(backend.Item{}, err) }
555+
}
556+
if params.EndKey.IsZero() {
557+
err := trace.BadParameter("missing parameter endKey")
558+
return func(yield func(backend.Item, error) bool) { yield(backend.Item{}, err) }
559+
}
560+
561+
const (
562+
query = "HashKey = :hashKey AND FullPath BETWEEN :fullPath AND :rangeEnd"
563+
564+
// filter out expired items, otherwise they might show up in the query
565+
// http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/howitworks-ttl.html
566+
filter = "attribute_not_exists(Expires) OR Expires >= :timestamp"
567+
)
568+
569+
attrV := map[string]interface{}{
570+
":fullPath": prependPrefix(params.StartKey),
571+
":hashKey": hashKey,
572+
":timestamp": b.clock.Now().UTC().Unix(),
573+
":rangeEnd": prependPrefix(params.EndKey),
574+
}
575+
576+
av, err := attributevalue.MarshalMap(attrV)
577+
if err != nil {
578+
return func(yield func(backend.Item, error) bool) { yield(backend.Item{}, err) }
579+
}
580+
581+
input := dynamodb.QueryInput{
582+
KeyConditionExpression: aws.String(query),
583+
TableName: &b.TableName,
584+
ExpressionAttributeValues: av,
585+
FilterExpression: aws.String(filter),
586+
ConsistentRead: aws.Bool(true),
587+
ScanIndexForward: aws.Bool(!params.Descending),
588+
}
589+
if params.Limit > 0 {
590+
input.Limit = aws.Int32(int32(params.Limit))
591+
}
592+
593+
return func(yield func(backend.Item, error) bool) {
594+
count := 0
595+
defer func() {
596+
if count == backend.DefaultRangeLimit {
597+
b.logger.WarnContext(ctx, "Range query hit backend limit. (this is a bug!)", "start_key", params.StartKey, "limit", backend.DefaultRangeLimit)
598+
}
599+
}()
600+
601+
paginator := dynamodb.NewQueryPaginator(b.svc, &input)
602+
for paginator.HasMorePages() {
603+
page, err := paginator.NextPage(ctx)
604+
if err != nil {
605+
yield(backend.Item{}, convertError(err))
606+
return
607+
}
608+
609+
for _, itemAttributes := range page.Items {
610+
var r record
611+
if err := attributevalue.UnmarshalMap(itemAttributes, &r); err != nil {
612+
yield(backend.Item{}, convertError(err))
613+
return
614+
}
615+
616+
item := backend.Item{
617+
Key: trimPrefix(r.FullPath),
618+
Value: r.Value,
619+
Revision: r.Revision,
620+
}
621+
if r.Expires != nil {
622+
item.Expires = time.Unix(*r.Expires, 0).UTC()
623+
}
624+
if item.Revision == "" {
625+
item.Revision = backend.BlankRevision
626+
}
627+
628+
if !yield(item, nil) {
629+
return
630+
}
631+
count++
632+
if params.Limit != backend.NoLimit && count >= params.Limit {
633+
return
634+
}
635+
}
636+
}
637+
}
638+
}
639+
551640
// GetRange returns range of elements
552641
func (b *Backend) GetRange(ctx context.Context, startKey, endKey backend.Key, limit int) (*backend.GetResult, error) {
553642
if startKey.IsZero() {
@@ -560,51 +649,11 @@ func (b *Backend) GetRange(ctx context.Context, startKey, endKey backend.Key, li
560649
limit = backend.DefaultRangeLimit
561650
}
562651

563-
result, err := b.getAllRecords(ctx, startKey, endKey, limit)
564-
if err != nil {
565-
return nil, trace.Wrap(err)
566-
}
567-
sort.Sort(records(result.records))
568-
values := make([]backend.Item, len(result.records))
569-
for i, r := range result.records {
570-
values[i] = backend.Item{
571-
Key: trimPrefix(r.FullPath),
572-
Value: r.Value,
573-
Revision: r.Revision,
574-
}
575-
if r.Expires != nil {
576-
values[i].Expires = time.Unix(*r.Expires, 0).UTC()
577-
}
578-
if values[i].Revision == "" {
579-
values[i].Revision = backend.BlankRevision
580-
}
652+
var result backend.GetResult
653+
for i := range b.Items(ctx, backend.IterateParams{StartKey: startKey, EndKey: endKey, Limit: limit}) {
654+
result.Items = append(result.Items, i)
581655
}
582-
return &backend.GetResult{Items: values}, nil
583-
}
584-
585-
func (b *Backend) getAllRecords(ctx context.Context, startKey, endKey backend.Key, limit int) (*getResult, error) {
586-
var result getResult
587-
588-
// this code is being extra careful here not to introduce endless loop
589-
// by some unfortunate series of events
590-
for i := 0; i < backend.DefaultRangeLimit/100; i++ {
591-
re, err := b.getRecords(ctx, prependPrefix(startKey), prependPrefix(endKey), limit, result.lastEvaluatedKey)
592-
if err != nil {
593-
return nil, trace.Wrap(err)
594-
}
595-
result.records = append(result.records, re.records...)
596-
// If the limit was exceeded or there are no more records to fetch return the current result
597-
// otherwise updated lastEvaluatedKey and proceed with obtaining new records.
598-
if (limit != 0 && len(result.records) >= limit) || len(re.lastEvaluatedKey) == 0 {
599-
if len(result.records) == backend.DefaultRangeLimit {
600-
b.logger.WarnContext(ctx, "Range query hit backend limit. (this is a bug!)", "start_key", startKey, "limit", backend.DefaultRangeLimit)
601-
}
602-
result.lastEvaluatedKey = nil
603-
return &result, nil
604-
}
605-
result.lastEvaluatedKey = re.lastEvaluatedKey
606-
}
607-
return nil, trace.BadParameter("backend entered endless loop")
656+
return &result, nil
608657
}
609658

610659
const (
@@ -623,38 +672,54 @@ func (b *Backend) DeleteRange(ctx context.Context, startKey, endKey backend.Key)
623672
if endKey.IsZero() {
624673
return trace.BadParameter("missing parameter endKey")
625674
}
626-
// keep fetching and deleting until no records left,
627-
// keep the very large limit, just in case if someone else
628-
// keeps adding records
629-
for i := 0; i < backend.DefaultRangeLimit/100; i++ {
630-
result, err := b.getRecords(ctx, prependPrefix(startKey), prependPrefix(endKey), batchOperationItemsLimit, nil)
631-
if err != nil {
632-
return trace.Wrap(err)
675+
676+
// Attempt to pull all existing items and delete them in batches
677+
// in accordance with the BatchWriteItem limits. There is a hard
678+
// cap on the total number of items that can be deleted in a single
679+
// DeleteRange call to avoid racing with additional records being added.
680+
const maxDeletions = backend.DefaultRangeLimit / 100
681+
requests := make([]types.WriteRequest, batchOperationItemsLimit)
682+
pageCount, totalCount := 0, 0
683+
for item := range b.Items(ctx, backend.IterateParams{StartKey: startKey, EndKey: endKey}) {
684+
if totalCount >= maxDeletions {
685+
break
633686
}
634-
if len(result.records) == 0 {
635-
return nil
687+
totalCount++
688+
689+
requests[pageCount] = types.WriteRequest{
690+
DeleteRequest: &types.DeleteRequest{
691+
Key: map[string]types.AttributeValue{
692+
hashKeyKey: &types.AttributeValueMemberS{Value: hashKey},
693+
fullPathKey: &types.AttributeValueMemberS{Value: prependPrefix(item.Key)},
694+
},
695+
},
636696
}
637-
requests := make([]types.WriteRequest, 0, len(result.records))
638-
for _, record := range result.records {
639-
requests = append(requests, types.WriteRequest{
640-
DeleteRequest: &types.DeleteRequest{
641-
Key: map[string]types.AttributeValue{
642-
hashKeyKey: &types.AttributeValueMemberS{Value: hashKey},
643-
fullPathKey: &types.AttributeValueMemberS{Value: record.FullPath},
644-
},
697+
pageCount++
698+
699+
if pageCount == batchOperationItemsLimit {
700+
if _, err := b.svc.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
701+
RequestItems: map[string][]types.WriteRequest{
702+
b.TableName: requests,
645703
},
646-
})
704+
}); err != nil {
705+
return trace.Wrap(err)
706+
}
707+
pageCount = 0
647708
}
648-
input := dynamodb.BatchWriteItemInput{
709+
}
710+
711+
if totalCount < maxDeletions && pageCount > 0 {
712+
if _, err := b.svc.BatchWriteItem(ctx, &dynamodb.BatchWriteItemInput{
649713
RequestItems: map[string][]types.WriteRequest{
650-
b.TableName: requests,
714+
b.TableName: requests[:pageCount],
651715
},
652-
}
653-
654-
if _, err = b.svc.BatchWriteItem(ctx, &input); err != nil {
716+
}); err != nil {
655717
return trace.Wrap(err)
656718
}
719+
720+
return nil
657721
}
722+
658723
return trace.ConnectionProblem(nil, "not all items deleted, too many requests")
659724
}
660725

@@ -961,60 +1026,6 @@ func (b *Backend) createTable(ctx context.Context, tableName *string, rangeKey s
9611026
return trace.Wrap(err)
9621027
}
9631028

964-
type getResult struct {
965-
// lastEvaluatedKey is the primary key of the item where the operation stopped, inclusive of the
966-
// previous result set. Use this value to start a new operation, excluding this
967-
// value in the new request.
968-
lastEvaluatedKey map[string]types.AttributeValue
969-
records []record
970-
}
971-
972-
// getRecords retrieves all keys by path
973-
func (b *Backend) getRecords(ctx context.Context, startKey, endKey string, limit int, lastEvaluatedKey map[string]types.AttributeValue) (*getResult, error) {
974-
query := "HashKey = :hashKey AND FullPath BETWEEN :fullPath AND :rangeEnd"
975-
attrV := map[string]interface{}{
976-
":fullPath": startKey,
977-
":hashKey": hashKey,
978-
":timestamp": b.clock.Now().UTC().Unix(),
979-
":rangeEnd": endKey,
980-
}
981-
982-
// filter out expired items, otherwise they might show up in the query
983-
// http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/howitworks-ttl.html
984-
filter := "attribute_not_exists(Expires) OR Expires >= :timestamp"
985-
av, err := attributevalue.MarshalMap(attrV)
986-
if err != nil {
987-
return nil, convertError(err)
988-
}
989-
input := dynamodb.QueryInput{
990-
KeyConditionExpression: aws.String(query),
991-
TableName: &b.TableName,
992-
ExpressionAttributeValues: av,
993-
FilterExpression: aws.String(filter),
994-
ConsistentRead: aws.Bool(true),
995-
ExclusiveStartKey: lastEvaluatedKey,
996-
}
997-
if limit > 0 {
998-
input.Limit = aws.Int32(int32(limit))
999-
}
1000-
out, err := b.svc.Query(ctx, &input)
1001-
if err != nil {
1002-
return nil, trace.Wrap(err)
1003-
}
1004-
var result getResult
1005-
for _, item := range out.Items {
1006-
var r record
1007-
if err := attributevalue.UnmarshalMap(item, &r); err != nil {
1008-
return nil, trace.Wrap(err)
1009-
}
1010-
result.records = append(result.records, r)
1011-
}
1012-
sort.Sort(records(result.records))
1013-
result.records = removeDuplicates(result.records)
1014-
result.lastEvaluatedKey = out.LastEvaluatedKey
1015-
return &result, nil
1016-
}
1017-
10181029
// isExpired returns 'true' if the given object (record) has a TTL and
10191030
// it's due.
10201031
func (r *record) isExpired(now time.Time) bool {
@@ -1025,23 +1036,6 @@ func (r *record) isExpired(now time.Time) bool {
10251036
return now.UTC().After(expiryDateUTC)
10261037
}
10271038

1028-
func removeDuplicates(elements []record) []record {
1029-
// Use map to record duplicates as we find them.
1030-
encountered := map[string]bool{}
1031-
var result []record
1032-
1033-
for v := range elements {
1034-
if !encountered[elements[v].FullPath] {
1035-
// Record this element as an encountered element.
1036-
encountered[elements[v].FullPath] = true
1037-
// Append to result slice.
1038-
result = append(result, elements[v])
1039-
}
1040-
}
1041-
// Return the new slice.
1042-
return result
1043-
}
1044-
10451039
const (
10461040
modeCreate = iota
10471041
modePut
@@ -1235,23 +1229,6 @@ func convertError(err error) error {
12351229
return err
12361230
}
12371231

1238-
type records []record
1239-
1240-
// Len is part of sort.Interface.
1241-
func (r records) Len() int {
1242-
return len(r)
1243-
}
1244-
1245-
// Swap is part of sort.Interface.
1246-
func (r records) Swap(i, j int) {
1247-
r[i], r[j] = r[j], r[i]
1248-
}
1249-
1250-
// Less is part of sort.Interface.
1251-
func (r records) Less(i, j int) bool {
1252-
return r[i].FullPath < r[j].FullPath
1253-
}
1254-
12551232
func fullPathToAttributeValueMap(fullPath string) map[string]types.AttributeValue {
12561233
return map[string]types.AttributeValue{
12571234
hashKeyKey: &types.AttributeValueMemberS{Value: hashKey},

0 commit comments

Comments
 (0)