Skip to content

Commit 00d6aa4

Browse files
committed
revert back loadStore.Stop() to accept context
1 parent 7da08b1 commit 00d6aa4

File tree

8 files changed

+74
-51
lines changed

8 files changed

+74
-51
lines changed

xds/internal/balancer/clusterimpl/clusterimpl.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
package clusterimpl
2525

2626
import (
27+
"context"
2728
"encoding/json"
2829
"fmt"
2930
"sync"
@@ -99,7 +100,7 @@ type clusterImplBalancer struct {
99100
// The following fields are only accessed from balancer API methods, which
100101
// are guaranteed to be called serially by gRPC.
101102
xdsClient xdsclient.XDSClient // Sent down in ResolverState attributes.
102-
cancelLoadReport func(time.Duration) // To stop reporting load through the above xDS client.
103+
cancelLoadReport func(context.Context) // To stop reporting load through the above xDS client.
103104
edsServiceName string // EDS service name to report load for.
104105
lrsServer *bootstrap.ServerConfig // Load reporting server configuration.
105106
dropCategories []DropConfig // The categories for drops.
@@ -221,7 +222,9 @@ func (b *clusterImplBalancer) updateLoadStore(newConfig *LBConfig) error {
221222

222223
if stopOldLoadReport {
223224
if b.cancelLoadReport != nil {
224-
b.cancelLoadReport(loadStoreStopTimeout)
225+
stopCtx, stopCancel := context.WithTimeout(context.Background(), loadStoreStopTimeout)
226+
defer stopCancel()
227+
b.cancelLoadReport(stopCtx)
225228
b.cancelLoadReport = nil
226229
if !startNewLoadReport {
227230
// If a new LRS stream will be started later, no need to update
@@ -347,7 +350,9 @@ func (b *clusterImplBalancer) Close() {
347350
b.childState = balancer.State{}
348351

349352
if b.cancelLoadReport != nil {
350-
b.cancelLoadReport(loadStoreStopTimeout)
353+
stopCtx, stopCancel := context.WithTimeout(context.Background(), loadStoreStopTimeout)
354+
defer stopCancel()
355+
b.cancelLoadReport(stopCtx)
351356
b.cancelLoadReport = nil
352357
}
353358
b.logger.Infof("Shutdown")

xds/internal/clients/lrsclient/load_store.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package lrsclient
2020

2121
import (
22+
"context"
2223
"sync"
2324
"sync/atomic"
2425
"time"
@@ -35,7 +36,7 @@ import (
3536
// It is safe for concurrent use.
3637
type LoadStore struct {
3738
// stop is the function to call to Stop the LoadStore reporting.
38-
stop func(timeout time.Duration)
39+
stop func(ctx context.Context)
3940

4041
// mu only protects the map (2 layers). The read/write to
4142
// *PerClusterReporter doesn't need to hold the mu.
@@ -61,13 +62,13 @@ func newLoadStore() *LoadStore {
6162
// Stop signals the LoadStore to stop reporting.
6263
//
6364
// Before closing the underlying LRS stream, this method may block until a
64-
// final load report send attempt completes or the provided timeout duration
65+
// final load report send attempt completes or the provided context `ctx`
6566
// expires.
6667
//
67-
// The `timeout` duration should be set to prevent Stop from blocking
68-
// indefinitely if the final send attempt fails to complete.
69-
func (ls *LoadStore) Stop(timeout time.Duration) {
70-
ls.stop(timeout)
68+
// The provided context must have a deadline or timeout set to prevent Stop
69+
// from blocking indefinitely if the final send attempt fails to complete.
70+
func (ls *LoadStore) Stop(ctx context.Context) {
71+
ls.stop(ctx)
7172
}
7273

7374
// ReporterForCluster returns the PerClusterReporter for the given cluster and

xds/internal/clients/lrsclient/loadreport_test.go

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
146146
if err != nil {
147147
t.Fatalf("client.ReportLoad() failed: %v", err)
148148
}
149-
defer loadStore1.Stop(defaultTestShortTimeout)
149+
ssCtx, ssCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
150+
defer ssCancel()
151+
defer loadStore1.Stop(ssCtx)
150152

151153
// Call the load reporting API to report load to the first management
152154
// server, and ensure that a connection to the server is created.
@@ -231,7 +233,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
231233
}
232234

233235
// Stop this load reporting stream, server should see error canceled.
234-
loadStore2.Stop(defaultTestShortTimeout)
236+
ssCtx, ssCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
237+
defer ssCancel()
238+
loadStore2.Stop(ssCtx)
235239

236240
// Server should receive a stream canceled error. There may be additional
237241
// load reports from the client in the channel.
@@ -418,15 +422,19 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
418422

419423
// Cancel the first load reporting call, and ensure that the stream does not
420424
// close (because we have another call open).
421-
loadStore1.Stop(defaultTestShortTimeout)
425+
ssCtx, ssCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
426+
defer ssCancel()
427+
loadStore1.Stop(ssCtx)
422428
sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
423429
defer sCancel()
424430
if _, err := lrsServer.LRSStreamCloseChan.Receive(sCtx); err != context.DeadlineExceeded {
425431
t.Fatal("LRS stream closed when expected to stay open")
426432
}
427433

428434
// Stop the second load reporting call, and ensure the stream is closed.
429-
loadStore2.Stop(defaultTestShortTimeout)
435+
ssCtx, ssCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
436+
defer ssCancel()
437+
loadStore2.Stop(ssCtx)
430438
if _, err := lrsServer.LRSStreamCloseChan.Receive(ctx); err != nil {
431439
t.Fatal("Timeout waiting for LRS stream to close")
432440
}
@@ -441,16 +449,18 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
441449
if _, err := lrsServer.LRSStreamOpenChan.Receive(ctx); err != nil {
442450
t.Fatalf("Timeout when waiting for LRS stream to be created: %v", err)
443451
}
444-
loadStore3.Stop(defaultTestShortTimeout)
452+
ssCtx, ssCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
453+
defer ssCancel()
454+
loadStore3.Stop(ssCtx)
445455
}
446456

447-
// TestReportLoad_StopWithTimeout tests the behavior of LoadStore.Stop() when
448-
// called with a timeout duration. It verifies that:
449-
// - Stop() blocks until the timeout expires or final load send attempt is
457+
// TestReportLoad_StopWithContext tests the behavior of LoadStore.Stop() when
458+
// called with a context. It verifies that:
459+
// - Stop() blocks until the context expires or final load send attempt is
450460
// made.
451461
// - Final load report is seen on the server after stop is called.
452462
// - The stream is closed after Stop() returns.
453-
func (s) TestReportLoad_StopWithTimeout(t *testing.T) {
463+
func (s) TestReportLoad_StopWithContext(t *testing.T) {
454464
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
455465
defer cancel()
456466

@@ -535,11 +545,11 @@ func (s) TestReportLoad_StopWithTimeout(t *testing.T) {
535545
t.Fatalf("Unexpected diff in LRS request (-got, +want):\n%s", diff)
536546
}
537547

538-
// Create a timeout duration for Stop() that remains until the end of test
539-
// to ensure that only possibility of Stop() to finish is if final load
540-
// send attempt is made. If final load attempt is not made, test itself
541-
// will timeout.
542-
largeStopTimeout := 10 * defaultTestTimeout
548+
// Create a context for Stop() that remains until the end of the test to
549+
// ensure that the only way for Stop() to finish is for the final load send
550+
// attempt to be made. If it is not made, the test itself will time out.
551+
stopCtx, stopCancel := context.WithCancel(ctx)
552+
defer stopCancel()
543553

544554
// Push more loads.
545555
loadStore.ReporterForCluster("cluster2", "eds2").CallDropped("test")
@@ -548,7 +558,7 @@ func (s) TestReportLoad_StopWithTimeout(t *testing.T) {
548558
// Call Stop in a separate goroutine. It will block until
549559
// final load send attempt is made.
550560
go func() {
551-
loadStore.Stop(largeStopTimeout)
561+
loadStore.Stop(stopCtx)
552562
close(stopCalled)
553563
}()
554564
<-stopCalled

xds/internal/clients/lrsclient/lrsclient.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
package lrsclient
2323

2424
import (
25+
"context"
2526
"errors"
2627
"fmt"
2728
"sync"
@@ -140,8 +141,8 @@ func (c *LRSClient) getOrCreateLRSStream(serverIdentifier clients.ServerIdentifi
140141
// the LRS stream when the last reference is removed and closes the
141142
// transport and removes the lrs stream and its references from the
142143
// respective maps. Before closing the stream, it waits for the provided
143-
// timeout duration for the final load report attempt to complete.
144-
stop := func(timeout time.Duration) {
144+
// context to be done (timeout or cancellation), or until the final load report attempt completes, whichever happens first.
145+
stop := func(ctx context.Context) {
145146
c.mu.Lock()
146147
defer c.mu.Unlock()
147148

@@ -156,16 +157,13 @@ func (c *LRSClient) getOrCreateLRSStream(serverIdentifier clients.ServerIdentifi
156157

157158
lrs.finalSendRequest <- struct{}{}
158159

159-
timer := time.NewTimer(timeout)
160-
defer timer.Stop()
161-
162160
select {
163161
case err := <-lrs.finalSendDone:
164162
if err != nil {
165163
c.logger.Warningf("Final send attempt failed: %v", err)
166164
}
167-
case <-timer.C:
168-
c.logger.Warningf("Timed out before finishing the final send attempt: %v", err)
165+
case <-ctx.Done():
166+
c.logger.Warningf("Context canceled before finishing the final send attempt: %v", err)
169167
}
170168

171169
lrs.cancelStream()

xds/internal/testutils/fakeclient/client.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ package fakeclient
2121

2222
import (
2323
"context"
24-
"time"
2524

2625
"google.golang.org/grpc/internal/testutils"
2726
"google.golang.org/grpc/internal/xds/bootstrap"
@@ -81,14 +80,14 @@ func (*stream) Recv() ([]byte, error) {
8180
}
8281

8382
// ReportLoad starts reporting load to the given server.
84-
func (xdsC *Client) ReportLoad(server *bootstrap.ServerConfig) (loadStore *lrsclient.LoadStore, cancel func(time.Duration)) {
83+
func (xdsC *Client) ReportLoad(server *bootstrap.ServerConfig) (loadStore *lrsclient.LoadStore, cancel func(context.Context)) {
8584
lrsClient, _ := lrsclient.New(lrsclient.Config{Node: clients.Node{ID: "fake-node-id"}, TransportBuilder: &transportBuilder{}})
8685
xdsC.loadStore, _ = lrsClient.ReportLoad(clients.ServerIdentifier{ServerURI: server.ServerURI()})
8786

8887
xdsC.loadReportCh.Send(ReportLoadArgs{Server: server})
8988

90-
return xdsC.loadStore, func(timeout time.Duration) {
91-
xdsC.loadStore.Stop(timeout)
89+
return xdsC.loadStore, func(ctx context.Context) {
90+
xdsC.loadStore.Stop(ctx)
9291
xdsC.lrsCancelCh.Send(nil)
9392
}
9493
}

xds/internal/xdsclient/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
package xdsclient
2222

2323
import (
24-
"time"
24+
"context"
2525

2626
v3statuspb "github.com/envoyproxy/go-control-plane/envoy/service/status/v3"
2727
"google.golang.org/grpc/internal/xds/bootstrap"
@@ -49,7 +49,7 @@ type XDSClient interface {
4949
// the watcher is canceled. Callers need to handle this case.
5050
WatchResource(rType xdsresource.Type, resourceName string, watcher xdsresource.ResourceWatcher) (cancel func())
5151

52-
ReportLoad(*bootstrap.ServerConfig) (*lrsclient.LoadStore, func(time.Duration))
52+
ReportLoad(*bootstrap.ServerConfig) (*lrsclient.LoadStore, func(context.Context))
5353

5454
BootstrapConfig() *bootstrap.Config
5555
}

xds/internal/xdsclient/clientimpl_loadreport.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
package xdsclient
1919

2020
import (
21+
"context"
2122
"sync"
22-
"time"
2323

2424
"google.golang.org/grpc/internal/xds/bootstrap"
2525
"google.golang.org/grpc/xds/internal/clients"
@@ -31,26 +31,24 @@ import (
3131
// reports to the same server share the LRS stream.
3232
//
3333
// It returns a lrsclient.LoadStore for the user to report loads.
34-
func (c *clientImpl) ReportLoad(server *bootstrap.ServerConfig) (*lrsclient.LoadStore, func(time.Duration)) {
34+
func (c *clientImpl) ReportLoad(server *bootstrap.ServerConfig) (*lrsclient.LoadStore, func(context.Context)) {
3535
if c.lrsClient == nil {
3636
lrsConfig := lrsclient.Config{Node: c.gConfig.Node, TransportBuilder: c.gConfig.TransportBuilder}
3737
lrsC, err := lrsclient.New(lrsConfig)
3838
if err != nil {
3939
c.logger.Warningf("Failed to create an lrs client to the management server to report load: %v", server, err)
40-
return nil, func(time.Duration) {}
40+
return nil, func(context.Context) {}
4141
}
4242
c.lrsClient = lrsC
4343
}
4444

4545
load, err := c.lrsClient.ReportLoad(clients.ServerIdentifier{ServerURI: server.ServerURI(), Extensions: grpctransport.ServerIdentifierExtension{ConfigName: server.SelectedCreds().Type}})
4646
if err != nil {
4747
c.logger.Warningf("Failed to create a load store to the management server to report load: %v", server, err)
48-
return nil, func(time.Duration) {}
48+
return nil, func(context.Context) {}
4949
}
5050
var loadStop sync.Once
51-
return load, func(timeout time.Duration) {
52-
loadStop.Do(func() {
53-
load.Stop(timeout)
54-
})
51+
return load, func(ctx context.Context) {
52+
loadStop.Do(func() { load.Stop(ctx) })
5553
}
5654
}

xds/internal/xdsclient/tests/loadreport_test.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
144144
// Call the load reporting API to report load to the first management
145145
// server, and ensure that a connection to the server is created.
146146
store1, lrsCancel1 := client.ReportLoad(serverCfg1)
147-
defer lrsCancel1(defaultTestShortTimeout)
147+
sCtx, sCancel := context.WithTimeout(ctx, defaultTestShortTimeout)
148+
defer sCancel()
149+
defer lrsCancel1(sCtx)
148150
if _, err := newConnChan1.Receive(ctx); err != nil {
149151
t.Fatal("Timeout when waiting for a connection to the first management server, after starting load reporting")
150152
}
@@ -159,7 +161,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
159161
// Call the load reporting API to report load to the second management
160162
// server, and ensure that a connection to the server is created.
161163
store2, lrsCancel2 := client.ReportLoad(serverCfg2)
162-
defer lrsCancel2(defaultTestShortTimeout)
164+
sCtx2, sCancel2 := context.WithTimeout(ctx, defaultTestShortTimeout)
165+
defer sCancel2()
166+
defer lrsCancel2(sCtx2)
163167
if _, err := newConnChan2.Receive(ctx); err != nil {
164168
t.Fatal("Timeout when waiting for a connection to the second management server, after starting load reporting")
165169
}
@@ -227,7 +231,9 @@ func (s) TestReportLoad_ConnectionCreation(t *testing.T) {
227231
}
228232

229233
// Cancel this load reporting stream, server should see error canceled.
230-
lrsCancel2(defaultTestShortTimeout)
234+
sCtx2, sCancel2 = context.WithTimeout(ctx, defaultTestShortTimeout)
235+
defer sCancel2()
236+
lrsCancel2(sCtx2)
231237

232238
// Server should receive a stream canceled error. There may be additional
233239
// load reports from the client in the channel.
@@ -403,15 +409,19 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
403409

404410
// Cancel the first load reporting call, and ensure that the stream does not
405411
// close (because we have another call open).
406-
cancel1(defaultTestShortTimeout)
412+
sCtx1, sCancel1 := context.WithTimeout(ctx, defaultTestShortTimeout)
413+
defer sCancel1()
414+
cancel1(sCtx1)
407415
sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
408416
defer sCancel()
409417
if _, err := lrsServer.LRSStreamCloseChan.Receive(sCtx); err != context.DeadlineExceeded {
410418
t.Fatal("LRS stream closed when expected to stay open")
411419
}
412420

413421
// Cancel the second load reporting call, and ensure the stream is closed.
414-
cancel2(defaultTestShortTimeout)
422+
sCtx2, sCancel2 := context.WithTimeout(ctx, defaultTestShortTimeout)
423+
defer sCancel2()
424+
cancel2(sCtx2)
415425
if _, err := lrsServer.LRSStreamCloseChan.Receive(ctx); err != nil {
416426
t.Fatal("Timeout waiting for LRS stream to close")
417427
}
@@ -423,5 +433,7 @@ func (s) TestReportLoad_StreamCreation(t *testing.T) {
423433
if _, err := lrsServer.LRSStreamOpenChan.Receive(ctx); err != nil {
424434
t.Fatalf("Timeout when waiting for LRS stream to be created: %v", err)
425435
}
426-
cancel3(defaultTestShortTimeout)
436+
sCtx3, sCancel3 := context.WithTimeout(ctx, defaultTestShortTimeout)
437+
defer sCancel3()
438+
cancel3(sCtx3)
427439
}

0 commit comments

Comments
 (0)