diff --git a/src/pkg/cli/cert.go b/src/pkg/cli/cert.go index 70219265b..8e4946af7 100644 --- a/src/pkg/cli/cert.go +++ b/src/pkg/cli/cert.go @@ -2,6 +2,7 @@ package cli import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -111,7 +112,7 @@ func GenerateLetsEncryptCert(ctx context.Context, client client.FabricClient, pr } func generateCert(ctx context.Context, domain string, targets []string, client client.FabricClient) { - term.Infof("Triggering TLS cert generation for %v", domain) + term.Infof("Checking DNS setup for %v", domain) if err := waitForCNAME(ctx, domain, targets, client); err != nil { term.Errorf("Error waiting for CNAME: %v", err) return @@ -198,7 +199,6 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() - msgShown := false serverSideVerified := false serverVerifyRpcFailure := 0 doSpinner := term.StdoutCanColor() && term.IsTerminal() @@ -211,41 +211,50 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c defer cancelSpinner() } + verifyDNS := func() error { + if !serverSideVerified && serverVerifyRpcFailure < 3 { + if err := client.VerifyDNSSetup(ctx, &defangv1.VerifyDNSSetupRequest{Domain: domain, Targets: targets}); err == nil { + term.Debugf("Server side DNS verification for %v successful", domain) + serverSideVerified = true + } else { + if cerr := new(connect.Error); errors.As(err, &cerr) && cerr.Code() == connect.CodeFailedPrecondition { + term.Debugf("Server side DNS verification negative result: %v", cerr.Message()) + } else { + term.Debugf("Server side DNS verification request for %v failed: %v", domain, err) + serverVerifyRpcFailure++ + } + } + if serverVerifyRpcFailure >= 3 { + term.Warnf("Server side DNS verification for %v failed multiple times, skipping server side DNS verification.", domain) + } + } + if serverSideVerified || serverVerifyRpcFailure >= 3 { + locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets) + if serverSideVerified && !locallyVerified { + term.Warnf("The DNS configuration for %v has been successfully verified. However, your local environment may still be using cached data, so it could take several minutes for the DNS changes to propagate on your system.", domain) + return nil + } + if locallyVerified { + return nil + } + } + return errors.New("not verified") + } + + if err := verifyDNS(); err == nil { + return nil + } + term.Infof("Configure a CNAME or ALIAS record for the domain name: %v", domain) + fmt.Printf(" %v -> %v\n", domain, strings.Join(targets, " or ")) + term.Infof("Awaiting DNS record setup and propagation... This may take a while.") + for { select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: - if !serverSideVerified && serverVerifyRpcFailure < 3 { - if err := client.VerifyDNSSetup(ctx, &defangv1.VerifyDNSSetupRequest{Domain: domain, Targets: targets}); err == nil { - term.Debugf("Server side DNS verification for %v successful", domain) - serverSideVerified = true - } else { - if cerr := new(connect.Error); errors.As(err, &cerr) && cerr.Code() == connect.CodeFailedPrecondition { - term.Debugf("Server side DNS verification negative result: %v", cerr.Message()) - } else { - term.Debugf("Server side DNS verification request for %v failed: %v", domain, err) - serverVerifyRpcFailure++ - } - } - if serverVerifyRpcFailure >= 3 { - term.Warnf("Server side DNS verification for %v failed multiple times, skipping server side DNS verification.", domain) - } - } else { - locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets) - if serverSideVerified && !locallyVerified { - term.Warnf("The DNS configuration for %v has been successfully verified. However, your local environment may still be using cached data, so it could take several minutes for the DNS changes to propagate on your system.", domain) - return nil - } - if locallyVerified { - return nil - } - } - if !msgShown { - term.Infof("Please set up a CNAME record for %v", domain) - fmt.Printf(" %v CNAME or as an alias to [ %v ]\n", domain, strings.Join(targets, " or ")) - term.Infof("Waiting for CNAME record setup and DNS propagation...") - msgShown = true + if err := verifyDNS(); err == nil { + return nil } } } @@ -261,14 +270,20 @@ func getWithRetries(ctx context.Context, url string, tries int) error { resp, err := httpClient.Do(req) if err == nil { defer resp.Body.Close() - var msg []byte - msg, err = io.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) // Read the body to ensure the request is not swallowed by alb if resp.StatusCode == http.StatusOK { return nil } + if resp != nil && resp.Request != nil && resp.Request.URL.Scheme == "https" { + term.Debugf("cert gen request success, received redirect to %v", resp.Request.URL) + return nil // redirect to https indicate a successful cert generation + } if err == nil { - err = fmt.Errorf("HTTP %v: %v", resp.StatusCode, string(msg)) + err = fmt.Errorf("HTTP: %v", resp.StatusCode) } + } else if cve := new(tls.CertificateVerificationError); errors.As(err, &cve) { + term.Debugf("cert gen request success, received tls error: %v", cve) + return nil // tls error indicate a successful cert gen trigger, as it has to be redirected to https } term.Debugf("Error fetching %v: %v, tries left %v", url, err, tries-i-1) diff --git a/src/pkg/cli/cert_test.go b/src/pkg/cli/cert_test.go index 1dfd460ae..6f0ee27d4 100644 --- a/src/pkg/cli/cert_test.go +++ b/src/pkg/cli/cert_test.go @@ -2,6 +2,7 @@ package cli import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -32,6 +33,9 @@ func (c *testClient) Do(req *http.Request) (*http.Response, error) { } tr := c.tries[0] c.tries = c.tries[1:] + if tr.result != nil && tr.result.Request == nil { + tr.result.Request = req + } return tr.result, tr.err } @@ -88,13 +92,44 @@ func TestGetWithRetries(t *testing.T) { err := getWithRetries(context.Background(), "http://example.com", 3) if err == nil { t.Errorf("Expected error, got %v", err) - } else if !strings.Contains(err.Error(), "HTTP 503: Random Error") { - t.Errorf("Expected HTTP 503: Random Error, got %v", err) + } else if !strings.Contains(err.Error(), "HTTP: 503") { + t.Errorf("Expected HTTP 503:, got %v", err) } if tc.calls != 3 { t.Errorf("Expected 3 calls, got %v", tc.calls) } }) + t.Run("redirect to https considers success", func(t *testing.T) { + redirectURL, _ := url.Parse("https://example.com") + tc := &testClient{tries: []tryResult{ + {result: &http.Response{StatusCode: 503, Request: &http.Request{URL: redirectURL}, Body: mockBody("Random Error")}, err: nil}, + }} + originalClient := httpClient + t.Cleanup(func() { httpClient = originalClient }) + httpClient = tc + err := getWithRetries(context.Background(), "http://example.com", 3) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if tc.calls != 1 { + t.Errorf("Expected 1 call, got %v", tc.calls) + } + }) + t.Run("TLS error considers success", func(t *testing.T) { + tc := &testClient{tries: []tryResult{ + {result: nil, err: &tls.CertificateVerificationError{Err: errors.New("error")}}, + }} + originalClient := httpClient + t.Cleanup(func() { httpClient = originalClient }) + httpClient = tc + err := getWithRetries(context.Background(), "http://example.com", 3) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if tc.calls != 1 { + t.Errorf("Expected 1 call, got %v", tc.calls) + } + }) t.Run("handles all http errors", func(t *testing.T) { tc := &testClient{tries: []tryResult{ {result: &http.Response{StatusCode: 404, Body: mockBody("Random Error")}, err: nil}, @@ -107,7 +142,7 @@ func TestGetWithRetries(t *testing.T) { err := getWithRetries(context.Background(), "http://example.com", 3) if err == nil { t.Errorf("Expected error, got %v", err) - } else if !strings.Contains(err.Error(), "HTTP 404: Random Error") || !strings.Contains(err.Error(), "HTTP 502: Random Error") || !strings.Contains(err.Error(), "HTTP 503: Random Error") { + } else if !strings.Contains(err.Error(), "HTTP: 404") || !strings.Contains(err.Error(), "HTTP: 502") || !strings.Contains(err.Error(), "HTTP: 503") { t.Errorf("Expected HTTP 404,502,503 erros, got %v", err) } if tc.calls != 3 {