Skip to content

Commit f72f114

Browse files
committed
feat(streaming): add streaming proxying
1 parent abc9a49 commit f72f114

File tree

23 files changed

+2412
-46
lines changed

23 files changed

+2412
-46
lines changed

.github/workflows/ci.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ jobs:
1313
runs-on: ubuntu-latest
1414
steps:
1515
- name: Checkout repository
16+
# yamllint disable-line rule:line-length rule:comments
1617
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
1718
- name: Run tests
1819
run: make test
@@ -21,6 +22,7 @@ jobs:
2122
runs-on: ubuntu-latest
2223
steps:
2324
- name: Checkout repository
25+
# yamllint disable-line rule:line-length rule:comments
2426
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
2527
- name: Run linter
2628
run: make lint
@@ -29,6 +31,7 @@ jobs:
2931
runs-on: ubuntu-latest
3032
steps:
3133
- name: Checkout repository
34+
# yamllint disable-line rule:line-length rule:comments
3235
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
3336
- name: Check go mod tidy
3437
run: make tidy-ci
@@ -37,6 +40,7 @@ jobs:
3740
runs-on: ubuntu-latest
3841
steps:
3942
- name: Checkout repository
43+
# yamllint disable-line rule:line-length rule:comments
4044
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
4145
- name: Build binary
4246
run: make docker-build

.github/workflows/container.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ jobs:
1717

1818
steps:
1919
- name: Checkout repository
20+
# yamllint disable-line rule:line-length rule:comments
2021
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
2122

2223
- name: Log in to Container Registry
24+
# yamllint disable-line rule:line-length rule:comments
2325
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3
2426
with:
2527
registry: ghcr.io
@@ -28,6 +30,7 @@ jobs:
2830

2931
- name: Extract metadata
3032
id: meta
33+
# yamllint disable-line rule:line-length rule:comments
3134
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 # v5
3235
with:
3336
images: ghcr.io/jkoelker/schwab-proxy
@@ -37,6 +40,7 @@ jobs:
3740
type=semver,pattern={{version}}
3841
3942
- name: Build and push Docker image
43+
# yamllint disable-line rule:line-length rule:comments
4044
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6
4145
with:
4246
context: .

.golangci.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ linters:
3333
- go.opentelemetry.io/otel/metric.Meter
3434
- go.opentelemetry.io/otel/trace.Span
3535
- go.opentelemetry.io/otel/trace.Tracer
36+
- net.Conn
3637

3738
spancheck:
3839
extra-start-span-signatures:

Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ COPY health/ ./health/
2424
COPY kdf/ ./kdf/
2525
COPY log/ ./log/
2626
COPY metrics/ ./metrics/
27+
COPY middleware/ ./middleware/
2728
COPY observability/ ./observability/
2829
COPY proxy/ ./proxy/
2930
COPY storage/ ./storage/
31+
COPY streaming/ ./streaming/
3032
COPY tls/ ./tls/
3133
COPY tracing/ ./tracing/
3234

cmd/schwab-proxy/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ func handleShutdown(ctx context.Context, server *http.Server, apiProxy *proxy.AP
261261
defer cancel()
262262

263263
// Shutdown background services first
264-
apiProxy.Shutdown()
264+
apiProxy.Shutdown(shutdownCtx)
265265

266266
// Shutdown server
267267
if err := server.Shutdown(shutdownCtx); err != nil {

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ require (
99
github.com/dgraph-io/badger/v4 v4.8.0
1010
github.com/fsnotify/fsnotify v1.9.0
1111
github.com/google/uuid v1.6.0
12+
github.com/gorilla/websocket v1.5.0
1213
github.com/ory/fosite v0.49.0
1314
github.com/prometheus/client_golang v1.22.0
1415
github.com/stretchr/testify v1.10.0
@@ -45,7 +46,6 @@ require (
4546
github.com/golang/protobuf v1.5.3 // indirect
4647
github.com/google/flatbuffers v25.2.10+incompatible // indirect
4748
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
48-
github.com/gorilla/websocket v1.5.0 // indirect
4949
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 // indirect
5050
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
5151
github.com/hashicorp/go-retryablehttp v0.7.7 // indirect

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ github.com/cristalhq/jwt/v4 v4.0.2/go.mod h1:HnYraSNKDRag1DZP92rYHyrjyQHnVEHPNqe
7171
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
7272
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
7373
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
74-
github.com/dgraph-io/badger/v4 v4.7.0 h1:Q+J8HApYAY7UMpL8d9owqiB+odzEc0zn/aqOD9jhc6Y=
75-
github.com/dgraph-io/badger/v4 v4.7.0/go.mod h1:He7TzG3YBy3j4f5baj5B7Zl2XyfNe5bl4Udl0aPemVA=
7674
github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs=
7775
github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w=
7876
github.com/dgraph-io/ristretto v1.0.0 h1:SYG07bONKMlFDUYu5pEu3DGAh8c2OFNzKm6G9J4Si84=

log/middleware.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
package log
22

33
import (
4+
"bufio"
45
"context"
56
"crypto/rand"
67
"encoding/hex"
8+
"fmt"
9+
"net"
710
"net/http"
811
"strings"
12+
13+
"github.com/jkoelker/schwab-proxy/middleware"
914
)
1015

1116
const (
@@ -90,7 +95,7 @@ func LoggingMiddleware(next http.Handler, opts ...func(*LoggingOptions)) http.Ha
9095
Log(ctx, level, "HTTP request started",
9196
"method", request.Method,
9297
"path", request.URL.Path,
93-
"remote_addr", request.RemoteAddr,
98+
"remote_addr", middleware.GetRealIP(request),
9499
"user_agent", request.Header.Get("User-Agent"),
95100
)
96101

@@ -121,6 +126,16 @@ func (rw *responseWriter) WriteHeader(code int) {
121126
rw.ResponseWriter.WriteHeader(code)
122127
}
123128

129+
// Hijack implements the http.Hijacker interface to support WebSocket upgrades.
130+
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
131+
conn, buf, err := middleware.HijackConnection(rw.ResponseWriter, &rw.statusCode)
132+
if err != nil {
133+
return nil, nil, fmt.Errorf("error hijacking connection: %w", err)
134+
}
135+
136+
return conn, buf, nil
137+
}
138+
124139
// generateCorrelationID creates a new random correlation ID.
125140
func generateCorrelationID() string {
126141
bytes := make([]byte, correlationIDByteLength)

middleware/utils.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package middleware
2+
3+
import (
4+
"bufio"
5+
"errors"
6+
"fmt"
7+
"net"
8+
"net/http"
9+
"strings"
10+
)
11+
12+
// ErrHijackNotSupported is returned when the ResponseWriter does not support hijacking.
13+
var ErrHijackNotSupported = errors.New("ResponseWriter does not support hijacking")
14+
15+
// HijackConnection hijacks the underlying connection from a ResponseWriter for WebSocket upgrades.
16+
// It sets statusCode to 101 (Switching Protocols) and returns the hijacked connection.
17+
func HijackConnection(w http.ResponseWriter, statusCode *int) (net.Conn, *bufio.ReadWriter, error) {
18+
hijacker, ok := w.(http.Hijacker)
19+
if !ok {
20+
return nil, nil, ErrHijackNotSupported
21+
}
22+
23+
// Set status code to 101 for WebSocket upgrades
24+
if statusCode != nil {
25+
*statusCode = http.StatusSwitchingProtocols
26+
}
27+
28+
conn, buf, err := hijacker.Hijack()
29+
if err != nil {
30+
return nil, nil, fmt.Errorf("failed to hijack connection: %w", err)
31+
}
32+
33+
return conn, buf, nil
34+
}
35+
36+
// GetRealIP extracts the real client IP from the request, checking various headers.
37+
func GetRealIP(req *http.Request) string {
38+
// Check X-Real-IP header first (single IP)
39+
if ip := req.Header.Get("X-Real-IP"); ip != "" {
40+
return ip
41+
}
42+
43+
// Check X-Forwarded-For header (comma-separated list, first is original client)
44+
if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
45+
// Take the first IP in the list
46+
if idx := strings.Index(xff, ","); idx != -1 {
47+
return strings.TrimSpace(xff[:idx])
48+
}
49+
50+
return strings.TrimSpace(xff)
51+
}
52+
53+
// Check CF-Connecting-IP for Cloudflare
54+
if ip := req.Header.Get("Cf-Connecting-Ip"); ip != "" {
55+
return ip
56+
}
57+
58+
// Check True-Client-IP for Cloudflare Enterprise
59+
if ip := req.Header.Get("True-Client-Ip"); ip != "" {
60+
return ip
61+
}
62+
63+
// Fall back to RemoteAddr
64+
host, _, err := net.SplitHostPort(req.RemoteAddr)
65+
if err != nil {
66+
// If splitting fails, return the whole RemoteAddr
67+
return req.RemoteAddr
68+
}
69+
70+
return host
71+
}

middleware/utils_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package middleware_test
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/jkoelker/schwab-proxy/middleware"
8+
)
9+
10+
func TestGetRealIP(t *testing.T) {
11+
t.Parallel()
12+
13+
tests := []struct {
14+
name string
15+
headers map[string]string
16+
remoteAddr string
17+
expected string
18+
}{
19+
{
20+
name: "X-Real-IP present",
21+
headers: map[string]string{"X-Real-IP": "192.168.1.100"},
22+
remoteAddr: "10.0.0.1:12345",
23+
expected: "192.168.1.100",
24+
},
25+
{
26+
name: "X-Forwarded-For single IP",
27+
headers: map[string]string{"X-Forwarded-For": "192.168.1.100"},
28+
remoteAddr: "10.0.0.1:12345",
29+
expected: "192.168.1.100",
30+
},
31+
{
32+
name: "X-Forwarded-For multiple IPs",
33+
headers: map[string]string{"X-Forwarded-For": "192.168.1.100, 10.0.0.2, 10.0.0.3"},
34+
remoteAddr: "10.0.0.1:12345",
35+
expected: "192.168.1.100",
36+
},
37+
{
38+
name: "CF-Connecting-IP present",
39+
headers: map[string]string{"Cf-Connecting-Ip": "192.168.1.100"},
40+
remoteAddr: "10.0.0.1:12345",
41+
expected: "192.168.1.100",
42+
},
43+
{
44+
name: "True-Client-IP present",
45+
headers: map[string]string{"True-Client-Ip": "192.168.1.100"},
46+
remoteAddr: "10.0.0.1:12345",
47+
expected: "192.168.1.100",
48+
},
49+
{
50+
name: "No headers - RemoteAddr with port",
51+
headers: map[string]string{},
52+
remoteAddr: "192.168.1.100:12345",
53+
expected: "192.168.1.100",
54+
},
55+
{
56+
name: "No headers - RemoteAddr without port",
57+
headers: map[string]string{},
58+
remoteAddr: "192.168.1.100",
59+
expected: "192.168.1.100",
60+
},
61+
{
62+
name: "Priority: X-Real-IP over others",
63+
headers: map[string]string{
64+
"X-Real-IP": "192.168.1.100",
65+
"X-Forwarded-For": "10.0.0.100",
66+
"Cf-Connecting-Ip": "172.16.0.100",
67+
},
68+
remoteAddr: "10.0.0.1:12345",
69+
expected: "192.168.1.100",
70+
},
71+
}
72+
73+
for _, test := range tests {
74+
t.Run(test.name, func(t *testing.T) {
75+
t.Parallel()
76+
77+
req := &http.Request{
78+
Header: http.Header{},
79+
RemoteAddr: test.remoteAddr,
80+
}
81+
82+
for key, value := range test.headers {
83+
req.Header.Set(key, value)
84+
}
85+
86+
got := middleware.GetRealIP(req)
87+
if got != test.expected {
88+
t.Errorf("GetRealIP() = %v, want %v", got, test.expected)
89+
}
90+
})
91+
}
92+
}

0 commit comments

Comments
 (0)