Skip to content

Commit 57c909d

Browse files
authored
Add fork after authentication for tsh ssh (#54696)
This change adds fork after authentication support to tsh ssh.
1 parent de6aa5d commit 57c909d

File tree

10 files changed

+808
-12
lines changed

10 files changed

+808
-12
lines changed

lib/client/api.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1901,6 +1901,14 @@ type SSHOptions struct {
19011901
// machine. If provided, it will be used instead of establishing a connection
19021902
// to the target host and executing the command remotely.
19031903
LocalCommandExecutor func(string, []string) error
1904+
// OnChildAuthenticate is a function to run in the child process during
1905+
// --fork-after authentications. It runs after authentication completes
1906+
// but before the session begins.
1907+
OnChildAuthenticate func() error
1908+
}
1909+
1910+
func (opts SSHOptions) forkAfterAuthentication() bool {
1911+
return opts.OnChildAuthenticate != nil
19041912
}
19051913

19061914
// WithHostAddress returns a SSHOptions which overrides the
@@ -1919,6 +1927,15 @@ func WithLocalCommandExecutor(executor func(string, []string) error) func(*SSHOp
19191927
}
19201928
}
19211929

1930+
// WithForkAfterAuthentication indicates that tsh is currently reexec-ing
1931+
// for --fork-after-authentication. The given function is called after
1932+
// authentication is complete but before the session starts.
1933+
func WithForkAfterAuthentication(onAuthenticate func() error) func(*SSHOptions) {
1934+
return func(opt *SSHOptions) {
1935+
opt.OnChildAuthenticate = onAuthenticate
1936+
}
1937+
}
1938+
19221939
// SSH connects to a node and, if 'command' is specified, executes the command on it,
19231940
// otherwise runs interactive shell
19241941
//
@@ -1961,9 +1978,14 @@ func (tc *TeleportClient) SSH(ctx context.Context, command []string, opts ...fun
19611978
}
19621979

19631980
if len(nodeAddrs) > 1 {
1981+
if options.forkAfterAuthentication() {
1982+
return &NonRetryableError{
1983+
Err: trace.BadParameter("fork after authentication not supported for commands on multiple nodes"),
1984+
}
1985+
}
19641986
return tc.runShellOrCommandOnMultipleNodes(ctx, clt, nodeAddrs, command)
19651987
}
1966-
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options.LocalCommandExecutor)
1988+
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0].Addr, command, options)
19671989
}
19681990

19691991
// ConnectToNode attempts to establish a connection to the node resolved to by the provided
@@ -2166,7 +2188,7 @@ func (tc *TeleportClient) connectToNodeWithMFA(ctx context.Context, clt *Cluster
21662188
return nodeClient, trace.Wrap(err)
21672189
}
21682190

2169-
func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, commandExecutor func(string, []string) error) error {
2191+
func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt *ClusterClient, nodeAddr string, command []string, options SSHOptions) error {
21702192
cluster := clt.ClusterName()
21712193
ctx, span := tc.Tracer.Start(
21722194
ctx,
@@ -2190,6 +2212,12 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
21902212
return trace.Wrap(err)
21912213
}
21922214
defer nodeClient.Close()
2215+
2216+
if options.OnChildAuthenticate != nil {
2217+
if err := options.OnChildAuthenticate(); err != nil {
2218+
return trace.Wrap(err)
2219+
}
2220+
}
21932221
// If forwarding ports were specified, start port forwarding.
21942222
if err := tc.startPortForwarding(ctx, nodeClient); err != nil {
21952223
return trace.Wrap(err)
@@ -2221,11 +2249,11 @@ func (tc *TeleportClient) runShellOrCommandOnSingleNode(ctx context.Context, clt
22212249

22222250
// After port forwarding, run a local command that uses the connection, and
22232251
// then disconnect.
2224-
if commandExecutor != nil {
2252+
if options.LocalCommandExecutor != nil {
22252253
if len(tc.Config.LocalForwardPorts) == 0 {
22262254
fmt.Println("Executing command locally without connecting to any servers. This makes no sense.")
22272255
}
2228-
return commandExecutor(tc.Config.HostLogin, command)
2256+
return options.LocalCommandExecutor(tc.Config.HostLogin, command)
22292257
}
22302258

22312259
if len(command) > 0 {
@@ -2260,7 +2288,7 @@ func (tc *TeleportClient) runShellOrCommandOnMultipleNodes(ctx context.Context,
22602288

22612289
// Issue "shell" request to the first matching node.
22622290
fmt.Printf("\x1b[1mWARNING\x1b[0m: Multiple nodes match the label selector, picking first: %q\n", nodeAddrs[0])
2263-
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, nil)
2291+
return tc.runShellOrCommandOnSingleNode(ctx, clt, nodeAddrs[0], nil, SSHOptions{})
22642292
}
22652293

22662294
func (tc *TeleportClient) startPortForwarding(ctx context.Context, nodeClient *NodeClient) error {

lib/client/reexec/reexec.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
package reexec
18+
19+
import (
20+
"context"
21+
"errors"
22+
"io"
23+
"os"
24+
"os/exec"
25+
"strings"
26+
27+
"github.com/gravitational/trace"
28+
)
29+
30+
// NotifyFileSignal signals on the returned channel when the provided file
31+
// receives a signal (a one-byte read).
32+
func NotifyFileSignal(f *os.File) <-chan error {
33+
errorCh := make(chan error, 1)
34+
go func() {
35+
n, err := f.Read(make([]byte, 1))
36+
if n > 0 {
37+
errorCh <- nil
38+
} else if err == nil {
39+
// this should be impossible according to the io.Reader contract
40+
errorCh <- io.ErrUnexpectedEOF
41+
} else {
42+
errorCh <- err
43+
}
44+
}()
45+
return errorCh
46+
}
47+
48+
// SignalAndClose writes a byte to the provided file (to signal a caller of
49+
// NotifyFileSignal) and closes it.
50+
func SignalAndClose(f *os.File) error {
51+
_, err := f.Write([]byte{0x00})
52+
return trace.NewAggregate(err, f.Close())
53+
}
54+
55+
// ForkAuthenticateParams are the parameters to RunForkAuthenticate.
56+
type ForkAuthenticateParams struct {
57+
// GetArgs gets the arguments to re-exec with, excluding the executable
58+
// (equivalent to os.Args[1:]).
59+
GetArgs func(signalFd, killFd uint64) []string
60+
// executable is the executable to run while re-execing. Overridden in tests.
61+
executable string
62+
// Stdin is the child process' stdin.
63+
Stdin io.Reader
64+
// Stdout is the child process' stdout.
65+
Stdout io.Writer
66+
// Stderr is the child process' stderr.
67+
Stderr io.Writer
68+
}
69+
70+
// RunForkAuthenticate re-execs the current executable and waits for any of
71+
// the following:
72+
// - The child process exits (usually in error).
73+
// - The child process signals the parent that it is ready to be disowned.
74+
// - The context is canceled.
75+
func RunForkAuthenticate(ctx context.Context, params ForkAuthenticateParams) error {
76+
if params.executable == "" {
77+
executable, err := getExecutable()
78+
if err != nil {
79+
return trace.Wrap(err)
80+
}
81+
params.executable = executable
82+
}
83+
cmd := exec.Command(params.executable)
84+
// Set up signal pipes.
85+
disownR, disownW, err := os.Pipe()
86+
if err != nil {
87+
return trace.Wrap(err)
88+
}
89+
killR, killW, err := os.Pipe()
90+
if err != nil {
91+
return trace.Wrap(err)
92+
}
93+
defer func() {
94+
// If the child is still listening, kill it. If the child successfully
95+
// disowned, this will do nothing.
96+
SignalAndClose(killW)
97+
killR.Close()
98+
disownW.Close()
99+
disownR.Close()
100+
}()
101+
102+
signalFd, killFd := configureReexecForOS(cmd, disownW, killR)
103+
cmd.Args = append(cmd.Args, params.GetArgs(signalFd, killFd)...)
104+
cmd.Args[0] = os.Args[0]
105+
cmd.Stdin = params.Stdin
106+
cmd.Stdout = params.Stdout
107+
cmd.Stderr = params.Stderr
108+
109+
if err := cmd.Start(); err != nil {
110+
return trace.Wrap(err)
111+
}
112+
113+
// Clean up parent end of pipes.
114+
if err := disownW.Close(); err != nil {
115+
return trace.NewAggregate(err, killAndWaitProcess(cmd))
116+
}
117+
if err := killR.Close(); err != nil {
118+
return trace.NewAggregate(err, killAndWaitProcess(cmd))
119+
}
120+
121+
select {
122+
case err := <-NotifyFileSignal(disownR):
123+
if err == nil {
124+
return trace.Wrap(cmd.Process.Release())
125+
} else if errors.Is(err, io.EOF) {
126+
// EOF means the child process exited, no need to report it on top of kill/wait.
127+
return trace.Wrap(killAndWaitProcess(cmd))
128+
}
129+
return trace.NewAggregate(err, killAndWaitProcess(cmd))
130+
case <-ctx.Done():
131+
return trace.NewAggregate(ctx.Err(), killAndWaitProcess(cmd))
132+
}
133+
}
134+
135+
func killAndWaitProcess(cmd *exec.Cmd) error {
136+
if err := cmd.Process.Kill(); err != nil {
137+
return trace.Wrap(err)
138+
}
139+
err := cmd.Wait()
140+
var execErr *exec.ExitError
141+
if errors.As(err, &execErr) && execErr.ExitCode() != 0 {
142+
return trace.Wrap(err)
143+
} else if err != nil && strings.Contains(err.Error(), "signal: killed") {
144+
// If the process was successfully killed, there is no issue.
145+
return nil
146+
}
147+
return trace.Wrap(err)
148+
}

lib/client/reexec/reexec_darwin.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
//go:build darwin
18+
19+
package reexec
20+
21+
import (
22+
"os"
23+
"os/exec"
24+
"syscall"
25+
26+
"github.com/gravitational/trace"
27+
)
28+
29+
// getExecutable gets the path to the executable that should be used for re-exec.
30+
func getExecutable() (string, error) {
31+
executable, err := os.Executable()
32+
return executable, trace.Wrap(err)
33+
}
34+
35+
// configureReexecForOS configures the command with files to inherit and
36+
// os-specific tweaks.
37+
func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
38+
cmd.SysProcAttr = &syscall.SysProcAttr{
39+
Setsid: true,
40+
}
41+
cmd.ExtraFiles = []*os.File{signal, kill}
42+
return 3, 4
43+
}

lib/client/reexec/reexec_linux.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Teleport
2+
// Copyright (C) 2025 Gravitational, Inc.
3+
//
4+
// This program is free software: you can redistribute it and/or modify
5+
// it under the terms of the GNU Affero General Public License as published by
6+
// the Free Software Foundation, either version 3 of the License, or
7+
// (at your option) any later version.
8+
//
9+
// This program is distributed in the hope that it will be useful,
10+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
// GNU Affero General Public License for more details.
13+
//
14+
// You should have received a copy of the GNU Affero General Public License
15+
// along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
17+
//go:build linux
18+
19+
package reexec
20+
21+
import (
22+
"os"
23+
"os/exec"
24+
"syscall"
25+
)
26+
27+
// getExecutable gets the path to the executable that should be used for re-exec.
28+
func getExecutable() (string, error) {
29+
return "/proc/self/exe", nil
30+
}
31+
32+
// configureReexecForOS configures the command with files to inherit and
33+
// os-specific tweaks.
34+
func configureReexecForOS(cmd *exec.Cmd, signal, kill *os.File) (signalFd, killFd uint64) {
35+
cmd.SysProcAttr = &syscall.SysProcAttr{
36+
Setsid: true,
37+
}
38+
cmd.ExtraFiles = []*os.File{signal, kill}
39+
return 3, 4
40+
}

0 commit comments

Comments
 (0)