Skip to content

[Feature] Add timeout for apiserver grpc server #3427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions apiserver/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"path"
"strings"
"sync/atomic"
"time"

assetfs "github.com/elazarl/go-bindata-assetfs"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
Expand All @@ -27,6 +28,7 @@ import (
"github.com/ray-project/kuberay/apiserver/pkg/manager"
"github.com/ray-project/kuberay/apiserver/pkg/server"
"github.com/ray-project/kuberay/apiserver/pkg/swagger"
"github.com/ray-project/kuberay/apiserver/pkg/util"
api "github.com/ray-project/kuberay/proto/go_client"
)

Expand All @@ -36,6 +38,7 @@ var (
collectMetricsFlag = flag.Bool("collectMetricsFlag", true, "Whether to collect Prometheus metrics in API server.")
logFile = flag.String("logFilePath", "", "Synchronize logs to local file")
localSwaggerPath = flag.String("localSwaggerPath", "", "Specify the root directory for `*.swagger.json` the swagger files.")
grpcTimeout = flag.Duration("grpc_timeout", util.GRPCServerDefaultTimeout, "gRPC server timeout duration")
healthy int32
)

Expand All @@ -54,7 +57,8 @@ func main() {
resourceManager := manager.NewResourceManager(&clientManager)

atomic.StoreInt32(&healthy, 1)
go startRPCServer(resourceManager)
klog.Infof("Setting gRPC server timeout to %v", *grpcTimeout)
go startRPCServer(resourceManager, *grpcTimeout)
startHttpProxy()
// See also https://gist.github.com/enricofoltran/10b4a980cd07cb02836f70a4ab3e72d7
quit := make(chan os.Signal, 1)
Expand All @@ -70,7 +74,7 @@ func main() {

type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error

func startRPCServer(resourceManager *manager.ResourceManager) {
func startRPCServer(resourceManager *manager.ResourceManager, grpcTimeout time.Duration) {
klog.Infof("Starting gRPC server at port %s", *rpcPortFlag)

listener, err := net.Listen("tcp", *rpcPortFlag)
Expand All @@ -86,8 +90,13 @@ func startRPCServer(resourceManager *manager.ResourceManager) {

s := grpc.NewServer(
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(grpc_prometheus.UnaryServerInterceptor, interceptor.APIServerInterceptor)),
grpc.MaxRecvMsgSize(math.MaxInt32))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
interceptor.TimeoutInterceptor(grpcTimeout),
grpc_prometheus.UnaryServerInterceptor,
interceptor.APIServerInterceptor,
)),
grpc.MaxRecvMsgSize(math.MaxInt32),
)
api.RegisterClusterServiceServer(s, clusterServer)
api.RegisterComputeTemplateServiceServer(s, templateServer)
api.RegisterRayJobServiceServer(s, jobServer)
Expand Down
62 changes: 35 additions & 27 deletions apiserver/pkg/client/kubernetes_mock.go
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is automatically updated when running make test

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions apiserver/pkg/interceptor/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package interceptor

import (
"context"
"time"

"google.golang.org/grpc"
klog "k8s.io/klog/v2"
Expand All @@ -19,3 +20,17 @@ func APIServerInterceptor(ctx context.Context, req interface{}, info *grpc.Unary
klog.Infof("%v handler finished", info.FullMethod)
return
}

// TimeoutInterceptor implements UnaryServerInterceptor that sets the timeout for the request
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
_ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return handler(ctx, req)
}
}
91 changes: 86 additions & 5 deletions apiserver/pkg/interceptor/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import (
"io"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
klog "k8s.io/klog/v2"
)

Expand All @@ -20,9 +23,28 @@ type mockHandler struct {
called bool
}

func (h *mockHandler) Handle(_ context.Context, _ interface{}) (interface{}, error) {
// Handle simulates the behavior of a gRPC handler with an optional delay.
// If the delay completes before the context expires, it returns "test_response" along with predefined error.
// If the context is canceled or the deadline is exceeded before the delay completes,
// it returns a corresponding gRPC status error instead.
func (h *mockHandler) Handle(ctx context.Context, _ interface{}, delay time.Duration) (interface{}, error) {
h.called = true
return "test_response", h.returnErr

select {
case <-time.After(delay):
return "test_response", h.returnErr
case <-ctx.Done():
var grpcCode codes.Code
switch ctx.Err() {
case context.Canceled:
grpcCode = codes.Canceled
case context.DeadlineExceeded:
grpcCode = codes.DeadlineExceeded
default:
grpcCode = codes.Unknown
}
return nil, status.Error(grpcCode, ctx.Err().Error())
}
Comment on lines +33 to +47
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding this to mimic the grpc IO handler for testing

}

func TestAPIServerInterceptor(t *testing.T) {
Expand Down Expand Up @@ -61,7 +83,7 @@ func TestAPIServerInterceptor(t *testing.T) {
req,
info,
func(ctx context.Context, req interface{}) (interface{}, error) {
return tt.handler.Handle(ctx, req)
return tt.handler.Handle(ctx, req, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 0 /*delay*/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Thanks!

},
)

Expand Down Expand Up @@ -96,7 +118,7 @@ func TestAPIServerInterceptorContextPassing(t *testing.T) {
func(receivedCtx context.Context, req interface{}) (interface{}, error) {
// Verify context value is passed through
assert.Equal(t, "test_value", receivedCtx.Value(testContextKey("test_key")))
return handler.Handle(receivedCtx, req)
return handler.Handle(receivedCtx, req, 0)
},
)
}
Expand Down Expand Up @@ -153,7 +175,7 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
"test_request",
info,
func(receivedCtx context.Context, req interface{}) (interface{}, error) {
return handler.Handle(receivedCtx, req)
return handler.Handle(receivedCtx, req, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add a comment besides constants

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Thanks!

},
)

Expand Down Expand Up @@ -192,3 +214,62 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
})
}
}

func TestTimeoutInterceptor(t *testing.T) {
tests := []struct {
expectedError error
name string
timeout time.Duration
handlerDelay time.Duration
expectedCalled bool
}{
{
name: "handler completes before timeout",
timeout: 100 * time.Millisecond,
handlerDelay: 50 * time.Millisecond,
expectedError: nil,
expectedCalled: true,
},
{
name: "handler exceeds timeout",
timeout: 50 * time.Millisecond,
handlerDelay: 100 * time.Millisecond,
expectedError: status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error()),
expectedCalled: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test context and request
ctx := context.Background()
req := "test_request"
handler := &mockHandler{}

// Create the interceptor with the specified timeout
interceptor := TimeoutInterceptor(tt.timeout)

// Call the interceptor
resp, err := interceptor(
ctx,
req,
&grpc.UnaryServerInfo{FullMethod: "TestTimeoutMethod"},
func(ctx context.Context, req interface{}) (interface{}, error) {
return handler.Handle(ctx, req, tt.handlerDelay)
},
)

// Verify response and error
if tt.expectedError == nil {
// Verify handler was called
assert.Equal(t, tt.expectedCalled, handler.called, "handler call status should match expected")

require.NoError(t, err)
assert.Equal(t, "test_response", resp, "response should match expected")
} else {
require.Error(t, err)
require.Equal(t, tt.expectedError, err, "A matching error is expected")
}
})
}
}
Loading
Loading