Skip to content

[vnet] fix: close proxied channel only after data and requests are complete #56020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 112 additions & 42 deletions lib/vnet/ssh_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ package vnet
import (
"context"
"errors"
"io"
"log/slog"
"sync"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport/lib/utils"
)

// sshConn represents an established SSH client or server connection.
Expand Down Expand Up @@ -171,92 +170,163 @@ func proxyChannel(
return
}

// Copy channel requests in both directions concurrently. If either fails or
// exits it will cancel the context so that utils.ProxyConn below will close
// both channels so the other goroutine can also exit.
// Copy channel data and requests from the incoming channel to the target
// channel, and vice-versa.
target := newSSHChan(targetChan, targetChanRequests, slog.With("direction", "client->target"))
incoming := newSSHChan(incomingChan, incomingChanRequests, slog.With("direction", "target->client"))

var wg sync.WaitGroup
wg.Add(2)
ctx, cancel := context.WithCancel(ctx)
go func() {
proxyChannelRequests(ctx, log, targetChan, incomingChanRequests, cancel)
cancel()
target.writeFrom(ctx, incoming)
wg.Done()
}()
go func() {
proxyChannelRequests(ctx, log, incomingChan, targetChanRequests, cancel)
cancel()
incoming.writeFrom(ctx, target)
wg.Done()
}()
wg.Wait()
}

// ProxyConn copies channel data bidirectionally. If the context is
// canceled it will terminate, it always closes both channels before
// returning.
if err := utils.ProxyConn(ctx, incomingChan, targetChan); err != nil &&
!utils.IsOKNetworkError(err) && !errors.Is(err, context.Canceled) {
log.DebugContext(ctx, "Unexpected error proxying channel data", "error", err)
// sshChan manages all writes to an SSH channel and handles closing the channel
// once no more data or requests will be written to it.
type sshChan struct {
ch ssh.Channel
requests <-chan *ssh.Request
log *slog.Logger
}

func newSSHChan(ch ssh.Channel, requests <-chan *ssh.Request, log *slog.Logger) *sshChan {
return &sshChan{
ch: ch,
requests: requests,
log: log,
}
}

// writeFrom writes channel data and requests from the source to this SSH channel.
//
// In the happy path it waits for:
// - channel data reads from source to return EOF
// - the source request channel to be closed
// and then closes this channel.
//
// Channel data reads from source can return EOF at any time if it has sent
// SSH_MSG_CHANNEL_EOF but it is still valid to send more channel requests
// after this.
//
// If an unrecoverable error is encountered it immediately closes both
// channels.
func (c *sshChan) writeFrom(ctx context.Context, source *sshChan) {
// Close the channel after all data and request writes are complete.
defer c.ch.Close()

// Wait for all goroutines to terminate.
var wg sync.WaitGroup
wg.Add(2)
go func() {
c.writeDataFrom(ctx, source)
wg.Done()
}()
go func() {
c.writeRequestsFrom(ctx, source)
wg.Done()
}()
wg.Wait()
}

func proxyChannelRequests(
ctx context.Context,
log *slog.Logger,
targetChan ssh.Channel,
reqs <-chan *ssh.Request,
closeChannels func(),
) {
log = log.With("request_layer", "channel")
// writeDataFrom writes channel data from source to this SSH channel.
// It handles standard channel data and extended channel data of type stderr.
func (c *sshChan) writeDataFrom(ctx context.Context, source *sshChan) {
// Close the channel for writes only after both the standard and stderr
// streams are finished writing.
defer c.ch.CloseWrite()

errors := make(chan error, 2)
go func() {
_, err := io.Copy(c.ch, source.ch)
errors <- err
}()
go func() {
_, err := io.Copy(c.ch.Stderr(), source.ch.Stderr())
errors <- err
}()

// Read both errors to make sure both goroutines terminate, but only do
// anything on the first non-nil error, the second error is likely either
// the same as the first one or caused by closing the channel.
handledError := false
for range 2 {
err := <-errors
if err != nil && !handledError {
handledError = true
// Failed to write channel data from source to this channel. This was
// not an EOF from source or io.Copy would have returned nil. The
// stream might be missing data so close both channels.
//
// This should also unblock the stderr stream if the regular stream
// returned an error, and vice-versa.
c.log.ErrorContext(ctx, "Fatal error proxying SSH channel data", "error", err)
c.ch.Close()
source.ch.Close()
}
}
}

// writeRequestsFrom forwards channel requests from source to this SSH channel.
func (c *sshChan) writeRequestsFrom(ctx context.Context, source *sshChan) {
log := c.log.With("request_layer", "channel")
sendRequest := func(name string, wantReply bool, payload []byte) (bool, []byte, error) {
ok, err := targetChan.SendRequest(name, wantReply, payload)
ok, err := c.ch.SendRequest(name, wantReply, payload)
// Replies to channel requests never have a payload.
return ok, nil, err
}
proxyRequests(ctx, log, sendRequest, reqs, closeChannels)
// Must forcibly close both channels if there was a fatal error proxying
// channel requests so that we don't continue in a bad state.
onFatalError := func() {
c.ch.Close()
source.ch.Close()
}
proxyRequests(ctx, log, sendRequest, source.requests, onFatalError)
}

func proxyGlobalRequests(
ctx context.Context,
targetConn ssh.Conn,
reqs <-chan *ssh.Request,
closeConnections func(),
onFatalError func(),
) {
log := log.With("request_layer", "global")
sendRequest := targetConn.SendRequest
proxyRequests(ctx, log, sendRequest, reqs, closeConnections)
proxyRequests(ctx, log, sendRequest, reqs, onFatalError)
}

func proxyRequests(
ctx context.Context,
log *slog.Logger,
sendRequest func(name string, wantReply bool, payload []byte) (bool, []byte, error),
reqs <-chan *ssh.Request,
closeRequestSources func(),
onFatalError func(),
) {
for req := range reqs {
log := log.With("request_type", req.Type)
log.DebugContext(ctx, "Proxying SSH request")
ok, reply, err := sendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
// We failed to send the request, the target must be dead.
log.DebugContext(ctx, "Failed to forward SSH request", "request_type", req.Type, "error", err)
// Close both connections or channels to clean up but we must
// continue handling requests on the chan until it is closed by
// crypto/ssh.
closeRequestSources()
_ = req.Reply(false, nil)
continue
log.DebugContext(ctx, "Failed to forward SSH request", "error", err)
onFatalError()
req.Reply(false, nil)
ssh.DiscardRequests(reqs)
return
}
if err := req.Reply(ok, reply); err != nil {
// A reply was expected and returned by the target but we failed to
// forward it back, the connection that initiated the request must
// be dead.
log.DebugContext(ctx, "Failed to reply to SSH request", "request_type", req.Type, "error", err)
// Close both connections or channels to clean up but we must
// continue handling requests on the chan until it is closed by
// crypto/ssh.
closeRequestSources()
log.DebugContext(ctx, "Failed to reply to SSH request", "error", err)
onFatalError()
ssh.DiscardRequests(reqs)
return
}
}
}
104 changes: 88 additions & 16 deletions lib/vnet/ssh_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"net"
"sync"
"testing"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -108,11 +109,17 @@ func testSSHConnection(t *testing.T, dial dialer) {
}

func testConnectionToSshEchoServer(t *testing.T, sshConn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) {
go ssh.DiscardRequests(reqs)
requestStreamEnded := make(chan struct{})
go func() {
ssh.DiscardRequests(reqs)
close(requestStreamEnded)
}()
chanStreamEnded := make(chan struct{})
go func() {
for newChan := range chans {
newChan.Reject(ssh.Prohibited, "test")
}
close(chanStreamEnded)
}()

// Try sending some global requests.
Expand All @@ -136,6 +143,26 @@ func testConnectionToSshEchoServer(t *testing.T, sshConn ssh.Conn, chans <-chan
t.Run("echo channel 2", func(t *testing.T) {
testEchoChannel(t, sshConn)
})

t.Run("closing", func(t *testing.T) {
// Send a request that causes the target server to close the connection
// immediately and make sure channel reads are unblocked, and the global
// request and channel request streams end.
ch, reqs, err := sshConn.OpenChannel("echo", nil)
require.NoError(t, err)
go ssh.DiscardRequests(reqs)
readErr := make(chan error)
go func() {
var b [1]byte
_, err := ch.Read(b[:])
readErr <- err
}()
_, _, err = sshConn.SendRequest("close", false, nil)
require.NoError(t, err)
require.ErrorIs(t, <-readErr, io.EOF)
<-requestStreamEnded
<-chanStreamEnded
})
}

func testGlobalRequests(t *testing.T, conn ssh.Conn) {
Expand All @@ -156,7 +183,11 @@ func testGlobalRequests(t *testing.T, conn ssh.Conn) {
func testEchoChannel(t *testing.T, conn ssh.Conn) {
ch, reqs, err := conn.OpenChannel("echo", nil)
require.NoError(t, err)
go ssh.DiscardRequests(reqs)
requestStreamEnded := make(chan struct{})
go func() {
ssh.DiscardRequests(reqs)
close(requestStreamEnded)
}()
defer ch.Close()

// Try sending a message over the SSH channel and asserting that it is
Expand All @@ -170,16 +201,43 @@ func testEchoChannel(t *testing.T, conn ssh.Conn) {
require.Equal(t, len(msg), n)
require.Equal(t, msg, buf[:n])

// Try sending a message over stderr and asserting that it is echoed back.
_, err = ch.Stderr().Write(msg)
require.NoError(t, err)
n, err = ch.Stderr().Read(buf[:])
require.NoError(t, err)
require.Equal(t, len(msg), n)
require.Equal(t, msg, buf[:n])

// Try sending a channel request that expects a reply.
reply, err := ch.SendRequest("echo", true, nil)
require.NoError(t, err)
require.True(t, reply)

// Close the channel for writes of in-band data and send another channel
// request, which should succeed.
require.NoError(t, ch.CloseWrite())
reply, err = ch.SendRequest("echo", true, nil)
require.NoError(t, err)
require.True(t, reply)

// The test server replies false to channel requests with type other than
// "echo".
reply, err = ch.SendRequest("unknown", true, nil)
require.NoError(t, err)
require.False(t, reply)

// Send a channel request that causes the server to close the channel and
// make sure channel reads get unblocked and the incoming request stream ends.
readErr := make(chan error)
go func() {
_, err := ch.Read(buf[:])
readErr <- err
}()
_, err = ch.SendRequest("close", false, nil)
require.NoError(t, err)
require.ErrorIs(t, <-readErr, io.EOF)
<-requestStreamEnded
}

type dialer interface {
Expand Down Expand Up @@ -282,25 +340,14 @@ func runTestSSHServerInstance(tcpConn net.Conn, cfg *ssh.ServerConfig) error {
return trace.Wrap(err)
}
go func() {
handleEchoRequests(reqs)
handleSSHRequests(reqs, sshConn.Close)
sshConn.Close()
}()
handleEchoChannels(chans)
sshConn.Close()
return nil
}

func handleEchoRequests(reqs <-chan *ssh.Request) {
for req := range reqs {
switch req.Type {
case "echo":
req.Reply(true, req.Payload)
default:
req.Reply(false, nil)
}
}
}

func handleEchoChannels(chans <-chan ssh.NewChannel) {
for newChan := range chans {
switch newChan.ChannelType() {
Expand All @@ -317,8 +364,33 @@ func handleEchoChannel(newChan ssh.NewChannel) {
if err != nil {
return
}
go handleEchoRequests(reqs)
io.Copy(ch, ch)
go handleSSHRequests(reqs, ch.Close)
defer ch.CloseWrite()
var wg sync.WaitGroup
wg.Add(2)
go func() {
io.Copy(ch, ch)
wg.Done()
}()
go func() {
io.Copy(ch.Stderr(), ch.Stderr())
wg.Done()
}()
wg.Wait()
}

func handleSSHRequests(reqs <-chan *ssh.Request, closeSource func() error) {
defer closeSource()
for req := range reqs {
switch req.Type {
case "echo":
req.Reply(true, req.Payload)
case "close":
closeSource()
default:
req.Reply(false, nil)
}
}
}

func sshServerConfig(t *testing.T) *ssh.ServerConfig {
Expand Down
Loading