diff --git a/lib/vnet/ssh_proxy.go b/lib/vnet/ssh_proxy.go index 9dd2c2a0235e6..df82175c25fd7 100644 --- a/lib/vnet/ssh_proxy.go +++ b/lib/vnet/ssh_proxy.go @@ -19,13 +19,12 @@ package vnet import ( "context" "errors" + "io" "log/slog" "sync" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" - - "github.com/gravitational/teleport/lib/utils" ) // sshConn represents an established SSH client or server connection. @@ -171,60 +170,134 @@ func proxyChannel( return } - // Copy channel requests in both directions concurrently. If either fails or - // exits it will cancel the context so that utils.ProxyConn below will close - // both channels so the other goroutine can also exit. + // Copy channel data and requests from the incoming channel to the target + // channel, and vice-versa. + target := newSSHChan(targetChan, targetChanRequests, slog.With("direction", "client->target")) + incoming := newSSHChan(incomingChan, incomingChanRequests, slog.With("direction", "target->client")) + var wg sync.WaitGroup wg.Add(2) - ctx, cancel := context.WithCancel(ctx) go func() { - proxyChannelRequests(ctx, log, targetChan, incomingChanRequests, cancel) - cancel() + target.writeFrom(ctx, incoming) wg.Done() }() go func() { - proxyChannelRequests(ctx, log, incomingChan, targetChanRequests, cancel) - cancel() + incoming.writeFrom(ctx, target) wg.Done() }() + wg.Wait() +} - // ProxyConn copies channel data bidirectionally. If the context is - // canceled it will terminate, it always closes both channels before - // returning. - if err := utils.ProxyConn(ctx, incomingChan, targetChan); err != nil && - !utils.IsOKNetworkError(err) && !errors.Is(err, context.Canceled) { - log.DebugContext(ctx, "Unexpected error proxying channel data", "error", err) +// sshChan manages all writes to an SSH channel and handles closing the channel +// once no more data or requests will be written to it. +type sshChan struct { + ch ssh.Channel + requests <-chan *ssh.Request + log *slog.Logger +} + +func newSSHChan(ch ssh.Channel, requests <-chan *ssh.Request, log *slog.Logger) *sshChan { + return &sshChan{ + ch: ch, + requests: requests, + log: log, } +} + +// writeFrom writes channel data and requests from the source to this SSH channel. +// +// In the happy path it waits for: +// - channel data reads from source to return EOF +// - the source request channel to be closed +// and then closes this channel. +// +// Channel data reads from source can return EOF at any time if it has sent +// SSH_MSG_CHANNEL_EOF but it is still valid to send more channel requests +// after this. +// +// If an unrecoverable error is encountered it immediately closes both +// channels. +func (c *sshChan) writeFrom(ctx context.Context, source *sshChan) { + // Close the channel after all data and request writes are complete. + defer c.ch.Close() - // Wait for all goroutines to terminate. + var wg sync.WaitGroup + wg.Add(2) + go func() { + c.writeDataFrom(ctx, source) + wg.Done() + }() + go func() { + c.writeRequestsFrom(ctx, source) + wg.Done() + }() wg.Wait() } -func proxyChannelRequests( - ctx context.Context, - log *slog.Logger, - targetChan ssh.Channel, - reqs <-chan *ssh.Request, - closeChannels func(), -) { - log = log.With("request_layer", "channel") +// writeDataFrom writes channel data from source to this SSH channel. +// It handles standard channel data and extended channel data of type stderr. +func (c *sshChan) writeDataFrom(ctx context.Context, source *sshChan) { + // Close the channel for writes only after both the standard and stderr + // streams are finished writing. + defer c.ch.CloseWrite() + + errors := make(chan error, 2) + go func() { + _, err := io.Copy(c.ch, source.ch) + errors <- err + }() + go func() { + _, err := io.Copy(c.ch.Stderr(), source.ch.Stderr()) + errors <- err + }() + + // Read both errors to make sure both goroutines terminate, but only do + // anything on the first non-nil error, the second error is likely either + // the same as the first one or caused by closing the channel. + handledError := false + for range 2 { + err := <-errors + if err != nil && !handledError { + handledError = true + // Failed to write channel data from source to this channel. This was + // not an EOF from source or io.Copy would have returned nil. The + // stream might be missing data so close both channels. + // + // This should also unblock the stderr stream if the regular stream + // returned an error, and vice-versa. + c.log.ErrorContext(ctx, "Fatal error proxying SSH channel data", "error", err) + c.ch.Close() + source.ch.Close() + } + } +} + +// writeRequestsFrom forwards channel requests from source to this SSH channel. +func (c *sshChan) writeRequestsFrom(ctx context.Context, source *sshChan) { + log := c.log.With("request_layer", "channel") sendRequest := func(name string, wantReply bool, payload []byte) (bool, []byte, error) { - ok, err := targetChan.SendRequest(name, wantReply, payload) + ok, err := c.ch.SendRequest(name, wantReply, payload) // Replies to channel requests never have a payload. return ok, nil, err } - proxyRequests(ctx, log, sendRequest, reqs, closeChannels) + // Must forcibly close both channels if there was a fatal error proxying + // channel requests so that we don't continue in a bad state. + onFatalError := func() { + c.ch.Close() + source.ch.Close() + } + proxyRequests(ctx, log, sendRequest, source.requests, onFatalError) } func proxyGlobalRequests( ctx context.Context, targetConn ssh.Conn, reqs <-chan *ssh.Request, - closeConnections func(), + onFatalError func(), ) { log := log.With("request_layer", "global") sendRequest := targetConn.SendRequest - proxyRequests(ctx, log, sendRequest, reqs, closeConnections) + proxyRequests(ctx, log, sendRequest, reqs, onFatalError) } func proxyRequests( @@ -232,7 +305,7 @@ func proxyRequests( log *slog.Logger, sendRequest func(name string, wantReply bool, payload []byte) (bool, []byte, error), reqs <-chan *ssh.Request, - closeRequestSources func(), + onFatalError func(), ) { for req := range reqs { log := log.With("request_type", req.Type) @@ -240,23 +313,20 @@ func proxyRequests( ok, reply, err := sendRequest(req.Type, req.WantReply, req.Payload) if err != nil { // We failed to send the request, the target must be dead. - log.DebugContext(ctx, "Failed to forward SSH request", "request_type", req.Type, "error", err) - // Close both connections or channels to clean up but we must - // continue handling requests on the chan until it is closed by - // crypto/ssh. - closeRequestSources() - _ = req.Reply(false, nil) - continue + log.DebugContext(ctx, "Failed to forward SSH request", "error", err) + onFatalError() + req.Reply(false, nil) + ssh.DiscardRequests(reqs) + return } if err := req.Reply(ok, reply); err != nil { // A reply was expected and returned by the target but we failed to // forward it back, the connection that initiated the request must // be dead. - log.DebugContext(ctx, "Failed to reply to SSH request", "request_type", req.Type, "error", err) - // Close both connections or channels to clean up but we must - // continue handling requests on the chan until it is closed by - // crypto/ssh. - closeRequestSources() + log.DebugContext(ctx, "Failed to reply to SSH request", "error", err) + onFatalError() + ssh.DiscardRequests(reqs) + return } } } diff --git a/lib/vnet/ssh_proxy_test.go b/lib/vnet/ssh_proxy_test.go index 4c9b0e026ab80..d7eb8313c8a11 100644 --- a/lib/vnet/ssh_proxy_test.go +++ b/lib/vnet/ssh_proxy_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "net" + "sync" "testing" "github.com/gravitational/trace" @@ -108,11 +109,17 @@ func testSSHConnection(t *testing.T, dial dialer) { } func testConnectionToSshEchoServer(t *testing.T, sshConn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { - go ssh.DiscardRequests(reqs) + requestStreamEnded := make(chan struct{}) + go func() { + ssh.DiscardRequests(reqs) + close(requestStreamEnded) + }() + chanStreamEnded := make(chan struct{}) go func() { for newChan := range chans { newChan.Reject(ssh.Prohibited, "test") } + close(chanStreamEnded) }() // Try sending some global requests. @@ -136,6 +143,26 @@ func testConnectionToSshEchoServer(t *testing.T, sshConn ssh.Conn, chans <-chan t.Run("echo channel 2", func(t *testing.T) { testEchoChannel(t, sshConn) }) + + t.Run("closing", func(t *testing.T) { + // Send a request that causes the target server to close the connection + // immediately and make sure channel reads are unblocked, and the global + // request and channel request streams end. + ch, reqs, err := sshConn.OpenChannel("echo", nil) + require.NoError(t, err) + go ssh.DiscardRequests(reqs) + readErr := make(chan error) + go func() { + var b [1]byte + _, err := ch.Read(b[:]) + readErr <- err + }() + _, _, err = sshConn.SendRequest("close", false, nil) + require.NoError(t, err) + require.ErrorIs(t, <-readErr, io.EOF) + <-requestStreamEnded + <-chanStreamEnded + }) } func testGlobalRequests(t *testing.T, conn ssh.Conn) { @@ -156,7 +183,11 @@ func testGlobalRequests(t *testing.T, conn ssh.Conn) { func testEchoChannel(t *testing.T, conn ssh.Conn) { ch, reqs, err := conn.OpenChannel("echo", nil) require.NoError(t, err) - go ssh.DiscardRequests(reqs) + requestStreamEnded := make(chan struct{}) + go func() { + ssh.DiscardRequests(reqs) + close(requestStreamEnded) + }() defer ch.Close() // Try sending a message over the SSH channel and asserting that it is @@ -170,16 +201,43 @@ func testEchoChannel(t *testing.T, conn ssh.Conn) { require.Equal(t, len(msg), n) require.Equal(t, msg, buf[:n]) + // Try sending a message over stderr and asserting that it is echoed back. + _, err = ch.Stderr().Write(msg) + require.NoError(t, err) + n, err = ch.Stderr().Read(buf[:]) + require.NoError(t, err) + require.Equal(t, len(msg), n) + require.Equal(t, msg, buf[:n]) + // Try sending a channel request that expects a reply. reply, err := ch.SendRequest("echo", true, nil) require.NoError(t, err) require.True(t, reply) + // Close the channel for writes of in-band data and send another channel + // request, which should succeed. + require.NoError(t, ch.CloseWrite()) + reply, err = ch.SendRequest("echo", true, nil) + require.NoError(t, err) + require.True(t, reply) + // The test server replies false to channel requests with type other than // "echo". reply, err = ch.SendRequest("unknown", true, nil) require.NoError(t, err) require.False(t, reply) + + // Send a channel request that causes the server to close the channel and + // make sure channel reads get unblocked and the incoming request stream ends. + readErr := make(chan error) + go func() { + _, err := ch.Read(buf[:]) + readErr <- err + }() + _, err = ch.SendRequest("close", false, nil) + require.NoError(t, err) + require.ErrorIs(t, <-readErr, io.EOF) + <-requestStreamEnded } type dialer interface { @@ -282,7 +340,7 @@ func runTestSSHServerInstance(tcpConn net.Conn, cfg *ssh.ServerConfig) error { return trace.Wrap(err) } go func() { - handleEchoRequests(reqs) + handleSSHRequests(reqs, sshConn.Close) sshConn.Close() }() handleEchoChannels(chans) @@ -290,17 +348,6 @@ func runTestSSHServerInstance(tcpConn net.Conn, cfg *ssh.ServerConfig) error { return nil } -func handleEchoRequests(reqs <-chan *ssh.Request) { - for req := range reqs { - switch req.Type { - case "echo": - req.Reply(true, req.Payload) - default: - req.Reply(false, nil) - } - } -} - func handleEchoChannels(chans <-chan ssh.NewChannel) { for newChan := range chans { switch newChan.ChannelType() { @@ -317,8 +364,33 @@ func handleEchoChannel(newChan ssh.NewChannel) { if err != nil { return } - go handleEchoRequests(reqs) - io.Copy(ch, ch) + go handleSSHRequests(reqs, ch.Close) + defer ch.CloseWrite() + var wg sync.WaitGroup + wg.Add(2) + go func() { + io.Copy(ch, ch) + wg.Done() + }() + go func() { + io.Copy(ch.Stderr(), ch.Stderr()) + wg.Done() + }() + wg.Wait() +} + +func handleSSHRequests(reqs <-chan *ssh.Request, closeSource func() error) { + defer closeSource() + for req := range reqs { + switch req.Type { + case "echo": + req.Reply(true, req.Payload) + case "close": + closeSource() + default: + req.Reply(false, nil) + } + } } func sshServerConfig(t *testing.T) *ssh.ServerConfig {