Skip to content

Return custom header when request was returned from a cold start #366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions interceptor/forward_wait_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

// forwardWaitFunc is a function that waits for a condition
// before proceeding to serve the request.
type forwardWaitFunc func(context.Context, string, string) error
type forwardWaitFunc func(context.Context, string, string) (int, error)

func deploymentCanServe(depl appsv1.Deployment) bool {
return depl.Status.ReadyReplicas > 0
Expand All @@ -21,7 +21,7 @@ func newDeployReplicasForwardWaitFunc(
lggr logr.Logger,
deployCache k8s.DeploymentCache,
) forwardWaitFunc {
return func(ctx context.Context, deployNS, deployName string) error {
return func(ctx context.Context, deployNS, deployName string) (int, error) {
// get a watcher & its result channel before querying the
// deployment cache, to ensure we don't miss events
watcher := deployCache.Watch(deployNS, deployName)
Expand All @@ -31,7 +31,7 @@ func newDeployReplicasForwardWaitFunc(
deployment, err := deployCache.Get(deployNS, deployName)
if err != nil {
// if we didn't get the initial deployment state, bail out
return fmt.Errorf(
return 0, fmt.Errorf(
"error getting state for deployment %s/%s (%s)",
deployNS,
deployName,
Expand All @@ -40,7 +40,7 @@ func newDeployReplicasForwardWaitFunc(
}
// if there is 1 or more replica, we're done waiting
if deploymentCanServe(deployment) {
return nil
return int(deployment.Status.ReadyReplicas), nil
}

for {
Expand All @@ -51,14 +51,13 @@ func newDeployReplicasForwardWaitFunc(
lggr.Info(
"Didn't get a deployment back in event",
)
}
if deploymentCanServe(*deployment) {
return nil
} else if deploymentCanServe(*deployment) {
return 0, nil
}
case <-ctx.Done():
// otherwise, if the context is marked done before
// we're done waiting, fail.
return fmt.Errorf(
return 0, fmt.Errorf(
"context marked done while waiting for deployment %s to reach > 0 replicas (%w)",
deployName,
ctx.Err(),
Expand Down
8 changes: 5 additions & 3 deletions interceptor/forward_wait_func_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func TestForwardWaitFuncOneReplica(t *testing.T) {
)

group.Go(func() error {
return waitFunc(ctx, ns, deployName)
_, err := waitFunc(ctx, ns, deployName)
return err
})
r.NoError(group.Wait(), "wait function failed, but it shouldn't have")
}
Expand Down Expand Up @@ -76,7 +77,7 @@ func TestForwardWaitFuncNoReplicas(t *testing.T) {
cache,
)

err := waitFunc(ctx, ns, deployName)
_, err := waitFunc(ctx, ns, deployName)
r.Error(err)
}

Expand Down Expand Up @@ -120,6 +121,7 @@ func TestWaitFuncWaitsUntilReplicas(t *testing.T) {
watcher.Action(watch.Modified, modifiedDeployment)
close(replicasIncreasedCh)
}()
r.NoError(waitFunc(ctx, ns, deployName))
_, err := waitFunc(ctx, ns, deployName)
r.NoError(err)
done()
}
7 changes: 5 additions & 2 deletions interceptor/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ func TestRunProxyServerCountMiddleware(t *testing.T) {
)
timeouts := &config.Timeouts{}
waiterCh := make(chan struct{})
waitFunc := func(ctx context.Context, ns, name string) error {
waitFunc := func(ctx context.Context, ns, name string) (int, error) {
<-waiterCh
return nil
return 1, nil
}
g.Go(func() error {
return runProxyServer(
Expand Down Expand Up @@ -106,6 +106,9 @@ func TestRunProxyServerCountMiddleware(t *testing.T) {
resp.StatusCode,
)
}
if resp.Header.Get("X-KEDA-HTTP-Cold-Start") != "false" {
return fmt.Errorf("expected X-KEDA-HTTP-Cold-Start false, but got %s", resp.Header.Get("X-KEDA-HTTP-Cold-Start"))
}
return nil
})
time.Sleep(100 * time.Millisecond)
Expand Down
10 changes: 8 additions & 2 deletions interceptor/proxy_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,12 @@ func newForwardingHandler(

waitFuncCtx, done := context.WithTimeout(r.Context(), fwdCfg.waitTimeout)
defer done()
if err := waitFunc(
replicas, err := waitFunc(
waitFuncCtx,
routingTarget.Namespace,
routingTarget.Deployment,
); err != nil {
)
if err != nil {
lggr.Error(err, "wait function failed, not forwarding request")
w.WriteHeader(502)
w.Write([]byte(fmt.Sprintf("error on backend (%s)", err)))
Expand All @@ -91,6 +92,11 @@ func newForwardingHandler(
w.Write([]byte("error getting backend service URL"))
return
}
isColdStart := "false"
if replicas == 0 {
isColdStart = "true"
}
w.Header().Add("X-KEDA-HTTP-Cold-Start", isColdStart)
forwardRequest(w, r, roundTripper, targetSvcURL)
})
}
35 changes: 22 additions & 13 deletions interceptor/proxy_handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ func TestImmediatelySuccessfulProxy(t *testing.T) {

timeouts := defaultTimeouts()
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
waitFunc := func(context.Context, string, string) error {
return nil
waitFunc := func(context.Context, string, string) (int, error) {
return 1, nil
}
hdl := newForwardingHandler(
logr.Discard(),
Expand All @@ -68,6 +68,7 @@ func TestImmediatelySuccessfulProxy(t *testing.T) {

hdl.ServeHTTP(res, req)

r.Equal("false", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start false")
r.Equal(200, res.Code, "expected response code 200")
r.Equal("test response", res.Body.String())
}
Expand All @@ -85,8 +86,8 @@ func TestWaitFailedConnection(t *testing.T) {
timeouts,
backoff,
)
waitFunc := func(context.Context, string, string) error {
return nil
waitFunc := func(context.Context, string, string) (int, error) {
return 1, nil
}
routingTable := routing.NewTable()
routingTable.AddTarget(host, routing.NewTarget(
Expand Down Expand Up @@ -117,6 +118,7 @@ func TestWaitFailedConnection(t *testing.T) {

hdl.ServeHTTP(res, req)

r.Equal("false", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start false")
r.Equal(502, res.Code, "response code was unexpected")
}

Expand Down Expand Up @@ -166,13 +168,19 @@ func TestTimesOutOnWaitFunc(t *testing.T) {

t.Logf("elapsed time was %s", elapsed)
// serving should take at least timeouts.DeploymentReplicas, but no more than
// timeouts.DeploymentReplicas*2
// elapsed time should be more than the deployment replicas wait time
// but not an amount that is much greater than that
// timeouts.DeploymentReplicas*4
r.GreaterOrEqual(elapsed, timeouts.DeploymentReplicas)
r.LessOrEqual(elapsed, timeouts.DeploymentReplicas*4)
r.Equal(502, res.Code, "response code was unexpected")

// we will always return the X-KEDA-HTTP-Cold-Start header
// when we are able to forward the
// request to the backend but not if we have failed due
// to a timeout from a waitFunc or earlier in the pipeline,
// for example, if we cannot reach the Kubernetes control
// plane.
r.Equal("", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start to be empty")

// waitFunc should have been called, even though it timed out
waitFuncCalled := false
select {
Expand Down Expand Up @@ -277,8 +285,8 @@ func TestWaitHeaderTimeout(t *testing.T) {

timeouts := defaultTimeouts()
dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff())
waitFunc := func(context.Context, string, string) error {
return nil
waitFunc := func(context.Context, string, string) (int, error) {
return 1, nil
}
routingTable := routing.NewTable()
target := routing.NewTarget(
Expand Down Expand Up @@ -309,6 +317,7 @@ func TestWaitHeaderTimeout(t *testing.T) {

hdl.ServeHTTP(res, req)

r.Equal("false", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start false")
r.Equal(502, res.Code, "response code was unexpected")
close(originHdlCh)
}
Expand Down Expand Up @@ -346,19 +355,19 @@ func waitForSignal(sig <-chan struct{}, waitDur time.Duration) error {
// is called, or the context that is passed to it is done (e.g. cancelled, timed out,
// etc...). in the former case, the returned func itself returns nil. in the latter,
// it returns ctx.Err()
func notifyingFunc() (func(context.Context, string, string) error, <-chan struct{}, func()) {
func notifyingFunc() (forwardWaitFunc, <-chan struct{}, func()) {
calledCh := make(chan struct{})
finishCh := make(chan struct{})
finishFunc := func() {
close(finishCh)
}
return func(ctx context.Context, _, _ string) error {
return func(ctx context.Context, _, _ string) (int, error) {
close(calledCh)
select {
case <-finishCh:
return nil
return 0, nil
case <-ctx.Done():
return fmt.Errorf("TEST FUNCTION CONTEXT ERROR: %w", ctx.Err())
return 0, fmt.Errorf("TEST FUNCTION CONTEXT ERROR: %w", ctx.Err())
}
}, calledCh, finishFunc
}
Expand Down