Skip to content

feat(streaming): add streaming proxying #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
# yamllint disable-line rule:line-length rule:comments
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
- name: Run tests
run: make test
Expand All @@ -21,6 +22,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
# yamllint disable-line rule:line-length rule:comments
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
- name: Run linter
run: make lint
Expand All @@ -29,6 +31,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
# yamllint disable-line rule:line-length rule:comments
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
- name: Check go mod tidy
run: make tidy-ci
Expand All @@ -37,6 +40,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
# yamllint disable-line rule:line-length rule:comments
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
- name: Build binary
run: make docker-build
4 changes: 4 additions & 0 deletions .github/workflows/container.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ jobs:

steps:
- name: Checkout repository
# yamllint disable-line rule:line-length rule:comments
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4

- name: Log in to Container Registry
# yamllint disable-line rule:line-length rule:comments
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3
with:
registry: ghcr.io
Expand All @@ -28,6 +30,7 @@ jobs:

- name: Extract metadata
id: meta
# yamllint disable-line rule:line-length rule:comments
uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 # v5
with:
images: ghcr.io/jkoelker/schwab-proxy
Expand All @@ -37,6 +40,7 @@ jobs:
type=semver,pattern={{version}}
- name: Build and push Docker image
# yamllint disable-line rule:line-length rule:comments
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6
with:
context: .
Expand Down
1 change: 1 addition & 0 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ linters:
- go.opentelemetry.io/otel/metric.Meter
- go.opentelemetry.io/otel/trace.Span
- go.opentelemetry.io/otel/trace.Tracer
- net.Conn

spancheck:
extra-start-span-signatures:
Expand Down
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ COPY health/ ./health/
COPY kdf/ ./kdf/
COPY log/ ./log/
COPY metrics/ ./metrics/
COPY middleware/ ./middleware/
COPY observability/ ./observability/
COPY proxy/ ./proxy/
COPY storage/ ./storage/
COPY streaming/ ./streaming/
COPY tls/ ./tls/
COPY tracing/ ./tracing/

Expand Down
2 changes: 1 addition & 1 deletion cmd/schwab-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func handleShutdown(ctx context.Context, server *http.Server, apiProxy *proxy.AP
defer cancel()

// Shutdown background services first
apiProxy.Shutdown()
apiProxy.Shutdown(shutdownCtx)

// Shutdown server
if err := server.Shutdown(shutdownCtx); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ require (
github.com/dgraph-io/badger/v4 v4.8.0
github.com/fsnotify/fsnotify v1.9.0
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.0
github.com/ory/fosite v0.49.0
github.com/prometheus/client_golang v1.22.0
github.com/stretchr/testify v1.10.0
Expand Down Expand Up @@ -45,7 +46,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/flatbuffers v25.2.10+incompatible // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.7 // indirect
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ github.com/cristalhq/jwt/v4 v4.0.2/go.mod h1:HnYraSNKDRag1DZP92rYHyrjyQHnVEHPNqe
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgraph-io/badger/v4 v4.7.0 h1:Q+J8HApYAY7UMpL8d9owqiB+odzEc0zn/aqOD9jhc6Y=
github.com/dgraph-io/badger/v4 v4.7.0/go.mod h1:He7TzG3YBy3j4f5baj5B7Zl2XyfNe5bl4Udl0aPemVA=
github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs=
github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w=
github.com/dgraph-io/ristretto v1.0.0 h1:SYG07bONKMlFDUYu5pEu3DGAh8c2OFNzKm6G9J4Si84=
Expand Down
17 changes: 16 additions & 1 deletion log/middleware.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package log

import (
"bufio"
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/http"
"strings"

"github.com/jkoelker/schwab-proxy/middleware"
)

const (
Expand Down Expand Up @@ -90,7 +95,7 @@ func LoggingMiddleware(next http.Handler, opts ...func(*LoggingOptions)) http.Ha
Log(ctx, level, "HTTP request started",
"method", request.Method,
"path", request.URL.Path,
"remote_addr", request.RemoteAddr,
"remote_addr", middleware.GetRealIP(request),
"user_agent", request.Header.Get("User-Agent"),
)

Expand Down Expand Up @@ -121,6 +126,16 @@ func (rw *responseWriter) WriteHeader(code int) {
rw.ResponseWriter.WriteHeader(code)
}

// Hijack implements the http.Hijacker interface to support WebSocket upgrades.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
conn, buf, err := middleware.HijackConnection(rw.ResponseWriter, &rw.statusCode)
if err != nil {
return nil, nil, fmt.Errorf("error hijacking connection: %w", err)
}

return conn, buf, nil
}

// generateCorrelationID creates a new random correlation ID.
func generateCorrelationID() string {
bytes := make([]byte, correlationIDByteLength)
Expand Down
71 changes: 71 additions & 0 deletions middleware/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package middleware

import (
"bufio"
"errors"
"fmt"
"net"
"net/http"
"strings"
)

// ErrHijackNotSupported is returned when the ResponseWriter does not support hijacking.
var ErrHijackNotSupported = errors.New("ResponseWriter does not support hijacking")

// HijackConnection hijacks the underlying connection from a ResponseWriter for WebSocket upgrades.
// It sets statusCode to 101 (Switching Protocols) and returns the hijacked connection.
func HijackConnection(w http.ResponseWriter, statusCode *int) (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := w.(http.Hijacker)
if !ok {
return nil, nil, ErrHijackNotSupported
}

// Set status code to 101 for WebSocket upgrades
if statusCode != nil {
*statusCode = http.StatusSwitchingProtocols
}

conn, buf, err := hijacker.Hijack()
if err != nil {
return nil, nil, fmt.Errorf("failed to hijack connection: %w", err)
}

return conn, buf, nil
}

// GetRealIP extracts the real client IP from the request, checking various headers.
func GetRealIP(req *http.Request) string {
// Check X-Real-IP header first (single IP)
if ip := req.Header.Get("X-Real-IP"); ip != "" {
return ip
}

// Check X-Forwarded-For header (comma-separated list, first is original client)
if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}

return strings.TrimSpace(xff)
}

// Check CF-Connecting-IP for Cloudflare
if ip := req.Header.Get("Cf-Connecting-Ip"); ip != "" {
return ip
}

// Check True-Client-IP for Cloudflare Enterprise
if ip := req.Header.Get("True-Client-Ip"); ip != "" {
return ip
}

// Fall back to RemoteAddr
host, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
// If splitting fails, return the whole RemoteAddr
return req.RemoteAddr
}

return host
}
92 changes: 92 additions & 0 deletions middleware/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package middleware_test

import (
"net/http"
"testing"

"github.com/jkoelker/schwab-proxy/middleware"
)

func TestGetRealIP(t *testing.T) {
t.Parallel()

tests := []struct {
name string
headers map[string]string
remoteAddr string
expected string
}{
{
name: "X-Real-IP present",
headers: map[string]string{"X-Real-IP": "192.168.1.100"},
remoteAddr: "10.0.0.1:12345",
expected: "192.168.1.100",
},
{
name: "X-Forwarded-For single IP",
headers: map[string]string{"X-Forwarded-For": "192.168.1.100"},
remoteAddr: "10.0.0.1:12345",
expected: "192.168.1.100",
},
{
name: "X-Forwarded-For multiple IPs",
headers: map[string]string{"X-Forwarded-For": "192.168.1.100, 10.0.0.2, 10.0.0.3"},
remoteAddr: "10.0.0.1:12345",
expected: "192.168.1.100",
},
{
name: "CF-Connecting-IP present",
headers: map[string]string{"Cf-Connecting-Ip": "192.168.1.100"},
remoteAddr: "10.0.0.1:12345",
expected: "192.168.1.100",
},
{
name: "True-Client-IP present",
headers: map[string]string{"True-Client-Ip": "192.168.1.100"},
remoteAddr: "10.0.0.1:12345",
expected: "192.168.1.100",
},
{
name: "No headers - RemoteAddr with port",
headers: map[string]string{},
remoteAddr: "192.168.1.100:12345",
expected: "192.168.1.100",
},
{
name: "No headers - RemoteAddr without port",
headers: map[string]string{},
remoteAddr: "192.168.1.100",
expected: "192.168.1.100",
},
{
name: "Priority: X-Real-IP over others",
headers: map[string]string{
"X-Real-IP": "192.168.1.100",
"X-Forwarded-For": "10.0.0.100",
"Cf-Connecting-Ip": "172.16.0.100",
},
remoteAddr: "10.0.0.1:12345",
expected: "192.168.1.100",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

req := &http.Request{
Header: http.Header{},
RemoteAddr: test.remoteAddr,
}

for key, value := range test.headers {
req.Header.Set(key, value)
}

got := middleware.GetRealIP(req)
if got != test.expected {
t.Errorf("GetRealIP() = %v, want %v", got, test.expected)
}
})
}
}
15 changes: 14 additions & 1 deletion observability/middleware.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package observability

import (
"bufio"
"fmt"
"net"
"net/http"
"strconv"
"time"

"github.com/jkoelker/schwab-proxy/metrics"
"github.com/jkoelker/schwab-proxy/middleware"
"github.com/jkoelker/schwab-proxy/tracing"
)

Expand Down Expand Up @@ -79,7 +82,7 @@ func TracingMiddleware(next http.Handler) http.Handler {
"http.scheme", request.URL.Scheme,
"http.host", request.Host,
"http.user_agent", request.Header.Get("User-Agent"),
"http.remote_addr", request.RemoteAddr,
"http.remote_addr", middleware.GetRealIP(request),
)

// Create a response writer wrapper to capture status code
Expand Down Expand Up @@ -140,6 +143,16 @@ func (rw *responseWriter) Write(data []byte) (int, error) {
return bytesWritten, nil
}

// Hijack implements the http.Hijacker interface to support WebSocket upgrades.
func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
conn, buf, err := middleware.HijackConnection(rw.ResponseWriter, &rw.statusCode)
if err != nil {
return nil, nil, fmt.Errorf("error hijacking connection: %w", err)
}

return conn, buf, nil
}

// httpError represents an HTTP error for tracing.
type httpError struct {
statusCode int
Expand Down
2 changes: 1 addition & 1 deletion proxy/approvals.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (p *APIProxy) handleApproveRequest(writer http.ResponseWriter, request *htt
// Delete the approval request
if err := p.storage.DeleteApprovalRequest(approvalID); err != nil {
// Log but don't fail the request
p.logger.Warn("failed to delete approval request", "id", approvalID, "error", err)
log.Warn(ctx, "failed to delete approval request", "id", approvalID, "error", err)
}

// Build the complete redirect URL
Expand Down
Loading
Loading