Skip to content

Commit ae2c628

Browse files
kjschnei001MrAlias
andauthored
otelhttp: handle nil base http transport (#713)
* handle nil base http transport * update godoc Co-authored-by: Tyler Yahn <[email protected]>
1 parent e8c2192 commit ae2c628

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

instrumentation/net/http/otelhttp/transport.go

+7
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,14 @@ var _ http.RoundTripper = &Transport{}
4040

4141
// NewTransport wraps the provided http.RoundTripper with one that
4242
// starts a span and injects the span context into the outbound request headers.
43+
//
44+
// If the provided http.RoundTripper is nil, http.DefaultTransport will be used
45+
// as the base http.RoundTripper
4346
func NewTransport(base http.RoundTripper, opts ...Option) *Transport {
47+
if base == nil {
48+
base = http.DefaultTransport
49+
}
50+
4451
t := Transport{
4552
rt: base,
4653
}

instrumentation/net/http/otelhttp/transport_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,51 @@ func TestTransportBasics(t *testing.T) {
7474
t.Fatalf("unexpected content: got %s, expected %s", body, content)
7575
}
7676
}
77+
78+
func TestNilTransport(t *testing.T) {
79+
prop := propagation.TraceContext{}
80+
provider := oteltest.NewTracerProvider()
81+
content := []byte("Hello, world!")
82+
83+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
84+
ctx := prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
85+
span := trace.RemoteSpanContextFromContext(ctx)
86+
tgtID, err := trace.SpanIDFromHex(fmt.Sprintf("%016x", uint(2)))
87+
if err != nil {
88+
t.Fatalf("Error converting id to SpanID: %s", err.Error())
89+
}
90+
if span.SpanID() != tgtID {
91+
t.Fatalf("testing remote SpanID: got %s, expected %s", span.SpanID(), tgtID)
92+
}
93+
if _, err := w.Write(content); err != nil {
94+
t.Fatal(err)
95+
}
96+
}))
97+
defer ts.Close()
98+
99+
r, err := http.NewRequest(http.MethodGet, ts.URL, nil)
100+
if err != nil {
101+
t.Fatal(err)
102+
}
103+
104+
tr := NewTransport(
105+
nil,
106+
WithTracerProvider(provider),
107+
WithPropagators(prop),
108+
)
109+
110+
c := http.Client{Transport: tr}
111+
res, err := c.Do(r)
112+
if err != nil {
113+
t.Fatal(err)
114+
}
115+
116+
body, err := ioutil.ReadAll(res.Body)
117+
if err != nil {
118+
t.Fatal(err)
119+
}
120+
121+
if !bytes.Equal(body, content) {
122+
t.Fatalf("unexpected content: got %s, expected %s", body, content)
123+
}
124+
}

0 commit comments

Comments
 (0)