Skip to content

Consider redirection a success response for cert gen trigger #832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions src/pkg/cli/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cli

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -111,7 +112,7 @@ func GenerateLetsEncryptCert(ctx context.Context, client client.FabricClient, pr
}

func generateCert(ctx context.Context, domain string, targets []string, client client.FabricClient) {
term.Infof("Triggering TLS cert generation for %v", domain)
term.Infof("Checking DNS setup for %v", domain)
if err := waitForCNAME(ctx, domain, targets, client); err != nil {
term.Errorf("Error waiting for CNAME: %v", err)
return
Expand Down Expand Up @@ -198,7 +199,6 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()

msgShown := false
serverSideVerified := false
serverVerifyRpcFailure := 0
doSpinner := term.StdoutCanColor() && term.IsTerminal()
Expand All @@ -211,41 +211,50 @@ func waitForCNAME(ctx context.Context, domain string, targets []string, client c
defer cancelSpinner()
}

verifyDNS := func() error {
if !serverSideVerified && serverVerifyRpcFailure < 3 {
if err := client.VerifyDNSSetup(ctx, &defangv1.VerifyDNSSetupRequest{Domain: domain, Targets: targets}); err == nil {
term.Debugf("Server side DNS verification for %v successful", domain)
serverSideVerified = true
} else {
if cerr := new(connect.Error); errors.As(err, &cerr) && cerr.Code() == connect.CodeFailedPrecondition {
term.Debugf("Server side DNS verification negative result: %v", cerr.Message())
} else {
term.Debugf("Server side DNS verification request for %v failed: %v", domain, err)
serverVerifyRpcFailure++
}
}
if serverVerifyRpcFailure >= 3 {
term.Warnf("Server side DNS verification for %v failed multiple times, skipping server side DNS verification.", domain)
}
}
if serverSideVerified || serverVerifyRpcFailure >= 3 {
locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets)
if serverSideVerified && !locallyVerified {
term.Warnf("The DNS configuration for %v has been successfully verified. However, your local environment may still be using cached data, so it could take several minutes for the DNS changes to propagate on your system.", domain)
return nil
}
if locallyVerified {
return nil
}
}
return errors.New("not verified")
}

if err := verifyDNS(); err == nil {
return nil
}
term.Infof("Please set up a CNAME record for %v", domain)
fmt.Printf(" %v CNAME or as an alias to [ %v ]\n", domain, strings.Join(targets, " or "))
term.Infof("Waiting for CNAME record setup and DNS propagation...")

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
if !serverSideVerified && serverVerifyRpcFailure < 3 {
if err := client.VerifyDNSSetup(ctx, &defangv1.VerifyDNSSetupRequest{Domain: domain, Targets: targets}); err == nil {
term.Debugf("Server side DNS verification for %v successful", domain)
serverSideVerified = true
} else {
if cerr := new(connect.Error); errors.As(err, &cerr) && cerr.Code() == connect.CodeFailedPrecondition {
term.Debugf("Server side DNS verification negative result: %v", cerr.Message())
} else {
term.Debugf("Server side DNS verification request for %v failed: %v", domain, err)
serverVerifyRpcFailure++
}
}
if serverVerifyRpcFailure >= 3 {
term.Warnf("Server side DNS verification for %v failed multiple times, skipping server side DNS verification.", domain)
}
} else {
locallyVerified := dns.CheckDomainDNSReady(ctx, domain, targets)
if serverSideVerified && !locallyVerified {
term.Warnf("The DNS configuration for %v has been successfully verified. However, your local environment may still be using cached data, so it could take several minutes for the DNS changes to propagate on your system.", domain)
return nil
}
if locallyVerified {
return nil
}
}
if !msgShown {
term.Infof("Please set up a CNAME record for %v", domain)
fmt.Printf(" %v CNAME or as an alias to [ %v ]\n", domain, strings.Join(targets, " or "))
term.Infof("Waiting for CNAME record setup and DNS propagation...")
msgShown = true
if err := verifyDNS(); err == nil {
return nil
}
}
}
Expand All @@ -261,14 +270,20 @@ func getWithRetries(ctx context.Context, url string, tries int) error {
resp, err := httpClient.Do(req)
if err == nil {
defer resp.Body.Close()
var msg []byte
msg, err = io.ReadAll(resp.Body)
_, err = io.ReadAll(resp.Body) // Read the body to ensure the request is not swallowed by alb
if resp.StatusCode == http.StatusOK {
return nil
}
if resp != nil && resp.Request != nil && resp.Request.URL.Scheme == "https" {
term.Debugf("cert gen request success, received redirect to %v", resp.Request.URL)
return nil // redirect to https indicate a successful cert generation
}
if err == nil {
err = fmt.Errorf("HTTP %v: %v", resp.StatusCode, string(msg))
err = fmt.Errorf("HTTP: %v", resp.StatusCode)
}
} else if cve := new(tls.CertificateVerificationError); errors.As(err, &cve) {
term.Debugf("cert gen request success, received tls error: %v", cve)
return nil // tls error indicate a successful cert generation, as it has to be redirected to https
}

term.Debugf("Error fetching %v: %v, tries left %v", url, err, tries-i-1)
Expand Down
41 changes: 38 additions & 3 deletions src/pkg/cli/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cli

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -32,6 +33,9 @@ func (c *testClient) Do(req *http.Request) (*http.Response, error) {
}
tr := c.tries[0]
c.tries = c.tries[1:]
if tr.result != nil && tr.result.Request == nil {
tr.result.Request = req
}
return tr.result, tr.err
}

Expand Down Expand Up @@ -88,13 +92,44 @@ func TestGetWithRetries(t *testing.T) {
err := getWithRetries(context.Background(), "http://example.com", 3)
if err == nil {
t.Errorf("Expected error, got %v", err)
} else if !strings.Contains(err.Error(), "HTTP 503: Random Error") {
t.Errorf("Expected HTTP 503: Random Error, got %v", err)
} else if !strings.Contains(err.Error(), "HTTP: 503") {
t.Errorf("Expected HTTP 503:, got %v", err)
}
if tc.calls != 3 {
t.Errorf("Expected 3 calls, got %v", tc.calls)
}
})
t.Run("redirect to https considers success", func(t *testing.T) {
redirectURL, _ := url.Parse("https://example.com")
tc := &testClient{tries: []tryResult{
{result: &http.Response{StatusCode: 503, Request: &http.Request{URL: redirectURL}, Body: mockBody("Random Error")}, err: nil},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(context.Background(), "http://example.com", 3)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if tc.calls != 1 {
t.Errorf("Expected 1 call, got %v", tc.calls)
}
})
t.Run("TLS error considers success", func(t *testing.T) {
tc := &testClient{tries: []tryResult{
{result: nil, err: &tls.CertificateVerificationError{Err: errors.New("error")}},
}}
originalClient := httpClient
t.Cleanup(func() { httpClient = originalClient })
httpClient = tc
err := getWithRetries(context.Background(), "http://example.com", 3)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if tc.calls != 1 {
t.Errorf("Expected 1 call, got %v", tc.calls)
}
})
t.Run("handles all http errors", func(t *testing.T) {
tc := &testClient{tries: []tryResult{
{result: &http.Response{StatusCode: 404, Body: mockBody("Random Error")}, err: nil},
Expand All @@ -107,7 +142,7 @@ func TestGetWithRetries(t *testing.T) {
err := getWithRetries(context.Background(), "http://example.com", 3)
if err == nil {
t.Errorf("Expected error, got %v", err)
} else if !strings.Contains(err.Error(), "HTTP 404: Random Error") || !strings.Contains(err.Error(), "HTTP 502: Random Error") || !strings.Contains(err.Error(), "HTTP 503: Random Error") {
} else if !strings.Contains(err.Error(), "HTTP: 404") || !strings.Contains(err.Error(), "HTTP: 502") || !strings.Contains(err.Error(), "HTTP: 503") {
t.Errorf("Expected HTTP 404,502,503 erros, got %v", err)
}
if tc.calls != 3 {
Expand Down
Loading