Skip to content

Commit ebdc788

Browse files
authored
fix(term): confirmation prompts on windows (#413)
This fixes confirmation prompts on windows by scanning for a line of input and comparing that to the desired prompt. term.Confirm(...) is refactored slightly to make it easier to test. Fixes #347
1 parent d36d1e5 commit ebdc788

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

pkg/term/alert.go

+27-7
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,42 @@ package term
33
import (
44
"bufio"
55
"fmt"
6+
"io"
67
"os"
78

89
"github.com/pkg/errors"
910
)
1011

12+
var ErrConfirmationFailed = errors.New("aborted by user")
13+
1114
// Confirm asks the user for confirmation
1215
func Confirm(msg, approval string) error {
13-
reader := bufio.NewReader(os.Stdin)
14-
fmt.Println(msg)
15-
fmt.Printf("Please type '%s' to confirm: ", approval)
16-
read, err := reader.ReadString('\n')
16+
return confirmFrom(os.Stdin, os.Stdout, msg, approval)
17+
}
18+
19+
func confirmFrom(r io.Reader, w io.Writer, msg, approval string) error {
20+
reader := bufio.NewScanner(r)
21+
_, err := fmt.Fprintln(w, msg)
1722
if err != nil {
18-
return errors.Wrap(err, "reading from stdin")
23+
return errors.Wrap(err, "writing to stdout")
1924
}
20-
if read != approval+"\n" {
21-
return errors.New("aborted by user")
25+
26+
_, err = fmt.Fprintf(w, "Please type '%s' to confirm: ", approval)
27+
if err != nil {
28+
return errors.Wrap(err, "writing to stdout")
2229
}
30+
31+
if !reader.Scan() {
32+
if err := reader.Err(); err != nil {
33+
return errors.Wrap(err, "reading from stdin")
34+
}
35+
36+
return ErrConfirmationFailed
37+
}
38+
39+
if reader.Text() != approval {
40+
return ErrConfirmationFailed
41+
}
42+
2343
return nil
2444
}

pkg/term/alert_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package term
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestConfirm(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
input string
14+
expected error
15+
}{
16+
{name: "linux yes", input: "yes\n", expected: nil},
17+
{name: "windows yes", input: "yes\r\n", expected: nil},
18+
{name: "linux no", input: "no\n", expected: ErrConfirmationFailed},
19+
{name: "windows no", input: "no\r\n", expected: ErrConfirmationFailed},
20+
}
21+
22+
for _, tt := range tests {
23+
t.Run(tt.name, func(t *testing.T) {
24+
in := strings.NewReader(tt.input)
25+
out := &strings.Builder{}
26+
27+
err := confirmFrom(in, out, "foo", "yes")
28+
29+
assert.Equal(t, "foo\nPlease type 'yes' to confirm: ", out.String())
30+
31+
if tt.expected != nil {
32+
assert.EqualError(t, err, tt.expected.Error())
33+
} else {
34+
assert.NoError(t, err)
35+
}
36+
})
37+
}
38+
}

0 commit comments

Comments
 (0)