diff --git a/interceptor/forward_wait_func.go b/interceptor/forward_wait_func.go index 7dc23cb71..50a850437 100644 --- a/interceptor/forward_wait_func.go +++ b/interceptor/forward_wait_func.go @@ -11,7 +11,7 @@ import ( // forwardWaitFunc is a function that waits for a condition // before proceeding to serve the request. -type forwardWaitFunc func(context.Context, string, string) error +type forwardWaitFunc func(context.Context, string, string) (int, error) func deploymentCanServe(depl appsv1.Deployment) bool { return depl.Status.ReadyReplicas > 0 @@ -21,7 +21,7 @@ func newDeployReplicasForwardWaitFunc( lggr logr.Logger, deployCache k8s.DeploymentCache, ) forwardWaitFunc { - return func(ctx context.Context, deployNS, deployName string) error { + return func(ctx context.Context, deployNS, deployName string) (int, error) { // get a watcher & its result channel before querying the // deployment cache, to ensure we don't miss events watcher := deployCache.Watch(deployNS, deployName) @@ -31,7 +31,7 @@ func newDeployReplicasForwardWaitFunc( deployment, err := deployCache.Get(deployNS, deployName) if err != nil { // if we didn't get the initial deployment state, bail out - return fmt.Errorf( + return 0, fmt.Errorf( "error getting state for deployment %s/%s (%s)", deployNS, deployName, @@ -40,7 +40,7 @@ func newDeployReplicasForwardWaitFunc( } // if there is 1 or more replica, we're done waiting if deploymentCanServe(deployment) { - return nil + return int(deployment.Status.ReadyReplicas), nil } for { @@ -51,14 +51,13 @@ func newDeployReplicasForwardWaitFunc( lggr.Info( "Didn't get a deployment back in event", ) - } - if deploymentCanServe(*deployment) { - return nil + } else if deploymentCanServe(*deployment) { + return 0, nil } case <-ctx.Done(): // otherwise, if the context is marked done before // we're done waiting, fail. - return fmt.Errorf( + return 0, fmt.Errorf( "context marked done while waiting for deployment %s to reach > 0 replicas (%w)", deployName, ctx.Err(), diff --git a/interceptor/forward_wait_func_test.go b/interceptor/forward_wait_func_test.go index 77bc10827..21959ba0e 100644 --- a/interceptor/forward_wait_func_test.go +++ b/interceptor/forward_wait_func_test.go @@ -43,7 +43,8 @@ func TestForwardWaitFuncOneReplica(t *testing.T) { ) group.Go(func() error { - return waitFunc(ctx, ns, deployName) + _, err := waitFunc(ctx, ns, deployName) + return err }) r.NoError(group.Wait(), "wait function failed, but it shouldn't have") } @@ -76,7 +77,7 @@ func TestForwardWaitFuncNoReplicas(t *testing.T) { cache, ) - err := waitFunc(ctx, ns, deployName) + _, err := waitFunc(ctx, ns, deployName) r.Error(err) } @@ -120,6 +121,7 @@ func TestWaitFuncWaitsUntilReplicas(t *testing.T) { watcher.Action(watch.Modified, modifiedDeployment) close(replicasIncreasedCh) }() - r.NoError(waitFunc(ctx, ns, deployName)) + _, err := waitFunc(ctx, ns, deployName) + r.NoError(err) done() } diff --git a/interceptor/main_test.go b/interceptor/main_test.go index 33c036d77..c75c07793 100644 --- a/interceptor/main_test.go +++ b/interceptor/main_test.go @@ -66,9 +66,9 @@ func TestRunProxyServerCountMiddleware(t *testing.T) { ) timeouts := &config.Timeouts{} waiterCh := make(chan struct{}) - waitFunc := func(ctx context.Context, ns, name string) error { + waitFunc := func(ctx context.Context, ns, name string) (int, error) { <-waiterCh - return nil + return 1, nil } g.Go(func() error { return runProxyServer( @@ -106,6 +106,9 @@ func TestRunProxyServerCountMiddleware(t *testing.T) { resp.StatusCode, ) } + if resp.Header.Get("X-KEDA-HTTP-Cold-Start") != "false" { + return fmt.Errorf("expected X-KEDA-HTTP-Cold-Start false, but got %s", resp.Header.Get("X-KEDA-HTTP-Cold-Start")) + } return nil }) time.Sleep(100 * time.Millisecond) diff --git a/interceptor/proxy_handlers.go b/interceptor/proxy_handlers.go index 9f1b27840..1d3860098 100644 --- a/interceptor/proxy_handlers.go +++ b/interceptor/proxy_handlers.go @@ -74,11 +74,12 @@ func newForwardingHandler( waitFuncCtx, done := context.WithTimeout(r.Context(), fwdCfg.waitTimeout) defer done() - if err := waitFunc( + replicas, err := waitFunc( waitFuncCtx, routingTarget.Namespace, routingTarget.Deployment, - ); err != nil { + ) + if err != nil { lggr.Error(err, "wait function failed, not forwarding request") w.WriteHeader(502) w.Write([]byte(fmt.Sprintf("error on backend (%s)", err))) @@ -91,6 +92,11 @@ func newForwardingHandler( w.Write([]byte("error getting backend service URL")) return } + isColdStart := "false" + if replicas == 0 { + isColdStart = "true" + } + w.Header().Add("X-KEDA-HTTP-Cold-Start", isColdStart) forwardRequest(w, r, roundTripper, targetSvcURL) }) } diff --git a/interceptor/proxy_handlers_test.go b/interceptor/proxy_handlers_test.go index 7af8de95b..788f10c57 100644 --- a/interceptor/proxy_handlers_test.go +++ b/interceptor/proxy_handlers_test.go @@ -45,8 +45,8 @@ func TestImmediatelySuccessfulProxy(t *testing.T) { timeouts := defaultTimeouts() dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff()) - waitFunc := func(context.Context, string, string) error { - return nil + waitFunc := func(context.Context, string, string) (int, error) { + return 1, nil } hdl := newForwardingHandler( logr.Discard(), @@ -68,6 +68,7 @@ func TestImmediatelySuccessfulProxy(t *testing.T) { hdl.ServeHTTP(res, req) + r.Equal("false", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start false") r.Equal(200, res.Code, "expected response code 200") r.Equal("test response", res.Body.String()) } @@ -85,8 +86,8 @@ func TestWaitFailedConnection(t *testing.T) { timeouts, backoff, ) - waitFunc := func(context.Context, string, string) error { - return nil + waitFunc := func(context.Context, string, string) (int, error) { + return 1, nil } routingTable := routing.NewTable() routingTable.AddTarget(host, routing.NewTarget( @@ -117,6 +118,7 @@ func TestWaitFailedConnection(t *testing.T) { hdl.ServeHTTP(res, req) + r.Equal("false", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start false") r.Equal(502, res.Code, "response code was unexpected") } @@ -166,13 +168,19 @@ func TestTimesOutOnWaitFunc(t *testing.T) { t.Logf("elapsed time was %s", elapsed) // serving should take at least timeouts.DeploymentReplicas, but no more than - // timeouts.DeploymentReplicas*2 - // elapsed time should be more than the deployment replicas wait time - // but not an amount that is much greater than that + // timeouts.DeploymentReplicas*4 r.GreaterOrEqual(elapsed, timeouts.DeploymentReplicas) r.LessOrEqual(elapsed, timeouts.DeploymentReplicas*4) r.Equal(502, res.Code, "response code was unexpected") + // we will always return the X-KEDA-HTTP-Cold-Start header + // when we are able to forward the + // request to the backend but not if we have failed due + // to a timeout from a waitFunc or earlier in the pipeline, + // for example, if we cannot reach the Kubernetes control + // plane. + r.Equal("", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start to be empty") + // waitFunc should have been called, even though it timed out waitFuncCalled := false select { @@ -277,8 +285,8 @@ func TestWaitHeaderTimeout(t *testing.T) { timeouts := defaultTimeouts() dialCtxFunc := retryDialContextFunc(timeouts, timeouts.DefaultBackoff()) - waitFunc := func(context.Context, string, string) error { - return nil + waitFunc := func(context.Context, string, string) (int, error) { + return 1, nil } routingTable := routing.NewTable() target := routing.NewTarget( @@ -309,6 +317,7 @@ func TestWaitHeaderTimeout(t *testing.T) { hdl.ServeHTTP(res, req) + r.Equal("false", res.Header().Get("X-KEDA-HTTP-Cold-Start"), "expected X-KEDA-HTTP-Cold-Start false") r.Equal(502, res.Code, "response code was unexpected") close(originHdlCh) } @@ -346,19 +355,19 @@ func waitForSignal(sig <-chan struct{}, waitDur time.Duration) error { // is called, or the context that is passed to it is done (e.g. cancelled, timed out, // etc...). in the former case, the returned func itself returns nil. in the latter, // it returns ctx.Err() -func notifyingFunc() (func(context.Context, string, string) error, <-chan struct{}, func()) { +func notifyingFunc() (forwardWaitFunc, <-chan struct{}, func()) { calledCh := make(chan struct{}) finishCh := make(chan struct{}) finishFunc := func() { close(finishCh) } - return func(ctx context.Context, _, _ string) error { + return func(ctx context.Context, _, _ string) (int, error) { close(calledCh) select { case <-finishCh: - return nil + return 0, nil case <-ctx.Done(): - return fmt.Errorf("TEST FUNCTION CONTEXT ERROR: %w", ctx.Err()) + return 0, fmt.Errorf("TEST FUNCTION CONTEXT ERROR: %w", ctx.Err()) } }, calledCh, finishFunc }