diff --git a/CHANGELOG.md b/CHANGELOG.md index d34cf803f27..d8f64e004c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ## [Unreleased] +### Added + +- Consistent probability sampler implementation. (#1379) + ### Fixed - Fix the `otelmux` middleware by using `SpanKindServer` when deciding the `SpanStatus`. diff --git a/samplers/probability/consistent/base2.go b/samplers/probability/consistent/base2.go new file mode 100644 index 00000000000..62f559c35da --- /dev/null +++ b/samplers/probability/consistent/base2.go @@ -0,0 +1,69 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consistent // import "go.opentelemetry.io/contrib/samplers/probability/consistent" + +import "math" + +// These are IEEE 754 double-precision floating point constants used with +// math.Float64bits and math.Float64frombits (exponent mask, bias, significand width). +const ( + offsetExponentMask = 0x7ff0000000000000 + offsetExponentBias = 1023 + significandBits = 52 +) + +// expFromFloat64 returns floor(log2(x)) for normal, non-zero, finite x; the sign bit is masked off and only the exponent field is read. +func expFromFloat64(x float64) int { + return int((math.Float64bits(x)&offsetExponentMask)>>significandBits) - offsetExponentBias +} + +// expToFloat64 returns 2^x for x in the normal float64 exponent range. 
+func expToFloat64(x int) float64 { + return math.Float64frombits(uint64(offsetExponentBias+x) << significandBits) +} + +// splitProb returns the two values of log-adjusted-count nearest to p +// Example: +// +// splitProb(0.375) => (2, 1, 0.5) +// +// indicates to sample with probability (2^-2) 50% of the time +// and (2^-1) 50% of the time. +func splitProb(p float64) (uint8, uint8, float64) { + if p < 2e-62 { + // Note: spec. + return pZeroValue, pZeroValue, 1 + } + // Take the exponent and drop the significand to locate the + // smaller of two powers of two. + exp := expFromFloat64(p) + + // Low is the smaller of two log-adjusted counts, the negative + // of the exponent computed above. + low := -exp + // High is the greater of two log-adjusted counts (i.e., one + // less than low, a smaller adjusted count means a larger + // probability). + high := low - 1 + + // Return these to probability values and use linear + // interpolation to compute the required probability of + // choosing the low-probability Sampler. + lowP := expToFloat64(-low) + highP := expToFloat64(-high) + lowProb := (highP - p) / (highP - lowP) + + return uint8(low), uint8(high), lowProb +} diff --git a/samplers/probability/consistent/base2_test.go b/samplers/probability/consistent/base2_test.go new file mode 100644 index 00000000000..9dad5c46f22 --- /dev/null +++ b/samplers/probability/consistent/base2_test.go @@ -0,0 +1,61 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package consistent + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSplitProb(t *testing.T) { + require.Equal(t, -1, expFromFloat64(0.6)) + require.Equal(t, -2, expFromFloat64(0.4)) + require.Equal(t, 0.5, expToFloat64(-1)) + require.Equal(t, 0.25, expToFloat64(-2)) + + for _, tc := range []struct { + in float64 + low uint8 + lowProb float64 + }{ + // Probability 0.75 corresponds with choosing S=1 (the + // "low" probability) 50% of the time and S=0 (the + // "high" probability) 50% of the time. + {0.75, 1, 0.5}, + {0.6, 1, 0.8}, + {0.9, 1, 0.2}, + + // Powers of 2 exactly + {1, 0, 1}, + {0.5, 1, 1}, + {0.25, 2, 1}, + + // Smaller numbers + {0.05, 5, 0.4}, + {0.1, 4, 0.4}, // 0.1 == 0.4 * 1/16 + 0.6 * 1/8 + {0.003, 9, 0.464}, + + // Special cases: + {0, 63, 1}, + } { + low, high, lowProb := splitProb(tc.in) + require.Equal(t, tc.low, low, "got %v want %v", low, tc.low) + if lowProb != 1 { + require.Equal(t, tc.low-1, high, "got %v want %v", high, tc.low-1) + } + require.InEpsilon(t, tc.lowProb, lowProb, 1e-6, "got %v want %v", lowProb, tc.lowProb) + } +} diff --git a/samplers/probability/consistent/go.mod b/samplers/probability/consistent/go.mod new file mode 100644 index 00000000000..3f21dde3d7e --- /dev/null +++ b/samplers/probability/consistent/go.mod @@ -0,0 +1,10 @@ +module go.opentelemetry.io/contrib/samplers/probability/consistent + +go 1.16 + +require ( + github.com/stretchr/testify v1.7.1 + go.opentelemetry.io/otel v1.6.1 + go.opentelemetry.io/otel/sdk v1.6.1 + go.opentelemetry.io/otel/trace v1.6.1 +) diff --git a/samplers/probability/consistent/go.sum b/samplers/probability/consistent/go.sum new file mode 100644 index 00000000000..c2909f67f14 --- /dev/null +++ b/samplers/probability/consistent/go.sum @@ -0,0 +1,27 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +go.opentelemetry.io/otel v1.6.1 h1:6r1YrcTenBvYa1x491d0GGpTVBsNECmrc/K6b+zDeis= +go.opentelemetry.io/otel v1.6.1/go.mod h1:blzUabWHkX6LJewxvadmzafgh/wnvBSDBdOuwkAtrWQ= +go.opentelemetry.io/otel/sdk v1.6.1 h1:ZmcNyMhcuAYIb/Nr6QhBPTMopMTbov/47wHt1gibkoY= +go.opentelemetry.io/otel/sdk v1.6.1/go.mod h1:IVYrddmFZ+eJqu2k38qD3WezFR2pymCzm8tdxyh3R4E= +go.opentelemetry.io/otel/trace v1.6.1 h1:f8c93l5tboBYZna1nWk0W9DYyMzJXDWdZcJZ0Kb400U= +go.opentelemetry.io/otel/trace v1.6.1/go.mod h1:RkFRM1m0puWIq10oxImnGEduNBzxiN7TXluRBtE+5j0= +golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 h1:iGu644GcxtEcrInvDsQRCwJjtCIOlT2V7IRt6ah2Whw= +golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/samplers/probability/consistent/parent.go b/samplers/probability/consistent/parent.go new file mode 100644 index 00000000000..f21052f77b6 --- /dev/null +++ b/samplers/probability/consistent/parent.go @@ -0,0 +1,74 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consistent // import "go.opentelemetry.io/contrib/samplers/probability/consistent" + +import ( + "strings" + + "go.opentelemetry.io/otel" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +type ( + parentProbabilitySampler struct { + delegate sdktrace.Sampler + } +) + +// ParentProbabilityBased is an implementation of the OpenTelemetry +// Trace Sampler interface that provides additional checks for tracestate +// Probability Sampling fields. 
+func ParentProbabilityBased(root sdktrace.Sampler, samplers ...sdktrace.ParentBasedSamplerOption) sdktrace.Sampler { + return &parentProbabilitySampler{ + delegate: sdktrace.ParentBased(root, samplers...), + } +} + +// ShouldSample implements "go.opentelemetry.io/otel/sdk/trace".Sampler. +func (p *parentProbabilitySampler) ShouldSample(params sdktrace.SamplingParameters) sdktrace.SamplingResult { + psc := trace.SpanContextFromContext(params.ParentContext) + + // Note: We do not check psc.IsValid(), i.e., we repair the tracestate + // with or without a parent TraceId and SpanId. + state := psc.TraceState() + + otts, err := parseOTelTraceState(state.Get(traceStateKey), psc.IsSampled()) + + if err != nil { + otel.Handle(err) + value := otts.serialize() + if len(value) > 0 { + // Note: see the note in + // "go.opentelemetry.io/otel/trace".TraceState.Insert(). The + // error below is not a condition we're supposed to handle. + state, _ = state.Insert(traceStateKey, value) + } else { + state = state.Delete(traceStateKey) + } + + // Fix the broken tracestate before calling the delegate. + params.ParentContext = trace.ContextWithSpanContext(params.ParentContext, psc.WithTraceState(state)) + } + + return p.delegate.ShouldSample(params) +} + +// Description returns the same description as the built-in +// ParentBased sampler, with "ParentBased" replaced by +// "ParentProbabilityBased". +func (p *parentProbabilitySampler) Description() string { + return "ParentProbabilityBased" + strings.TrimPrefix(p.delegate.Description(), "ParentBased") +} diff --git a/samplers/probability/consistent/parent_test.go b/samplers/probability/consistent/parent_test.go new file mode 100644 index 00000000000..056179d7d50 --- /dev/null +++ b/samplers/probability/consistent/parent_test.go @@ -0,0 +1,187 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consistent + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +func TestParentSamplerDescription(t *testing.T) { + opts := []sdktrace.ParentBasedSamplerOption{ + sdktrace.WithRemoteParentNotSampled(sdktrace.AlwaysSample()), + } + root := ProbabilityBased(1) + compare := sdktrace.ParentBased(root, opts...) + parent := ParentProbabilityBased(root, opts...) + require.Equal(t, + strings.Replace( + compare.Description(), + "ParentBased", + "ParentProbabilityBased", + 1, + ), + parent.Description(), + ) +} + +func TestParentSamplerValidContext(t *testing.T) { + parent := ParentProbabilityBased(sdktrace.NeverSample()) + type testCase struct { + in string + sampled bool + } + for _, valid := range []testCase{ + // sampled tests + {"r:10", true}, + {"r:10;a:b", true}, + {"r:10;p:1", true}, + {"r:10;p:10", true}, + {"r:10;p:10;a:b", true}, + {"r:10;p:63", true}, + {"r:10;p:63;a:b", true}, + {"p:0", true}, + {"p:10;a:b", true}, + {"p:63", true}, + {"p:63;a:b", true}, + + // unsampled tests + {"r:10", false}, + {"r:10;a:b", false}, + } { + t.Run(testName(valid.in), func(t *testing.T) { + traceID, _ := trace.TraceIDFromHex("4bf92f3577b34da6a3ce929d0e0e4736") + spanID, _ := trace.SpanIDFromHex("00f067aa0ba902b7") + traceState, err := trace.TraceState{}.Insert(traceStateKey, valid.in) + require.NoError(t, err) + + sccfg := trace.SpanContextConfig{ + TraceID: traceID, + SpanID: spanID, + 
TraceState: traceState, + } + + if valid.sampled { + sccfg.TraceFlags = trace.FlagsSampled + } + + parentCtx := trace.ContextWithSpanContext( + context.Background(), + trace.NewSpanContext(sccfg), + ) + + result := parent.ShouldSample( + sdktrace.SamplingParameters{ + ParentContext: parentCtx, + TraceID: traceID, + Name: "test", + Kind: trace.SpanKindServer, + }, + ) + + if valid.sampled { + require.Equal(t, sdktrace.RecordAndSample, result.Decision) + } else { + require.Equal(t, sdktrace.Drop, result.Decision) + } + require.Equal(t, []attribute.KeyValue(nil), result.Attributes) + require.Equal(t, valid.in, result.Tracestate.Get(traceStateKey)) + }) + } +} + +func TestParentSamplerInvalidContext(t *testing.T) { + parent := ParentProbabilityBased(sdktrace.NeverSample()) + type testCase struct { + in string + sampled bool + expect string + } + for _, invalid := range []testCase{ + // sampled + {"r:100", true, ""}, + {"r:100;p:1", true, ""}, + {"r:100;p:1;a:b", true, "a:b"}, + {"r:10;p:100", true, "r:10"}, + {"r:10;p:100;a:b", true, "r:10;a:b"}, + + // unsampled + {"r:63;p:1", false, ""}, + {"r:10;p:1", false, "r:10"}, + {"r:10;p:1;a:b", false, "r:10;a:b"}, + } { + testInvalid := func(t *testing.T, isChildContext bool) { + + traceID, _ := trace.TraceIDFromHex("4bf92f3577b34da6a3ce929d0e0e4736") + traceState, err := trace.TraceState{}.Insert(traceStateKey, invalid.in) + require.NoError(t, err) + + sccfg := trace.SpanContextConfig{ + TraceState: traceState, + } + if isChildContext { + spanID, _ := trace.SpanIDFromHex("00f067aa0ba902b7") + + sccfg.TraceID = traceID + sccfg.SpanID = spanID + + // Note: the other branch is testing a fabricated + // situation where the context has a tracestate and + // no TraceID. 
+ } + if invalid.sampled { + sccfg.TraceFlags = trace.FlagsSampled + } + + parentCtx := trace.ContextWithSpanContext( + context.Background(), + trace.NewSpanContext(sccfg), + ) + + result := parent.ShouldSample( + sdktrace.SamplingParameters{ + ParentContext: parentCtx, + TraceID: sccfg.TraceID, + Name: "test", + Kind: trace.SpanKindServer, + }, + ) + + if isChildContext && invalid.sampled { + require.Equal(t, sdktrace.RecordAndSample, result.Decision) + } else { + // if we're not a child context, ShouldSample + // falls through to the delegate, which is NeverSample. + require.Equal(t, sdktrace.Drop, result.Decision) + } + require.Equal(t, []attribute.KeyValue(nil), result.Attributes) + require.Equal(t, invalid.expect, result.Tracestate.Get(traceStateKey)) + } + + t.Run(testName(invalid.in)+"_with_parent", func(t *testing.T) { + testInvalid(t, true) + }) + t.Run(testName(invalid.in)+"_no_parent", func(t *testing.T) { + testInvalid(t, false) + }) + } +} diff --git a/samplers/probability/consistent/sampler.go b/samplers/probability/consistent/sampler.go new file mode 100644 index 00000000000..814f52c503b --- /dev/null +++ b/samplers/probability/consistent/sampler.go @@ -0,0 +1,179 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package consistent provides a consistent probability based sampler. 
+package consistent // import "go.opentelemetry.io/contrib/samplers/probability/consistent" + +import ( + "fmt" + "math/bits" + "math/rand" + "sync" + + "go.opentelemetry.io/otel" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +type ( + // ProbabilityBasedOption is an option to the + // consistent ProbabilityBased sampler. + ProbabilityBasedOption interface { + apply(*consistentProbabilityBasedConfig) + } + + consistentProbabilityBasedConfig struct { + source rand.Source + } + + consistentProbabilityBasedRandomSource struct { + rand.Source + } + + consistentProbabilityBased struct { + // "LAC" is an abbreviation for the logarithm of + // adjusted count. Greater values have greater + // representativity, therefore lesser sampling + // probability. + + // lowLAC is the lower-probability log-adjusted count + lowLAC uint8 + // highLAC is the higher-probability log-adjusted + // count. Except for the zero probability special + // case, highLAC == lowLAC - 1. + highLAC uint8 + // lowProb is the probability that lowLAC should be used, + // in the interval (0, 1]. For exact powers of two and the + // special case of 0 probability, lowProb == 1. + lowProb float64 + + // lock protects rnd + lock sync.Mutex + rnd *rand.Rand + } +) + +// WithRandomSource sets the source of the randomness used by the Sampler. +func WithRandomSource(source rand.Source) ProbabilityBasedOption { + return consistentProbabilityBasedRandomSource{source} +} + +func (s consistentProbabilityBasedRandomSource) apply(cfg *consistentProbabilityBasedConfig) { + cfg.source = s.Source +} + +// ProbabilityBased samples a given fraction of traces. Each decision +// uses a power-of-two probability, per the OpenTelemetry specification. +// When the input fraction is not a power of two, the Sampler randomly +// chooses between the two nearest powers of two (see splitProb). +// - Fractions >= 1 will always sample. +// - Fractions < 2^-62 are treated as zero. 
+// +// This Sampler sets the OpenTelemetry tracestate p-value and/or r-value. +// +// To respect the parent trace's `SampledFlag`, this sampler should be +// used as the root delegate of a `Parent` sampler. +func ProbabilityBased(fraction float64, opts ...ProbabilityBasedOption) sdktrace.Sampler { + cfg := consistentProbabilityBasedConfig{ + source: rand.NewSource(rand.Int63()), + } + for _, opt := range opts { + opt.apply(&cfg) + } + + if fraction < 0 { + fraction = 0 + } else if fraction > 1 { + fraction = 1 + } + + lowLAC, highLAC, lowProb := splitProb(fraction) + + return &consistentProbabilityBased{ + lowLAC: lowLAC, + highLAC: highLAC, + lowProb: lowProb, + rnd: rand.New(cfg.source), + } +} + +func (cs *consistentProbabilityBased) newR() uint8 { + cs.lock.Lock() + defer cs.lock.Unlock() + return uint8(bits.LeadingZeros64(uint64(cs.rnd.Int63())) - 1) +} + +func (cs *consistentProbabilityBased) lowChoice() bool { + cs.lock.Lock() + defer cs.lock.Unlock() + return cs.rnd.Float64() < cs.lowProb +} + +// ShouldSample implements "go.opentelemetry.io/otel/sdk/trace".Sampler. +func (cs *consistentProbabilityBased) ShouldSample(p sdktrace.SamplingParameters) sdktrace.SamplingResult { + psc := trace.SpanContextFromContext(p.ParentContext) + + // Note: this ignores whether psc.IsValid() because this + // allows other otel trace state keys to pass through even + // for root decisions. + state := psc.TraceState() + + otts, err := parseOTelTraceState(state.Get(traceStateKey), psc.IsSampled()) + if err != nil { + // Note: a state.Insert(traceStateKey) + // follows, nothing else needs to be done here. 
+ otel.Handle(err) + } + + if !otts.hasRValue() { + otts.rvalue = cs.newR() + } + + var decision sdktrace.SamplingDecision + var lac uint8 + + if cs.lowProb == 1 || cs.lowChoice() { + lac = cs.lowLAC + } else { + lac = cs.highLAC + } + + if lac <= otts.rvalue { + decision = sdktrace.RecordAndSample + otts.pvalue = lac + } else { + decision = sdktrace.Drop + otts.pvalue = invalidValue + } + + // Note: see the note in + // "go.opentelemetry.io/otel/trace".TraceState.Insert(). The + // error below is not a condition we're supposed to handle. + state, _ = state.Insert(traceStateKey, otts.serialize()) + + return sdktrace.SamplingResult{ + Decision: decision, + Tracestate: state, + } +} + +// Description returns "ProbabilityBased{%g}" with the configured probability. +func (cs *consistentProbabilityBased) Description() string { + var prob float64 + if cs.lowLAC != pZeroValue { + prob = cs.lowProb * expToFloat64(-int(cs.lowLAC)) + prob += (1 - cs.lowProb) * expToFloat64(-int(cs.highLAC)) + } + return fmt.Sprintf("ProbabilityBased{%g}", prob) +} diff --git a/samplers/probability/consistent/sampler_test.go b/samplers/probability/consistent/sampler_test.go new file mode 100644 index 00000000000..a7d74a77191 --- /dev/null +++ b/samplers/probability/consistent/sampler_test.go @@ -0,0 +1,235 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package consistent + +import ( + "context" + "fmt" + "math/rand" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +type ( + testDegrees int + pValue int + + testErrorHandler struct { + lock sync.Mutex + errors []error + } +) + +func parsePR(s string) (p, r string) { + for _, kvf := range strings.Split(s, ";") { + kv := strings.SplitN(kvf, ":", 2) + switch kv[0] { + case "p": + p = kv[1] + case "r": + r = kv[1] + } + } + return +} + +func (eh *testErrorHandler) Handle(err error) { + eh.lock.Lock() + defer eh.lock.Unlock() + eh.errors = append(eh.errors, err) +} + +func (eh *testErrorHandler) Errors() []error { + eh.lock.Lock() + defer eh.lock.Unlock() + return eh.errors +} + +func TestSamplerDescription(t *testing.T) { + const minProb = 0x1p-62 // 2.168404344971009e-19 + + for _, tc := range []struct { + prob float64 + expect string + }{ + {1, "ProbabilityBased{1}"}, + {0, "ProbabilityBased{0}"}, + {0.75, "ProbabilityBased{0.75}"}, + {0.05, "ProbabilityBased{0.05}"}, + {0.003, "ProbabilityBased{0.003}"}, + {0.99999999, "ProbabilityBased{0.99999999}"}, + {0.00000001, "ProbabilityBased{1e-08}"}, + {minProb, "ProbabilityBased{2.168404344971009e-19}"}, + {minProb * 1.5, "ProbabilityBased{3.2526065174565133e-19}"}, + {3e-19, "ProbabilityBased{3e-19}"}, + + // out-of-range > 1 + {1.01, "ProbabilityBased{1}"}, + {101.1, "ProbabilityBased{1}"}, + + // out-of-range < 2^-62 + {-1, "ProbabilityBased{0}"}, + {-0.001, "ProbabilityBased{0}"}, + {minProb * 0.999, "ProbabilityBased{0}"}, + } { + s := ProbabilityBased(tc.prob) + require.Equal(t, tc.expect, s.Description(), "%#v", tc.prob) + } +} + +func getUnknowns(otts otelTraceState) string { + otts.pvalue = invalidValue + otts.rvalue = invalidValue + return otts.serialize() +} + +func TestSamplerBehavior(t *testing.T) { + type testGroup struct { 
+ probability float64 + minP uint8 + maxP uint8 + } + type testCase struct { + isRoot bool + parentSampled bool + ctxTracestate string + hasErrors bool + } + + for _, group := range []testGroup{ + {1.0, 0, 0}, + {0.75, 0, 1}, + {0.5, 1, 1}, + {0, 63, 63}, + } { + t.Run(fmt.Sprint(group.probability), func(t *testing.T) { + for _, test := range []testCase{ + // roots do not care if the context is + // sampled, however preserve other + // otel tracestate keys + {true, false, "", false}, + {true, false, "a:b", false}, + + // non-roots insert r + {false, true, "", false}, + {false, true, "a:b", false}, + {false, false, "", false}, + {false, false, "a:b", false}, + + // error cases: r-p inconsistency + {false, true, "r:10;p:20", true}, + {false, true, "r:10;p:20;a:b", true}, + {false, false, "r:10;p:5", true}, + {false, false, "r:10;p:5;a:b", true}, + + // error cases: out-of-range + {false, false, "r:100", true}, + {false, false, "r:100;a:b", true}, + {false, true, "r:100;p:100", true}, + {false, true, "r:100;p:100;a:b", true}, + {false, true, "r:10;p:100", true}, + {false, true, "r:10;p:100;a:b", true}, + } { + t.Run(testName(test.ctxTracestate), func(t *testing.T) { + handler := &testErrorHandler{} + otel.SetErrorHandler(handler) + + src := rand.NewSource(99999199999) + sampler := ProbabilityBased(group.probability, WithRandomSource(src)) + + traceID, _ := trace.TraceIDFromHex("4bf92f3577b34da6a3ce929d0e0e4736") + spanID, _ := trace.SpanIDFromHex("00f067aa0ba902b7") + + traceState := trace.TraceState{} + if test.ctxTracestate != "" { + var err error + traceState, err = traceState.Insert(traceStateKey, test.ctxTracestate) + require.NoError(t, err) + } + + sccfg := trace.SpanContextConfig{ + TraceState: traceState, + } + + if !test.isRoot { + sccfg.TraceID = traceID + sccfg.SpanID = spanID + } + + if test.parentSampled { + sccfg.TraceFlags = trace.FlagsSampled + } + + parentCtx := trace.ContextWithSpanContext( + context.Background(), + trace.NewSpanContext(sccfg), + ) + 
+ // Note: the error below is sometimes expected + testState, _ := parseOTelTraceState(test.ctxTracestate, test.parentSampled) + hasRValue := testState.hasRValue() + + const repeats = 10 + for i := 0; i < repeats; i++ { + result := sampler.ShouldSample( + sdktrace.SamplingParameters{ + ParentContext: parentCtx, + TraceID: traceID, + Name: "test", + Kind: trace.SpanKindServer, + }, + ) + sampled := result.Decision == sdktrace.RecordAndSample + + // The result is deterministically random. Parse the tracestate + // to see that it is consistent. + otts, err := parseOTelTraceState(result.Tracestate.Get(traceStateKey), sampled) + require.NoError(t, err) + require.True(t, otts.hasRValue()) + require.Equal(t, []attribute.KeyValue(nil), result.Attributes) + + if otts.hasPValue() { + require.LessOrEqual(t, group.minP, otts.pvalue) + require.LessOrEqual(t, otts.pvalue, group.maxP) + require.Equal(t, sdktrace.RecordAndSample, result.Decision) + } else { + require.Equal(t, sdktrace.Drop, result.Decision) + } + + require.Equal(t, getUnknowns(testState), getUnknowns(otts)) + + if hasRValue { + require.Equal(t, testState.rvalue, otts.rvalue) + } + + if test.hasErrors { + require.Less(t, 0, len(handler.Errors())) + } else { + require.Equal(t, 0, len(handler.Errors())) + } + } + }) + } + }) + } +} diff --git a/samplers/probability/consistent/statistical_test.go b/samplers/probability/consistent/statistical_test.go new file mode 100644 index 00000000000..8da51ca366b --- /dev/null +++ b/samplers/probability/consistent/statistical_test.go @@ -0,0 +1,319 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !race +// +build !race + +package consistent + +import ( + "context" + "fmt" + "math" + "math/rand" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/require" + + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +const ( + oneDegree testDegrees = 1 + twoDegrees testDegrees = 2 +) + +var ( + trials = 20 + populationSize = 1e5 + + // These may be computed using Gonum, e.g., + // import "gonum.org/v1/gonum/stat/distuv" + // with significance = 1 / float64(trials) = 0.05 + // chiSquaredDF1 = distuv.ChiSquared{K: 1}.Quantile(significance) + // chiSquaredDF2 = distuv.ChiSquared{K: 2}.Quantile(significance) + // + // These have been specified using significance = 0.05: + chiSquaredDF1 = 0.003932140000019522 + chiSquaredDF2 = 0.1025865887751011 + + chiSquaredByDF = [3]float64{ + 0, + chiSquaredDF1, + chiSquaredDF2, + } +) + +func TestSamplerStatistics(t *testing.T) { + + seedBankRng := rand.New(rand.NewSource(77777677777)) + seedBank := make([]int64, 7) // N.B. Max=6 below. + for i := range seedBank { + seedBank[i] = seedBankRng.Int63() + } + type ( + testCase struct { + // prob is the sampling probability under test. + prob float64 + + // upperP reflects the larger of the one or two + // distinct adjusted counts represented in the test. + // + // For power-of-two tests, there is one distinct p-value, + // and each span counts as 2**upperP representative spans. 
+ // + // For non-power-of-two tests, there are two distinct + // p-values expected, the test is specified using the + // larger of these values corresponding with the + // smaller sampling probability. The sampling + // probability under test rounded down to the nearest + // power of two is expected to equal 2**(-upperP). + upperP pValue + + // degrees is 1 for power-of-two tests and 2 for + // non-power-of-two tests. + degrees testDegrees + + // seedIndex is the index into seedBank of the test seed. + // If this is -1 the code below will search for the smallest + // seed index that passes the test. + seedIndex int + } + testResult struct { + test testCase + expected []float64 + } + ) + var ( + testSummary []testResult + + allTests = []testCase{ + // Non-powers of two + {0.90000, 1, twoDegrees, 3}, + {0.60000, 1, twoDegrees, 2}, + {0.33000, 2, twoDegrees, 2}, + {0.13000, 3, twoDegrees, 1}, + {0.10000, 4, twoDegrees, 0}, + {0.05000, 5, twoDegrees, 0}, + {0.01700, 6, twoDegrees, 2}, + {0.01000, 7, twoDegrees, 2}, + {0.00500, 8, twoDegrees, 2}, + {0.00290, 9, twoDegrees, 4}, + {0.00100, 10, twoDegrees, 6}, + {0.00050, 11, twoDegrees, 0}, + + // Powers of two + {0x1p-1, 1, oneDegree, 0}, + {0x1p-4, 4, oneDegree, 0}, + {0x1p-7, 7, oneDegree, 1}, + } + ) + + // Limit the test runtime by choosing 3 of the above + // non-deterministically + rand.New(rand.NewSource(time.Now().UnixNano())).Shuffle(len(allTests), func(i, j int) { + allTests[i], allTests[j] = allTests[j], allTests[i] + }) + allTests = allTests[0:3] + + for _, test := range allTests { + t.Run(fmt.Sprint(test.prob), func(t *testing.T) { + var expected []float64 + trySeedIndex := 0 + + for { + var seed int64 + seedIndex := test.seedIndex + if seedIndex >= 0 { + seed = seedBank[seedIndex] + } else { + seedIndex = trySeedIndex + seed = seedBank[trySeedIndex] + trySeedIndex++ + } + + countFailures := func(src rand.Source) int { + failed := 0 + + for j := 0; j < trials; j++ { + var x float64 + x, expected = 
sampleTrials(t, test.prob, test.degrees, test.upperP, src) + + if x < chiSquaredByDF[test.degrees] { + failed++ + } + } + return failed + } + + failed := countFailures(rand.NewSource(seed)) + + if failed != 1 && test.seedIndex < 0 { + t.Logf("%d probabilistic failures, trying a new seed for %g was 0x%x", failed, test.prob, seed) + continue + } else if failed != 1 { + t.Errorf("wrong number of probabilistic failures for %g, should be 1 was %d for seed 0x%x", test.prob, failed, seed) + } else if test.seedIndex < 0 { + t.Logf("update the test for %g to use seed index %d", test.prob, seedIndex) + t.Fail() + return + } else { + // Note: this can be uncommented to verify that the preceding seed failed the test, + // however this just doubles runtime and adds little evidence. For example: + // if seedIndex != 0 && countFailures(rand.NewSource(seedBank[seedIndex-1])) == 1 { + // t.Logf("update the test for %g to use seed index < %d", test.prob, seedIndex) + // t.Fail() + // } + break + } + } + testSummary = append(testSummary, testResult{ + test: test, + expected: expected, + }) + }) + } + + // Note: This produces a table that should match what is in + // the specification if it's the same test. 
+ for idx, res := range testSummary { + var probability, pvalues, expectLower, expectUpper, expectUnsampled string + if res.test.degrees == twoDegrees { + probability = fmt.Sprintf("%.6f", res.test.prob) + pvalues = fmt.Sprint(res.test.upperP-1, ", ", res.test.upperP) + expectUnsampled = fmt.Sprintf("%.10g", res.expected[0]) + expectLower = fmt.Sprintf("%.10g", res.expected[1]) + expectUpper = fmt.Sprintf("%.10g", res.expected[2]) + } else { + probability = fmt.Sprintf("%x (%.6f)", res.test.prob, res.test.prob) + pvalues = fmt.Sprint(res.test.upperP) + expectUnsampled = fmt.Sprintf("%.10g", res.expected[0]) + expectLower = fmt.Sprintf("%.10g", res.expected[1]) + expectUpper = "n/a" + } + t.Logf("| %d | %s | %s | %s | %s | %s |\n", idx+1, probability, pvalues, expectLower, expectUpper, expectUnsampled) + } +} + +func sampleTrials(t *testing.T, prob float64, degrees testDegrees, upperP pValue, source rand.Source) (float64, []float64) { + ctx := context.Background() + + sampler := ProbabilityBased( + prob, + WithRandomSource(source), + ) + + recorder := &tracetest.InMemoryExporter{} + provider := sdktrace.NewTracerProvider( + sdktrace.WithSyncer(recorder), + sdktrace.WithSampler(sampler), + ) + + tracer := provider.Tracer("test") + + for i := 0; i < int(populationSize); i++ { + _, span := tracer.Start(ctx, "span") + span.End() + } + + var minP, maxP pValue + + counts := map[pValue]int64{} + + for idx, r := range recorder.GetSpans() { + ts := r.SpanContext.TraceState() + p, _ := parsePR(ts.Get("ot")) + + pi, err := strconv.ParseUint(p, 10, 64) + require.NoError(t, err) + + if idx == 0 { + maxP = pValue(pi) + minP = maxP + } else { + if pValue(pi) < minP { + minP = pValue(pi) + } + if pValue(pi) > maxP { + maxP = pValue(pi) + } + } + counts[pValue(pi)]++ + } + + require.Less(t, maxP, minP+pValue(degrees), "%v %v %v", minP, maxP, degrees) + require.Less(t, maxP, pValue(63)) + require.LessOrEqual(t, len(counts), 2) + + var ceilingProb, floorProb, floorChoice float64 + + 
// Note: we have to test len(counts) == 0 because this outcome + // is actually possible, just very unlikely. If this happens + // during development, a new initial seed must be used for + // this test. + // + // The test specification ensures the test ensures there are + // at least 20 expected items per category in these tests. + require.NotEqual(t, 0, len(counts)) + + if degrees == 2 { + // Note: because the test is probabilistic, we can't be + // sure that both the min and max P values happen. We + // can only assert that one of these is true. + require.GreaterOrEqual(t, maxP, upperP-1) + require.GreaterOrEqual(t, minP, upperP-1) + require.LessOrEqual(t, maxP, upperP) + require.LessOrEqual(t, minP, upperP) + require.LessOrEqual(t, maxP-minP, 1) + + ceilingProb = 1 / float64(int64(1)<<(upperP-1)) + floorProb = 1 / float64(int64(1)< !hasRValue() + pvalue: invalidValue, // out-of-range => !hasPValue() + } +} + +func (otts otelTraceState) serialize() string { + var sb strings.Builder + semi := func() { + if sb.Len() != 0 { + _, _ = sb.WriteString(";") + } + } + + if otts.hasPValue() { + _, _ = sb.WriteString(fmt.Sprintf("p:%d", otts.pvalue)) + } + if otts.hasRValue() { + semi() + _, _ = sb.WriteString(fmt.Sprintf("r:%d", otts.rvalue)) + } + for _, unk := range otts.unknown { + ex := 0 + if sb.Len() != 0 { + ex = 1 + } + if sb.Len()+ex+len(unk) > traceStateSizeLimit { + // Note: should this generate an explicit error? + break + } + semi() + _, _ = sb.WriteString(unk) + } + return sb.String() +} + +func isValueByte(r byte) bool { + if isLCAlphaNum(r) { + return true + } + if isUCAlpha(r) { + return true + } + return r == '.' 
|| r == '_' || r == '-' +} + +func isLCAlphaNum(r byte) bool { + if isLCAlpha(r) { + return true + } + return r >= '0' && r <= '9' +} + +func isLCAlpha(r byte) bool { + return r >= 'a' && r <= 'z' +} + +func isUCAlpha(r byte) bool { + return r >= 'A' && r <= 'Z' +} + +func parseOTelTraceState(ts string, isSampled bool) (otelTraceState, error) { + var pval, rval string + var unknown []string + + if len(ts) == 0 { + return newTraceState(), nil + } + + if len(ts) > traceStateSizeLimit { + return newTraceState(), errTraceStateSyntax + } + + for len(ts) > 0 { + eqPos := 0 + for ; eqPos < len(ts); eqPos++ { + if eqPos == 0 { + if isLCAlpha(ts[eqPos]) { + continue + } + } else if isLCAlphaNum(ts[eqPos]) { + continue + } + break + } + if eqPos == 0 || eqPos == len(ts) || ts[eqPos] != ':' { + return newTraceState(), errTraceStateSyntax + } + + key := ts[0:eqPos] + tail := ts[eqPos+1:] + + sepPos := 0 + + for ; sepPos < len(tail); sepPos++ { + if isValueByte(tail[sepPos]) { + continue + } + break + } + + if key == pValueSubkey { + // Note: does the spec say how to handle duplicates? + pval = tail[0:sepPos] + } else if key == rValueSubkey { + rval = tail[0:sepPos] + } else { + unknown = append(unknown, ts[0:sepPos+eqPos+1]) + } + + if sepPos < len(tail) && tail[sepPos] != ';' { + return newTraceState(), errTraceStateSyntax + } + + if sepPos == len(tail) { + break + } + + ts = tail[sepPos+1:] + + // test for a trailing ; + if ts == "" { + return newTraceState(), errTraceStateSyntax + } + } + + otts := newTraceState() + otts.unknown = unknown + + // Note: set R before P, so that P won't propagate if R has an error. + value, err := parseNumber(rValueSubkey, rval, pZeroValue-1) + if err != nil { + return otts, err + } + otts.rvalue = value + + value, err = parseNumber(pValueSubkey, pval, pZeroValue) + if err != nil { + return otts, err + } + otts.pvalue = value + + // Invariant checking: unset P when the values are inconsistent. 
+ if otts.hasPValue() && otts.hasRValue() { + implied := otts.pvalue <= otts.rvalue || otts.pvalue == pZeroValue + + if !isSampled || !implied { + // Note: the error ensures the parent-based + // sampler repairs the broken tracestate entry. + otts.pvalue = invalidValue + return otts, parseError(pValueSubkey, errTraceStateInconsistent) + } + } + + return otts, nil +} + +func parseNumber(key string, input string, maximum uint8) (uint8, error) { + if input == "" { + return maximum + 1, nil + } + value, err := strconv.ParseUint(input, 10, 64) + if err != nil { + return maximum + 1, parseError(key, err) + } + if value > uint64(maximum) { + return maximum + 1, parseError(key, strconv.ErrRange) + } + return uint8(value), nil +} + +func parseError(key string, err error) error { + return fmt.Errorf("otel tracestate: %s-value %w", key, err) +} + +func (otts otelTraceState) hasRValue() bool { + return otts.rvalue < pZeroValue +} + +func (otts otelTraceState) hasPValue() bool { + return otts.pvalue <= pZeroValue +} diff --git a/samplers/probability/consistent/tracestate_test.go b/samplers/probability/consistent/tracestate_test.go new file mode 100644 index 00000000000..671b13e6a15 --- /dev/null +++ b/samplers/probability/consistent/tracestate_test.go @@ -0,0 +1,290 @@ +// Copyright The OpenTelemetry Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package consistent + +import ( + "errors" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func testName(in string) string { + x := strings.NewReplacer(":", "_", ";", "_").Replace(in) + if len(x) > 32 { + return "" + } + return x +} + +func TestNewTraceState(t *testing.T) { + otts := newTraceState() + require.False(t, otts.hasPValue()) + require.False(t, otts.hasRValue()) + require.Equal(t, "", otts.serialize()) +} + +func TestTraceStatePRValueSerialize(t *testing.T) { + otts := newTraceState() + otts.pvalue = 3 + otts.rvalue = 4 + otts.unknown = []string{"a:b", "c:d"} + require.True(t, otts.hasPValue()) + require.True(t, otts.hasRValue()) + require.Equal(t, "p:3;r:4;a:b;c:d", otts.serialize()) +} + +func TestTraceStateSerializeOverflow(t *testing.T) { + long := "x:" + strings.Repeat(".", 254) + otts := newTraceState() + otts.unknown = []string{long} + // this drops the extra key, sorry! + require.Equal(t, long, otts.serialize()) + otts.pvalue = 1 + require.Equal(t, "p:1", otts.serialize()) +} + +func TestParseTraceStateUnsampled(t *testing.T) { + type testCase struct { + in string + rval uint8 + expectErr error + } + const notset = 255 + for _, test := range []testCase{ + // All are unsampled tests, i.e., `sampled` is not set in traceparent. 
+ {"r:2", 2, nil}, + {"r:1;", notset, strconv.ErrSyntax}, + {"r:1", 1, nil}, + {"r:1=p:2", notset, strconv.ErrSyntax}, + {"r:1;p:2=s:3", notset, strconv.ErrSyntax}, + {":1;p:2=s:3", notset, strconv.ErrSyntax}, + {":;p:2=s:3", notset, strconv.ErrSyntax}, + {":;:", notset, strconv.ErrSyntax}, + {":", notset, strconv.ErrSyntax}, + {"", notset, nil}, + {"r:;p=1", notset, strconv.ErrSyntax}, + {"r:1", 1, nil}, + {"r:10", 10, nil}, + {"r:33", 33, nil}, + {"r:61", 61, nil}, + {"r:62", 62, nil}, // max r-value + {"r:63", notset, strconv.ErrRange}, // out-of-range + {"r:100", notset, strconv.ErrRange}, // out-of-range + {"r:100001", notset, strconv.ErrRange}, // out-of-range + {"p:64", notset, strconv.ErrRange}, + {"p:100", notset, strconv.ErrRange}, + {"r:1a", notset, strconv.ErrSyntax}, // not-hexadecimal + {"p:-1", notset, strconv.ErrSyntax}, // non-negative + + // Inconsistent trace state: any p-value when unsampled + {"p:4;r:2", 2, errTraceStateInconsistent}, + {"p:1;r:2", 2, errTraceStateInconsistent}, + } { + t.Run(testName(test.in), func(t *testing.T) { + // Note: passing isSampled=false as stated above. 
+ otts, err := parseOTelTraceState(test.in, false) + + require.False(t, otts.hasPValue(), "should have no p-value") + + if test.expectErr != nil { + require.True(t, errors.Is(err, test.expectErr), "not expecting %v", err) + } + if test.rval != notset { + require.True(t, otts.hasRValue()) + require.Equal(t, test.rval, otts.rvalue) + } else { + require.False(t, otts.hasRValue(), "should have no r-value") + } + require.EqualValues(t, []string(nil), otts.unknown) + + if test.expectErr == nil { + // Require serialize to round-trip + otts2, err := parseOTelTraceState(otts.serialize(), false) + require.NoError(t, err) + require.Equal(t, otts, otts2) + } + }) + } +} + +func TestParseTraceStateSampled(t *testing.T) { + type testCase struct { + in string + rval, pval uint8 + expectErr error + } + const notset = 255 + for _, test := range []testCase{ + // All are sampled tests, i.e., `sampled` is set in traceparent. + {"r:2;p:2", 2, 2, nil}, + {"r:2;p:1", 2, 1, nil}, + {"r:2;p:0", 2, 0, nil}, + + {"r:1;p:1", 1, 1, nil}, + {"r:1;p:0", 1, 0, nil}, + + {"r:0;p:0", 0, 0, nil}, + + {"r:62;p:0", 62, 0, nil}, + {"r:62;p:62", 62, 62, nil}, + + // The important special case: + {"r:0;p:63", 0, 63, nil}, + {"r:2;p:63", 2, 63, nil}, + {"r:62;p:63", 62, 63, nil}, + + // Inconsistent p causes unset p-value. + {"r:2;p:3", 2, notset, errTraceStateInconsistent}, + {"r:2;p:4", 2, notset, errTraceStateInconsistent}, + {"r:2;p:62", 2, notset, errTraceStateInconsistent}, + {"r:0;p:1", 0, notset, errTraceStateInconsistent}, + {"r:1;p:2", 1, notset, errTraceStateInconsistent}, + {"r:61;p:62", 61, notset, errTraceStateInconsistent}, + + // Inconsistent r causes unset p-value and r-value. 
+ {"r:63;p:2", notset, notset, strconv.ErrRange}, + {"r:120;p:2", notset, notset, strconv.ErrRange}, + {"r:ab;p:2", notset, notset, strconv.ErrSyntax}, + + // Syntax is tested before range errors + {"r:ab;p:77", notset, notset, strconv.ErrSyntax}, + + // p without r (when sampled) + {"p:1", notset, 1, nil}, + {"p:62", notset, 62, nil}, + {"p:63", notset, 63, nil}, + + // r without p (when sampled) + {"r:2", 2, notset, nil}, + {"r:62", 62, notset, nil}, + {"r:0", 0, notset, nil}, + } { + t.Run(testName(test.in), func(t *testing.T) { + // Note: passing isSampled=true as stated above. + otts, err := parseOTelTraceState(test.in, true) + + if test.expectErr != nil { + require.True(t, errors.Is(err, test.expectErr), "not expecting %v", err) + } else { + require.NoError(t, err) + } + if test.pval != notset { + require.True(t, otts.hasPValue()) + require.Equal(t, test.pval, otts.pvalue) + } else { + require.False(t, otts.hasPValue(), "should have no p-value") + } + if test.rval != notset { + require.True(t, otts.hasRValue()) + require.Equal(t, test.rval, otts.rvalue) + } else { + require.False(t, otts.hasRValue(), "should have no r-value") + } + require.EqualValues(t, []string(nil), otts.unknown) + + if test.expectErr == nil { + // Require serialize to round-trip + otts2, err := parseOTelTraceState(otts.serialize(), true) + require.NoError(t, err) + require.Equal(t, otts, otts2) + } + }) + } +} + +func TestParseTraceStateExtra(t *testing.T) { + type testCase struct { + in string + rval, pval uint8 + sampled bool + extra []string + expectErr error + } + const notset = 255 + for _, test := range []testCase{ + // one field + {"e100:1", notset, notset, false, []string{"e100:1"}, nil}, + + // two fields + {"e1:1;e2:2", notset, notset, false, []string{"e1:1", "e2:2"}, nil}, + {"e1:1;e2:2", notset, notset, false, []string{"e1:1", "e2:2"}, nil}, + + // one extra key, three ways + {"r:2;p:2;extra:stuff", 2, 2, true, []string{"extra:stuff"}, nil}, + {"extra:stuff;r:2;p:2", 2, 2, 
true, []string{"extra:stuff"}, nil}, + {"p:2;extra:stuff;r:2", 2, 2, true, []string{"extra:stuff"}, nil}, + + // extra with inconsistent p with and without sampling + {"r:3;extra:stuff;p:4", 3, notset, true, []string{"extra:stuff"}, errTraceStateInconsistent}, + {"extra:stuff;r:3;p:2", 3, notset, false, []string{"extra:stuff"}, errTraceStateInconsistent}, + + // two extra fields + {"e100:100;r:2;p:1;e101:101", 2, 1, true, []string{"e100:100", "e101:101"}, nil}, + {"r:2;p:1;e100:100;e101:101", 2, 1, true, []string{"e100:100", "e101:101"}, nil}, + {"e100:100;e101:101;r:2;p:1", 2, 1, true, []string{"e100:100", "e101:101"}, nil}, + + // parse error prevents capturing unrecognized keys + {"1:1;u:V", notset, notset, true, nil, strconv.ErrSyntax}, + {"X:1;u:V", notset, notset, true, nil, strconv.ErrSyntax}, + {"x:1;u:V", notset, notset, true, []string{"x:1", "u:V"}, nil}, + + // no trailing ; + {"x:1;", notset, notset, true, nil, strconv.ErrSyntax}, + + // empty key + {"x:", notset, notset, true, []string{"x:"}, nil}, + + // charset test + {"x:0X1FFF;y:.-_-.;z:", notset, notset, true, []string{"x:0X1FFF", "y:.-_-.", "z:"}, nil}, + {"x1y2z3:1-2-3;y1:y_1;xy:-;r:50", 50, notset, true, []string{"x1y2z3:1-2-3", "y1:y_1", "xy:-"}, nil}, + + // size exceeded + {"x:" + strings.Repeat("_", 255), notset, notset, false, nil, strconv.ErrSyntax}, + {"x:" + strings.Repeat("_", 254), notset, notset, false, []string{"x:" + strings.Repeat("_", 254)}, nil}, + } { + t.Run(testName(test.in), func(t *testing.T) { + // Note: These tests are independent of sampling state, + // so both are tested. 
+ otts, err := parseOTelTraceState(test.in, test.sampled) + + if test.expectErr != nil { + require.True(t, errors.Is(err, test.expectErr), "not expecting %v", err) + } else { + require.NoError(t, err) + } + if test.pval != notset { + require.True(t, otts.hasPValue()) + require.Equal(t, test.pval, otts.pvalue) + } else { + require.False(t, otts.hasPValue(), "should have no p-value") + } + if test.rval != notset { + require.True(t, otts.hasRValue()) + require.Equal(t, test.rval, otts.rvalue) + } else { + require.False(t, otts.hasRValue(), "should have no r-value") + } + require.EqualValues(t, test.extra, otts.unknown) + + // on success w/o r-value or p-value, serialize() should not modify + if !otts.hasRValue() && !otts.hasPValue() && test.expectErr == nil { + require.Equal(t, test.in, otts.serialize()) + } + }) + } +}