Skip to content

Replace --auth option with --bearer-auth / --basic-auth / --raw-auth options #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 40 additions & 53 deletions cmd/emcee/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import (
"syscall"
"time"

"github.com/hashicorp/go-retryablehttp"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"

"encoding/base64"

"github.com/loopwork-ai/emcee/internal"
"github.com/loopwork-ai/emcee/mcp"
)
Expand All @@ -24,13 +25,16 @@ var rootCmd = &cobra.Command{
Use: "emcee [spec-path-or-url]",
Short: "Creates an MCP server for an OpenAPI specification",
Long: `emcee is a CLI tool that provides an Model Context Protocol (MCP) stdio transport for a given OpenAPI specification.
It takes an OpenAPI specification path or URL as input and processes JSON-RPC requests
from stdin, making corresponding API calls and returning JSON-RPC responses to stdout.
It takes an OpenAPI specification path or URL as input and processes JSON-RPC requests from stdin, making corresponding API calls and returning JSON-RPC responses to stdout.

The spec-path-or-url argument can be:
- A local file path
- An HTTP(S) URL
- "-" to read from stdin`,
- A local file path (e.g. ./openapi.json)
- An HTTP(S) URL (e.g. https://api.example.com/openapi.json)
- "-" to read from stdin

By default, a GET request with no additional headers is made to the spec URL to download the OpenAPI specification.
If additional authentication is required to download the specification, you can first download it to a local file using your preferred HTTP client with the necessary authentication headers, and then provide the local file path to emcee.
`,
Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error {
// Set up context and signal handling
Expand All @@ -57,53 +61,32 @@ The spec-path-or-url argument can be:
// Set logger
opts = append(opts, mcp.WithLogger(logger))

// Configure HTTP client
retryClient := retryablehttp.NewClient()
retryClient.RetryMax = retries
retryClient.RetryWaitMin = 1 * time.Second
retryClient.RetryWaitMax = 30 * time.Second
retryClient.HTTPClient.Timeout = timeout
retryClient.Logger = logger
if rps > 0 {
retryClient.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
// Ensure we wait at least 1/rps between requests
minWait := time.Second / time.Duration(rps)
if min < minWait {
min = minWait
}
return retryablehttp.DefaultBackoff(min, max, attemptNum, resp)
}
}

// Set default headers if auth is provided
if auth != "" {
parts := strings.SplitN(auth, " ", 2)
if len(parts) == 1 {
// Only token provided, add Bearer prefix
logger.Warn("no auth scheme provided, automatically adding 'Bearer' prefix")
auth = "Bearer " + parts[0]
} else if len(parts) == 2 {
// Scheme and token provided, use as-is
auth = fmt.Sprintf("%s %s", parts[0], parts[1])
}

headers := http.Header{}
headers.Add("Authorization", auth)

retryClient.HTTPClient.Transport = &internal.HeaderTransport{
Base: retryClient.HTTPClient.Transport,
Headers: headers,
if bearerAuth != "" {
opts = append(opts, mcp.WithAuth("Bearer "+bearerAuth))
} else if basicAuth != "" {
// Check if already base64 encoded
if strings.Contains(basicAuth, ":") {
encoded := base64.StdEncoding.EncodeToString([]byte(basicAuth))
opts = append(opts, mcp.WithAuth("Basic "+encoded))
} else {
// Assume it's already base64 encoded
opts = append(opts, mcp.WithAuth("Basic "+basicAuth))
}
} else if rawAuth != "" {
opts = append(opts, mcp.WithAuth(rawAuth))
}

client := retryClient.StandardClient()
// Set HTTP client
client, err := internal.RetryableClient(retries, timeout, rps, logger)
if err != nil {
return fmt.Errorf("error creating client: %w", err)
}
opts = append(opts, mcp.WithClient(client))

// Read OpenAPI specification data
var rpcInput io.Reader = os.Stdin
var specData []byte
var err error

if args[0] == "-" {
logger.Info("reading spec from stdin")

Expand All @@ -130,11 +113,6 @@ The spec-path-or-url argument can be:
return fmt.Errorf("error creating request: %w", err)
}

// Apply auth header if provided
if auth != "" {
req.Header.Set("Authorization", auth)
}

// Make HTTP request
resp, err := client.Do(req)
if err != nil {
Expand Down Expand Up @@ -201,24 +179,33 @@ The spec-path-or-url argument can be:
}

var (
auth string
verbose bool
bearerAuth string
basicAuth string
rawAuth string

retries int
timeout time.Duration
rps int

verbose bool

version = "dev"
commit = "none"
date = "unknown"
)

func init() {
rootCmd.Flags().StringVar(&auth, "auth", "", "Authorization header value (e.g. 'Bearer token123' or 'Basic dXNlcjpwYXNz')")
rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose logging to stderr")
rootCmd.Flags().StringVar(&bearerAuth, "bearer-auth", "", "Bearer token value (will be prefixed with 'Bearer ')")
rootCmd.Flags().StringVar(&basicAuth, "basic-auth", "", "Basic auth value (either user:pass or base64 encoded, will be prefixed with 'Basic ')")
rootCmd.Flags().StringVar(&rawAuth, "raw-auth", "", "Raw value for Authorization header")
rootCmd.MarkFlagsMutuallyExclusive("bearer-auth", "basic-auth", "raw-auth")

rootCmd.Flags().IntVar(&retries, "retries", 3, "Maximum number of retries for failed requests")
rootCmd.Flags().DurationVar(&timeout, "timeout", 60*time.Second, "HTTP request timeout")
rootCmd.Flags().IntVarP(&rps, "rps", "r", 0, "Maximum requests per second (0 for no limit)")

rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose logging to stderr")

rootCmd.Version = fmt.Sprintf("%s (commit: %s, built at: %s)", version, commit, date)
}

Expand Down
42 changes: 41 additions & 1 deletion internal/http.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package internal

import "net/http"
import (
"fmt"
"net/http"
"time"

"github.com/hashicorp/go-retryablehttp"
)

// HeaderTransport is a custom RoundTripper that adds default headers to requests
type HeaderTransport struct {
Base http.RoundTripper
Headers http.Header
}

// RoundTrip adds the default headers to the request
func (t *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for key, values := range t.Headers {
for _, value := range values {
Expand All @@ -20,3 +27,36 @@ func (t *HeaderTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
return base.RoundTrip(req)
}

// RetryableClient returns a new http.Client with a retryablehttp.Client
// configured with the provided parameters.
func RetryableClient(retries int, timeout time.Duration, rps int, logger interface{}) (*http.Client, error) {
if retries < 0 {
return nil, fmt.Errorf("retries must be greater than 0")
}
if timeout < 0 {
return nil, fmt.Errorf("timeout must be greater than 0")
}
if rps < 0 {
return nil, fmt.Errorf("rps must be greater than 0")
}

retryClient := retryablehttp.NewClient()
retryClient.RetryMax = retries
retryClient.RetryWaitMin = 1 * time.Second
retryClient.RetryWaitMax = 30 * time.Second
retryClient.HTTPClient.Timeout = timeout
retryClient.Logger = logger
if rps > 0 {
retryClient.Backoff = func(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
// Ensure we wait at least 1/rps between requests
minWait := time.Second / time.Duration(rps)
if min < minWait {
min = minWait
}
return retryablehttp.DefaultBackoff(min, max, attemptNum, resp)
}
}

return retryClient.StandardClient(), nil
}
62 changes: 44 additions & 18 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,37 @@ import (
"github.com/pb33f/libopenapi"
v3 "github.com/pb33f/libopenapi/datamodel/high/v3"

"github.com/loopwork-ai/emcee/internal"
"github.com/loopwork-ai/emcee/jsonrpc"
)

// Server represents an MCP server that processes JSON-RPC requests
type Server struct {
auth string
doc libopenapi.Document
model *v3.Document
baseURL string
client *http.Client
info ServerInfo
logger *slog.Logger
}

// ServerOption configures a Server
type ServerOption func(*Server) error

// WithAuth sets the authentication header for the server
func WithAuth(auth string) ServerOption {
return func(s *Server) error {
auth = strings.TrimSpace(auth)
parts := strings.SplitN(auth, " ", 2)
if len(parts) != 2 {
return fmt.Errorf("invalid auth header format: %s", auth)
}
s.auth = fmt.Sprintf("%s %s", parts[0], parts[1])
return nil
}
}

// WithClient sets the HTTP client
func WithClient(client *http.Client) ServerOption {
return func(s *Server) error {
Expand All @@ -29,6 +54,14 @@ func WithClient(client *http.Client) ServerOption {
}
}

// WithLogger sets the logger for the server
func WithLogger(logger *slog.Logger) ServerOption {
return func(s *Server) error {
s.logger = logger
return nil
}
}

// WithServerInfo sets server info
func WithServerInfo(name, version string) ServerOption {
return func(s *Server) error {
Expand Down Expand Up @@ -70,24 +103,6 @@ func WithSpecData(data []byte) ServerOption {
}
}

// WithLogger sets the logger for the server
func WithLogger(logger *slog.Logger) ServerOption {
return func(s *Server) error {
s.logger = logger
return nil
}
}

// Server represents an MCP server that processes JSON-RPC requests
type Server struct {
doc libopenapi.Document
model *v3.Document
baseURL string
client *http.Client
info ServerInfo
logger *slog.Logger
}

// NewServer creates a new MCP server instance
func NewServer(opts ...ServerOption) (*Server, error) {
s := &Server{
Expand All @@ -103,6 +118,17 @@ func NewServer(opts ...ServerOption) (*Server, error) {
}
}

// Apply custom transport to inject auth header, if provided
if s.auth != "" {
headers := http.Header{}
headers.Add("Authorization", s.auth)

s.client.Transport = &internal.HeaderTransport{
Base: s.client.Transport,
Headers: headers,
}
}

// Validate required fields
if s.doc == nil {
return nil, fmt.Errorf("OpenAPI spec URL is required")
Expand Down
73 changes: 73 additions & 0 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -685,3 +685,76 @@ func TestWithSpecData(t *testing.T) {
})
}
}

func TestWithAuth(t *testing.T) {
tests := []struct {
name string
auth string
wantErr bool
assert func(*testing.T, *Server)
}{
{
name: "valid bearer token",
auth: "Bearer mytoken123",
assert: func(t *testing.T, s *Server) {
assert.Equal(t, "Bearer mytoken123", s.auth)
},
},
{
name: "valid basic auth",
auth: "Basic dXNlcjpwYXNz",
assert: func(t *testing.T, s *Server) {
assert.Equal(t, "Basic dXNlcjpwYXNz", s.auth)
},
},
{
name: "missing auth type",
auth: "mytoken123",
wantErr: true,
},
{
name: "empty auth",
auth: "",
wantErr: true,
},
{
name: "whitespace only",
auth: " ",
wantErr: true,
},
}

// Create a minimal valid spec for server initialization
validSpec := `{
"openapi": "3.0.0",
"servers": [{"url": "https://api.example.com"}],
"paths": {}
}`

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server, err := NewServer(
WithSpecData([]byte(validSpec)),
WithAuth(tt.auth),
)

if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, server)
return
}

assert.NoError(t, err)
assert.NotNil(t, server)

if tt.assert != nil {
tt.assert(t, server)
}

// Verify the auth header is properly set in the client transport
transport, ok := server.client.Transport.(*internal.HeaderTransport)
assert.True(t, ok)
assert.Equal(t, tt.auth, transport.Headers.Get("Authorization"))
})
}
}
Loading