Skip to content

[Feature] Add timeout for apiserver grpc server #3427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions apiserver/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"path"
"strings"
"sync/atomic"
"time"

assetfs "github.com/elazarl/go-bindata-assetfs"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
Expand Down Expand Up @@ -50,11 +51,21 @@ func main() {
_ = flagSet.Set("log_file", *logFile)
}

grpcTimeout := 60 * time.Second // Default timeout
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we following mechanisms to define constants or we are adding to each files where we are using it ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quickly glancing over the code, we have constants.go for other components (i.e. operator)
https://github.com/ray-project/kuberay/blob/ebb5ba441b0a7f888c17aa5c2d33943084a9a2d9/ray-operator/controllers/ray/utils/constant.go

I usually do this in two ways:

  • either place it to constant file, just as what we did for kuberay operator
    • the benefit of which is we group all constants in one place, rather than scattered around the codebase
  • or define a util function getGrpcServerTimeoutOrDefault and have default timeout besides
    • the benefit of which is it's easy to locate all timeout related functions and features

Our codebase seems to prefer (1).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I found that in apiserver, they put constants in config.go, I'll add it here

const (
// Label keys
RayClusterNameLabelKey = "ray.io/cluster-name"
RayClusterUserLabelKey = "ray.io/user"
RayClusterVersionLabelKey = "ray.io/version"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

if timeoutStr := os.Getenv("GRPC_SERVER_TIMEOUT"); timeoutStr != "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw why do we use env var instead of flags? I think flags are strictly better in a few ways:

  • program checks env variables; for example, bazel uses env to decide whether we could reuse cache
  • impose security issue, because env is shared among all processes which could be accessed everywhere

I almost only use env variables when:

  • across language boundary
  • across process boundary, if no other easier way

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the guidance!

The reason why I put it in environment variable instead of flag is because I search through the code base and find they put this (which I think is a bit similar to timeout?) in the environment variable, so I just simply follow what it does

requeueAfterSeconds, err := strconv.Atoi(os.Getenv(utils.RAYCLUSTER_DEFAULT_REQUEUE_SECONDS_ENV))

I agree to your points, if there's no other places that need this value, I think I'll just move it to flag instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to flag

if timeout, err := time.ParseDuration(timeoutStr); err == nil {
grpcTimeout = timeout
klog.Infof("gRPC servier timeout set to %v", grpcTimeout)
} else {
klog.Warningf("Invalid GRPC_SERVER_TIMEOUT value: %v, using default timeout (60 seconds)", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use %d to print out default value, in case we change in the future

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just added

}
}

clientManager := manager.NewClientManager()
resourceManager := manager.NewResourceManager(&clientManager)

atomic.StoreInt32(&healthy, 1)
go startRPCServer(resourceManager)
go startRPCServer(resourceManager, grpcTimeout)
startHttpProxy()
// See also https://gist.github.com/enricofoltran/10b4a980cd07cb02836f70a4ab3e72d7
quit := make(chan os.Signal, 1)
Expand All @@ -70,7 +81,7 @@ func main() {

type RegisterHttpHandlerFromEndpoint func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error

func startRPCServer(resourceManager *manager.ResourceManager) {
func startRPCServer(resourceManager *manager.ResourceManager, grpcTimeout time.Duration) {
klog.Infof("Starting gRPC server at port %s", *rpcPortFlag)

listener, err := net.Listen("tcp", *rpcPortFlag)
Expand All @@ -86,8 +97,13 @@ func startRPCServer(resourceManager *manager.ResourceManager) {

s := grpc.NewServer(
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(grpc_prometheus.UnaryServerInterceptor, interceptor.APIServerInterceptor)),
grpc.MaxRecvMsgSize(math.MaxInt32))
grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(
interceptor.TimeoutInterceptor(grpcTimeout),
grpc_prometheus.UnaryServerInterceptor,
interceptor.APIServerInterceptor,
)),
grpc.MaxRecvMsgSize(math.MaxInt32),
)
api.RegisterClusterServiceServer(s, clusterServer)
api.RegisterComputeTemplateServiceServer(s, templateServer)
api.RegisterRayJobServiceServer(s, jobServer)
Expand Down
40 changes: 40 additions & 0 deletions apiserver/pkg/interceptor/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package interceptor

import (
"context"
"fmt"
"time"

"google.golang.org/grpc"
klog "k8s.io/klog/v2"
Expand All @@ -19,3 +21,41 @@ func APIServerInterceptor(ctx context.Context, req interface{}, info *grpc.Unary
klog.Infof("%v handler finished", info.FullMethod)
return
}

// TimeoutInterceptor implements UnaryServerInterceptor that sets the timeout for the request
func TimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
_ *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
// Create a context with timeout
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

// Channel to capture execution result
done := make(chan struct{})
var (
resp interface{}
err error
)

go func() {
resp, err = handler(ctx, req)
close(done)
}()

select {
case <-ctx.Done():
// Raise error if time out
if ctx.Err() == context.DeadlineExceeded {
return nil, fmt.Errorf("grpc server timed out")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we name the grpc server with KubeRay API server ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! Just changed

}
return nil, ctx.Err()
case <-done:
// Handler finished
return resp, err
}
}
}
62 changes: 62 additions & 0 deletions apiserver/pkg/interceptor/interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -192,3 +194,63 @@ func TestAPIServerInterceptorLogging(t *testing.T) {
})
}
}

func TestTimeoutInterceptor(t *testing.T) {
tests := []struct {
expectedError error
name string
timeout time.Duration
handlerDelay time.Duration
expectedCalled bool
}{
{
name: "handler completes before timeout",
timeout: 100 * time.Millisecond,
handlerDelay: 50 * time.Millisecond,
expectedError: nil,
expectedCalled: true,
},
{
name: "handler exceeds timeout",
timeout: 50 * time.Millisecond,
handlerDelay: 100 * time.Millisecond,
expectedError: fmt.Errorf("grpc server timed out"),
expectedCalled: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test context and request
ctx := context.Background()
req := "test_request"
handler := &mockHandler{}

// Create the interceptor with the specified timeout
interceptor := TimeoutInterceptor(tt.timeout)

// Call the interceptor
resp, err := interceptor(
ctx,
req,
&grpc.UnaryServerInfo{FullMethod: "TestTimeoutMethod"},
func(ctx context.Context, req interface{}) (interface{}, error) {
time.Sleep(tt.handlerDelay)
return handler.Handle(ctx, req)
},
)

// Verify response and error
if tt.expectedError == nil {
// Verify handler was called
assert.Equal(t, tt.expectedCalled, handler.called, "handler call status should match expected")

require.NoError(t, err)
assert.Equal(t, "test_response", resp, "response should match expected")
} else {
require.Error(t, err)
require.EqualError(t, err, tt.expectedError.Error(), "A matching error is expected")
}
})
}
}
Loading