diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 47ec013..690c1b9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,6 +13,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository + # yamllint disable-line rule:line-length rule:comments uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Run tests run: make test @@ -21,6 +22,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository + # yamllint disable-line rule:line-length rule:comments uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Run linter run: make lint @@ -29,6 +31,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository + # yamllint disable-line rule:line-length rule:comments uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Check go mod tidy run: make tidy-ci @@ -37,6 +40,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository + # yamllint disable-line rule:line-length rule:comments uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Build binary run: make docker-build diff --git a/.github/workflows/container.yaml b/.github/workflows/container.yaml index e0a112b..75b646c 100644 --- a/.github/workflows/container.yaml +++ b/.github/workflows/container.yaml @@ -17,9 +17,11 @@ jobs: steps: - name: Checkout repository + # yamllint disable-line rule:line-length rule:comments uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 - name: Log in to Container Registry + # yamllint disable-line rule:line-length rule:comments uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3 with: registry: ghcr.io @@ -28,6 +30,7 @@ jobs: - name: Extract metadata id: meta + # yamllint disable-line rule:line-length rule:comments uses: docker/metadata-action@902fa8ec7d6ecbf8d84d538b9b233a880e428804 # v5 with: images: ghcr.io/jkoelker/schwab-proxy @@ -37,6 +40,7 @@ jobs: type=semver,pattern={{version}} - name: Build and push Docker image + # yamllint disable-line rule:line-length rule:comments uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # v6 with: context: . 
diff --git a/.golangci.yaml b/.golangci.yaml index b5d3a0f..129984a 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -33,6 +33,7 @@ linters: - go.opentelemetry.io/otel/metric.Meter - go.opentelemetry.io/otel/trace.Span - go.opentelemetry.io/otel/trace.Tracer + - net.Conn spancheck: extra-start-span-signatures: diff --git a/Dockerfile b/Dockerfile index dcac0b2..ac6fabe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,9 +24,11 @@ COPY health/ ./health/ COPY kdf/ ./kdf/ COPY log/ ./log/ COPY metrics/ ./metrics/ +COPY middleware/ ./middleware/ COPY observability/ ./observability/ COPY proxy/ ./proxy/ COPY storage/ ./storage/ +COPY streaming/ ./streaming/ COPY tls/ ./tls/ COPY tracing/ ./tracing/ diff --git a/cmd/schwab-proxy/main.go b/cmd/schwab-proxy/main.go index d3debdf..9b79dcf 100644 --- a/cmd/schwab-proxy/main.go +++ b/cmd/schwab-proxy/main.go @@ -261,7 +261,7 @@ func handleShutdown(ctx context.Context, server *http.Server, apiProxy *proxy.AP defer cancel() // Shutdown background services first - apiProxy.Shutdown() + apiProxy.Shutdown(shutdownCtx) // Shutdown server if err := server.Shutdown(shutdownCtx); err != nil { diff --git a/go.mod b/go.mod index a531a5f..4e9103d 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/dgraph-io/badger/v4 v4.8.0 github.com/fsnotify/fsnotify v1.9.0 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.0 github.com/ory/fosite v0.49.0 github.com/prometheus/client_golang v1.22.0 github.com/stretchr/testify v1.10.0 @@ -45,7 +46,6 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/google/flatbuffers v25.2.10+incompatible // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/gorilla/websocket v1.5.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect diff --git a/go.sum b/go.sum index c892095..80653b5 100644 --- a/go.sum +++ b/go.sum @@ -71,8 +71,6 @@ github.com/cristalhq/jwt/v4 v4.0.2/go.mod h1:HnYraSNKDRag1DZP92rYHyrjyQHnVEHPNqe github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/badger/v4 v4.7.0 h1:Q+J8HApYAY7UMpL8d9owqiB+odzEc0zn/aqOD9jhc6Y= -github.com/dgraph-io/badger/v4 v4.7.0/go.mod h1:He7TzG3YBy3j4f5baj5B7Zl2XyfNe5bl4Udl0aPemVA= github.com/dgraph-io/badger/v4 v4.8.0 h1:JYph1ChBijCw8SLeybvPINizbDKWZ5n/GYbz2yhN/bs= github.com/dgraph-io/badger/v4 v4.8.0/go.mod h1:U6on6e8k/RTbUWxqKR0MvugJuVmkxSNc79ap4917h4w= github.com/dgraph-io/ristretto v1.0.0 h1:SYG07bONKMlFDUYu5pEu3DGAh8c2OFNzKm6G9J4Si84= diff --git a/log/middleware.go b/log/middleware.go index 1d65937..c6dfb88 100644 --- a/log/middleware.go +++ b/log/middleware.go @@ -1,11 +1,16 @@ package log import ( + "bufio" "context" "crypto/rand" "encoding/hex" + "fmt" + "net" "net/http" "strings" + + "github.com/jkoelker/schwab-proxy/middleware" ) const ( @@ -90,7 +95,7 @@ func LoggingMiddleware(next http.Handler, opts ...func(*LoggingOptions)) http.Ha Log(ctx, level, "HTTP request started", "method", request.Method, "path", request.URL.Path, - "remote_addr", request.RemoteAddr, + "remote_addr", middleware.GetRealIP(request), "user_agent", 
request.Header.Get("User-Agent"), ) @@ -121,6 +126,16 @@ func (rw *responseWriter) WriteHeader(code int) { rw.ResponseWriter.WriteHeader(code) } +// Hijack implements the http.Hijacker interface to support WebSocket upgrades. +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn, buf, err := middleware.HijackConnection(rw.ResponseWriter, &rw.statusCode) + if err != nil { + return nil, nil, fmt.Errorf("error hijacking connection: %w", err) + } + + return conn, buf, nil +} + // generateCorrelationID creates a new random correlation ID. func generateCorrelationID() string { bytes := make([]byte, correlationIDByteLength) diff --git a/middleware/utils.go b/middleware/utils.go new file mode 100644 index 0000000..0a20480 --- /dev/null +++ b/middleware/utils.go @@ -0,0 +1,71 @@ +package middleware + +import ( + "bufio" + "errors" + "fmt" + "net" + "net/http" + "strings" +) + +// ErrHijackNotSupported is returned when the ResponseWriter does not support hijacking. +var ErrHijackNotSupported = errors.New("ResponseWriter does not support hijacking") + +// HijackConnection hijacks the underlying connection from a ResponseWriter for WebSocket upgrades. +// It sets statusCode to 101 (Switching Protocols) and returns the hijacked connection. +func HijackConnection(w http.ResponseWriter, statusCode *int) (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := w.(http.Hijacker) + if !ok { + return nil, nil, ErrHijackNotSupported + } + + // Set status code to 101 for WebSocket upgrades + if statusCode != nil { + *statusCode = http.StatusSwitchingProtocols + } + + conn, buf, err := hijacker.Hijack() + if err != nil { + return nil, nil, fmt.Errorf("failed to hijack connection: %w", err) + } + + return conn, buf, nil +} + +// GetRealIP extracts the real client IP from the request, checking various headers. 
+func GetRealIP(req *http.Request) string { + // Check X-Real-IP header first (single IP) + if ip := req.Header.Get("X-Real-IP"); ip != "" { + return ip + } + + // Check X-Forwarded-For header (comma-separated list, first is original client) + if xff := req.Header.Get("X-Forwarded-For"); xff != "" { + // Take the first IP in the list + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + + return strings.TrimSpace(xff) + } + + // Check CF-Connecting-IP for Cloudflare + if ip := req.Header.Get("Cf-Connecting-Ip"); ip != "" { + return ip + } + + // Check True-Client-IP for Cloudflare Enterprise + if ip := req.Header.Get("True-Client-Ip"); ip != "" { + return ip + } + + // Fall back to RemoteAddr + host, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + // If splitting fails, return the whole RemoteAddr + return req.RemoteAddr + } + + return host +} diff --git a/middleware/utils_test.go b/middleware/utils_test.go new file mode 100644 index 0000000..3523f91 --- /dev/null +++ b/middleware/utils_test.go @@ -0,0 +1,92 @@ +package middleware_test + +import ( + "net/http" + "testing" + + "github.com/jkoelker/schwab-proxy/middleware" +) + +func TestGetRealIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headers map[string]string + remoteAddr string + expected string + }{ + { + name: "X-Real-IP present", + headers: map[string]string{"X-Real-IP": "192.168.1.100"}, + remoteAddr: "10.0.0.1:12345", + expected: "192.168.1.100", + }, + { + name: "X-Forwarded-For single IP", + headers: map[string]string{"X-Forwarded-For": "192.168.1.100"}, + remoteAddr: "10.0.0.1:12345", + expected: "192.168.1.100", + }, + { + name: "X-Forwarded-For multiple IPs", + headers: map[string]string{"X-Forwarded-For": "192.168.1.100, 10.0.0.2, 10.0.0.3"}, + remoteAddr: "10.0.0.1:12345", + expected: "192.168.1.100", + }, + { + name: "CF-Connecting-IP present", + headers: map[string]string{"Cf-Connecting-Ip": "192.168.1.100"}, + remoteAddr: "10.0.0.1:12345", + expected: "192.168.1.100", + }, + { + name: "True-Client-IP present", + headers: map[string]string{"True-Client-Ip": "192.168.1.100"}, + remoteAddr: "10.0.0.1:12345", + expected: "192.168.1.100", + }, + { + name: "No headers - RemoteAddr with port", + headers: map[string]string{}, + remoteAddr: "192.168.1.100:12345", + expected: "192.168.1.100", + }, + { + name: "No headers - RemoteAddr without port", + headers: map[string]string{}, + remoteAddr: "192.168.1.100", + expected: "192.168.1.100", + }, + { + name: "Priority: X-Real-IP over others", + headers: map[string]string{ + "X-Real-IP": "192.168.1.100", + "X-Forwarded-For": "10.0.0.100", + "Cf-Connecting-Ip": "172.16.0.100", + }, + remoteAddr: "10.0.0.1:12345", + expected: "192.168.1.100", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + req := &http.Request{ + Header: http.Header{}, + RemoteAddr: test.remoteAddr, + } + + for key, value := range test.headers { + req.Header.Set(key, value) + } + + got := middleware.GetRealIP(req) + if got != test.expected { + t.Errorf("GetRealIP() = %v, want %v", got, test.expected) + } + }) + } +} diff --git a/observability/middleware.go b/observability/middleware.go index e2992e9..f8a7a4e 100644 --- a/observability/middleware.go +++ b/observability/middleware.go @@ -1,12 +1,15 @@ package observability import ( + "bufio" "fmt" + "net" "net/http" "strconv" "time" "github.com/jkoelker/schwab-proxy/metrics" + "github.com/jkoelker/schwab-proxy/middleware" 
"github.com/jkoelker/schwab-proxy/tracing" ) @@ -79,7 +82,7 @@ func TracingMiddleware(next http.Handler) http.Handler { "http.scheme", request.URL.Scheme, "http.host", request.Host, "http.user_agent", request.Header.Get("User-Agent"), - "http.remote_addr", request.RemoteAddr, + "http.remote_addr", middleware.GetRealIP(request), ) // Create a response writer wrapper to capture status code @@ -140,6 +143,16 @@ func (rw *responseWriter) Write(data []byte) (int, error) { return bytesWritten, nil } +// Hijack implements the http.Hijacker interface to support WebSocket upgrades. +func (rw *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + conn, buf, err := middleware.HijackConnection(rw.ResponseWriter, &rw.statusCode) + if err != nil { + return nil, nil, fmt.Errorf("error hijacking connection: %w", err) + } + + return conn, buf, nil +} + // httpError represents an HTTP error for tracing. type httpError struct { statusCode int diff --git a/proxy/approvals.go b/proxy/approvals.go index f4697c6..2513482 100644 --- a/proxy/approvals.go +++ b/proxy/approvals.go @@ -147,7 +147,7 @@ func (p *APIProxy) handleApproveRequest(writer http.ResponseWriter, request *htt // Delete the approval request if err := p.storage.DeleteApprovalRequest(approvalID); err != nil { // Log but don't fail the request - p.logger.Warn("failed to delete approval request", "id", approvalID, "error", err) + log.Warn(ctx, "failed to delete approval request", "id", approvalID, "error", err) } // Build the complete redirect URL diff --git a/proxy/interceptor.go b/proxy/interceptor.go new file mode 100644 index 0000000..e2d33e1 --- /dev/null +++ b/proxy/interceptor.go @@ -0,0 +1,227 @@ +package proxy + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/jkoelker/schwab-proxy/log" + "github.com/jkoelker/schwab-proxy/streaming" +) + +// ResponseData holds the necessary data from an HTTP response. +type ResponseData struct { + StatusCode int + Headers http.Header + Body []byte +} + +// forwardResponse forwards the response to the client. +func forwardResponse(ctx context.Context, writer http.ResponseWriter, resp *ResponseData) { + // Copy headers + for key, values := range resp.Headers { + for _, value := range values { + writer.Header().Add(key, value) + } + } + + writer.WriteHeader(resp.StatusCode) + + if _, err := writer.Write(resp.Body); err != nil { + log.Error(ctx, err, "Failed to write response") + } +} + +// fetchUserPreferences retrieves user preferences from Schwab API and returns the response data. +func (p *APIProxy) fetchUserPreferences(ctx context.Context, request *http.Request) (*ResponseData, error) { + // Read request body if any + var body io.Reader + if request.Body != nil { + body = request.Body + } + + // Forward to Schwab + endpoint := request.URL.Path + if request.URL.RawQuery != "" { + endpoint = endpoint + "?" 
+ request.URL.RawQuery + } + + // Copy headers but remove Accept-Encoding to ensure we get uncompressed response + headers := make(http.Header) + + for key, values := range request.Header { + if !strings.EqualFold(key, "Accept-Encoding") { + headers[key] = values + } + } + + resp, err := p.schwabClient.Call(ctx, request.Method, endpoint, body, headers) + if err != nil { + return nil, fmt.Errorf("failed to call Schwab API: %w", err) + } + defer resp.Body.Close() + + // Read response body + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + return &ResponseData{ + StatusCode: resp.StatusCode, + Headers: resp.Header, + Body: respBody, + }, nil +} + +// processStreamingMetadata modifies the streaming URL to point to the proxy. +func (p *APIProxy) processStreamingMetadata(prefs *streaming.UserPreferencesResponse, request *http.Request) { + if len(prefs.StreamerInfo) == 0 { + return + } + + // Get the proxy's hostname from the request + scheme := "wss" + if request.TLS == nil { + scheme = "ws" + } + + // Build the proxy WebSocket URL + proxyURL := fmt.Sprintf("%s://%s/ws/stream", scheme, request.Host) + + // Update the URL in the response + for i := range prefs.StreamerInfo { + prefs.StreamerInfo[i].StreamerSocketURL = proxyURL + } + + log.Debug(request.Context(), "Modified streaming URL", + "original_url", "wss://streamer-api.schwab.com/ws", + "proxy_url", proxyURL, + "host", request.Host, + ) +} + +// modifyAndSendPreferences modifies and sends the user preferences response. +func (p *APIProxy) modifyAndSendPreferences( + writer http.ResponseWriter, + request *http.Request, + resp *ResponseData, +) { + // Parse response + var prefs streaming.UserPreferencesResponse + + if err := json.Unmarshal(resp.Body, &prefs); err != nil { + log.Error(request.Context(), err, "Failed to parse user preferences", + "body_length", len(resp.Body), + ) + + // Forward original response if parse fails + forwardResponse(request.Context(), writer, resp) + + return + } + + // Extract and store metadata if streaming info exists + p.processStreamingMetadata(&prefs, request) + + if len(prefs.StreamerInfo) > 0 { + log.Debug(request.Context(), "After processStreamingMetadata", + "modified_url", prefs.StreamerInfo[0].StreamerSocketURL, + ) + } + + // Re-encode with modifications + modified, err := json.Marshal(prefs) + if err != nil { + log.Error(request.Context(), err, "Failed to encode modified preferences") + // Forward original response if encode fails + forwardResponse(request.Context(), writer, resp) + + return + } + + // Send modified response + for key, values := range resp.Headers { + // Skip Content-Encoding and Content-Length since we're sending uncompressed modified JSON + if strings.EqualFold(key, "Content-Encoding") || strings.EqualFold(key, "Content-Length") { + continue + } + + for _, value := range values { + writer.Header().Add(key, value) + } + } + + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(resp.StatusCode) + + if _, err := writer.Write(modified); err != nil { + log.Error(request.Context(), err, "Failed to write modified response") + } +} + +// handleUserPreferences intercepts and modifies the user preferences response. 
+func (p *APIProxy) handleUserPreferences(writer http.ResponseWriter, request *http.Request) { + log.Debug(request.Context(), "handleUserPreferences called", + "method", request.Method, + "path", request.URL.Path, + ) + + // Only intercept GET requests + if request.Method != http.MethodGet { + p.handleTraderRequest(writer, request) + + return + } + + // Fetch preferences from Schwab + resp, err := p.fetchUserPreferences(request.Context(), request) + if err != nil { + log.Error(request.Context(), err, "Failed to fetch user preferences") + http.Error(writer, "Failed to fetch preferences", http.StatusBadGateway) + + return + } + + // If not successful, forward as-is + if resp.StatusCode != http.StatusOK { + forwardResponse(request.Context(), writer, resp) + + return + } + + // Modify and send the preferences + p.modifyAndSendPreferences(writer, request, resp) +} + +// isUserPreferencesRequest checks if this is a user preferences request. +func isUserPreferencesRequest(path string) bool { + // Remove query parameters for comparison + if idx := strings.Index(path, "?"); idx > 0 { + path = path[:idx] + } + + // Check if it matches the user preferences endpoint + return strings.HasSuffix(path, "/userPreference") +} + +// interceptableTraderRequest wraps handleTraderRequest to intercept specific endpoints. +func (p *APIProxy) interceptableTraderRequest(writer http.ResponseWriter, request *http.Request) { + log.Debug(request.Context(), "interceptableTraderRequest called", + "path", request.URL.Path, + "is_user_preferences", isUserPreferencesRequest(request.URL.Path), + ) + + // Check if this is a user preferences request + if isUserPreferencesRequest(request.URL.Path) { + p.handleUserPreferences(writer, request) + + return + } + + // Otherwise, forward normally + p.handleTraderRequest(writer, request) +} diff --git a/proxy/proxy.go b/proxy/proxy.go index f1f47d1..ac2d991 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "log/slog" "net/http" "strings" "time" @@ -17,6 +16,7 @@ import ( "github.com/jkoelker/schwab-proxy/log" "github.com/jkoelker/schwab-proxy/observability" "github.com/jkoelker/schwab-proxy/storage" + "github.com/jkoelker/schwab-proxy/streaming" ) // contextKey is a custom type for context keys to avoid collisions. @@ -42,34 +42,17 @@ type APIProxy struct { otelProviders *observability.OTelProviders server *auth.Server storage *storage.Store - logger *slog.Logger // Background refresh management refreshCancel context.CancelFunc -} - -// NewAPIProxy creates a new API proxy server. -func NewAPIProxy( - cfg *config.Config, - schwabClient api.ProviderClient, - tokenService auth.TokenServicer, - clientService *auth.ClientService, - store *storage.Store, - otelProviders *observability.OTelProviders, -) (*APIProxy, error) { - // Create health checker - healthChecker := health.NewManager("schwab-proxy-1.0.0") - // Add storage health check - healthChecker.AddChecker(health.NewStorageChecker(store)) - - // Add provider API health check - healthChecker.AddChecker(health.NewProviderChecker(schwabClient)) - - // Create OAuth2 server - storageAdapter := auth.NewStorageAdapter(store) + // Streaming support + streamManager *streaming.Proxy +} - // Get JWT KDF parameters and derive key +// deriveJWTSigningKey derives the JWT signing key from configuration. 
+func deriveJWTSigningKey(cfg *config.Config) ([]byte, error) { + // Get JWT KDF parameters jwtKDFParams, err := cfg.GetJWTKDFParams() if err != nil { return nil, fmt.Errorf("failed to get JWT KDF parameters: %w", err) @@ -87,7 +70,8 @@ func NewAPIProxy( return nil, fmt.Errorf("failed to get JWT salt: %w", err) } - jwtSigningKey, err := jwtKDFParams.DeriveKey( + // Derive the key + key, err := jwtKDFParams.DeriveKey( []byte(cfg.JWTSeed), jwtSalt, auth.JWTKeySize, @@ -96,6 +80,32 @@ func NewAPIProxy( return nil, fmt.Errorf("failed to derive JWT signing key: %w", err) } + return key, nil +} + +// NewAPIProxy creates a new API proxy server. +func NewAPIProxy( + cfg *config.Config, + schwabClient api.ProviderClient, + tokenService auth.TokenServicer, + clientService *auth.ClientService, + store *storage.Store, + otelProviders *observability.OTelProviders, +) (*APIProxy, error) { + // Create health checker + healthChecker := health.NewManager("schwab-proxy-1.0.0") + healthChecker.AddChecker(health.NewStorageChecker(store)) + healthChecker.AddChecker(health.NewProviderChecker(schwabClient)) + + // Create OAuth2 server + storageAdapter := auth.NewStorageAdapter(store) + + // Derive JWT signing key + jwtSigningKey, err := deriveJWTSigningKey(cfg) + if err != nil { + return nil, err + } + server, err := auth.NewServer(storageAdapter, cfg, jwtSigningKey) if err != nil { return nil, fmt.Errorf("failed to create OAuth2 server: %w", err) @@ -117,7 +127,12 @@ func NewAPIProxy( otelProviders: otelProviders, server: server, storage: store, - logger: slog.Default(), + + streamManager: streaming.NewProxy( + tokenService, + server, + streaming.CreateMetadataFunc(schwabClient), + ), } // Set up routes @@ -126,6 +141,14 @@ func NewAPIProxy( // Start background token refresh proxy.startBackgroundTokenRefresh() + // Start streaming manager if enabled + if proxy.streamManager != nil { + ctx := context.Background() + if err := proxy.streamManager.Start(ctx); err != nil { + return nil, fmt.Errorf("failed to start streaming manager: %w", err) + } + } + return proxy, nil } @@ -140,11 +163,18 @@ func (p *APIProxy) GetServer() *auth.Server { return p.server } -// Shutdown gracefully stops the background token refresh. -func (p *APIProxy) Shutdown() { +// Shutdown gracefully stops the background token refresh and streaming manager. +func (p *APIProxy) Shutdown(ctx context.Context) { if p.refreshCancel != nil { p.refreshCancel() } + + // Shutdown streaming if enabled + if p.streamManager != nil { + if err := p.streamManager.Shutdown(ctx); err != nil { + log.Error(ctx, err, "Failed to shutdown streaming manager") + } + } } // setupRoutes configures all API routes. @@ -176,7 +206,7 @@ func (p *APIProxy) setupRoutes() { // Trader endpoints - following Schwab's structure: /trader/v1/... p.mux.HandleFunc( "/trader/v1/", - p.withTokenValidation(p.withTokenRefresh(p.handleTraderRequest)), + p.withTokenValidation(p.withTokenRefresh(p.interceptableTraderRequest)), ) // Client management endpoints (admin only) @@ -190,6 +220,11 @@ func (p *APIProxy) setupRoutes() { p.mux.HandleFunc("GET /api/approvals", p.withAPIAuth(p.handleListApprovals)) p.mux.HandleFunc("POST /api/approvals/{id}", p.withAPIAuth(p.handleApproveRequest)) p.mux.HandleFunc("DELETE /api/approvals/{id}", p.withAPIAuth(p.handleDenyRequest)) + + // Streaming endpoint + if p.streamManager != nil { + p.mux.HandleFunc("/ws/stream", p.streamManager.HandleWebSocket) + } } // startBackgroundTokenRefresh starts a goroutine that proactively refreshes the Schwab token. 
diff --git a/python/schwab_monkeypatch.py b/python/schwab_monkeypatch.py index c4509d7..1b23135 100644 --- a/python/schwab_monkeypatch.py +++ b/python/schwab_monkeypatch.py @@ -60,7 +60,7 @@ def _patched_get_request(self, path, params): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]) } - # Use the session's internal httpx client which has our SSL settings + return self.session.session.get(dest, params=params, headers=headers) def _patched_post_request(self, path, data): @@ -68,7 +68,7 @@ def _patched_post_request(self, path, data): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]), } - # Use the session's internal httpx client which has our SSL settings + return self.session.session.post(dest, json=data, headers=headers) def _patched_put_request(self, path, data): @@ -76,7 +76,7 @@ def _patched_put_request(self, path, data): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]), } - # Use the session's internal httpx client which has our SSL settings + return self.session.session.put(dest, json=data, headers=headers) def _patched_delete_request(self, path): @@ -84,10 +84,9 @@ def _patched_delete_request(self, path): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]) } - # Use the session's internal httpx client which has our SSL settings + return self.session.session.delete(dest, headers=headers) - # Apply patches Client._get_request = _patched_get_request Client._post_request = _patched_post_request Client._put_request = _patched_put_request @@ -103,7 +102,7 @@ async def _patched_async_get_request(self, path, params): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]) } - # Use the session's internal httpx client which has our SSL settings + return await self.session.session.get(dest, params=params, headers=headers) async def _patched_async_post_request(self, path, data): @@ -111,7 +110,7 @@ async def _patched_async_post_request(self, path, data): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]), } - # Use the session's internal httpx client which has our SSL settings + return await self.session.session.post(dest, json=data, headers=headers) async def _patched_async_put_request(self, path, data): @@ -119,7 +118,7 @@ async def _patched_async_put_request(self, path, data): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]), } - # Use the session's internal httpx client which has our SSL settings + return await self.session.session.put(dest, json=data, headers=headers) async def _patched_async_delete_request(self, path): @@ -127,10 +126,9 @@ async def _patched_async_delete_request(self, path): headers = { "Authorization": "Bearer " + str(self.token_metadata.token["access_token"]) } - # Use the session's internal httpx client which has our SSL settings + return await self.session.session.delete(dest, headers=headers) - # Apply patches AsyncClient._get_request = _patched_async_get_request AsyncClient._post_request = _patched_async_post_request AsyncClient._put_request = _patched_async_put_request @@ -139,30 +137,24 @@ async def _patched_async_delete_request(self, path): def _patch_oauth_endpoints(auth_module, proxy_base_url: str, verify_ssl: bool = True): """Patch OAuth endpoints in the auth module""" - # Patch the token endpoint constant auth_module.TOKEN_ENDPOINT = f"{proxy_base_url}/v1/oauth/token" - # Patch the get_auth_context 
function which creates the authorization URL if hasattr(auth_module, "get_auth_context"): def make_patched_get_auth_context(ssl_verify): def _patched_get_auth_context(api_key, callback_url, state=None): - # Import OAuth2Client here to avoid import issues from authlib.integrations.httpx_client import OAuth2Client import httpx - # Create httpx client with appropriate SSL verification setting httpx_client = httpx.Client(verify=ssl_verify) - # Create OAuth2Client and then override its session oauth = OAuth2Client(api_key, redirect_uri=callback_url) - oauth.session = httpx_client # Force it to use our custom client + oauth.session = httpx_client authorization_url, state = oauth.create_authorization_url( - f"{proxy_base_url}/v1/oauth/authorize", # Use proxy URL instead of Schwab + f"{proxy_base_url}/v1/oauth/authorize", state=state, ) - # Import AuthContext from the auth module AuthContext = auth_module.collections.namedtuple( "AuthContext", ["callback_url", "authorization_url", "state"] ) @@ -173,7 +165,6 @@ def _patched_get_auth_context(api_key, callback_url, state=None): auth_module.get_auth_context = make_patched_get_auth_context(verify_ssl) - # Also patch client_from_received_url to use an httpx client with SSL verification setting if hasattr(auth_module, "client_from_received_url"): def make_patched_client_from_received_url(ssl_verify): @@ -189,20 +180,17 @@ def _patched_client_from_received_url( from authlib.integrations.httpx_client import OAuth2Client import httpx - # Create httpx client with appropriate SSL verification setting httpx_client = httpx.Client(verify=ssl_verify) - # Create OAuth2Client and then override its session oauth = OAuth2Client(api_key, redirect_uri=auth_context.callback_url) - oauth.session = httpx_client # Force it to use our custom client + oauth.session = httpx_client token = oauth.fetch_token( - auth_module.TOKEN_ENDPOINT, # This now points to our proxy + auth_module.TOKEN_ENDPOINT, authorization_response=received_url, client_secret=app_secret, ) - # Set up token writing and perform the initial token write (like schwab-py does) import time metadata_manager = auth_module.TokenMetadata( @@ -211,15 +199,13 @@ def _patched_client_from_received_url( wrapped_token_write_func = metadata_manager.wrapped_token_write_func() wrapped_token_write_func(token) - # Create httpx client with appropriate SSL verification setting import httpx - httpx_client = httpx.Client(verify=ssl_verify) - - # Create the proper session class (OAuth2Client) as expected by Client constructor if asyncio: from authlib.integrations.httpx_client import AsyncOAuth2Client + httpx_client = httpx.AsyncClient(verify=ssl_verify) + async def oauth_client_update_token(t, *args, **kwargs): wrapped_token_write_func(t, *args, **kwargs) @@ -230,7 +216,9 @@ async def oauth_client_update_token(t, *args, **kwargs): update_token=oauth_client_update_token, leeway=300, ) - session.session = httpx_client # Set our custom httpx client + + session.session = httpx_client + return auth_module.AsyncClient( api_key, session, @@ -240,6 +228,8 @@ async def oauth_client_update_token(t, *args, **kwargs): else: from authlib.integrations.httpx_client import OAuth2Client + httpx_client = httpx.Client(verify=ssl_verify) + session = OAuth2Client( api_key, client_secret=app_secret, @@ -247,7 +237,9 @@ async def oauth_client_update_token(t, *args, **kwargs): update_token=wrapped_token_write_func, leeway=300, ) - session.session = httpx_client # Set our custom httpx client + + session.session = httpx_client + return auth_module.Client( 
api_key, session, @@ -261,7 +253,6 @@ async def oauth_client_update_token(t, *args, **kwargs): verify_ssl ) - # Also patch client_from_access_functions for token file loading original_client_from_access_functions = auth_module.client_from_access_functions def make_patched_client_from_access_functions(verify_ssl): @@ -273,7 +264,6 @@ def patched_client_from_access_functions( asyncio=False, enforce_enums=True, ): - # Call original function to get the client client = original_client_from_access_functions( api_key, app_secret, @@ -283,10 +273,13 @@ def patched_client_from_access_functions( enforce_enums, ) - # Update the session's httpx client to use our SSL verification setting import httpx - httpx_client = httpx.Client(verify=verify_ssl) + if asyncio: + httpx_client = httpx.AsyncClient(verify=verify_ssl) + else: + httpx_client = httpx.Client(verify=verify_ssl) + client.session.session = httpx_client return client @@ -300,15 +293,13 @@ def patched_client_from_access_functions( def _patch_auth_client_references(auth_module, sync_module, async_module): """Patch any cached client class references in the auth module""" - # Update the Client and AsyncClient references in auth module - # in case they were imported before we patched them if hasattr(auth_module, "Client"): auth_module.Client = sync_module.Client + if hasattr(auth_module, "AsyncClient"): auth_module.AsyncClient = async_module.AsyncClient -# Convenience function for common use case def patch_for_localhost(port: int = 8080, https: bool = True, verify_ssl: bool = True): """ Convenience function to patch for localhost proxy @@ -323,7 +314,6 @@ def patch_for_localhost(port: int = 8080, https: bool = True, verify_ssl: bool = if __name__ == "__main__": - # Example usage print("Example usage:") print("import schwab_monkeypatch") print("schwab_monkeypatch.patch_schwab_client('https://127.0.0.1:8080')") diff --git a/python/test_client.py b/python/test_client.py index 8cf248e..f1aced1 100755 --- a/python/test_client.py +++ b/python/test_client.py @@ -31,11 +31,46 @@ def test_api_calls(client): # Initialize account_hash for later tests account_hash = None + + # Helper function to handle 401 errors + def make_request_with_retry(request_func, *args, **kwargs): + """Make a request and retry once if we get a 401""" + response = request_func(*args, **kwargs) + + if response.status_code == 401: + print(" Got 401, attempting to refresh token...") + try: + # Manually refresh the token + if hasattr(client, 'session') and hasattr(client.session, 'refresh_token'): + # Get the token endpoint from the patched auth module + import schwab + token_endpoint = schwab.auth.TOKEN_ENDPOINT + + # Get current token + if hasattr(client.session, 'token') and client.session.token: + refresh_token = client.session.token.get('refresh_token') + if refresh_token: + print(" Refreshing token...") + client.session.refresh_token(token_endpoint, refresh_token=refresh_token) + print(" Token refreshed successfully") + else: + print(" No refresh token available") + + # Retry the request + response = request_func(*args, **kwargs) + + if response.status_code == 401: + print(" Still got 401 after token refresh.") + print(" Try deleting the token file and re-authenticating.") + except Exception as e: + print(f" Failed to refresh token: {e}") + + return response # Test 1: Account Numbers (no parameters required) print("\n1. 
Testing get_account_numbers()...") try: - response = client.get_account_numbers() + response = make_request_with_retry(client.get_account_numbers) print(f" ✓ Success: {response.status_code}") if response.status_code == 200: data = response.json() @@ -54,7 +89,7 @@ def test_api_calls(client): # Test 2: User Preferences (no parameters required) print("\n2. Testing get_user_preferences()...") try: - response = client.get_user_preferences() + response = make_request_with_retry(client.get_user_preferences) print(f" ✓ Success: {response.status_code}") if response.status_code == 200: data = response.json() @@ -69,7 +104,7 @@ def test_api_calls(client): # Test 3: All Accounts print("\n3. Testing get_accounts()...") try: - response = client.get_accounts() + response = make_request_with_retry(client.get_accounts) print(f" ✓ Success: {response.status_code}") if response.status_code == 200: data = response.json() @@ -84,7 +119,7 @@ def test_api_calls(client): # Test 4: Single Stock Quote print("\n4. Testing get_quote('AAPL')...") try: - response = client.get_quote("AAPL") + response = make_request_with_retry(client.get_quote, "AAPL") print(f" ✓ Success: {response.status_code}") if response.status_code == 200: data = response.json() @@ -109,7 +144,7 @@ def test_api_calls(client): # Test 5: Multiple Stock Quotes print("\n5. Testing get_quotes(['AAPL', 'GOOGL', 'MSFT'])...") try: - response = client.get_quotes(["AAPL", "GOOGL", "MSFT"]) + response = make_request_with_retry(client.get_quotes, ["AAPL", "GOOGL", "MSFT"]) print(f" ✓ Success: {response.status_code}") if response.status_code == 200: data = response.json() @@ -133,7 +168,7 @@ def test_api_calls(client): print("\n6. Testing get_market_hours(['equity'])...") try: # Use the correct Market enum from client.MarketHours.Market - response = client.get_market_hours([client.MarketHours.Market.EQUITY]) + response = make_request_with_retry(client.get_market_hours, [client.MarketHours.Market.EQUITY]) print(f" ✓ Success: {response.status_code}") if response.status_code == 200: data = response.json() @@ -149,7 +184,7 @@ def test_api_calls(client): if account_hash: print(f"\n7. Testing get_account('{account_hash}')...") try: - response = client.get_account(account_hash) + response = make_request_with_retry(client.get_account, account_hash) print(f" ✓ Success: {response.status_code}") if response.status_code == 200: data = response.json() diff --git a/python/test_streaming.py b/python/test_streaming.py new file mode 100755 index 0000000..1501605 --- /dev/null +++ b/python/test_streaming.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +""" +Test script for schwab-proxy streaming functionality using schwab-py client. + +This script tests: +1. WebSocket connection through the proxy +2. Streaming authentication +3. Subscription management (SUBS, ADD, UNSUBS) +4. 
Data message routing + +Usage: + python test_streaming.py --app-key YOUR_KEY --app-secret YOUR_SECRET --tickers AAPL,MSFT,GOOGL +""" + +import argparse +import asyncio +import json +import os +import sys +import time +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Set + +# CRITICAL: Patch BEFORE importing any schwab modules +import schwab_monkeypatch + + +class StreamingTestClient: + """Test client for streaming functionality""" + + def __init__(self, schwab_client, account_id: str): + self.schwab_client = schwab_client + self.account_id = account_id + self.stream_client = None + self.messages_received = 0 + self.subscribed_symbols: Set[str] = set() + self.data_by_symbol: Dict[str, List[dict]] = {} + self.errors: List[str] = [] + self.connected = False + + async def setup_streaming(self): + """Initialize the streaming client""" + from schwab.streaming import StreamClient + + self.stream_client = StreamClient( + self.schwab_client, + account_id=self.account_id + ) + + prefs_func = self.stream_client._client.get_user_preferences + + async def get_user_preferences(): + response = await prefs_func() + + # Handle 401 - try to refresh token once + if response.status_code == 401: + print(" Got 401 on user preferences, attempting to refresh token...") + try: + # Manually refresh the token + if hasattr(self.schwab_client, 'session') and hasattr(self.schwab_client.session, 'refresh_token'): + # Get the token endpoint from the patched auth module + import schwab + token_endpoint = schwab.auth.TOKEN_ENDPOINT + + # Get current token + if hasattr(self.schwab_client.session, 'token') and self.schwab_client.session.token: + refresh_token = self.schwab_client.session.token.get('refresh_token') + if refresh_token: + print(" Refreshing token...") + await self.schwab_client.session.refresh_token(token_endpoint, refresh_token=refresh_token) + print(" Token refreshed successfully") + # Retry the request + response = await prefs_func() + except Exception as e: + print(f" Failed to refresh token: {e}") + + print(f"User preferences response: {response.json()}") + return response + + # Patch the client to use the custom get_user_preferences + self.stream_client._client.get_user_preferences = get_user_preferences + + # Add handlers for different data types + self.stream_client.add_level_one_equity_handler( + lambda msg: self._handle_level_one_equity(msg) + ) + + print("StreamClient initialized") + return True + + async def connect_and_login(self): + """Connect to the WebSocket and login""" + print("Attempting to login to stream...") + await self.stream_client.login() + self.connected = True + print("✓ Successfully logged in to stream") + return True + + async def subscribe_to_symbols(self, symbols: List[str]): + """Subscribe to level 1 equity data for given symbols""" + print(f"Subscribing to symbols: {symbols}") + await self.stream_client.level_one_equity_subs(symbols) + self.subscribed_symbols.update(symbols) + print(f"✓ Subscribed to {len(symbols)} symbols") + return True + + async def add_symbols(self, symbols: List[str]): + """Add additional symbols to subscription""" + print(f"Adding symbols: {symbols}") + await self.stream_client.level_one_equity_add(symbols) + self.subscribed_symbols.update(symbols) + print(f"✓ Added {len(symbols)} symbols") + return True + + async def unsubscribe_symbols(self, symbols: List[str]): + """Unsubscribe from specific symbols""" + print(f"Unsubscribing from symbols: {symbols}") + await self.stream_client.level_one_equity_unsubs(symbols) + 
self.subscribed_symbols.difference_update(symbols) + print(f"✓ Unsubscribed from {len(symbols)} symbols") + return True + + async def handle_messages(self, duration: int = 30): + """Handle incoming messages for specified duration""" + print(f"Starting message handler for {duration} seconds...") + start_time = time.time() + + try: + while time.time() - start_time < duration: + await self.stream_client.handle_message() + await asyncio.sleep(0.01) # Small delay to prevent tight loop + except asyncio.CancelledError: + print("Message handler cancelled") + raise + + def _handle_level_one_equity(self, msg): + """Handler for level 1 equity data""" + self.messages_received += 1 + + # Extract symbol from message + symbol = msg.get('key', 'UNKNOWN') + + # Store message data + if symbol not in self.data_by_symbol: + self.data_by_symbol[symbol] = [] + + self.data_by_symbol[symbol].append({ + 'timestamp': datetime.now().isoformat(), + 'data': msg + }) + + # Print first few messages for debugging + if self.messages_received <= 5: + print(f"Received data for {symbol}: {json.dumps(msg, indent=2)}") + elif self.messages_received % 10 == 0: + print(f"Total messages received: {self.messages_received}") + + def get_summary(self): + """Get summary of test results""" + return { + 'connected': self.connected, + 'messages_received': self.messages_received, + 'symbols_subscribed': len(self.subscribed_symbols), + 'symbols_with_data': len(self.data_by_symbol), + 'errors': len(self.errors), + 'error_details': self.errors + } + + +async def test_streaming(schwab_client, account_id, tickers, duration): + """Test streaming with specified tickers""" + print("\n" + "="*60) + print("STREAMING TEST") + print("="*60) + + client = StreamingTestClient(schwab_client, account_id) + + # Setup streaming + await client.setup_streaming() + + # Connect and login + await client.connect_and_login() + + # Subscribe to specified tickers + print(f"\nSubscribing to tickers: {tickers}") + await client.subscribe_to_symbols(tickers) + + # Handle messages for specified duration + print(f"\nStreaming data for {duration} seconds...") + await client.handle_messages(duration) + + # Print summary + summary = client.get_summary() + print("\n" + "-"*40) + print("STREAMING TEST SUMMARY") + print("-"*40) + print(f"Connected: {summary['connected']}") + print(f"Messages received: {summary['messages_received']}") + print(f"Symbols subscribed: {summary['symbols_subscribed']}") + print(f"Symbols with data: {summary['symbols_with_data']}") + + # Show data summary per symbol + if client.data_by_symbol: + print("\nData received per symbol:") + for symbol, data in client.data_by_symbol.items(): + print(f" {symbol}: {len(data)} messages") + + + + +async def main_async(args): + """Main async function""" + # Get credentials + app_key = args.app_key or os.getenv("SCHWAB_APP_KEY") + app_secret = args.app_secret or os.getenv("SCHWAB_APP_SECRET") + + if not app_key or not app_secret: + print("Error: Must provide app key and secret (via --app-key/--app-secret or SCHWAB_APP_KEY/SCHWAB_APP_SECRET env vars)") + return 1 + + # Parse tickers + tickers = [t.strip().upper() for t in args.tickers.split(",") if t.strip()] + if not tickers: + print("Error: Must provide at least one ticker") + return 1 + + print("Schwab Proxy Streaming Test") + print("="*50) + print(f"Proxy URL: {args.proxy_url}") + print(f"SSL verification: {'Disabled' if args.no_verify_ssl else 'Enabled'}") + print(f"Tickers to monitor: {', '.join(tickers)}") + print(f"Stream duration: {args.stream_duration} 
seconds") + + # Patch schwab client + print("\n1. Patching schwab-py client to use proxy...") + schwab_monkeypatch.patch_schwab_client( + args.proxy_url, verify_ssl=not args.no_verify_ssl + ) + + # Import schwab after patching + print("\n2. Importing schwab modules...") + import schwab + + # Handle SSL + if args.no_verify_ssl: + print("\n3. Disabling SSL verification...") + import urllib3 + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + import ssl + ssl._create_default_https_context = ssl._create_unverified_context + os.environ["PYTHONHTTPSVERIFY"] = "0" + os.environ["CURL_CA_BUNDLE"] = "" + + # Create client + print("\n4. Creating Schwab client...") + token_path = Path(args.token_file) + + if token_path.exists(): + client = schwab.auth.client_from_token_file( + token_path=token_path, + api_key=app_key, + app_secret=app_secret, + asyncio=True # Enable async support + ) + print(f" ✓ Loaded existing token from {args.token_file}") + else: + print(f" Token file '{args.token_file}' not found.") + print("\n Starting OAuth authentication flow...") + print(f" Callback URL: {args.callback_url}") + # This will use the patched OAuth endpoints (proxy's OAuth server) + client = schwab.auth.client_from_login_flow( + api_key=app_key, + app_secret=app_secret, + callback_url=args.callback_url, + token_path=Path(args.token_file), + asyncio=True # Enable async support + ) + print(" ✓ OAuth flow completed successfully") + + # Get account ID + print("\n5. Getting account ID...") + resp = await client.get_account_numbers() + + # Handle 401 - try to refresh token once + if resp.status_code == 401: + print(" Got 401, attempting to refresh token...") + try: + # Manually refresh the token + if hasattr(client, 'session') and hasattr(client.session, 'refresh_token'): + # Get the token endpoint from the patched auth module + import schwab + token_endpoint = schwab.auth.TOKEN_ENDPOINT + + # Get current token + if hasattr(client.session, 'token') and client.session.token: + refresh_token = client.session.token.get('refresh_token') + if refresh_token: + print(" Refreshing token...") + await client.session.refresh_token(token_endpoint, refresh_token=refresh_token) + print(" Token refreshed successfully") + else: + print(" No refresh token available") + + # Retry the request + resp = await client.get_account_numbers() + except Exception as e: + print(f" Failed to refresh token: {e}") + + if resp.status_code == 200: + accounts = resp.json() + if accounts: + account_id = accounts[0]['hashValue'] + print(f" Account ID: {account_id}") + else: + print(" No accounts found") + return 1 + else: + print(f" Failed to get accounts: {resp.status_code}") + if resp.status_code == 401: + print(" Authentication failed. 
Try deleting the token file and re-authenticating.") + return 1 + + # Run streaming test + await test_streaming(client, account_id, tickers, args.stream_duration) + + print("\n🎉 Streaming test completed!") + return 0 + + +def main(): + parser = argparse.ArgumentParser( + description="Test schwab-proxy streaming functionality", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Stream with existing token file + python test_streaming.py --app-key YOUR_KEY --app-secret YOUR_SECRET --tickers AAPL,MSFT,GOOGL + + # Stream with custom token file + python test_streaming.py --app-key YOUR_KEY --app-secret YOUR_SECRET --token-file my_token.json --tickers SPY + + # Fresh authentication with custom callback URL + python test_streaming.py --app-key YOUR_KEY --app-secret YOUR_SECRET --callback-url https://localhost:9000 --tickers TSLA + + # Use environment variables for credentials + export SCHWAB_APP_KEY=your_key + export SCHWAB_APP_SECRET=your_secret + python test_streaming.py --tickers SPY,QQQ,IWM + + # Test with custom proxy URL and no SSL verification + python test_streaming.py --proxy-url https://myproxy:8080 --no-verify-ssl --tickers NVDA --stream-duration 60 +""" + ) + parser.add_argument( + "--proxy-url", + default="https://localhost:8080", + help="Proxy server URL (default: https://localhost:8080)", + ) + parser.add_argument( + "--app-key", help="Schwab app key (or set SCHWAB_APP_KEY env var)" + ) + parser.add_argument( + "--app-secret", help="Schwab app secret (or set SCHWAB_APP_SECRET env var)" + ) + parser.add_argument( + "--tickers", + required=True, + help="Comma-separated list of stock tickers to monitor (e.g., AAPL,MSFT,GOOGL)", + ) + parser.add_argument( + "--token-file", + default="schwab_token.json", + help="Token file path to save/load OAuth tokens (default: schwab_token.json)", + ) + parser.add_argument( + "--callback-url", + default="https://127.0.0.1:3000", + help="OAuth callback URL for fresh authentication (default: https://127.0.0.1:3000)", + ) + parser.add_argument( + "--no-verify-ssl", + action="store_true", + help="Disable SSL verification for self-signed certificates", + ) + parser.add_argument( + "--stream-duration", + type=int, + default=30, + help="Duration to stream data in seconds (default: 30)", + ) + + args = parser.parse_args() + + # Run async main + return asyncio.run(main_async(args)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/streaming/metadata.go b/streaming/metadata.go new file mode 100644 index 0000000..fb15f03 --- /dev/null +++ b/streaming/metadata.go @@ -0,0 +1,136 @@ +package streaming + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sync" + "time" + + "github.com/jkoelker/schwab-proxy/api" +) + +var ErrNoMetadata = errors.New("no metadata available") + +const defaultMetadataTTL = 24 * time.Hour + +// UserPreferencesResponse represents the user preferences API response. +type UserPreferencesResponse struct { + //nolint:tagliatelle // Schwab API response structure + StreamerInfo []StreamerInfo `json:"streamerInfo"` +} + +// StreamerInfo contains streaming configuration. 
+type StreamerInfo struct { + //nolint:tagliatelle // Schwab API response structure + StreamerSocketURL string `json:"streamerSocketUrl"` + + //nolint:tagliatelle // Schwab API response structure + SchwabClientCustomerID string `json:"schwabClientCustomerId"` + + //nolint:tagliatelle // Schwab API response structure + SchwabClientCorrelID string `json:"schwabClientCorrelId"` + + //nolint:tagliatelle // Schwab API response structure + SchwabClientChannel string `json:"schwabClientChannel"` + + //nolint:tagliatelle // Schwab API response structure + SchwabClientFunctionID string `json:"schwabClientFunctionId"` +} + +// MetadataManager manages streaming metadata with caching. +type MetadataManager struct { + metadata *Metadata + mu sync.RWMutex + refreshFunc func() (*Metadata, error) + lastRefresh time.Time +} + +// NewMetadataManager creates a metadata manager. +func NewMetadataManager(refreshFunc func() (*Metadata, error)) *MetadataManager { + return &MetadataManager{ + refreshFunc: refreshFunc, + } +} + +// GetMetadata returns current metadata, refreshing if needed. +func (m *MetadataManager) GetMetadata() (*Metadata, error) { + m.mu.RLock() + + if m.metadata != nil && time.Since(m.lastRefresh) < m.metadata.TTL { + metadata := m.metadata + m.mu.RUnlock() + + return metadata, nil + } + + m.mu.RUnlock() + + // Need to refresh + m.mu.Lock() + defer m.mu.Unlock() + + // Double-check after acquiring write lock + if m.metadata != nil && time.Since(m.lastRefresh) < m.metadata.TTL { + return m.metadata, nil + } + + // Refresh metadata + metadata, err := m.refreshFunc() + if err != nil { + return nil, err + } + + m.metadata = metadata + m.lastRefresh = time.Now() + + return metadata, nil +} + +// CreateMetadataFunc creates a metadata refresh function. +func CreateMetadataFunc(schwabClient api.ProviderClient) func() (*Metadata, error) { + return func() (*Metadata, error) { + ctx := context.Background() + + resp, err := schwabClient.Call(ctx, "GET", "/trader/v1/userPreference", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to fetch user preferences: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf( + "%w: preferences request failed: %d - %s", + ErrNoMetadata, + resp.StatusCode, + string(body), + ) + } + + var prefs UserPreferencesResponse + if err := json.NewDecoder(resp.Body).Decode(&prefs); err != nil { + return nil, fmt.Errorf("failed to decode user preferences: %w", err) + } + + if len(prefs.StreamerInfo) == 0 { + return nil, fmt.Errorf("%w: no streaming info in response", ErrNoMetadata) + } + + metadata := &Metadata{ + CorrelID: prefs.StreamerInfo[0].SchwabClientCorrelID, + CustomerID: prefs.StreamerInfo[0].SchwabClientCustomerID, + Channel: prefs.StreamerInfo[0].SchwabClientChannel, + FunctionID: prefs.StreamerInfo[0].SchwabClientFunctionID, + WSEndpoint: prefs.StreamerInfo[0].StreamerSocketURL, + ExtractedAt: time.Now(), + TTL: defaultMetadataTTL, + } + + return metadata, nil + } +} diff --git a/streaming/proxy.go b/streaming/proxy.go new file mode 100644 index 0000000..a782ab3 --- /dev/null +++ b/streaming/proxy.go @@ -0,0 +1,784 @@ +package streaming + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/jkoelker/schwab-proxy/auth" + "github.com/jkoelker/schwab-proxy/log" +) + +const ( + readBufferSize = 1024 + writeBufferSize = 1024 + maxMessageSize = 10 
* 1024 * 1024 // 10MB + + authTimeout = 30 * time.Second + readTimeout = 60 * time.Second + writeTimeout = 10 * time.Second + pingInterval = 30 * time.Second + pongWait = 60 * time.Second + + clientChannelSize = 100 + + reconnectStartInterval = 1 * time.Second + // Error codes. + codeInvalidJSON = 1 + codeNotAuthenticated = 3 + codeServiceError = 22 + // Connection timeouts. + handshakeTimeout = 10 * time.Second +) + +var ( + ErrMasterConnectionNotAvailable = errors.New("master connection not available") + ErrLoginFailed = errors.New("LOGIN failed") +) + +// Proxy manages WebSocket connections between clients and Schwab. +type Proxy struct { + // Dependencies + tokenManager auth.TokenServicer + authServer *auth.Server + metadataFunc func() (*Metadata, error) + + // Master connection + masterConn *websocket.Conn + masterMu sync.RWMutex + reconnectDelay time.Duration + + // Client tracking + clients *ClientMap + + // Lifecycle + cancel context.CancelFunc + wg sync.WaitGroup +} + +// Client represents a connected WebSocket client. +type Client struct { + id string + conn *websocket.Conn + info ClientInfo + authed bool + msgChan chan json.RawMessage + done chan struct{} +} + +// ClientInfo represents authenticated client information. +type ClientInfo struct { + ClientID string + Scopes []string +} + +// NewProxy creates a new streaming proxy. +func NewProxy( + tokenManager auth.TokenServicer, + authServer *auth.Server, + metadataFunc func() (*Metadata, error), +) *Proxy { + return &Proxy{ + tokenManager: tokenManager, + authServer: authServer, + metadataFunc: metadataFunc, + clients: NewClientMap(), + reconnectDelay: reconnectStartInterval, + } +} + +// Start begins the streaming proxy operations. +func (sp *Proxy) Start(ctx context.Context) error { + // Don't store context, just store the cancel function + _, cancel := context.WithCancel(ctx) + sp.cancel = cancel + + return nil +} + +// Shutdown gracefully shuts down the streaming proxy. +func (sp *Proxy) Shutdown(ctx context.Context) error { + log.Info(ctx, "Stopping streaming proxy") + + // Cancel context + if sp.cancel != nil { + sp.cancel() + } + + // Close master connection + sp.masterMu.Lock() + + if sp.masterConn != nil { + sp.masterConn.Close() + } + + sp.masterMu.Unlock() + + // Close all client connections + sp.clients.Range(func(_ string, client *Client) bool { + close(client.done) + client.conn.Close() + + return true + }) + + // Wait for goroutines + sp.wg.Wait() + + return nil +} + +// HandleWebSocket handles incoming WebSocket connections from clients. 
+func (sp *Proxy) HandleWebSocket(writer http.ResponseWriter, req *http.Request) { + // Upgrade connection + upgrader := websocket.Upgrader{ + ReadBufferSize: readBufferSize, + WriteBufferSize: writeBufferSize, + CheckOrigin: func(_ *http.Request) bool { + return true // Configure based on security needs + }, + } + + conn, err := upgrader.Upgrade(writer, req, nil) + if err != nil { + log.Error(req.Context(), err, "Failed to upgrade WebSocket") + + return + } + + // Create client + clientID := "client_" + uuid.New().String() + client := &Client{ + id: clientID, + conn: conn, + msgChan: make(chan json.RawMessage, clientChannelSize), + done: make(chan struct{}), + } + + // Register client + sp.clients.Store(clientID, client) + + // Create a context for the client that isn't tied to the HTTP request + // Use WithoutCancel to inherit values but not cancellation from the HTTP request + clientCtx, clientCancel := context.WithCancel(context.WithoutCancel(req.Context())) + + // Ensure master connection exists + if err := sp.ensureMasterConnection(clientCtx); err != nil { + log.Error(clientCtx, err, "Failed to establish master connection") + + _ = conn.WriteJSON(map[string]string{"error": "Service unavailable"}) + + conn.Close() + clientCancel() + + return + } + + // Handle client connection + sp.wg.Add(1) + + go func() { + defer clientCancel() + + sp.handleClient(clientCtx, client) + }() +} + +// GetConnectionState returns the current master connection state. +func (sp *Proxy) GetConnectionState() string { + sp.masterMu.RLock() + defer sp.masterMu.RUnlock() + + if sp.masterConn != nil { + return "connected" + } + + return "disconnected" +} + +// GetClientCount returns the number of connected clients. +func (sp *Proxy) GetClientCount() int { + return sp.clients.Count() +} + +// GetLastHeartbeat returns the last heartbeat time (not implemented in simplified version). +func (sp *Proxy) GetLastHeartbeat() time.Time { + return time.Now() // Simplified - could track if needed +} + +// handleClient manages a client connection lifecycle. +func (sp *Proxy) handleClient(ctx context.Context, client *Client) { + defer sp.wg.Done() + defer func() { + // Cleanup + close(client.done) + client.conn.Close() + + sp.clients.Delete(client.id) + + log.Info(ctx, "Client disconnected", "client_id", client.id) + }() + + log.Info(ctx, "Client connected", "client_id", client.id) + + // Start write loop + go sp.clientWriteLoop(client) + + // Set auth timeout + authTimer := time.NewTimer(authTimeout) + defer authTimer.Stop() + + // Configure connection + client.conn.SetReadLimit(maxMessageSize) + _ = client.conn.SetReadDeadline(time.Now().Add(readTimeout)) + + client.conn.SetPongHandler(func(string) error { + return client.conn.SetReadDeadline(time.Now().Add(readTimeout)) + }) + + // Read loop + for { + select { + case <-ctx.Done(): + return + case <-authTimer.C: + if !client.authed { + log.Info(ctx, "Client auth timeout", "client_id", client.id) + + return + } + + default: + } + + _, message, err := client.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError( + err, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) { + log.Error(ctx, err, "WebSocket read error", "client_id", client.id) + } + + return + } + + // Process message + if err := sp.processClientMessage(ctx, client, message); err != nil { + log.Error(ctx, err, "Failed to process message", "client_id", client.id) + } + } +} + +// clientWriteLoop sends messages to a client. 
+func (sp *Proxy) clientWriteLoop(client *Client) { + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() + + for { + select { + case <-client.done: + return + + case msg := <-client.msgChan: + _ = client.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + + if err := client.conn.WriteMessage(websocket.TextMessage, msg); err != nil { + return + } + + case <-ticker.C: + _ = client.conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + + if err := client.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +// processClientMessage handles a message from a client. +func (sp *Proxy) processClientMessage(ctx context.Context, client *Client, message []byte) error { + // Parse request + var req RequestBatch + if err := json.Unmarshal(message, &req); err != nil { + return sp.sendErrorResponse(client, "", "ADMIN", "", codeInvalidJSON, "Invalid JSON") + } + + // Process each command + for _, cmd := range req.Requests { + // Handle LOGIN specially + if cmd.Service == "ADMIN" && cmd.Command == "LOGIN" { + if err := sp.handleClientLogin(ctx, client, cmd); err != nil { + return err + } + + continue + } + + // Require auth for other commands + if !client.authed { + return sp.sendErrorResponse( + client, + cmd.RequestID, + cmd.Service, + cmd.Command, + codeNotAuthenticated, + "Not authenticated", + ) + } + + // Store original request ID before prefixing + originalRequestID := cmd.RequestID + + // Forward to master with client ID prefix + cmd.RequestID = PrefixRequestID(client.id, cmd.RequestID) + if err := sp.forwardToMaster(ctx, cmd); err != nil { + return sp.sendErrorResponse( + client, + originalRequestID, + cmd.Service, + cmd.Command, + codeServiceError, + "Service error", + ) + } + } + + return nil +} + +// handleClientLogin processes a client LOGIN command. +func (sp *Proxy) handleClientLogin(ctx context.Context, client *Client, cmd Request) error { + // Extract token + authToken, ok := cmd.Parameters["Authorization"].(string) + if !ok { + return sp.sendErrorResponse( + client, + cmd.RequestID, + cmd.Service, + cmd.Command, + codeInvalidJSON, + "Missing Authorization", + ) + } + + // Validate proxy JWT + clientID, scopes, err := sp.authServer.ValidateAccessToken(ctx, authToken) + if err != nil { + return sp.sendErrorResponse( + client, + cmd.RequestID, + cmd.Service, + cmd.Command, + codeNotAuthenticated, + "Invalid token", + ) + } + + // Mark as authenticated + client.info = ClientInfo{ClientID: clientID, Scopes: scopes} + client.authed = true + + // Send success response + response := Response{ + Response: []ResponseItem{{ + Service: cmd.Service, + Command: cmd.Command, + RequestID: cmd.RequestID, + SchwabClientCorrelID: cmd.SchwabClientCorrelID, + Timestamp: time.Now().UnixMilli(), + Content: ResponseContent{ + Code: 0, + Msg: "Login successful", + }, + }}, + } + + data, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal response: %w", err) + } + + select { + case client.msgChan <- data: + return nil + + case <-client.done: + return nil + } +} + +// sendErrorResponse sends an error response to a client. 
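// For reference, the success frame built by handleClientLogin marshals to JSON
// shaped like this (timestamp abbreviated; SchwabClientCorrelId appears only
// when the client supplied one):
//
//	{"response":[{"service":"ADMIN","requestid":"1","command":"LOGIN",
//	    "timestamp":1700000000000,"content":{"code":0,"msg":"Login successful"}}]}
//
// Failures reuse the same envelope via sendErrorResponse, e.g. a command sent
// before LOGIN yields content {"code":3,"msg":"Not authenticated"}.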
+func (sp *Proxy) sendErrorResponse( + client *Client, + reqID string, + service string, + command string, + code int, + msg string, +) error { + response := Response{ + Response: []ResponseItem{{ + Service: service, + Command: command, + RequestID: reqID, + Timestamp: time.Now().UnixMilli(), + Content: ResponseContent{ + Code: code, + Msg: msg, + }, + }}, + } + + data, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal response: %w", err) + } + + select { + case client.msgChan <- data: + return nil + + case <-client.done: + return nil + } +} + +// forwardToMaster sends a command to the master connection. +func (sp *Proxy) forwardToMaster(ctx context.Context, cmd Request) error { + // First attempt with existing connection + sp.masterMu.RLock() + conn := sp.masterConn + sp.masterMu.RUnlock() + + if conn == nil { + // Try to reconnect + if err := sp.ensureMasterConnection(ctx); err != nil { + return fmt.Errorf("master connection unavailable: %w", err) + } + + // Get connection again after reconnect + sp.masterMu.RLock() + conn = sp.masterConn + sp.masterMu.RUnlock() + + if conn == nil { + return ErrMasterConnectionNotAvailable + } + } + + req := RequestBatch{Requests: []Request{cmd}} + + if err := conn.WriteJSON(req); err != nil { + // Connection might be stale, mark it as closed + sp.masterMu.Lock() + + if sp.masterConn == conn { + sp.masterConn = nil + } + + sp.masterMu.Unlock() + + return fmt.Errorf("failed to forward command to master: %w", err) + } + + return nil +} + +// ensureMasterConnection establishes master connection if needed. +func (sp *Proxy) ensureMasterConnection(ctx context.Context) error { + sp.masterMu.Lock() + defer sp.masterMu.Unlock() + + // Already connected + if sp.masterConn != nil { + return nil + } + + // Get metadata + metadata, err := sp.metadataFunc() + if err != nil { + return fmt.Errorf("failed to get metadata: %w", err) + } + + // Connect to Schwab + dialer := websocket.Dialer{ + HandshakeTimeout: handshakeTimeout, + } + + conn, resp, err := dialer.Dial(metadata.WSEndpoint, nil) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + if resp != nil && resp.Body != nil { + resp.Body.Close() + } + + sp.masterConn = conn + + // Configure connection timeouts and handlers + conn.SetReadLimit(maxMessageSize) + _ = conn.SetReadDeadline(time.Now().Add(pongWait)) + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(pongWait)) + }) + + // Authenticate master connection + if err := sp.authenticateMaster(ctx, metadata); err != nil { + sp.masterConn.Close() + sp.masterConn = nil + + return err + } + + // Start master read loop + sp.wg.Add(1) + + go sp.masterReadLoop(ctx) + + log.Info(ctx, "Master connection established") + + return nil +} + +// authenticateMaster sends LOGIN to Schwab. 
+func (sp *Proxy) authenticateMaster(ctx context.Context, metadata *Metadata) error { + token, err := sp.tokenManager.GetProviderToken(ctx) + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + loginCmd := RequestBatch{ + Requests: []Request{{ + Service: "ADMIN", + Command: "LOGIN", + RequestID: "master_login", + SchwabClientCustomerID: metadata.CustomerID, + SchwabClientCorrelID: metadata.CorrelID, + Parameters: map[string]interface{}{ + "Authorization": token.AccessToken, + "SchwabClientChannel": metadata.Channel, + "SchwabClientFunctionId": metadata.FunctionID, + }, + }}, + } + + if err := sp.masterConn.WriteJSON(loginCmd); err != nil { + return fmt.Errorf("failed to send LOGIN: %w", err) + } + + // Wait for response + _ = sp.masterConn.SetReadDeadline(time.Now().Add(authTimeout)) + + var response Response + if err := sp.masterConn.ReadJSON(&response); err != nil { + return fmt.Errorf("failed to read LOGIN response: %w", err) + } + + if len(response.Response) == 0 || response.Response[0].Content.Code != 0 { + return ErrLoginFailed + } + + return nil +} + +// masterReadLoop reads messages from master and routes to appropriate clients. +func (sp *Proxy) masterReadLoop(ctx context.Context) { + defer sp.wg.Done() + defer sp.cleanupMasterConnection(ctx) + + // Get the connection to pass to the reader goroutine + sp.masterMu.RLock() + conn := sp.masterConn + sp.masterMu.RUnlock() + + if conn == nil { + log.Error(ctx, nil, "Master connection is nil in read loop") + + return + } + + // Start ping ticker to keep connection alive + pingTicker := time.NewTicker(pingInterval) + defer pingTicker.Stop() + + // Create a channel for read messages + const readChannelSize = 10 + + msgChan := make(chan []byte, readChannelSize) + errChan := make(chan error, 1) + + // Start goroutine to read messages + go sp.readMasterMessages(conn, msgChan, errChan) + + for { + select { + case <-ctx.Done(): + return + case <-pingTicker.C: + if err := sp.sendMasterPing(ctx); err != nil { + return + } + + case err := <-errChan: + log.Error(ctx, err, "Master connection read error") + + return + case message := <-msgChan: + // Route message to appropriate client(s) + sp.routeMessage(ctx, message) + } + } +} + +// cleanupMasterConnection handles cleanup when master connection is lost. +func (sp *Proxy) cleanupMasterConnection(ctx context.Context) { + sp.masterMu.Lock() + + if sp.masterConn != nil { + sp.masterConn.Close() + sp.masterConn = nil + } + + sp.masterMu.Unlock() + + // Trigger reconnection if context not cancelled + select { + case <-ctx.Done(): + // Shutting down, don't reconnect + default: + log.Info(ctx, "Master connection lost, will attempt reconnection") + } +} + +// readMasterMessages reads messages from the master connection. +func (sp *Proxy) readMasterMessages(conn *websocket.Conn, msgChan chan<- []byte, errChan chan<- error) { + for { + _, message, err := conn.ReadMessage() + if err != nil { + errChan <- err + + return + } + + // Reset read deadline on any message (including heartbeats) + _ = conn.SetReadDeadline(time.Now().Add(pongWait)) + + msgChan <- message + } +} + +// sendMasterPing sends a ping message to the master connection. 
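// For reference, the master LOGIN assembled above serializes to a batch shaped
// like this (placeholder values in angle brackets come from Metadata and the
// provider token):
//
//	{"requests":[{"service":"ADMIN","command":"LOGIN","requestid":"master_login",
//	    "parameters":{"Authorization":"<provider access token>",
//	        "SchwabClientChannel":"<channel>","SchwabClientFunctionId":"<function>"},
//	    "SchwabClientCustomerId":"<customer>","SchwabClientCorrelId":"<correl>"}]}
//
// The LOGIN reply is read synchronously here, before masterReadLoop starts, so
// it is consumed by authenticateMaster and never reaches the client routing path.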
+func (sp *Proxy) sendMasterPing(ctx context.Context) error { + sp.masterMu.RLock() + conn := sp.masterConn + sp.masterMu.RUnlock() + + if conn != nil { + _ = conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + log.Error(ctx, err, "Failed to send ping to master") + + return fmt.Errorf("failed to send ping to master: %w", err) + } + } + + return nil +} + +// routeMessage routes a message from master to the appropriate client(s). +func (sp *Proxy) routeMessage(ctx context.Context, message []byte) { + // Try to parse the message to check if it's a response + var resp Response + if err := json.Unmarshal(message, &resp); err != nil { + // If it's not a valid response, broadcast to all clients + sp.broadcastToClients(message) + + return + } + + // Check if this is a routable response message + if !sp.isRoutableResponse(resp) { + // For data/notify messages or messages without routing info, broadcast to all + sp.broadcastToClients(message) + + return + } + + // Route to specific client based on request ID + sp.routeToSpecificClient(ctx, resp, message) +} + +// isRoutableResponse checks if a response can be routed to a specific client. +func (sp *Proxy) isRoutableResponse(resp Response) bool { + return len(resp.Response) > 0 && resp.Response[0].RequestID != "" +} + +// routeToSpecificClient routes a response to a specific client based on request ID. +func (sp *Proxy) routeToSpecificClient(ctx context.Context, resp Response, originalMessage []byte) { + requestID := resp.Response[0].RequestID + + // Extract client ID and original request ID + clientID, originalRequestID, err := UnprefixRequestID(requestID) + if err != nil { + // Not a client-prefixed request ID, broadcast instead + sp.broadcastToClients(originalMessage) + + return + } + + // Restore original request ID + resp.Response[0].RequestID = originalRequestID + + // Re-marshal with original request ID + modifiedMsg, err := json.Marshal(resp) + if err != nil { + log.Error(ctx, err, "Failed to marshal modified response") + + return + } + + // Send to specific client + sp.sendToClient(ctx, clientID, modifiedMsg) +} + +// sendToClient sends a message to a specific client. +func (sp *Proxy) sendToClient(ctx context.Context, clientID string, message []byte) { + client, exists := sp.clients.Load(clientID) + if !exists || !client.authed { + return + } + + select { + case client.msgChan <- message: + // Message sent successfully + default: + log.Warn(ctx, "Client channel full", "client_id", clientID) + } +} + +// broadcastToClients sends a message to all authenticated clients. 
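// Worked example of the rewrite performed by routeToSpecificClient (the
// service/command values are illustrative only):
//
//	from master: {"response":[{"service":"LEVELONE_EQUITIES","requestid":"client_2f9c..._42","command":"SUBS", ...}]}
//	to client:   {"response":[{"service":"LEVELONE_EQUITIES","requestid":"42","command":"SUBS", ...}]}
//
// Data and notify frames carry no response request ID, so isRoutableResponse
// returns false and they are broadcast to every authenticated client instead.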
+func (sp *Proxy) broadcastToClients(message []byte) { + sp.clients.Range(func(_ string, client *Client) bool { + if client.authed { + select { + case client.msgChan <- message: + // Client channel full, skip + default: + } + } + + return true + }) +} diff --git a/streaming/proxy_test.go b/streaming/proxy_test.go new file mode 100644 index 0000000..c28e8b6 --- /dev/null +++ b/streaming/proxy_test.go @@ -0,0 +1,138 @@ +package streaming_test + +import ( + "context" + "testing" + "time" + + "golang.org/x/oauth2" + + "github.com/jkoelker/schwab-proxy/auth" + "github.com/jkoelker/schwab-proxy/config" + "github.com/jkoelker/schwab-proxy/streaming" +) + +func TestProxy(t *testing.T) { + t.Parallel() + // Mock dependencies + tokenManager := &mockTokenManager{} + + // Create a minimal auth server for testing + mockStore := &mockStorage{data: make(map[string]interface{})} + signingKey := []byte("test-key-that-is-at-least-32-bytes-long") + testConfig := &config.Config{ + OAuth2AccessTokenExpiry: 12 * time.Hour, + OAuth2RefreshTokenExpiry: 7 * 24 * time.Hour, + OAuth2AuthCodeExpiry: 10 * time.Minute, + } + + authServer, err := auth.NewServer(mockStore, testConfig, signingKey) + if err != nil { + t.Fatalf("Failed to create auth server: %v", err) + } + + metadataFunc := func() (*streaming.Metadata, error) { + return &streaming.Metadata{ + CorrelID: "test-correl", + CustomerID: "test-customer", + Channel: "test-channel", + FunctionID: "test-function", + WSEndpoint: "wss://test.example.com/stream", + }, nil + } + + // Create proxy + proxy := streaming.NewProxy(tokenManager, authServer, metadataFunc) + + // Test Start + ctx := context.Background() + + err = proxy.Start(ctx) + if err != nil { + t.Fatalf("Start failed: %v", err) + } + + // Test Stop + err = proxy.Shutdown(ctx) + if err != nil { + t.Fatalf("Stop failed: %v", err) + } + + // Test GetConnectionState + state := proxy.GetConnectionState() + if state != "disconnected" { + t.Errorf("Expected disconnected state, got %s", state) + } + + // Test GetClientCount + count := proxy.GetClientCount() + if count != 0 { + t.Errorf("Expected 0 clients, got %d", count) + } +} + +// Mock implementations. +type mockTokenManager struct{} + +func (m *mockTokenManager) GetProviderToken(_ context.Context) (*oauth2.Token, error) { + return &oauth2.Token{ + AccessToken: "test-token", + TokenType: "Bearer", + Expiry: time.Now().Add(time.Hour), + }, nil +} + +func (m *mockTokenManager) RefreshProviderToken(ctx context.Context) (*oauth2.Token, error) { + return m.GetProviderToken(ctx) +} + +func (m *mockTokenManager) SaveProviderToken(_ context.Context, _ *oauth2.Token) error { + return nil +} + +func (m *mockTokenManager) StoreProviderToken(_ context.Context, _, _, _ string, _ int) error { + return nil +} + +func (m *mockTokenManager) NeedsProactiveRefresh(_ context.Context) bool { + return false +} + +// Mock storage for auth server. 
+type mockStorage struct { + data map[string]interface{} +} + +func (m *mockStorage) Get(key string, value any) error { + if val, ok := m.data[key]; ok { + // Simple type assertion for test + switch targetValue := value.(type) { + case *string: + str, ok := val.(string) + if ok { + *targetValue = str + } + case *[]byte: + b, ok := val.([]byte) + if ok { + *targetValue = b + } + } + + return nil + } + + return nil +} + +func (m *mockStorage) Set(key string, value any, _ time.Duration) error { + m.data[key] = value + + return nil +} + +func (m *mockStorage) Delete(key string) error { + delete(m.data, key) + + return nil +} diff --git a/streaming/types.go b/streaming/types.go new file mode 100644 index 0000000..c131e69 --- /dev/null +++ b/streaming/types.go @@ -0,0 +1,78 @@ +package streaming + +import ( + "encoding/json" + "time" +) + +// Request represents a streaming API request. +type Request struct { + Service string `json:"service"` + Command string `json:"command"` + RequestID string `json:"requestid"` + Parameters map[string]any `json:"parameters,omitempty"` + + //nolint:tagliatelle // Required by Schwab API + SchwabClientCustomerID string `json:"SchwabClientCustomerId,omitempty"` + + //nolint:tagliatelle // Required by Schwab API + SchwabClientCorrelID string `json:"SchwabClientCorrelId,omitempty"` +} + +// RequestBatch wraps multiple requests. +type RequestBatch struct { + Requests []Request `json:"requests"` +} + +// Response represents a streaming API response. +type Response struct { + Response []ResponseItem `json:"response,omitempty"` + Data []DataItem `json:"data,omitempty"` + Notify []NotifyItem `json:"notify,omitempty"` +} + +// ResponseItem represents a single response. +type ResponseItem struct { + Service string `json:"service"` + RequestID string `json:"requestid"` + Command string `json:"command"` + Timestamp int64 `json:"timestamp"` + Content ResponseContent `json:"content"` + + //nolint:tagliatelle // Required by Schwab API + SchwabClientCorrelID string `json:"SchwabClientCorrelId,omitempty"` +} + +// ResponseContent represents response content. +type ResponseContent struct { + Code int `json:"code"` + Msg string `json:"msg,omitempty"` +} + +// DataItem represents streaming data. +type DataItem struct { + Service string `json:"service"` + Timestamp int64 `json:"timestamp"` + Command string `json:"command"` + Content json.RawMessage `json:"content"` +} + +// NotifyItem represents a notification. +type NotifyItem struct { + Service string `json:"service,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + Heartbeat string `json:"heartbeat,omitempty"` +} + +// Metadata represents streaming service metadata. +type Metadata struct { + CorrelID string `json:"correl_id"` + CustomerID string `json:"customer_id"` + Channel string `json:"channel"` + FunctionID string `json:"function_id"` + WSEndpoint string `json:"ws_endpoint"` + ExtractedAt time.Time `json:"extracted_at"` + + TTL time.Duration +} diff --git a/streaming/utils.go b/streaming/utils.go new file mode 100644 index 0000000..7842f8e --- /dev/null +++ b/streaming/utils.go @@ -0,0 +1,109 @@ +package streaming + +import ( + "errors" + "fmt" + "strings" + "sync" +) + +// ErrInvalidPrefixedRequestID is returned when a prefixed request ID has an invalid format. 
+var ErrInvalidPrefixedRequestID = errors.New("invalid prefixed request ID format") + +// PrefixRequestID adds a client ID prefix to a request ID to prevent collisions +// and enable proper routing of responses. +// Format: "client_<uuid>_<original request ID>". +func PrefixRequestID(clientID, requestID string) string { + return fmt.Sprintf("%s_%s", clientID, requestID) +} + +// UnprefixRequestID extracts the client ID and original request ID from a prefixed request ID. +// Returns an error if the format is invalid. +func UnprefixRequestID(prefixedID string) (string, string, error) { + const expectedParts = 3 + + parts := strings.SplitN(prefixedID, "_", expectedParts) + + if len(parts) != expectedParts || parts[0] != "client" { + return "", "", fmt.Errorf("%w: %s", ErrInvalidPrefixedRequestID, prefixedID) + } + + // Reconstruct the client ID + clientID := "client_" + parts[1] + requestID := parts[2] + + return clientID, requestID, nil +} + +// IsPrefixedRequestID checks if a request ID has the expected client prefix format. +func IsPrefixedRequestID(requestID string) bool { + const expectedParts = 3 + + parts := strings.SplitN(requestID, "_", expectedParts) + + return len(parts) == expectedParts && parts[0] == "client" +} + +// ClientMap provides a thread-safe map for managing WebSocket clients. +type ClientMap struct { + m sync.Map +} + +// NewClientMap creates a new ClientMap. +func NewClientMap() *ClientMap { + return &ClientMap{} +} + +// Store stores a client in the map. +func (cm *ClientMap) Store(id string, client *Client) { + cm.m.Store(id, client) +} + +// Load retrieves a client from the map. +func (cm *ClientMap) Load(id string) (*Client, bool) { + val, ok := cm.m.Load(id) + if !ok { + return nil, false + } + + if client, ok := val.(*Client); ok { + return client, true + } + + return nil, false +} + +// Delete removes a client from the map. +func (cm *ClientMap) Delete(id string) { + cm.m.Delete(id) +} + +// Range calls f for each client in the map. +func (cm *ClientMap) Range(function func(id string, client *Client) bool) { + cm.m.Range(func(key any, value any) bool { + ident, ok := key.(string) + if !ok { + return true + } + + client, ok := value.(*Client) + if !ok { + return true + } + + return function(ident, client) + }) +} + +// Count returns the number of clients in the map. 
+func (cm *ClientMap) Count() int { + count := 0 + + cm.m.Range(func(_ any, _ any) bool { + count++ + + return true + }) + + return count +} diff --git a/streaming/utils_test.go b/streaming/utils_test.go new file mode 100644 index 0000000..b41b617 --- /dev/null +++ b/streaming/utils_test.go @@ -0,0 +1,216 @@ +package streaming_test + +import ( + "errors" + "testing" + + "github.com/jkoelker/schwab-proxy/streaming" +) + +func TestPrefixRequestID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + clientID string + requestID string + want string + }{ + { + name: "standard case", + clientID: "client_12345", + requestID: "req_001", + want: "client_12345_req_001", + }, + { + name: "empty request ID", + clientID: "client_12345", + requestID: "", + want: "client_12345_", + }, + { + name: "request ID with underscores", + clientID: "client_12345", + requestID: "req_with_underscores_123", + want: "client_12345_req_with_underscores_123", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + got := streaming.PrefixRequestID(test.clientID, test.requestID) + if got != test.want { + t.Errorf("PrefixRequestID() = %v, want %v", got, test.want) + } + }) + } +} + +func TestUnprefixRequestID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + prefixedID string + wantClientID string + wantRequestID string + wantErr bool + }{ + { + name: "standard case", + prefixedID: "client_12345_req_001", + wantClientID: "client_12345", + wantRequestID: "req_001", + wantErr: false, + }, + { + name: "request ID with underscores", + prefixedID: "client_12345_req_with_underscores_123", + wantClientID: "client_12345", + wantRequestID: "req_with_underscores_123", + wantErr: false, + }, + { + name: "empty request ID", + prefixedID: "client_12345_", + wantClientID: "client_12345", + wantRequestID: "", + wantErr: false, + }, + { + name: "invalid format - missing parts", + prefixedID: "client_12345", + wantClientID: "", + wantRequestID: "", + wantErr: true, + }, + { + name: "invalid format - no client prefix", + prefixedID: "user_12345_req_001", + wantClientID: "", + wantRequestID: "", + wantErr: true, + }, + { + name: "invalid format - empty string", + prefixedID: "", + wantClientID: "", + wantRequestID: "", + wantErr: true, + }, + { + name: "master login request ID", + prefixedID: "master_login", + wantClientID: "", + wantRequestID: "", + wantErr: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + gotClientID, gotRequestID, err := streaming.UnprefixRequestID(test.prefixedID) + if (err != nil) != test.wantErr { + t.Errorf("UnprefixRequestID() error = %v, wantErr %v", err, test.wantErr) + + return + } + + // Check that error is wrapped correctly + if err != nil && !errors.Is(err, streaming.ErrInvalidPrefixedRequestID) { + t.Errorf("UnprefixRequestID() error type = %T, want wrapped ErrInvalidPrefixedRequestID", err) + } + + if gotClientID != test.wantClientID { + t.Errorf("UnprefixRequestID() clientID = %v, want %v", gotClientID, test.wantClientID) + } + + if gotRequestID != test.wantRequestID { + t.Errorf("UnprefixRequestID() requestID = %v, want %v", gotRequestID, test.wantRequestID) + } + }) + } +} + +func TestIsPrefixedRequestID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + requestID string + want bool + }{ + { + name: "valid prefixed ID", + requestID: "client_12345_req_001", + want: true, + }, + { + name: "valid prefixed ID with underscores in request", + 
requestID: "client_12345_req_with_underscores", + want: true, + }, + { + name: "invalid - missing parts", + requestID: "client_12345", + want: false, + }, + { + name: "invalid - wrong prefix", + requestID: "user_12345_req_001", + want: false, + }, + { + name: "invalid - empty string", + requestID: "", + want: false, + }, + { + name: "invalid - master login", + requestID: "master_login", + want: false, + }, + { + name: "invalid - single underscore", + requestID: "client_", + want: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + if got := streaming.IsPrefixedRequestID(test.requestID); got != test.want { + t.Errorf("IsPrefixedRequestID() = %v, want %v", got, test.want) + } + }) + } +} + +func TestRoundTrip(t *testing.T) { + t.Parallel() + + // Test that prefix/unprefix are inverse operations + clientID := "client_abc123" + requestID := "original_request_456" + + prefixed := streaming.PrefixRequestID(clientID, requestID) + + gotClientID, gotRequestID, err := streaming.UnprefixRequestID(prefixed) + if err != nil { + t.Errorf("UnprefixRequestID() unexpected error: %v", err) + } + + if gotClientID != clientID { + t.Errorf("Round trip clientID = %v, want %v", gotClientID, clientID) + } + + if gotRequestID != requestID { + t.Errorf("Round trip requestID = %v, want %v", gotRequestID, requestID) + } +}