Skip to content

Commit 2ef6c21

Browse files
author
David Pinheiro
committed
Finish implementation of insecure mode
This change rewrite how the insecure flag dictates the implementation of the host key validation callback and add unit tests to cover the changes.
1 parent 7782e5c commit 2ef6c21

File tree

4 files changed

+82
-42
lines changed

4 files changed

+82
-42
lines changed

cli/cli.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func (c *App) Parse() error {
5656
f.BoolVar(&c.Version, "version", false, "display the mole version")
5757
f.BoolVar(&c.Detach, "detach", false, "(optional) run process in background")
5858
f.StringVar(&c.Stop, "stop", "", "stop background process")
59-
f.BoolVar(&c.InsecureMode, "insecure", false, "(optional) ignore unknown host keys when connecting to an ssh server")
59+
f.BoolVar(&c.InsecureMode, "insecure", false, "(optional) skip host key validation when connecting to ssh server")
6060

6161
f.Parse(c.args[1:])
6262

cmd/mole/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ func start(app cli.App) error {
214214
return err
215215
}
216216

217-
s.SetInsecureMode(app.InsecureMode)
217+
s.Insecure = app.InsecureMode
218218

219219
log.Debugf("server: %s", s)
220220

tunnel/tunnel.go

+23-25
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ const (
2222

2323
// Server holds the SSH Server attributes used for the client to connect to it.
2424
type Server struct {
25-
Name string
26-
Address string
27-
User string
28-
Key string
29-
insecure bool
25+
Name string
26+
Address string
27+
User string
28+
Key string
29+
// Insecure is a flag to indicate if the host keys should be validated.
30+
Insecure bool
3031
}
3132

3233
// NewServer creates a new instance of Server using $HOME/.ssh/config to
@@ -92,11 +93,6 @@ func (s Server) String() string {
9293
return fmt.Sprintf("[name=%s, address=%s, user=%s, key=%s]", s.Name, s.Address, s.User, s.Key)
9394
}
9495

95-
// Set whether or not to check the known_hosts file
96-
func (s *Server) SetInsecureMode(flag bool) {
97-
s.insecure = flag
98-
}
99-
10096
// Tunnel represents the ssh tunnel used to forward a local connection to a
10197
// a remote endpoint through a ssh server.
10298
type Tunnel struct {
@@ -263,7 +259,7 @@ func sshClientConfig(server Server) (*ssh.ClientConfig, error) {
263259
return nil, err
264260
}
265261

266-
callback, err := knownHostsCallback(server)
262+
clb, err := knownHostsCallback(server.Insecure)
267263
if err != nil {
268264
return nil, err
269265
}
@@ -273,7 +269,7 @@ func sshClientConfig(server Server) (*ssh.ClientConfig, error) {
273269
Auth: []ssh.AuthMethod{
274270
ssh.PublicKeys(signer),
275271
},
276-
HostKeyCallback: callback,
272+
HostKeyCallback: clb,
277273
Timeout: 3 * time.Second,
278274
}, nil
279275
}
@@ -285,23 +281,25 @@ func copyConn(writer, reader net.Conn) {
285281
}
286282
}
287283

288-
func knownHostsCallback(s Server) (ssh.HostKeyCallback, error) {
289-
knownHostFile := filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")
290-
291-
log.Debugf("known_hosts file used: %s", knownHostFile)
292-
293-
secureCallback, err := knownhosts.New(knownHostFile)
294-
if err != nil {
295-
return nil, fmt.Errorf("error while parsing 'known_hosts' file: %s: %v", knownHostFile, err)
296-
}
284+
func knownHostsCallback(insecure bool) (ssh.HostKeyCallback, error) {
285+
var clb func(hostname string, remote net.Addr, key ssh.PublicKey) error
297286

298-
callback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
299-
if s.insecure {
287+
if insecure {
288+
clb = func(hostname string, remote net.Addr, key ssh.PublicKey) error {
300289
return nil
301290
}
302-
return secureCallback(hostname, remote, key)
291+
} else {
292+
var err error
293+
knownHostFile := filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")
294+
log.Debugf("known_hosts file used: %s", knownHostFile)
295+
296+
clb, err = knownhosts.New(knownHostFile)
297+
if err != nil {
298+
return nil, fmt.Errorf("error while parsing 'known_hosts' file: %s: %v", knownHostFile, err)
299+
}
303300
}
304-
return callback, nil
301+
302+
return clb, nil
305303
}
306304

307305
func reconcileHostname(givenHostname, resolvedHostname string) string {

tunnel/tunnel_test.go

+57-15
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,10 @@ func TestTunnelOptions(t *testing.T) {
166166

167167
}
168168

169+
//TODO teardown the tunnel
169170
func TestTunnel(t *testing.T) {
170171
expected := "ABC"
171-
tun := prepareTunnel(t)
172+
tun := prepareTunnel(t, false)
172173

173174
select {
174175
case <-tun.Ready:
@@ -191,6 +192,37 @@ func TestTunnel(t *testing.T) {
191192
if expected != response {
192193
t.Errorf("expected: %s, value: %s", expected, response)
193194
}
195+
196+
tun.Stop()
197+
}
198+
199+
func TestInsecureTunnel(t *testing.T) {
200+
expected := "ABC"
201+
tun := prepareTunnel(t, true)
202+
203+
select {
204+
case <-tun.Ready:
205+
t.Log("tunnel is ready to accept connections")
206+
case <-time.After(1 * time.Second):
207+
t.Errorf("no connection after a while")
208+
return
209+
}
210+
211+
resp, err := http.Get(fmt.Sprintf("http://%s/%s", tun.listener.Addr(), expected))
212+
if err != nil {
213+
t.Errorf("error while making local connection: %v", err)
214+
return
215+
}
216+
defer resp.Body.Close()
217+
218+
body, _ := ioutil.ReadAll(resp.Body)
219+
response := string(body)
220+
221+
if expected != response {
222+
t.Errorf("expected: %s, value: %s", expected, response)
223+
}
224+
225+
tun.Stop()
194226
}
195227

196228
func TestRandomLocalPort(t *testing.T) {
@@ -244,13 +276,18 @@ func TestMain(m *testing.M) {
244276

245277
// prepareTunnel creates a Tunnel object making sure all infrastructure
246278
// dependencies (ssh and http servers) are ready.
247-
func prepareTunnel(t *testing.T) *Tunnel {
248-
sshAddr := createSSHServer(keyPath)
249-
generateKnownHosts(sshAddr.String(), publicKeyPath, knownHostsPath)
250-
s, _ := NewServer("mole", sshAddr.String(), "")
279+
func prepareTunnel(t *testing.T, insecure bool) *Tunnel {
280+
ssh := createSSHServer(keyPath)
281+
srv, _ := NewServer("mole", ssh.Addr().String(), "")
251282

252-
httpAddr := createWebServer()
253-
tun := &Tunnel{local: "127.0.0.1:0", server: s, remote: httpAddr.String(), done: make(chan error), Ready: make(chan bool, 1)}
283+
srv.Insecure = insecure
284+
285+
if !insecure {
286+
generateKnownHosts(ssh.Addr().String(), publicKeyPath, knownHostsPath)
287+
}
288+
289+
web := createWebServer()
290+
tun := &Tunnel{local: "127.0.0.1:0", server: srv, remote: web.Addr().String(), done: make(chan error), Ready: make(chan bool, 1)}
254291

255292
go func(t *testing.T) {
256293
err := tun.Start()
@@ -322,19 +359,24 @@ func get(client http.Client, resource string) (string, error) {
322359
//
323360
// Example: If the request URI is /this-is-a-test, the response will be
324361
// this-is-a-test
325-
func createWebServer() net.Addr {
362+
func createWebServer() net.Listener {
326363

327364
handler := func(w http.ResponseWriter, r *http.Request) {
328365
fmt.Fprintf(w, r.URL.Path[1:])
329366
}
330-
http.HandleFunc("/", handler)
367+
368+
mux := http.NewServeMux()
369+
mux.HandleFunc("/", handler)
370+
371+
server := &http.Server{
372+
Handler: mux,
373+
}
374+
331375
l, _ := net.Listen("tcp", "127.0.0.1:0")
332376

333-
go func(l net.Listener) {
334-
http.Serve(l, nil)
335-
}(l)
377+
go server.Serve(l)
336378

337-
return l.Addr()
379+
return l
338380
}
339381

340382
// createSSHServer starts a SSH server that authenticates connections using
@@ -347,7 +389,7 @@ func createWebServer() net.Addr {
347389
// References:
348390
// https://gist.github.com/jpillora/b480fde82bff51a06238
349391
// https://tools.ietf.org/html/rfc4254#section-7.2
350-
func createSSHServer(keyPath string) net.Addr {
392+
func createSSHServer(keyPath string) net.Listener {
351393
conf := &ssh.ServerConfig{
352394
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
353395
return &ssh.Permissions{}, nil
@@ -398,7 +440,7 @@ func createSSHServer(keyPath string) net.Addr {
398440
}
399441
}(l)
400442

401-
return l.Addr()
443+
return l
402444
}
403445

404446
// generateKnownHosts creates a new "known_hosts" file on a given path with a

0 commit comments

Comments
 (0)