Skip to content

Don't cache sanitization results for large sql statements #13353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ default String getDbSystem(REQUEST request) {

@Deprecated
@Nullable
String getUser(REQUEST request);
default String getUser(REQUEST request) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since these are removed in the stable semconv we don't need to force users to implement them.

return null;
}

/**
* @deprecated Use {@link #getDbNamespace(Object)} instead.
Expand All @@ -43,5 +45,7 @@ default String getDbNamespace(REQUEST request) {

@Deprecated
@Nullable
String getConnectionString(REQUEST request);
default String getConnectionString(REQUEST request) {
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ public String extract(REQUEST request) {
private static final class SqlClientSpanNameExtractor<REQUEST>
extends DbClientSpanNameExtractor<REQUEST> {

// a dedicated sanitizer just for extracting the operation and identifier name
private static final SqlStatementSanitizer sanitizer = SqlStatementSanitizer.create(true);

private final SqlClientAttributesGetter<REQUEST> getter;

private SqlClientSpanNameExtractor(SqlClientAttributesGetter<REQUEST> getter) {
Expand All @@ -106,13 +103,15 @@ public String extract(REQUEST request) {
if (rawQueryTexts.size() > 1) { // for backcompat(?)
return computeSpanName(namespace, null, null);
}
SqlStatementInfo sanitizedStatement = sanitizer.sanitize(rawQueryTexts.iterator().next());
SqlStatementInfo sanitizedStatement =
SqlStatementSanitizerUtil.sanitize(rawQueryTexts.iterator().next());
return computeSpanName(
namespace, sanitizedStatement.getOperation(), sanitizedStatement.getMainIdentifier());
}

if (rawQueryTexts.size() == 1) {
SqlStatementInfo sanitizedStatement = sanitizer.sanitize(rawQueryTexts.iterator().next());
SqlStatementInfo sanitizedStatement =
SqlStatementSanitizerUtil.sanitize(rawQueryTexts.iterator().next());
String operation = sanitizedStatement.getOperation();
if (isBatch(request)) {
operation = "BATCH " + operation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.Set;

class MultiQuery {
private static final SqlStatementSanitizer sanitizer = SqlStatementSanitizer.create(true);

private final String mainIdentifier;
private final String operation;
Expand All @@ -28,7 +27,7 @@ static MultiQuery analyze(
UniqueValue uniqueOperation = new UniqueValue();
Set<String> uniqueStatements = new LinkedHashSet<>();
for (String rawQueryText : rawQueryTexts) {
SqlStatementInfo sanitizedStatement = sanitizer.sanitize(rawQueryText);
SqlStatementInfo sanitizedStatement = SqlStatementSanitizerUtil.sanitize(rawQueryText);
String mainIdentifier = sanitizedStatement.getMainIdentifier();
uniqueMainIdentifier.set(mainIdentifier);
String operation = sanitizedStatement.getOperation();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ public static <REQUEST, RESPONSE> SqlClientAttributesExtractorBuilder<REQUEST, R
}

private static final String SQL_CALL = "CALL";
// sanitizer is also used to extract operation and table name, so we have it always enabled here
private static final SqlStatementSanitizer sanitizer = SqlStatementSanitizer.create(true);

private final AttributeKey<String> oldSemconvTableAttribute;
private final boolean statementSanitizationEnabled;
Expand All @@ -83,7 +81,7 @@ public void onStart(AttributesBuilder attributes, Context parentContext, REQUEST
if (SemconvStability.emitOldDatabaseSemconv()) {
if (rawQueryTexts.size() == 1) { // for backcompat(?)
String rawQueryText = rawQueryTexts.iterator().next();
SqlStatementInfo sanitizedStatement = sanitizer.sanitize(rawQueryText);
SqlStatementInfo sanitizedStatement = SqlStatementSanitizerUtil.sanitize(rawQueryText);
String operation = sanitizedStatement.getOperation();
internalSet(
attributes,
Expand All @@ -104,7 +102,7 @@ public void onStart(AttributesBuilder attributes, Context parentContext, REQUEST
}
if (rawQueryTexts.size() == 1) {
String rawQueryText = rawQueryTexts.iterator().next();
SqlStatementInfo sanitizedStatement = sanitizer.sanitize(rawQueryText);
SqlStatementInfo sanitizedStatement = SqlStatementSanitizerUtil.sanitize(rawQueryText);
String operation = sanitizedStatement.getOperation();
internalSet(
attributes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public final class SqlStatementSanitizer {

private static final Cache<CacheKey, SqlStatementInfo> sqlToStatementInfoCache =
Cache.bounded(1000);
private static final int LARGE_STATEMENT_THRESHOLD = 10 * 1024;

public static SqlStatementSanitizer create(boolean statementSanitizationEnabled) {
return new SqlStatementSanitizer(statementSanitizationEnabled);
Expand All @@ -40,12 +41,24 @@ public SqlStatementInfo sanitize(@Nullable String statement, SqlDialect dialect)
if (!statementSanitizationEnabled || statement == null) {
return SqlStatementInfo.create(statement, null, null);
}
// sanitization result will not be cached for statements larger than the threshold to avoid
// cache growing too large
// https://github.com/open-telemetry/opentelemetry-java-instrumentation/issues/13180
if (statement.length() > LARGE_STATEMENT_THRESHOLD) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was thinking we could hash these larger statements instead of using the whole statement as the key, but that might be more computationally expensive, so this seems reasonable to me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually my first attempt was to use hashing. Computing a hash for a very large statement can be more expensive than applying the sanitizer as the sanitizer also applies a size limit. My guess is that many of these super large statements could be dynamically generated so it is likely that the statement is executed only once and would not benefit from caching anyway.

return sanitizeImpl(statement, dialect);
}
return sqlToStatementInfoCache.computeIfAbsent(
CacheKey.create(statement, dialect),
k -> {
supportability.incrementCounter(SQL_STATEMENT_SANITIZER_CACHE_MISS);
return AutoSqlSanitizer.sanitize(statement, dialect);
});
CacheKey.create(statement, dialect), k -> sanitizeImpl(statement, dialect));
}

private static SqlStatementInfo sanitizeImpl(@Nullable String statement, SqlDialect dialect) {
supportability.incrementCounter(SQL_STATEMENT_SANITIZER_CACHE_MISS);
return AutoSqlSanitizer.sanitize(statement, dialect);
}

// visible for tests
static boolean isCached(String statement) {
return sqlToStatementInfoCache.get(CacheKey.create(statement, SqlDialect.DEFAULT)) != null;
}

@AutoValue
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.api.incubator.semconv.db;

import io.opentelemetry.instrumentation.api.instrumenter.Instrumenter;
import io.opentelemetry.instrumentation.api.internal.InstrumenterContext;
import java.util.HashMap;
import java.util.Map;

/**
* Helper class for sanitizing sql that keeps sanitization results in {@link InstrumenterContext} so
* that each statement would be sanitized only once for given {@link Instrumenter} call.
*/
class SqlStatementSanitizerUtil {
private static final SqlStatementSanitizer sanitizer = SqlStatementSanitizer.create(true);

static SqlStatementInfo sanitize(String queryText) {
if (!InstrumenterContext.isActive()) {
return sanitizer.sanitize(queryText);
}

Map<String, SqlStatementInfo> map =
InstrumenterContext.computeIfAbsent("sanitized-sql-map", unused -> new HashMap<>());
return map.computeIfAbsent(queryText, sanitizer::sanitize);
}

private SqlStatementSanitizerUtil() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,27 @@ public void longInStatementDoesntCauseStackOverflow() {
assertThat(sanitized).isEqualTo("select col from table where col in (?)");
}

@Test
public void largeStatementCached() {
// test that short statement is cached
String shortStatement = "SELECT * FROM TABLE WHERE FIELD = 1234";
String sanitizedShort =
SqlStatementSanitizer.create(true).sanitize(shortStatement).getFullStatement();
assertThat(sanitizedShort).doesNotContain("1234");
assertThat(SqlStatementSanitizer.isCached(shortStatement)).isTrue();

// test that large statement is not cached
StringBuffer s = new StringBuffer();
for (int i = 0; i < 10000; i++) {
s.append("SELECT * FROM TABLE WHERE FIELD = 1234 AND ");
}
String largeStatement = s.toString();
String sanitizedLarge =
SqlStatementSanitizer.create(true).sanitize(largeStatement).getFullStatement();
assertThat(sanitizedLarge).doesNotContain("1234");
assertThat(SqlStatementSanitizer.isCached(largeStatement)).isFalse();
}

static class SqlArgs implements ArgumentsProvider {

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.opentelemetry.context.ContextKey;
import io.opentelemetry.instrumentation.api.internal.HttpRouteState;
import io.opentelemetry.instrumentation.api.internal.InstrumenterAccess;
import io.opentelemetry.instrumentation.api.internal.InstrumenterContext;
import io.opentelemetry.instrumentation.api.internal.InstrumenterUtil;
import io.opentelemetry.instrumentation.api.internal.SupportabilityMetrics;
import java.time.Instant;
Expand Down Expand Up @@ -164,6 +165,10 @@ Context startAndEnd(
}

private Context doStart(Context parentContext, REQUEST request, @Nullable Instant startTime) {
return InstrumenterContext.withContext(() -> doStartImpl(parentContext, request, startTime));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for the delayed feedback, I wish this (relatively small) overhead didn't affect all instrumentations just for this edge case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reworked this to only create the thread local instrumentation context when needed. The downside is that now if these classes are not used with the Instrumenter there may be a leak.

}

private Context doStartImpl(Context parentContext, REQUEST request, @Nullable Instant startTime) {
SpanKind spanKind = spanKindExtractor.extract(request);
SpanBuilder spanBuilder =
tracer.spanBuilder(spanNameExtractor.extract(request)).setSpanKind(spanKind);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.api.internal;

import io.opentelemetry.instrumentation.api.instrumenter.AttributesExtractor;
import io.opentelemetry.instrumentation.api.instrumenter.Instrumenter;
import io.opentelemetry.instrumentation.api.instrumenter.SpanNameExtractor;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

/**
* Helper class for sharing computed values between different {@link AttributesExtractor}s and
* {@link SpanNameExtractor} called in the start phase of the {@link Instrumenter}.
*
* <p>This class is internal and is hence not for public use. Its APIs are unstable and can change
* at any time.
*/
public final class InstrumenterContext {
private static final ThreadLocal<InstrumenterContext> instrumenterContext = new ThreadLocal<>();

private final Map<String, Object> map = new HashMap<>();
private int useCount;

private InstrumenterContext() {}

@SuppressWarnings("unchecked")
public static <T> T computeIfAbsent(String key, Function<String, T> function) {
InstrumenterContext context = instrumenterContext.get();
if (context == null) {
return function.apply(key);
}
return (T) context.map.computeIfAbsent(key, function);
}

// visible for testing
static Map<String, Object> get() {
return instrumenterContext.get().map;
}

public static boolean isActive() {
return instrumenterContext.get() != null;
}

public static <T> T withContext(Supplier<T> action) {
InstrumenterContext context = instrumenterContext.get();
if (context == null) {
context = new InstrumenterContext();
instrumenterContext.set(context);
}
context.useCount++;
try {
return action.get();
} finally {
context.useCount--;
if (context.useCount == 0) {
instrumenterContext.remove();
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.api.internal;

import static io.opentelemetry.instrumentation.testing.junit.db.SemconvStabilityUtil.maybeStable;
import static org.assertj.core.api.Assertions.assertThat;

import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.common.AttributesBuilder;
import io.opentelemetry.context.Context;
import io.opentelemetry.instrumentation.api.incubator.semconv.db.DbClientSpanNameExtractor;
import io.opentelemetry.instrumentation.api.incubator.semconv.db.SqlClientAttributesExtractor;
import io.opentelemetry.instrumentation.api.incubator.semconv.db.SqlClientAttributesGetter;
import io.opentelemetry.instrumentation.api.incubator.semconv.db.SqlStatementInfo;
import io.opentelemetry.instrumentation.api.instrumenter.AttributesExtractor;
import io.opentelemetry.instrumentation.api.instrumenter.SpanNameExtractor;
import io.opentelemetry.semconv.incubating.DbIncubatingAttributes;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import org.junit.jupiter.api.Test;

class InstrumenterContextTest {

@SuppressWarnings({"unchecked", "deprecation"}) // using deprecated DB_SQL_TABLE
@Test
void testSqlSanitizer() {
String testQuery = "SELECT name FROM test WHERE id = 1";
SqlClientAttributesGetter<Object> getter =
new SqlClientAttributesGetter<Object>() {

@Override
public Collection<String> getRawQueryTexts(Object request) {
return Collections.singletonList(testQuery);
}
};
SpanNameExtractor<Object> spanNameExtractor = DbClientSpanNameExtractor.create(getter);
AttributesExtractor<Object, Void> attributesExtractor =
SqlClientAttributesExtractor.create(getter);

assertThat(InstrumenterContext.isActive()).isFalse();
InstrumenterContext.withContext(
() -> {
assertThat(InstrumenterContext.isActive()).isTrue();
assertThat(InstrumenterContext.get()).isEmpty();
assertThat(spanNameExtractor.extract(null)).isEqualTo("SELECT test");
// verify that sanitized statement was cached, see SqlStatementSanitizerUtil
assertThat(InstrumenterContext.get()).containsKey("sanitized-sql-map");
Map<String, SqlStatementInfo> sanitizedMap =
(Map<String, SqlStatementInfo>) InstrumenterContext.get().get("sanitized-sql-map");
assertThat(sanitizedMap).containsKey(testQuery);

// replace cached sanitization result to verify it is used
sanitizedMap.put(
testQuery,
SqlStatementInfo.create("SELECT name2 FROM test2 WHERE id = ?", "SELECT", "test2"));
{
AttributesBuilder builder = Attributes.builder();
attributesExtractor.onStart(builder, Context.root(), null);
assertThat(builder.build().get(maybeStable(DbIncubatingAttributes.DB_SQL_TABLE)))
.isEqualTo("test2");
}

// clear cached value to see whether it gets recomputed correctly
sanitizedMap.clear();
{
AttributesBuilder builder = Attributes.builder();
attributesExtractor.onStart(builder, Context.root(), null);
assertThat(builder.build().get(maybeStable(DbIncubatingAttributes.DB_SQL_TABLE)))
.isEqualTo("test");
}

return null;
});
}
}
Loading