17
17
18
18
import com .hierynomus .sshj .test .HttpServer ;
19
19
import com .hierynomus .sshj .test .SshServerExtension ;
20
- import com .hierynomus .sshj .test .util .FileUtil ;
21
20
import net .schmizz .sshj .SSHClient ;
22
21
import net .schmizz .sshj .connection .ConnectionException ;
23
22
import net .schmizz .sshj .connection .channel .forwarded .RemotePortForwarder ;
27
26
import org .junit .jupiter .api .Test ;
28
27
import org .junit .jupiter .api .extension .RegisterExtension ;
29
28
30
- import java .io .File ;
31
29
import java .io .IOException ;
32
30
import java .net .HttpURLConnection ;
33
31
import java .net .InetSocketAddress ;
32
+ import java .net .URI ;
34
33
import java .net .URL ;
35
- import java .nio .file .Files ;
36
34
37
35
import static org .junit .jupiter .api .Assertions .assertEquals ;
38
36
39
37
public class RemotePortForwarderTest {
40
38
private static final PortRange RANGE = new PortRange (9000 , 9999 );
41
39
private static final String LOCALHOST = "127.0.0.1" ;
42
- private static final String LOCALHOST_URL_FORMAT = "http://127.0.0.1:%d" ;
43
- private static final InetSocketAddress HTTP_SERVER_SOCKET_ADDR = new InetSocketAddress (LOCALHOST , 8080 );
40
+ private static final String URL_FORMAT = "http://%s:%d" ;
44
41
45
42
@ RegisterExtension
46
43
public SshServerExtension fixture = new SshServerExtension ();
@@ -49,21 +46,21 @@ public class RemotePortForwarderTest {
49
46
public HttpServer httpServer = new HttpServer ();
50
47
51
48
@ BeforeEach
52
- public void setUp () throws IOException {
49
+ public void setUp () {
53
50
fixture .getServer ().setForwardingFilter (new AcceptAllForwardingFilter ());
54
- File file = Files .createFile (httpServer .getDocRoot ().toPath ().resolve ("index.html" )).toFile ();
55
- FileUtil .writeToFile (file , "<html><head/><body><h1>Hi!</h1></body></html>" );
56
51
}
57
52
58
53
@ Test
59
54
public void shouldHaveWorkingHttpServer () throws IOException {
60
- assertEquals (200 , httpGet (8080 ));
55
+ final URI serverUrl = httpServer .getServerUrl ();
56
+
57
+ assertEquals (HttpURLConnection .HTTP_NOT_FOUND , httpGet (serverUrl .getHost (), serverUrl .getPort ()));
61
58
}
62
59
63
60
@ Test
64
61
public void shouldDynamicallyForwardPortForLocalhost () throws IOException {
65
62
SSHClient sshClient = getFixtureClient ();
66
- RemotePortForwarder .Forward bind = forwardPort (sshClient , "127.0.0.1" , new SinglePort (0 ));
63
+ RemotePortForwarder .Forward bind = forwardPort (sshClient , LOCALHOST , new SinglePort (0 ));
67
64
assertHttpGetSuccess (bind );
68
65
}
69
66
@@ -84,7 +81,7 @@ public void shouldDynamicallyForwardPortForAllProtocols() throws IOException {
84
81
@ Test
85
82
public void shouldForwardPortForLocalhost () throws IOException {
86
83
SSHClient sshClient = getFixtureClient ();
87
- RemotePortForwarder .Forward bind = forwardPort (sshClient , "127.0.0.1" , RANGE );
84
+ RemotePortForwarder .Forward bind = forwardPort (sshClient , LOCALHOST , RANGE );
88
85
assertHttpGetSuccess (bind );
89
86
}
90
87
@@ -103,17 +100,22 @@ public void shouldForwardPortForAllProtocols() throws IOException {
103
100
}
104
101
105
102
private void assertHttpGetSuccess (final RemotePortForwarder .Forward bind ) throws IOException {
106
- assertEquals (200 , httpGet (bind .getPort ()));
103
+ final String bindAddress = bind .getAddress ();
104
+ final String address = bindAddress .isEmpty () ? LOCALHOST : bindAddress ;
105
+ final int port = bind .getPort ();
106
+ assertEquals (HttpURLConnection .HTTP_NOT_FOUND , httpGet (address , port ));
107
107
}
108
108
109
109
private RemotePortForwarder .Forward forwardPort (SSHClient sshClient , String address , PortRange portRange ) throws IOException {
110
110
while (true ) {
111
+ final URI serverUrl = httpServer .getServerUrl ();
112
+ final InetSocketAddress serverAddress = new InetSocketAddress (serverUrl .getHost (), serverUrl .getPort ());
111
113
try {
112
114
return sshClient .getRemotePortForwarder ().bind (
113
115
// where the server should listen
114
116
new RemotePortForwarder .Forward (address , portRange .nextPort ()),
115
117
// what we do with incoming connections that are forwarded to us
116
- new SocketForwardingConnectListener (HTTP_SERVER_SOCKET_ADDR ));
118
+ new SocketForwardingConnectListener (serverAddress ));
117
119
} catch (ConnectionException ce ) {
118
120
if (!portRange .hasNext ()) {
119
121
throw ce ;
@@ -122,8 +124,8 @@ private RemotePortForwarder.Forward forwardPort(SSHClient sshClient, String addr
122
124
}
123
125
}
124
126
125
- private int httpGet (final int port ) throws IOException {
126
- final URL url = new URL (String .format (LOCALHOST_URL_FORMAT , port ));
127
+ private int httpGet (final String address , final int port ) throws IOException {
128
+ final URL url = new URL (String .format (URL_FORMAT , address , port ));
127
129
final HttpURLConnection urlConnection = (HttpURLConnection ) url .openConnection ();
128
130
urlConnection .setConnectTimeout (3000 );
129
131
urlConnection .setRequestMethod ("GET" );
0 commit comments