diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index aeb4a45..dcf7331 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -29,7 +29,7 @@ jobs: run: go generate ./... - name: Tests - run: go test -coverprofile=coverage.out -covermode=atomic ./... + run: go test -race -coverprofile=coverage.out -covermode=atomic ./... - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 \ No newline at end of file diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 1e4811f..5704810 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -28,7 +28,7 @@ func run() error { return err } - pool := compute.NewPool(logger, aws.NewAWSWorkerFactory(logger, ec2.NewFromConfig(cfg), aws.DefaultAWSInstanceParams, 443)) + pool := compute.NewPool(logger, aws.NewWorkerFactory(logger, ec2.NewFromConfig(cfg), aws.DefaultInstanceParams, 443)) defer func(pool compute.Pool) { _ = pool.Close() }(pool) diff --git a/go.mod b/go.mod index 0964e3e..4926471 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.26.6 github.com/golang/mock v1.6.0 github.com/jaypipes/ghw v0.9.0 + github.com/jaypipes/pcidb v1.0.0 github.com/launchdarkly/go-test-helpers/v2 v2.3.1 github.com/pkg/errors v0.9.1 github.com/xfrr/goffmpeg v0.0.0-20210624103149-5ca2d3062daf @@ -38,7 +39,6 @@ require ( github.com/ghodss/yaml v1.0.0 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/golang/protobuf v1.5.2 // indirect - github.com/jaypipes/pcidb v1.0.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect go.uber.org/atomic v1.7.0 // indirect diff --git a/internal/compute/options.go b/internal/compute/options.go new file mode 100644 index 0000000..1e613be --- /dev/null +++ b/internal/compute/options.go @@ -0,0 +1,22 @@ +package compute + +import "time" + +type ReadyOptions struct { + TickerInterval time.Duration + ConnTimeout time.Duration +} + +type ReadyOptionsFunc func(options *ReadyOptions) + +func WithTickerInterval(interval time.Duration) ReadyOptionsFunc { + return func(opts *ReadyOptions) { + opts.TickerInterval = interval + } +} + +func WithConnTimeout(timeout time.Duration) ReadyOptionsFunc { + return func(opts *ReadyOptions) { + opts.ConnTimeout = timeout + } +} diff --git a/internal/compute/options_test.go b/internal/compute/options_test.go new file mode 100644 index 0000000..fd60d0c --- /dev/null +++ b/internal/compute/options_test.go @@ -0,0 +1,54 @@ +package compute + +import ( + "testing" + "time" + + m "github.com/launchdarkly/go-test-helpers/v2/matchers" +) + +func TestWithTickerInterval(t *testing.T) { + tests := []struct { + name string + interval time.Duration + }{ + {"0", time.Duration(0)}, + {"seconds", 1 * time.Second}, + {"ms", 100 * time.Millisecond}, + {"hours", 2 * time.Hour}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := WithTickerInterval(test.interval) + + opt := new(ReadyOptions) + + f(opt) + + m.For(t, "val").Assert(opt.TickerInterval, m.Equal(test.interval)) + }) + } +} + +func TestWithConnTimeout(t *testing.T) { + tests := []struct { + name string + interval time.Duration + }{ + {"0", time.Duration(0)}, + {"seconds", 1 * time.Second}, + {"ms", 100 * time.Millisecond}, + {"hours", 2 * time.Hour}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := WithConnTimeout(test.interval) + + opt := new(ReadyOptions) + + f(opt) + + m.For(t, "val").Assert(opt.ConnTimeout, m.Equal(test.interval)) + }) + } +} diff --git a/internal/compute/work_queue.go b/internal/compute/work_queue.go index ae4a786..4042f63 100644 --- a/internal/compute/work_queue.go +++ b/internal/compute/work_queue.go @@ -46,7 +46,10 @@ func NewQueue(logger *zap.Logger, pool Pool, maxSize int) WorkQueue { func (q *DefaultWorkQueue) run() { for { + q.mtx.Lock() if len(q.workers) < q.maxSize { + q.mtx.Unlock() + // Wait for work work := <-q.workQueue @@ -71,9 +74,11 @@ func (q *DefaultWorkQueue) run() { go func() { defer func(worker Worker) { - q.wg.Done() + q.mtx.Lock() + defer q.mtx.Lock() q.workers = removeItem(q.workers, worker) q.pool.ReturnWorker(worker) + q.wg.Done() }(worker) q.logger.Debug("Waiting for worker ready", zap.Any("req", work.getReq())) @@ -98,6 +103,8 @@ func (q *DefaultWorkQueue) run() { } q.logger.Info("Work finished", zap.Any("req", work.getReq())) }() + } else { + q.mtx.Unlock() } } } @@ -117,5 +124,7 @@ func (q *DefaultWorkQueue) GetMaxSize() int { } func (q *DefaultWorkQueue) SetMaxSize(size int) { + q.mtx.Lock() + defer q.mtx.Unlock() q.maxSize = size } diff --git a/internal/compute/work_queue_test.go b/internal/compute/work_queue_test.go index 2bb6acc..ccae801 100644 --- a/internal/compute/work_queue_test.go +++ b/internal/compute/work_queue_test.go @@ -134,8 +134,8 @@ func TestDefaultWorkQueue_Add(t *testing.T) { AnyTimes() mWorker.EXPECT(). - IsReadyChan(gomock.Any()). - DoAndReturn(func(ctx context.Context) <-chan error { + IsReadyChan(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, opts ...func(options *ReadyOptions)) <-chan error { ch := make(chan error) go func() { ch <- nil diff --git a/internal/compute/worker.go b/internal/compute/worker.go index 7028fc3..1650ad9 100644 --- a/internal/compute/worker.go +++ b/internal/compute/worker.go @@ -23,8 +23,8 @@ type Worker interface { Worker() proto.WorkerServiceClient Job() proto.JobServiceClient - IsReady(ctx context.Context) (bool, error) - IsReadyChan(ctx context.Context) <-chan error + IsReady(ctx context.Context, opts ...ReadyOptionsFunc) (bool, error) + IsReadyChan(ctx context.Context, opts ...ReadyOptionsFunc) <-chan error } type WorkerFactory interface { diff --git a/internal/worker/aws/aws.go b/internal/worker/aws/aws.go index 03a544d..f33dcfc 100644 --- a/internal/worker/aws/aws.go +++ b/internal/worker/aws/aws.go @@ -22,7 +22,7 @@ import ( //go:embed userdata.sh var userData []byte -var DefaultAWSInstanceParams = &ec2.RunInstancesInput{ +var DefaultInstanceParams = &ec2.RunInstancesInput{ MinCount: aws.Int32(1), MaxCount: aws.Int32(1), IamInstanceProfile: nil, @@ -37,16 +37,16 @@ var DefaultAWSInstanceParams = &ec2.RunInstancesInput{ UserData: aws.String(base64.StdEncoding.EncodeToString(userData)), } -type AWSWorkerEC2Client interface { +type WorkerEC2Client interface { RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) DescribeInstanceStatus(ctx context.Context, params *ec2.DescribeInstanceStatusInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error) TerminateInstances(ctx context.Context, params *ec2.TerminateInstancesInput, optFns ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error) } -type AWSWorker struct { +type Worker struct { logger *zap.Logger - client AWSWorkerEC2Client + client WorkerEC2Client id string port uint16 @@ -60,7 +60,7 @@ type AWSWorker struct { closed bool } -func (w *AWSWorker) Close() error { +func (w *Worker) Close() error { if w.closed { return compute.ErrClosed } @@ -81,16 +81,16 @@ func (w *AWSWorker) Close() error { return err } -func (w *AWSWorker) Equals(other compute.Worker) bool { +func (w *Worker) Equals(other compute.Worker) bool { switch v := other.(type) { - case *AWSWorker: + case *Worker: return w.id == v.id default: return false } } -func (w *AWSWorker) getIP(ctx context.Context) (netip.Addr, error) { +func (w *Worker) getIP(ctx context.Context) (netip.Addr, error) { instances, err := w.client.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ InstanceIds: []string{w.id}, }) @@ -110,7 +110,7 @@ func (w *AWSWorker) getIP(ctx context.Context) (netip.Addr, error) { return netip.ParseAddr(aws.ToString(instance.PublicIpAddress)) } -func (w *AWSWorker) Connect(ctx context.Context) (err error) { +func (w *Worker) Connect(ctx context.Context) (err error) { if w.closed { return compute.ErrClosed } @@ -139,15 +139,15 @@ func (w *AWSWorker) Connect(ctx context.Context) (err error) { return nil } -func (w *AWSWorker) Worker() proto.WorkerServiceClient { +func (w *Worker) Worker() proto.WorkerServiceClient { return w.worker } -func (w *AWSWorker) Job() proto.JobServiceClient { +func (w *Worker) Job() proto.JobServiceClient { return w.job } -func (w *AWSWorker) getInstanceStatus(ctx context.Context) (types.InstanceStateName, error) { +func (w *Worker) getInstanceStatus(ctx context.Context) (types.InstanceStateName, error) { statuses, err := w.client.DescribeInstanceStatus(ctx, &ec2.DescribeInstanceStatusInput{ InstanceIds: []string{w.id}, }) @@ -167,7 +167,7 @@ func (w *AWSWorker) getInstanceStatus(ctx context.Context) (types.InstanceStateN return types.InstanceStateNamePending, nil } -func (w *AWSWorker) IsReady(ctx context.Context) (bool, error) { +func (w *Worker) IsReady(ctx context.Context, opts ...compute.ReadyOptionsFunc) (bool, error) { if w.closed { return false, compute.ErrClosed } @@ -181,7 +181,15 @@ func (w *AWSWorker) IsReady(ctx context.Context) (bool, error) { return false, nil } - connCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + options := &compute.ReadyOptions{ + ConnTimeout: 10 * time.Second, + } + + for _, opt := range opts { + opt(options) + } + + connCtx, cancel := context.WithTimeout(ctx, options.ConnTimeout) defer cancel() err = w.Connect(connCtx) @@ -196,12 +204,20 @@ func (w *AWSWorker) IsReady(ctx context.Context) (bool, error) { return true, nil } -func (w *AWSWorker) IsReadyChan(ctx context.Context) <-chan error { +func (w *Worker) IsReadyChan(ctx context.Context, opts ...compute.ReadyOptionsFunc) <-chan error { ch := make(chan error) - ticker := time.NewTicker(15 * time.Second) + options := &compute.ReadyOptions{ + TickerInterval: 15 * time.Second, + ConnTimeout: 10 * time.Second, + } + + for _, opt := range opts { + opt(options) + } go func() { + ticker := time.NewTicker(options.TickerInterval) for { select { case <-ctx.Done(): @@ -209,9 +225,9 @@ func (w *AWSWorker) IsReadyChan(ctx context.Context) <-chan error { ticker.Stop() return case <-ticker.C: - isReady, err := w.IsReady(ctx) + isReady, err := w.IsReady(ctx, opts...) if err != nil { - if err.Error() == "no instance statuses returned" { + if err.Error() == "instance not found" { continue } ch <- err @@ -221,6 +237,8 @@ func (w *AWSWorker) IsReadyChan(ctx context.Context) <-chan error { if isReady { ch <- nil + ticker.Stop() + return } } } @@ -229,15 +247,15 @@ func (w *AWSWorker) IsReadyChan(ctx context.Context) <-chan error { return ch } -type AWSWorkerFactory struct { +type WorkerFactory struct { logger *zap.Logger - client AWSWorkerEC2Client + client WorkerEC2Client params *ec2.RunInstancesInput port uint16 } -func NewAWSWorkerFactory(logger *zap.Logger, client AWSWorkerEC2Client, input *ec2.RunInstancesInput, port uint16) compute.WorkerFactory { - return &AWSWorkerFactory{ +func NewWorkerFactory(logger *zap.Logger, client WorkerEC2Client, input *ec2.RunInstancesInput, port uint16) compute.WorkerFactory { + return &WorkerFactory{ logger: logger, client: client, params: input, @@ -245,7 +263,7 @@ func NewAWSWorkerFactory(logger *zap.Logger, client AWSWorkerEC2Client, input *e } } -func (f *AWSWorkerFactory) Create(ctx context.Context) (compute.Worker, error) { +func (f *WorkerFactory) Create(ctx context.Context) (compute.Worker, error) { instances, err := f.client.RunInstances(ctx, f.params) if err != nil { return nil, err @@ -257,7 +275,7 @@ func (f *AWSWorkerFactory) Create(ctx context.Context) (compute.Worker, error) { instance := instances.Instances[0] - return &AWSWorker{ + return &Worker{ logger: f.logger, client: f.client, id: aws.ToString(instance.InstanceId), diff --git a/internal/worker/aws/aws_test.go b/internal/worker/aws/aws_test.go index 0cb0013..2320ed0 100644 --- a/internal/worker/aws/aws_test.go +++ b/internal/worker/aws/aws_test.go @@ -37,7 +37,7 @@ func grpcServer(t *testing.T) (*net.TCPAddr, *grpc.Server) { return l.Addr().(*net.TCPAddr), gsrv } -func TestNewAWSWorkerFactory(t *testing.T) { +func TestNewWorkerFactory(t *testing.T) { logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) @@ -45,18 +45,18 @@ func TestNewAWSWorkerFactory(t *testing.T) { ctrl.Finish() }) - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) + mClient := NewMockWorkerEC2Client(ctrl) - factoryI := NewAWSWorkerFactory(logger, mAWSClient, DefaultAWSInstanceParams, 443) + factoryI := NewWorkerFactory(logger, mClient, DefaultInstanceParams, 443) - factory := factoryI.(*AWSWorkerFactory) + factory := factoryI.(*WorkerFactory) m.For(t, "factory").For("input"). - Assert(factory.params, m.Equal(DefaultAWSInstanceParams)) + Assert(factory.params, m.Equal(DefaultInstanceParams)) m.For(t, "factory").For("port").Assert(factory.port, m.Equal(uint16(443))) } -func TestAWSWorkerFactory_Create(t *testing.T) { +func TestWorkerFactory_Create(t *testing.T) { logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) @@ -65,22 +65,22 @@ func TestAWSWorkerFactory_Create(t *testing.T) { }) t.Run("normal behavior", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). RunInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(&ec2.RunInstancesOutput{ Instances: []types.Instance{{InstanceId: aws.String("i-123456")}}, }, nil) - factory := NewAWSWorkerFactory(logger, mAWSClient, DefaultAWSInstanceParams, 443) + factory := NewWorkerFactory(logger, mClient, DefaultInstanceParams, 443) workerI, err := factory.Create(context.Background()) m.For(t, "create err").Assert(err, m.BeNil()) - worker := workerI.(*AWSWorker) + worker := workerI.(*Worker) - m.For(t, "worker").For("client").Assert(worker.client, m.Equal(mAWSClient)) + m.For(t, "worker").For("client").Assert(worker.client, m.Equal(mClient)) m.For(t, "worker").For("id").Assert(worker.id, m.Equal("i-123456")) m.For(t, "worker").For("port").Assert(worker.port, m.Equal(uint16(443))) }) @@ -88,12 +88,12 @@ func TestAWSWorkerFactory_Create(t *testing.T) { t.Run("error", func(t *testing.T) { expectedErr := errors.New("something bad happened") - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). RunInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, expectedErr) - factory := NewAWSWorkerFactory(logger, mAWSClient, DefaultAWSInstanceParams, 443) + factory := NewWorkerFactory(logger, mClient, DefaultInstanceParams, 443) workerI, err := factory.Create(context.Background()) @@ -102,14 +102,14 @@ func TestAWSWorkerFactory_Create(t *testing.T) { }) t.Run("no instance returned", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). RunInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(&ec2.RunInstancesOutput{ Instances: []types.Instance{}, }, nil) - factory := NewAWSWorkerFactory(logger, mAWSClient, DefaultAWSInstanceParams, 443) + factory := NewWorkerFactory(logger, mClient, DefaultInstanceParams, 443) workerI, err := factory.Create(context.Background()) @@ -118,7 +118,7 @@ func TestAWSWorkerFactory_Create(t *testing.T) { }) } -func TestAWSWorker_Close(t *testing.T) { +func TestWorker_Close(t *testing.T) { logger := zaptest.NewLogger(t) ctrl := gomock.NewController(t) @@ -127,14 +127,14 @@ func TestAWSWorker_Close(t *testing.T) { }) t.Run("close new worker", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). TerminateInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", } @@ -143,14 +143,14 @@ func TestAWSWorker_Close(t *testing.T) { }) t.Run("double close worker", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). TerminateInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", } @@ -164,14 +164,14 @@ func TestAWSWorker_Close(t *testing.T) { t.Run("close connected worker", func(t *testing.T) { addr, _ := grpcServer(t) - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). TerminateInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", } @@ -190,14 +190,14 @@ func TestAWSWorker_Close(t *testing.T) { t.Run("close connected worker grpc error", func(t *testing.T) { addr, _ := grpcServer(t) - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). TerminateInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", } @@ -215,40 +215,40 @@ func TestAWSWorker_Close(t *testing.T) { }) } -func TestAWSWorker_Equals(t *testing.T) { +func TestWorker_Equals(t *testing.T) { tests := []struct { name string - w1 *AWSWorker + w1 *Worker w2 compute.Worker expect bool }{ { name: "same", - w1: &AWSWorker{id: "id"}, - w2: &AWSWorker{id: "id"}, + w1: &Worker{id: "id"}, + w2: &Worker{id: "id"}, expect: true, }, { name: "different", - w1: &AWSWorker{id: "id"}, - w2: &AWSWorker{id: "id2"}, + w1: &Worker{id: "id"}, + w2: &Worker{id: "id2"}, expect: false, }, { name: "different2", - w1: &AWSWorker{id: "id2"}, - w2: &AWSWorker{id: "id"}, + w1: &Worker{id: "id2"}, + w2: &Worker{id: "id"}, expect: false, }, { name: "nil", - w1: &AWSWorker{id: "id"}, + w1: &Worker{id: "id"}, w2: nil, expect: false, }, { name: "different worker", - w1: &AWSWorker{id: "id"}, + w1: &Worker{id: "id"}, w2: &compute.MockWorker{}, expect: false, }, @@ -260,7 +260,7 @@ func TestAWSWorker_Equals(t *testing.T) { } } -func TestAWSWorker_Connect(t *testing.T) { +func TestWorker_Connect(t *testing.T) { addr, _ := grpcServer(t) ip := addr.AddrPort().Addr() @@ -274,8 +274,8 @@ func TestAWSWorker_Connect(t *testing.T) { }) t.Run("normal behavior", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(&ec2.DescribeInstancesOutput{ Reservations: []types.Reservation{ @@ -285,9 +285,9 @@ func TestAWSWorker_Connect(t *testing.T) { }, }, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", port: port, } @@ -297,8 +297,8 @@ func TestAWSWorker_Connect(t *testing.T) { }) t.Run("double connect", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(&ec2.DescribeInstancesOutput{ Reservations: []types.Reservation{ @@ -309,9 +309,9 @@ func TestAWSWorker_Connect(t *testing.T) { }, nil). Times(2) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", port: port, } @@ -324,14 +324,14 @@ func TestAWSWorker_Connect(t *testing.T) { }) t.Run("connect to closed", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). TerminateInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", port: port, } @@ -344,8 +344,8 @@ func TestAWSWorker_Connect(t *testing.T) { }) t.Run("missing ip", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(&ec2.DescribeInstancesOutput{ Reservations: []types.Reservation{ @@ -355,9 +355,9 @@ func TestAWSWorker_Connect(t *testing.T) { }, }, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", port: port, } @@ -370,14 +370,14 @@ func TestAWSWorker_Connect(t *testing.T) { t.Run("aws error", func(t *testing.T) { expectedErr := errors.New("something bad happened") - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, expectedErr) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", port: port, } @@ -387,8 +387,8 @@ func TestAWSWorker_Connect(t *testing.T) { }) t.Run("connection failure", func(t *testing.T) { - mAWSClient := NewMockAWSWorkerEC2Client(ctrl) - mAWSClient.EXPECT(). + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). Return(&ec2.DescribeInstancesOutput{ Reservations: []types.Reservation{ @@ -398,9 +398,9 @@ func TestAWSWorker_Connect(t *testing.T) { }, }, nil) - worker := &AWSWorker{ + worker := &Worker{ logger: logger, - client: mAWSClient, + client: mClient, id: "id", port: port + 1, // <- CHANGE HERE!!! } @@ -412,3 +412,315 @@ func TestAWSWorker_Connect(t *testing.T) { m.For(t, "err").Assert(err, m.Equal(context.DeadlineExceeded)) }) } + +func TestWorker_IsReady(t *testing.T) { + addr, _ := grpcServer(t) + + ip := addr.AddrPort().Addr() + port := addr.AddrPort().Port() + + logger := zaptest.NewLogger(t) + + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + t.Run("normal behavior", func(t *testing.T) { + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + {Instances: []types.Instance{{ + PublicIpAddress: aws.String(ip.String()), + }}}, + }, + }, nil) + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstanceStatusOutput{ + InstanceStatuses: []types.InstanceStatus{{ + InstanceState: &types.InstanceState{ + Name: types.InstanceStateNameRunning, + }, + }}, + }, nil) + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + ready, err := worker.IsReady(context.Background()) + m.For(t, "err").Assert(err, m.BeNil()) + m.For(t, "ready").Assert(ready, m.Equal(true)) + }) + + t.Run("pending status", func(t *testing.T) { + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstanceStatusOutput{ + InstanceStatuses: []types.InstanceStatus{{ + InstanceState: &types.InstanceState{ + Name: types.InstanceStateNamePending, + }, + }}, + }, nil) + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + ready, err := worker.IsReady(context.Background()) + m.For(t, "err").Assert(err, m.BeNil()) + m.For(t, "ready").Assert(ready, m.Equal(false)) + }) + + t.Run("status error", func(t *testing.T) { + expectedErr := errors.New("something bad happened") + + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, expectedErr) + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + ready, err := worker.IsReady(context.Background()) + m.For(t, "err").Assert(err, m.Equal(expectedErr)) + m.For(t, "ready").Assert(ready, m.Equal(false)) + }) + + t.Run("connection timeout", func(t *testing.T) { + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + {Instances: []types.Instance{{ + PublicIpAddress: aws.String(ip.String()), + }}}, + }, + }, nil) + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstanceStatusOutput{ + InstanceStatuses: []types.InstanceStatus{{ + InstanceState: &types.InstanceState{ + Name: types.InstanceStateNameRunning, + }, + }}, + }, nil) + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port + 1, // <- CHANGE HERE + } + + ready, err := worker.IsReady(context.Background(), compute.WithConnTimeout(1*time.Millisecond)) + m.For(t, "err").Assert(err, m.BeNil()) + m.For(t, "ready").Assert(ready, m.Equal(false)) + }) + + t.Run("already closed", func(t *testing.T) { + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + TerminateInstances(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, nil) + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + err := worker.Close() + m.For(t, "close").For("err").Require(err, m.BeNil()) + + ready, err := worker.IsReady(context.Background()) + m.For(t, "err").Assert(err, m.Equal(compute.ErrClosed)) + m.For(t, "ready").Assert(ready, m.Equal(false)) + }) +} + +func TestWorker_IsReadyChan(t *testing.T) { + addr, _ := grpcServer(t) + + ip := addr.AddrPort().Addr() + port := addr.AddrPort().Port() + + logger := zaptest.NewLogger(t) + + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + t.Run("normal behavior", func(t *testing.T) { + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + {Instances: []types.Instance{{ + PublicIpAddress: aws.String(ip.String()), + }}}, + }, + }, nil). + AnyTimes() + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstanceStatusOutput{ + InstanceStatuses: []types.InstanceStatus{{ + InstanceState: &types.InstanceState{ + Name: types.InstanceStateNameRunning, + }, + }}, + }, nil). + AnyTimes() + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + ch := worker.IsReadyChan(context.Background(), + compute.WithTickerInterval(50*time.Millisecond), + compute.WithConnTimeout(100*time.Millisecond)) + m.For(t, "ch err").Assert(<-ch, m.BeNil()) + }) + + t.Run("double check", func(t *testing.T) { + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + {Instances: []types.Instance{{ + PublicIpAddress: aws.String(ip.String()), + }}}, + }, + }, nil). + AnyTimes() + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstanceStatusOutput{ + InstanceStatuses: []types.InstanceStatus{{ + InstanceState: &types.InstanceState{ + Name: types.InstanceStateNameRunning, + }, + }}, + }, nil). + After(mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstanceStatusOutput{ + InstanceStatuses: []types.InstanceStatus{{ + InstanceState: &types.InstanceState{ + Name: types.InstanceStateNamePending, + }, + }}, + }, nil). + Times(1)). + AnyTimes() + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + ch := worker.IsReadyChan(context.Background(), + compute.WithTickerInterval(50*time.Millisecond), + compute.WithConnTimeout(100*time.Millisecond)) + m.For(t, "ch err").Assert(<-ch, m.BeNil()) + }) + + t.Run("context expired", func(t *testing.T) { + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + {Instances: []types.Instance{{ + PublicIpAddress: aws.String(ip.String()), + }}}, + }, + }, nil). + AnyTimes() + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstanceStatusOutput{ + InstanceStatuses: []types.InstanceStatus{{ + InstanceState: &types.InstanceState{ + Name: types.InstanceStateNamePending, + }, + }}, + }, nil). + AnyTimes() + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + ch := worker.IsReadyChan(ctx, + compute.WithTickerInterval(50*time.Millisecond), + compute.WithConnTimeout(100*time.Millisecond)) + m.For(t, "ch err").Assert(<-ch, m.Equal(context.DeadlineExceeded)) + }) + + t.Run("error", func(t *testing.T) { + expectedErr := errors.New("something bad happened") + + mClient := NewMockWorkerEC2Client(ctrl) + mClient.EXPECT(). + DescribeInstances(gomock.Any(), gomock.Any(), gomock.Any()). + Return(&ec2.DescribeInstancesOutput{ + Reservations: []types.Reservation{ + {Instances: []types.Instance{{ + PublicIpAddress: aws.String(ip.String()), + }}}, + }, + }, nil). + AnyTimes() + mClient.EXPECT(). + DescribeInstanceStatus(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, expectedErr). + AnyTimes() + + worker := &Worker{ + logger: logger, + client: mClient, + id: "id", + port: port, + } + + ch := worker.IsReadyChan(context.Background(), + compute.WithTickerInterval(50*time.Millisecond), + compute.WithConnTimeout(100*time.Millisecond)) + m.For(t, "ch err").Assert(<-ch, m.Equal(expectedErr)) + }) +}