chore: minor Improvements to providerquerymanager #728

Merged: 11 commits, Nov 27, 2024
61 changes: 38 additions & 23 deletions routing/providerquerymanager/providerquerymanager.go
@@ -108,7 +108,9 @@
}
}

// WithMaxInProcessRequests is the maximum number of requests that can be processed in parallel
// WithMaxInProcessRequests is the maximum number of requests that can be
// processed in parallel. If this is 0, then the number is unlimited. Default
// is defaultMaxInProcessRequests.
func WithMaxInProcessRequests(count int) Option {
return func(mgr *ProviderQueryManager) error {
mgr.maxInProcessRequests = count
@@ -117,7 +119,7 @@
}

// WithMaxProviders is the maximum number of providers that will be looked up
// per query. We only return providers that we can connect to. Defaults to 0,
// which means unbounded.
func WithMaxProviders(count int) Option {
return func(mgr *ProviderQueryManager) error {
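
For context, here is a minimal usage sketch (not part of this diff) showing how the two options above compose. The New, Startup, and FindProvidersAsync calls mirror the tests below; the ProviderQueryDialer/ProviderQueryRouter parameter types and all variable and function names here are illustrative assumptions.

package example // hypothetical example, not from this PR

import (
    "context"
    "fmt"

    "github.com/ipfs/boxo/routing/providerquerymanager"
    "github.com/ipfs/go-cid"
)

// findSomeProviders wires the options together. The dialer/router parameter
// types are assumed names for this package's dialer and router interfaces.
func findSomeProviders(ctx context.Context, dialer providerquerymanager.ProviderQueryDialer,
    router providerquerymanager.ProviderQueryRouter, c cid.Cid) error {
    pqm, err := providerquerymanager.New(ctx, dialer, router,
        providerquerymanager.WithMaxInProcessRequests(0), // 0 = no cap on concurrent lookups
        providerquerymanager.WithMaxProviders(10),        // stop after 10 reachable providers per query
    )
    if err != nil {
        return err
    }
    pqm.Startup()

    // The returned channel is closed when the query finishes or ctx is done.
    for p := range pqm.FindProvidersAsync(ctx, c, 0) {
        fmt.Println("provider:", p.ID)
    }
    return nil
}
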
@@ -215,8 +217,8 @@
var receivedProviders deque.Deque[peer.AddrInfo]
receivedProviders.Grow(len(receivedInProgressRequest.providersSoFar))
for _, addrInfo := range receivedInProgressRequest.providersSoFar {
receivedProviders.PushBack(addrInfo)
}

incomingProviders := receivedInProgressRequest.incoming

// count how many providers we received from our workers etc.
@@ -304,25 +306,42 @@
}
}

// findProviderWorker cycles through incoming provider queries one at a time.
func (pqm *ProviderQueryManager) findProviderWorker() {
// findProviderWorker just cycles through incoming provider queries one
// at a time. We have six of these workers running at once
// to let requests go in parallel but keep them rate limited
for {
select {
case fpr, ok := <-pqm.providerRequestsProcessing.Out():
if !ok {
var findSem chan struct{}
// If limiting the number of concurrent requests, create a counting
// semaphore to enforce this limit.
if pqm.maxInProcessRequests > 0 {
findSem = make(chan struct{}, pqm.maxInProcessRequests)
}

// Read find provider requests until the channel is closed. The channel is
// closed as soon as pqm.ctx is canceled, so there is no need to select on
// that context here.
for fpr := range pqm.providerRequestsProcessing.Out() {
if findSem != nil {
select {
case findSem <- struct{}{}:
case <-pqm.ctx.Done():
return
}
k := fpr.k
}

go func(ctx context.Context, k cid.Cid) {
if findSem != nil {
defer func() {
<-findSem
}()
}

log.Debugf("Beginning Find Provider Request for cid: %s", k.String())
findProviderCtx, cancel := context.WithTimeout(fpr.ctx, pqm.findProviderTimeout)
findProviderCtx, cancel := context.WithTimeout(ctx, pqm.findProviderTimeout)
span := trace.SpanFromContext(findProviderCtx)
span.AddEvent("StartFindProvidersAsync")
// We set count == 0. We will cancel the query
// manually once we have enough. This assumes the
// ContentDiscovery implementation does that, which a
// requirement per the libp2p/core/routing interface.
// We set count == 0. We will cancel the query manually once we
// have enough. This assumes the ContentDiscovery
// implementation does that, which is a requirement per the
// libp2p/core/routing interface.
providers := pqm.router.FindProvidersAsync(findProviderCtx, k, 0)
wg := &sync.WaitGroup{}
for p := range providers {
@@ -339,7 +358,7 @@
span.AddEvent("ConnectedToProvider", trace.WithAttributes(attribute.Stringer("peer", p.ID)))
select {
case pqm.providerQueryMessages <- &receivedProviderMessage{
ctx: fpr.ctx,
ctx: ctx,
k: k,
p: p,
}:
@@ -352,14 +371,12 @@
cancel()
select {
case pqm.providerQueryMessages <- &finishedProviderQueryMessage{
ctx: fpr.ctx,
ctx: ctx,
k: k,
}:
case <-pqm.ctx.Done():
}
case <-pqm.ctx.Done():
return
}
}(fpr.ctx, fpr.k)
}
}
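
The comment above notes that the query is started with count == 0 and cancelled manually once enough providers arrive, relying on the router to honor context cancellation per libp2p/core/routing. A stripped-down sketch of that cancel-when-satisfied pattern, using a locally defined stand-in interface and assumed names rather than the PR's code:

package sketch

import (
    "context"

    "github.com/ipfs/go-cid"
    "github.com/libp2p/go-libp2p/core/peer"
)

// contentDiscovery is a local stand-in for the router used above: anything
// that can stream provider records for a CID (see libp2p/core/routing).
type contentDiscovery interface {
    FindProvidersAsync(context.Context, cid.Cid, int) <-chan peer.AddrInfo
}

// collectProviders asks for an unbounded stream (count == 0) and cancels the
// query context once `want` providers have arrived; want == 0 means
// unbounded, matching WithMaxProviders. The router is expected to wind down
// and close the channel when the context is cancelled.
func collectProviders(ctx context.Context, router contentDiscovery, key cid.Cid, want int) []peer.AddrInfo {
    qctx, cancel := context.WithCancel(ctx)
    defer cancel()

    var found []peer.AddrInfo
    for p := range router.FindProvidersAsync(qctx, key, 0) {
        found = append(found, p)
        if want > 0 && len(found) >= want {
            cancel() // enough providers; keep draining until the router closes the channel
        }
    }
    return found
}
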

@@ -378,9 +395,7 @@
pqm.providerRequestsProcessing = chanqueue.New[*findProviderRequest]()
defer pqm.providerRequestsProcessing.Shutdown()

for i := 0; i < pqm.maxInProcessRequests; i++ {
go pqm.findProviderWorker()
}
go pqm.findProviderWorker()

for {
select {
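
Taken together, the changes above replace the fixed pool of six workers with a single dispatcher goroutine that drains the request queue and spawns one goroutine per request, optionally gated by a buffered-channel counting semaphore. A self-contained sketch of that pattern in isolation, with generic names rather than the PR's types:

package main

import (
    "fmt"
    "sync"
    "time"
)

// dispatch drains `requests` with a single goroutine and runs handle() for
// each item in its own goroutine. A buffered channel acts as a counting
// semaphore; maxInFlight == 0 means unlimited, mirroring the behavior of
// WithMaxInProcessRequests above.
func dispatch(requests <-chan int, maxInFlight int, handle func(int)) {
    var sem chan struct{}
    if maxInFlight > 0 {
        sem = make(chan struct{}, maxInFlight)
    }

    var wg sync.WaitGroup
    for req := range requests { // exits when the queue channel is closed
        if sem != nil {
            sem <- struct{}{} // acquire a slot; blocks while maxInFlight handlers are running
        }
        wg.Add(1)
        go func(r int) {
            defer wg.Done()
            if sem != nil {
                defer func() { <-sem }() // release the slot when done
            }
            handle(r)
        }(req)
    }
    wg.Wait()
}

func main() {
    requests := make(chan int)
    go func() {
        for i := 0; i < 10; i++ {
            requests <- i
        }
        close(requests)
    }()
    dispatch(requests, 3, func(r int) {
        time.Sleep(50 * time.Millisecond)
        fmt.Println("handled request", r)
    })
}
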
56 changes: 50 additions & 6 deletions routing/providerquerymanager/providerquerymanager_test.go
@@ -263,6 +263,8 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
}

func TestRateLimitingRequests(t *testing.T) {
const maxInProcessRequests = 6

peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderDiscovery{
@@ -272,31 +274,73 @@
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn))
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(maxInProcessRequests)))
providerQueryManager.Startup()

keys := random.Cids(providerQueryManager.maxInProcessRequests + 1)
keys := random.Cids(maxInProcessRequests + 1)
sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
var requestChannels []<-chan peer.AddrInfo
for i := 0; i < providerQueryManager.maxInProcessRequests+1; i++ {
for i := 0; i < maxInProcessRequests+1; i++ {
requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], 0))
}
time.Sleep(20 * time.Millisecond)
fpn.queriesMadeMutex.Lock()
if fpn.liveQueries != providerQueryManager.maxInProcessRequests {
if fpn.liveQueries != maxInProcessRequests {
t.Logf("Queries made: %d\n", fpn.liveQueries)
t.Fatal("Did not limit parallel requests to rate limit")
}
fpn.queriesMadeMutex.Unlock()
for i := 0; i < providerQueryManager.maxInProcessRequests+1; i++ {
for i := 0; i < maxInProcessRequests+1; i++ {
for range requestChannels[i] {
}
}

fpn.queriesMadeMutex.Lock()
defer fpn.queriesMadeMutex.Unlock()
if fpn.queriesMade != maxInProcessRequests+1 {
t.Logf("Queries made: %d\n", fpn.queriesMade)
t.Fatal("Did not make all separate requests")
}
}

func TestUnlimitedRequests(t *testing.T) {
const inProcessRequests = 11

peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 5 * time.Millisecond,
}
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxInProcessRequests(0)))
providerQueryManager.Startup()

keys := random.Cids(inProcessRequests)
sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
var requestChannels []<-chan peer.AddrInfo
for i := 0; i < inProcessRequests; i++ {
requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], 0))
}
time.Sleep(20 * time.Millisecond)
fpn.queriesMadeMutex.Lock()
if fpn.liveQueries != inProcessRequests {
t.Logf("Queries made: %d\n", fpn.liveQueries)
t.Fatal("Parallel requests appear to be rate limited")
}
fpn.queriesMadeMutex.Unlock()
for i := 0; i < inProcessRequests; i++ {
for range requestChannels[i] {
}
}

fpn.queriesMadeMutex.Lock()
defer fpn.queriesMadeMutex.Unlock()
if fpn.queriesMade != providerQueryManager.maxInProcessRequests+1 {
if fpn.queriesMade != inProcessRequests {
t.Logf("Queries made: %d\n", fpn.queriesMade)
t.Fatal("Did not make all separate requests")
}