Skip to content

Use aws-lambda-java-serialization library, which is available by default, while deserializing input and serializing output #11868

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestStreamHandler;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.ApiGatewayProxyRequest;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.AwsLambdaFunctionInstrumenter;
Expand Down Expand Up @@ -67,49 +68,55 @@ protected TracingRequestStreamHandler(
@Override
public void handleRequest(InputStream input, OutputStream output, Context context)
throws IOException {

ApiGatewayProxyRequest proxyRequest = ApiGatewayProxyRequest.forStream(input);
AwsLambdaRequest request =
AwsLambdaRequest.create(context, proxyRequest, proxyRequest.getHeaders());
AwsLambdaRequest request = createRequest(input, context, proxyRequest);
io.opentelemetry.context.Context parentContext = instrumenter.extract(request);

if (!instrumenter.shouldStart(parentContext, request)) {
doHandleRequest(proxyRequest.freshStream(), output, context);
doHandleRequest(proxyRequest.freshStream(), output, context, request);
return;
}

io.opentelemetry.context.Context otelContext = instrumenter.start(parentContext, request);
Throwable error = null;
try (Scope ignored = otelContext.makeCurrent()) {
doHandleRequest(
proxyRequest.freshStream(),
new OutputStreamWrapper(output, otelContext, request, openTelemetrySdk),
context);
new OutputStreamWrapper(output, otelContext),
context,
request);
} catch (Throwable t) {
instrumenter.end(otelContext, request, null, t);
LambdaUtils.forceFlush(openTelemetrySdk, flushTimeoutNanos, TimeUnit.NANOSECONDS);
error = t;
throw t;
} finally {
instrumenter.end(otelContext, request, null, error);
LambdaUtils.forceFlush(openTelemetrySdk, flushTimeoutNanos, TimeUnit.NANOSECONDS);
}
}

protected AwsLambdaRequest createRequest(
InputStream input, Context context, ApiGatewayProxyRequest proxyRequest) throws IOException {
return AwsLambdaRequest.create(context, proxyRequest, proxyRequest.getHeaders());
}

protected void doHandleRequest(
InputStream input, OutputStream output, Context context, AwsLambdaRequest request)
throws IOException {
doHandleRequest(input, output, context);
}

protected abstract void doHandleRequest(InputStream input, OutputStream output, Context context)
throws IOException;

private class OutputStreamWrapper extends OutputStream {
private static class OutputStreamWrapper extends OutputStream {

private final OutputStream delegate;
private final io.opentelemetry.context.Context otelContext;
private final AwsLambdaRequest request;
private final OpenTelemetrySdk openTelemetrySdk;

private OutputStreamWrapper(
OutputStream delegate,
io.opentelemetry.context.Context otelContext,
AwsLambdaRequest request,
OpenTelemetrySdk openTelemetrySdk) {
OutputStream delegate, io.opentelemetry.context.Context otelContext) {
this.delegate = delegate;
this.otelContext = otelContext;
this.request = request;
this.openTelemetrySdk = openTelemetrySdk;
}

@Override
Expand All @@ -135,8 +142,8 @@ public void flush() throws IOException {
@Override
public void close() throws IOException {
delegate.close();
instrumenter.end(otelContext, request, null, null);
LambdaUtils.forceFlush(openTelemetrySdk, flushTimeoutNanos, TimeUnit.NANOSECONDS);
Span span = Span.fromContext(otelContext);
span.addEvent("Output stream closed");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
public class TracingRequestStreamWrapper extends TracingRequestStreamHandler {

private final WrappedLambda wrappedLambda;
protected final WrappedLambda wrappedLambda;

public TracingRequestStreamWrapper() {
this(
Expand All @@ -32,7 +32,8 @@ public TracingRequestStreamWrapper() {
}

// Visible for testing
TracingRequestStreamWrapper(OpenTelemetrySdk openTelemetrySdk, WrappedLambda wrappedLambda) {
protected TracingRequestStreamWrapper(
OpenTelemetrySdk openTelemetrySdk, WrappedLambda wrappedLambda) {
super(openTelemetrySdk, WrapperConfiguration.flushTimeout());
this.wrappedLambda = wrappedLambda;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ dependencies {
// in public API.
library("com.amazonaws:aws-lambda-java-events:2.2.1")

// By default, "aws-lambda-java-serialization" library is enabled in the classpath
// at the AWS Lambda environment except "java8" runtime which is deprecated.
// But it is available at "java8.al2" runtime, so it is still can be used
// by Java 8 based Lambda functions.
// So that is the reason that why we add it as compile only dependency.
compileOnly("com.amazonaws:aws-lambda-java-serialization:1.1.5")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment here about this library being available during runtime natively in lambda?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rapphil done


// We need Jackson for wrappers to reproduce the serialization does when Lambda invokes a RequestHandler with event
// since Lambda will only be able to invoke the wrapper itself with a generic Object.
// Note that Lambda itself uses Jackson, but does not expose it to the function so we need to include it here.
Expand All @@ -33,6 +40,7 @@ dependencies {
testImplementation("io.opentelemetry:opentelemetry-sdk-extension-autoconfigure")
testImplementation("io.opentelemetry:opentelemetry-extension-trace-propagators")
testImplementation("com.google.guava:guava")
testImplementation("com.amazonaws:aws-lambda-java-serialization:1.1.5")

testImplementation(project(":instrumentation:aws-lambda:aws-lambda-events-2.2:testing"))
testImplementation("uk.org.webcompere:system-stubs-jupiter")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package io.opentelemetry.instrumentation.awslambdaevents.v2_2;

import com.amazonaws.services.lambda.runtime.Context;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.util.function.BiFunction;

Expand All @@ -27,5 +28,35 @@ static <T> Object[] toArray(
return parameters;
}

static <T> Object[] toParameters(Method targetMethod, T input, Context context) {
Class<?>[] parameterTypes = targetMethod.getParameterTypes();
Object[] parameters = new Object[parameterTypes.length];
for (int i = 0; i < parameterTypes.length; i++) {
Class<?> clazz = parameterTypes[i];
boolean isContext = clazz.equals(Context.class);
if (isContext) {
parameters[i] = context;
} else if (i == 0) {
parameters[0] = input;
}
}
return parameters;
}

static Object toInput(
Method targetMethod,
InputStream inputStream,
BiFunction<InputStream, Class<?>, Object> mapper) {
Class<?>[] parameterTypes = targetMethod.getParameterTypes();
for (int i = 0; i < parameterTypes.length; i++) {
Class<?> clazz = parameterTypes[i];
boolean isContext = clazz.equals(Context.class);
if (i == 0 && !isContext) {
return mapper.apply(inputStream, clazz);
}
}
return null;
}

private LambdaParameters() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.WrappedLambda;
import io.opentelemetry.instrumentation.awslambdaevents.v2_2.internal.SerializationUtil;
import io.opentelemetry.sdk.OpenTelemetrySdk;
import java.util.function.BiFunction;

Expand All @@ -35,12 +35,7 @@ public TracingRequestApiGatewayWrapper() {

// Visible for testing
static <T> T map(APIGatewayProxyRequestEvent event, Class<T> clazz) {
try {
return OBJECT_MAPPER.readValue(event.getBody(), clazz);
} catch (JsonProcessingException e) {
throw new IllegalStateException(
"Could not map API Gateway event body to requested parameter type: " + clazz, e);
}
return SerializationUtil.fromJson(event.getBody(), clazz);
}

@Override
Expand All @@ -52,12 +47,8 @@ protected APIGatewayProxyResponseEvent doHandleRequest(
if (result instanceof APIGatewayProxyResponseEvent) {
event = (APIGatewayProxyResponseEvent) result;
} else {
try {
event = new APIGatewayProxyResponseEvent();
event.setBody(OBJECT_MAPPER.writeValueAsString(result));
} catch (JsonProcessingException e) {
throw new IllegalStateException("Could not serialize return value.", e);
}
event = new APIGatewayProxyResponseEvent();
event.setBody(SerializationUtil.toJson(result));
}
return event;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,90 @@

package io.opentelemetry.instrumentation.awslambdaevents.v2_2;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.AwsLambdaRequest;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.TracingRequestStreamWrapper;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.ApiGatewayProxyRequest;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.MapUtils;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.WrappedLambda;
import io.opentelemetry.instrumentation.awslambdaevents.v2_2.internal.SerializationUtil;
import io.opentelemetry.sdk.OpenTelemetrySdk;
import java.util.function.BiFunction;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Map;

/**
* Wrapper for {@link io.opentelemetry.instrumentation.awslambdacore.v1_0.TracingRequestHandler}.
* Allows for wrapping a regular lambda, not proxied through API Gateway. Therefore, HTTP headers
* propagation is not supported.
* Wrapper for {@link com.amazonaws.services.lambda.runtime.RequestHandler} based Lambda handlers.
*/
public class TracingRequestWrapper extends TracingRequestWrapperBase<Object, Object> {
public class TracingRequestWrapper extends TracingRequestStreamWrapper {
public TracingRequestWrapper() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this constructor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tylerbenson default constructors are used by the AWS Lambda runtime while instantiating our wrapper handlers. Other constructors with arguments are used by the tests.

super(TracingRequestWrapper::map);
super();
}

// Visible for testing
TracingRequestWrapper(
OpenTelemetrySdk openTelemetrySdk,
WrappedLambda wrappedLambda,
BiFunction<Object, Class<?>, Object> mapper) {
super(openTelemetrySdk, wrappedLambda, mapper);
TracingRequestWrapper(OpenTelemetrySdk openTelemetrySdk, WrappedLambda wrappedLambda) {
super(openTelemetrySdk, wrappedLambda);
}

@Override
protected final AwsLambdaRequest createRequest(
InputStream inputStream, Context context, ApiGatewayProxyRequest proxyRequest) {
Method targetMethod = wrappedLambda.getRequestTargetMethod();
Object input = LambdaParameters.toInput(targetMethod, inputStream, TracingRequestWrapper::map);
return AwsLambdaRequest.create(context, input, extractHeaders(input));
}

protected Map<String, String> extractHeaders(Object input) {
if (input instanceof APIGatewayProxyRequestEvent) {
return MapUtils.emptyIfNull(((APIGatewayProxyRequestEvent) input).getHeaders());
}
return Collections.emptyMap();
}

@Override
protected final void doHandleRequest(
InputStream input, OutputStream output, Context context, AwsLambdaRequest request) {
Method targetMethod = wrappedLambda.getRequestTargetMethod();
Object[] parameters = LambdaParameters.toParameters(targetMethod, request.getInput(), context);
try {
Object result = targetMethod.invoke(wrappedLambda.getTargetObject(), parameters);
SerializationUtil.toJson(output, result);
} catch (IllegalAccessException e) {
throw new IllegalStateException("Method is inaccessible", e);
} catch (InvocationTargetException e) {
throw (e.getCause() instanceof RuntimeException
? (RuntimeException) e.getCause()
: new IllegalStateException(e.getTargetException()));
Comment on lines +66 to +68
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we stripping off the top layer of the exception chain? Is it to remove the instrumentation itself or something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is the same exception handling logic borrowed from TracingRequestWrapperBase.doHandleRequest which was parent class of TracingRequestWrapper before these changes. Maybe @tylerbenson or @rapphil can enlighten us on the reason behind this logic.

}
}

@SuppressWarnings({"unchecked", "TypeParameterUnusedInFormals"})
// Used for testing
<INPUT, OUTPUT> OUTPUT handleRequest(INPUT input, Context context) throws IOException {
byte[] inputJsonData = SerializationUtil.toJsonData(input);
ByteArrayInputStream inputStream = new ByteArrayInputStream(inputJsonData);
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();

super.handleRequest(inputStream, outputStream, context);

byte[] outputJsonData = outputStream.toByteArray();
return (OUTPUT)
SerializationUtil.fromJson(
new ByteArrayInputStream(outputJsonData),
wrappedLambda.getRequestTargetMethod().getReturnType());
}

// Visible for testing
static <T> T map(Object jsonMap, Class<T> clazz) {
static <T> T map(InputStream inputStream, Class<T> clazz) {
try {
return OBJECT_MAPPER.convertValue(jsonMap, clazz);
return SerializationUtil.fromJson(inputStream, clazz);
} catch (IllegalArgumentException e) {
throw new IllegalStateException(
"Could not map input to requested parameter type: " + clazz, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.opentelemetry.instrumentation.api.internal.HttpConstants;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.TracingRequestHandler;
import io.opentelemetry.instrumentation.awslambdacore.v1_0.internal.MapUtils;
Expand All @@ -29,10 +27,6 @@
*/
abstract class TracingRequestWrapperBase<I, O> extends TracingRequestHandler<I, O> {

protected static final ObjectMapper OBJECT_MAPPER =
new ObjectMapper()
.registerModule(new CustomJodaModule())
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
private final WrappedLambda wrappedLambda;
private final Method targetMethod;
private final BiFunction<I, Class<?>, Object> parameterMapper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.awslambdaevents.v2_2;
package io.opentelemetry.instrumentation.awslambdaevents.v2_2.internal;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.DeserializationContext;
Expand Down
Loading
Loading