diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f0d0de90..bd55019b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,10 @@ [[GH-241]](https://github.com/digitalocean/csi-digitalocean/pull/241) * Check all snapshots for existence [[GH-240]](https://github.com/digitalocean/csi-digitalocean/pull/240) -* Update sidecars +* Implement graceful shutdown [[GH-238]](https://github.com/digitalocean/csi-digitalocean/pull/238) +* Update sidecars + [[GH-236]](https://github.com/digitalocean/csi-digitalocean/pull/236) * Support checkLimit for multiple pages [[GH-235]](https://github.com/digitalocean/csi-digitalocean/pull/235) * Return error when fetching the snapshot fails diff --git a/cmd/do-csi-plugin/main.go b/cmd/do-csi-plugin/main.go index 4ee3affc5..5a307c301 100644 --- a/cmd/do-csi-plugin/main.go +++ b/cmd/do-csi-plugin/main.go @@ -17,10 +17,13 @@ limitations under the License. package main import ( + "context" "flag" "fmt" "log" "os" + "os/signal" + "syscall" "github.com/digitalocean/csi-digitalocean/driver" ) @@ -47,7 +50,17 @@ func main() { log.Fatalln(err) } - if err := drv.Run(); err != nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-c + cancel() + }() + + if err := drv.Run(ctx); err != nil { log.Fatalln(err) } } diff --git a/driver/driver.go b/driver/driver.go index 27892514a..87c7cb9bc 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -164,7 +164,7 @@ func NewDriver(ep, token, url, doTag, driverName, address string) (*Driver, erro } // Run starts the CSI plugin by communication over the given endpoint -func (d *Driver) Run() error { +func (d *Driver) Run(ctx context.Context) error { u, err := url.Parse(d.endpoint) if err != nil { return fmt.Errorf("unable to parse address: %q", err) @@ -247,25 +247,31 @@ func (d *Driver) Run() error { var eg errgroup.Group eg.Go(func() error { + <-ctx.Done() + return d.httpSrv.Shutdown(context.Background()) + }) + eg.Go(func() error { + go func() { + <-ctx.Done() + d.log.Info("server stopped") + d.readyMu.Lock() + d.ready = false + d.readyMu.Unlock() + d.srv.GracefulStop() + }() return d.srv.Serve(grpcListener) }) eg.Go(func() error { - return d.httpSrv.Serve(httpListener) + err := d.httpSrv.Serve(httpListener) + if err == http.ErrServerClosed { + return nil + } + return err }) return eg.Wait() } -// Stop stops the plugin -func (d *Driver) Stop() { - d.readyMu.Lock() - d.ready = false - d.readyMu.Unlock() - - d.log.Info("server stopped") - d.srv.Stop() -} - // When building any packages that import version, pass the build/install cmd // ldflags like so: // go build -ldflags "-X github.com/digitalocean/csi-digitalocean/driver.version=0.0.1" diff --git a/driver/driver_test.go b/driver/driver_test.go index e84e1ac62..afef0ca34 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -29,6 +29,7 @@ import ( "github.com/digitalocean/godo" "github.com/kubernetes-csi/csi-test/pkg/sanity" "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" ) func init() { @@ -81,17 +82,25 @@ func TestDriverSuite(t *testing.T) { account: &fakeAccountDriver{}, tags: &fakeTagsDriver{}, } - defer driver.Stop() - go driver.Run() + ctx, cancel := context.WithCancel(context.Background()) + + var eg errgroup.Group + eg.Go(func() error { + return driver.Run(ctx) + }) cfg := &sanity.Config{ TargetPath: os.TempDir() + "/csi-target", StagingPath: os.TempDir() + "/csi-staging", Address: endpoint, } - sanity.Test(t, cfg) + + cancel() + if err := eg.Wait(); err != nil { + t.Errorf("driver run failed: %s", err) + } } type fakeAccountDriver struct {