diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c0eb929792..275d1f4be14 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
 ### Fixed
 
 - Fix `StreamClientInterceptor` in `go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc` to end the spans synchronously. (#4537)
+- Fix data race in stats handlers when processing messages received and sent metrics in `go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc`. (#4577)
 
 ## [1.21.0/0.46.0/0.15.0/0.1.0] - 2023-11-10
 
diff --git a/instrumentation/google.golang.org/grpc/otelgrpc/stats_handler.go b/instrumentation/google.golang.org/grpc/otelgrpc/stats_handler.go
index 0211e55e003..8e13bb54e87 100644
--- a/instrumentation/google.golang.org/grpc/otelgrpc/stats_handler.go
+++ b/instrumentation/google.golang.org/grpc/otelgrpc/stats_handler.go
@@ -195,8 +195,8 @@ func (c *config) handleRPC(ctx context.Context, rs stats.RPCStats) {
 		metricAttrs = append(metricAttrs, rpcStatusAttr)
 
 		c.rpcDuration.Record(wctx, float64(rs.EndTime.Sub(rs.BeginTime)), metric.WithAttributes(metricAttrs...))
-		c.rpcRequestsPerRPC.Record(wctx, gctx.messagesReceived, metric.WithAttributes(metricAttrs...))
-		c.rpcResponsesPerRPC.Record(wctx, gctx.messagesSent, metric.WithAttributes(metricAttrs...))
+		c.rpcRequestsPerRPC.Record(wctx, atomic.LoadInt64(&gctx.messagesReceived), metric.WithAttributes(metricAttrs...))
+		c.rpcResponsesPerRPC.Record(wctx, atomic.LoadInt64(&gctx.messagesSent), metric.WithAttributes(metricAttrs...))
 
 	default:
 		return
diff --git a/instrumentation/google.golang.org/grpc/otelgrpc/test/grpc_stats_handler_test.go b/instrumentation/google.golang.org/grpc/otelgrpc/test/grpc_stats_handler_test.go
index e6fd212f904..f8dd8871072 100644
--- a/instrumentation/google.golang.org/grpc/otelgrpc/test/grpc_stats_handler_test.go
+++ b/instrumentation/google.golang.org/grpc/otelgrpc/test/grpc_stats_handler_test.go
@@ -16,13 +16,17 @@ package test
 
 import (
 	"context"
+	"io"
 	"net"
+	"sync"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/interop"
+	"google.golang.org/grpc/status"
 
 	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
 	"go.opentelemetry.io/otel/attribute"
@@ -31,6 +35,8 @@ import (
 	"go.opentelemetry.io/otel/sdk/metric/metricdata"
 	"go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest"
 
+	testpb "google.golang.org/grpc/interop/grpc_testing"
+
 	"go.opentelemetry.io/otel/sdk/trace"
 	"go.opentelemetry.io/otel/sdk/trace/tracetest"
 	semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
@@ -1316,3 +1322,72 @@ func checkServerMetrics(t *testing.T, reader metric.Reader) {
 
 	metricdatatest.AssertEqual(t, expectedScopeMetric, rm.ScopeMetrics[0], metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreValue())
 }
+
+// Ensure there is no data race for the following scenario:
+// Bidirectional streaming + client cancels context in the middle of streaming.
+func TestStatsHandlerConcurrentSafeContextCancellation(t *testing.T) {
+	listener, err := net.Listen("tcp", "127.0.0.1:0")
+	require.NoError(t, err, "failed to open port")
+	client := newGrpcTest(t, listener,
+		[]grpc.DialOption{
+			grpc.WithStatsHandler(otelgrpc.NewClientHandler()),
+		},
+		[]grpc.ServerOption{
+			grpc.StatsHandler(otelgrpc.NewServerHandler()),
+		},
+	)
+
+	const n = 10
+	for i := 0; i < n; i++ {
+		ctx, cancel := context.WithCancel(context.Background())
+		stream, err := client.FullDuplexCall(ctx)
+		require.NoError(t, err)
+
+		const messageCount = 10
+		var wg sync.WaitGroup
+
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for i := 0; i < messageCount; i++ {
+				const reqSize = 1
+				pl := interop.ClientNewPayload(testpb.PayloadType_COMPRESSABLE, reqSize)
+				respParam := []*testpb.ResponseParameters{
+					{
+						Size: reqSize,
+					},
+				}
+				req := &testpb.StreamingOutputCallRequest{
+					ResponseType:       testpb.PayloadType_COMPRESSABLE,
+					ResponseParameters: respParam,
+					Payload:            pl,
+				}
+				err := stream.Send(req)
+				if err == io.EOF { // possible due to context cancellation
+					require.ErrorIs(t, ctx.Err(), context.Canceled)
+				} else {
+					require.NoError(t, err)
+				}
+			}
+			require.NoError(t, stream.CloseSend())
+		}()
+
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for i := 0; i < messageCount; i++ {
+				_, err := stream.Recv()
+				if i > messageCount/2 {
+					cancel()
+				}
+				// must continue to receive messages until server acknowledges the cancellation, to ensure no data race happens there too
+				if status.Code(err) == codes.Canceled {
+					return
+				}
+				require.NoError(t, err)
+			}
+		}()
+
+		wg.Wait()
+	}
+}
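
For context, a minimal sketch of the race pattern this patch addresses. It is not the otelgrpc source itself: a per-RPC counter is incremented on one goroutine while another goroutine reads it when recording end-of-RPC metrics, so both the increment (assumed here to use atomic.AddInt64, which this excerpt does not show) and the read must go through sync/atomic. The rpcContext type and its field names below are illustrative only.

// Illustrative sketch, not part of the patch: why the metric reads need atomic loads.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// rpcContext stands in for a per-RPC stats context holding message counters;
// the type and field names are hypothetical, not the otelgrpc internals.
type rpcContext struct {
	messagesReceived int64
	messagesSent     int64
}

func main() {
	gctx := &rpcContext{}
	var wg sync.WaitGroup

	// Writer: simulates the stats handler counting received messages
	// as message events arrive on one goroutine.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for i := 0; i < 10; i++ {
			atomic.AddInt64(&gctx.messagesReceived, 1) // safe increment
		}
	}()

	// Reader: simulates recording per-RPC metrics at end-of-RPC on another
	// goroutine; the atomic load mirrors what the patch adds to avoid the race.
	wg.Add(1)
	go func() {
		defer wg.Done()
		_ = atomic.LoadInt64(&gctx.messagesReceived) // safe read
	}()

	wg.Wait()
	fmt.Println("received:", atomic.LoadInt64(&gctx.messagesReceived))
}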