Skip to content

Commit 318b601

Browse files
committed
Add fork after authentication for tsh ssh (#54696)
This change adds fork after authentication support to tsh ssh.
1 parent 9470f29 commit 318b601

File tree

10 files changed

+809
-12
lines changed

10 files changed

+809
-12
lines changed

lib/client/api.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,6 +1889,14 @@ type SSHOptions struct {
18891889
// machine. If provided, it will be used instead of establishing a connection
18901890
// to the target host and executing the command remotely.
18911891
LocalCommandExecutor func(string, []string) error
1892+
// OnChildAuthenticate is a function to run in the child process during
1893+
// --fork-after authentications. It runs after authentication completes
1894+
// but before the session begins.
1895+
OnChildAuthenticate func() error
1896+
}
1897+
1898+
func (opts SSHOptions) forkAfterAuthentication() bool {
1899+
return opts.OnChildAuthenticate != nil
18921900
}
18931901

18941902
// WithHostAddress returns a SSHOptions which overrides the
@@ -1907,6 +1915,15 @@ func WithLocalCommandExecutor(executor func(string, []string) error) func(*SSHOp
19071915
}
19081916
}
19091917

1918+
// WithForkAfterAuthentication indicates that tsh is currently reexec-ing
1919+
// for --fork-after-authentication. The given function is called after
1920+
// authentication is complete but before the session starts.
1921+
func WithForkAfterAuthentication(onAuthenticate func() error) func(*SSHOptions) {
1922+
return func(opt *SSHOptions) {
1923+
opt.OnChildAuthenticate = onAuthenticate
1924+
}
1925+
}
1926+
19101927
// SSH connects to a node and, if 'command' is specified, executes the command on it,
19111928
// otherwise runs interactive shell
19121929
//
@@ -1949,9 +1966,14 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
19491966
}
19501967

19511968
if len(nodeAddrs) > 1 {
1969+
if options.forkAfterAuthentication() {
1970+
return &NonRetryableError{
1971+
Err: trace.BadParameter("fork after authentication not supported for commands on multiple nodes"),
1972+
}
1973+
}
19521974
return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command)
19531975
}
1954-
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor)
1976+
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options)
19551977
}
19561978

19571979
// ConnectToNode attempts to establish a connection to the node resolved to by the provided
@@ -2154,7 +2176,7 @@ func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *Cluster
21542176
return nodeClient, trace.Wrap(err)
21552177
}
21562178

2157-
func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, commandExecutor func(string, []string) error) error {
2179+
func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, options SSHOptions) error {
21582180
cluster := clt.ClusterName()
21592181
ctx, span := tc.Tracer.Start(
21602182
ctx,
@@ -2178,6 +2200,12 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
21782200
return trace.Wrap(err)
21792201
}
21802202
defer nodeClient.Close()
2203+
2204+
if options.OnChildAuthenticate != nil {
2205+
if err := options.OnChildAuthenticate(); err != nil {
2206+
return trace.Wrap(err)
2207+
}
2208+
}
21812209
// If forwarding ports were specified, start port forwarding.
21822210
if err := tc.startPortForwarding(ctx, nodeClient); err != nil {
21832211
return trace.Wrap(err)
@@ -2209,11 +2237,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
22092237

22102238
// After port forwarding, run a local command that uses the connection, and
22112239
// then disconnect.
2212-
if commandExecutor != nil {
2240+
if options.LocalCommandExecutor != nil {
22132241
if len(tc.Config.LocalForwardPorts) == 0 {
22142242
fmt.Println("Executing command locally without connecting to any servers. This makes no sense.")
22152243
}
2216-
return commandExecutor(tc.Config.HostLogin, command)
2244+
return options.LocalCommandExecutor(tc.Config.HostLogin, command)
22172245
}
22182246

22192247
if len(command) > 0 {
@@ -2248,7 +2276,7 @@ func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context,
22482276

22492277
// Issue "shell" request to the first matching node.
22502278
fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes match the label selector, picking first: %q\n", nodeAddrs[0])
2251-
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, nil)
2279+
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, SSHOptions{})
22522280
}
22532281

22542282
func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *NodeClient) error {

lib/client/reexec/reexec.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
package reexec
18+
19+
import (
20+
"context"
21+
"errors"
22+
"io"
23+
"os"
24+
"os/exec"
25+
"strings"
26+
27+
"github.com/gravitational/trace"
28+
)
29+
30+
// NotifyFileSignal signals on the returned channel when the provided file
31+
// receives a signal (a one-byte read).
32+
func NotifyFileSignal(f *os.File) <-chan error {
33+
errorCh := make(chan error, 1)
34+
go func() {
35+
n, err := f.Read(make([]byte, 1))
36+
if n > 0 {
37+
errorCh <- nil
38+
} else if err == nil {
39+
// this should be impossible according to the io.Reader contract
40+
errorCh <- io.ErrUnexpectedEOF
41+
} else {
42+
errorCh <- err
43+
}
44+
}()
45+
return errorCh
46+
}
47+
48+
// SignalAndClose writes a byte to the provided file (to signal a caller of
49+
// NotifyFileSignal) and closes it.
50+
func SignalAndClose(f *os.File) error {
51+
_, err := f.Write([]byte{0x00})
52+
return trace.NewAggregate(err, f.Close())
53+
}
54+
55+
// ForkAuthenticateParams are the parameters to RunForkAuthenticate.
56+
type ForkAuthenticateParams struct {
57+
// GetArgs gets the arguments to re-exec with, excluding the executable
58+
// (equivalent to os.Args[1:]).
59+
GetArgs func(signalFd, killFd uint64) []string
60+
// executable is the executable to run while re-execing. Overridden in tests.
61+
executable string
62+
// Stdin is the child process' stdin.
63+
Stdin io.Reader
64+
// Stdout is the child process' stdout.
65+
Stdout io.Writer
66+
// Stderr is the child process' stderr.
67+
Stderr io.Writer
68+
}
69+
70+
// RunForkAuthenticate re-execs the current executable and waits for any of
71+
// the following:
72+
// - The child process exits (usually in error).
73+
// - The child process signals the parent that it is ready to be disowned.
74+
// - The context is canceled.
75+
func RunForkAuthenticate(ctx context.Context, params ForkAuthenticateParams) error {
76+
if params.executable == "" {
77+
executable, err := getExecutable()
78+
if err != nil {
79+
return trace.Wrap(err)
80+
}
81+
params.executable = executable
82+
}
83+
cmd := exec.Command(params.executable)
84+
// Set up signal pipes.
85+
disownR, disownW, err := os.Pipe()
86+
if err != nil {
87+
return trace.Wrap(err)
88+
}
89+
killR, killW, err := os.Pipe()
90+
if err != nil {
91+
return trace.Wrap(err)
92+
}
93+
defer func() {
94+
// If the child is still listening, kill it. If the child successfully
95+
// disowned, this will do nothing.
96+
SignalAndClose(killW)
97+
killR.Close()
98+
disownW.Close()
99+
disownR.Close()
100+
}()
101+
102+
signalFd, killFd := configureReexecForOS(cmd, disownW, killR)
103+
cmd.Args = append(cmd.Args, params.GetArgs(signalFd, killFd)...)
104+
cmd.Args[0] = os.Args[0]
105+
cmd.Stdin = params.Stdin
106+
cmd.Stdout = params.Stdout
107+
cmd.Stderr = params.Stderr
108+
109+
if err := cmd.Start(); err != nil {
110+
return trace.Wrap(err)
111+
}
112+
113+
// Clean up parent end of pipes.
114+
if err := disownW.Close(); err != nil {
115+
return trace.NewAggregate(err, killAndWaitProcess(cmd))
116+
}
117+
if err := killR.Close(); err != nil {
118+
return trace.NewAggregate(err, killAndWaitProcess(cmd))
119+
}
120+
121+
select {
122+
case err := <-NotifyFileSignal(disownR):
123+
if err == nil {
124+
return trace.Wrap(cmd.Process.Release())
125+
} else if errors.Is(err, io.EOF) {
126+
// EOF means the child process exited, no need to report it on top of kill/wait.
127+
return trace.Wrap(killAndWaitProcess(cmd))
128+
}
129+
return trace.NewAggregate(err, killAndWaitProcess(cmd))
130+
case <-ctx.Done():
131+
return trace.NewAggregate(ctx.Err(), killAndWaitProcess(cmd))
132+
}
133+
}
134+
135+
func killAndWaitProcess(cmd *exec.Cmd) error {
136+
if err := cmd.Process.Kill(); err != nil {
137+
return trace.Wrap(err)
138+
}
139+
err := cmd.Wait()
140+
var execErr *exec.ExitError
141+
if errors.As(err, &execErr) && execErr.ExitCode() != 0 {
142+
return trace.Wrap(err)
143+
} else if err != nil && strings.Contains(err.Error(), "signal: killed") {
144+
// If the process was successfully killed, there is no issue.
145+
return nil
146+
}
147+
return trace.Wrap(err)
148+
}

lib/client/reexec/reexec_darwin.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
//go:build darwin
18+
19+
package reexec
20+
21+
import (
22+
"os"
23+
"os/exec"
24+
"syscall"
25+
26+
"github.com/gravitational/trace"
27+
)
28+
29+
// getExecutable gets the path to the executable that should be used for re-exec.
30+
func getExecutable() (string, error) {
31+
executable, err := os.Executable()
32+
return executable, trace.Wrap(err)
33+
}
34+
35+
// configureReexecForOS configures the command with files to inherit and
36+
// os-specific tweaks.
37+
func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
38+
cmd.SysProcAttr = &syscall.SysProcAttr{
39+
Setsid: true,
40+
}
41+
cmd.ExtraFiles = []*os.File{signal, kill}
42+
return 3, 4
43+
}

lib/client/reexec/reexec_linux.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
//go:build linux
18+
19+
package reexec
20+
21+
import (
22+
"os"
23+
"os/exec"
24+
"syscall"
25+
)
26+
27+
// getExecutable gets the path to the executable that should be used for re-exec.
28+
func getExecutable() (string, error) {
29+
return "/proc/self/exe", nil
30+
}
31+
32+
// configureReexecForOS configures the command with files to inherit and
33+
// os-specific tweaks.
34+
func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
35+
cmd.SysProcAttr = &syscall.SysProcAttr{
36+
Setsid: true,
37+
}
38+
cmd.ExtraFiles = []*os.File{signal, kill}
39+
return 3, 4
40+
}

0 commit comments

Comments
 (0)