Skip to content

Bmoric/add fetch source schema in oauth #19392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ public static void mergeMaps(final Map<String, Object> originalMap, final String
Entry::getValue)));
}

public static Map<String, String> deserializeToStringMap(JsonNode json) {
return OBJECT_MAPPER.convertValue(json, new TypeReference<>() {});
}

/**
* By the Jackson DefaultPrettyPrinter prints objects with an extra space as follows: {"name" :
* "airbyte"}. We prefer {"name": "airbyte"}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ public static ServerRunnable getServer(final ServerFactory apiFactory,

final HealthCheckHandler healthCheckHandler = new HealthCheckHandler(configRepository);

final OAuthHandler oAuthHandler = new OAuthHandler(configRepository, httpClient, trackingClient);
final OAuthHandler oAuthHandler = new OAuthHandler(configRepository, httpClient, trackingClient, secretsRepositoryReader);

final SourceHandler sourceHandler = new SourceHandler(
configRepository,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

package io.airbyte.server.handlers;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;
import io.airbyte.analytics.TrackingClient;
import io.airbyte.api.model.generated.CompleteDestinationOAuthRequest;
import io.airbyte.api.model.generated.CompleteSourceOauthRequest;
Expand All @@ -12,23 +15,32 @@
import io.airbyte.api.model.generated.SetInstancewideDestinationOauthParamsRequestBody;
import io.airbyte.api.model.generated.SetInstancewideSourceOauthParamsRequestBody;
import io.airbyte.api.model.generated.SourceOauthConsentRequest;
import io.airbyte.commons.constants.AirbyteSecretConstants;
import io.airbyte.commons.json.JsonPaths;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.DestinationConnection;
import io.airbyte.config.DestinationOAuthParameter;
import io.airbyte.config.SourceConnection;
import io.airbyte.config.SourceOAuthParameter;
import io.airbyte.config.StandardDestinationDefinition;
import io.airbyte.config.StandardSourceDefinition;
import io.airbyte.config.persistence.ConfigNotFoundException;
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.persistence.SecretsRepositoryReader;
import io.airbyte.oauth.OAuthFlowImplementation;
import io.airbyte.oauth.OAuthImplementationFactory;
import io.airbyte.persistence.job.factory.OAuthConfigSupplier;
import io.airbyte.persistence.job.tracker.TrackingMetadata;
import io.airbyte.protocol.models.ConnectorSpecification;
import io.airbyte.server.handlers.helpers.OAuthPathExtractor;
import io.airbyte.validation.json.JsonValidationException;
import java.io.IOException;
import java.net.http.HttpClient;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -40,67 +52,97 @@ public class OAuthHandler {
private final ConfigRepository configRepository;
private final OAuthImplementationFactory oAuthImplementationFactory;
private final TrackingClient trackingClient;
private final SecretsRepositoryReader secretsRepositoryReader;

public OAuthHandler(final ConfigRepository configRepository,
final HttpClient httpClient,
final TrackingClient trackingClient) {
final TrackingClient trackingClient,
final SecretsRepositoryReader secretsRepositoryReader) {
this.configRepository = configRepository;
this.oAuthImplementationFactory = new OAuthImplementationFactory(configRepository, httpClient);
this.trackingClient = trackingClient;
this.secretsRepositoryReader = secretsRepositoryReader;
}

public OAuthConsentRead getSourceOAuthConsent(final SourceOauthConsentRequest sourceDefinitionIdRequestBody)
public OAuthConsentRead getSourceOAuthConsent(final SourceOauthConsentRequest sourceOauthConsentRequest)
throws JsonValidationException, ConfigNotFoundException, IOException {
final StandardSourceDefinition sourceDefinition =
configRepository.getStandardSourceDefinition(sourceDefinitionIdRequestBody.getSourceDefinitionId());
configRepository.getStandardSourceDefinition(sourceOauthConsentRequest.getSourceDefinitionId());
final OAuthFlowImplementation oAuthFlowImplementation = oAuthImplementationFactory.create(sourceDefinition);
final ConnectorSpecification spec = sourceDefinition.getSpec();
final Map<String, Object> metadata = generateSourceMetadata(sourceDefinitionIdRequestBody.getSourceDefinitionId());
final Map<String, Object> metadata = generateSourceMetadata(sourceOauthConsentRequest.getSourceDefinitionId());
final OAuthConsentRead result;
if (OAuthConfigSupplier.hasOAuthConfigSpecification(spec)) {
final JsonNode oAuthInputConfigurationForConsent;

if (sourceOauthConsentRequest.getSourceId() == null) {
oAuthInputConfigurationForConsent = sourceOauthConsentRequest.getoAuthInputConfiguration();
} else {
final SourceConnection hydratedSourceConnection =
secretsRepositoryReader.getSourceConnectionWithSecrets(sourceOauthConsentRequest.getSourceId());

oAuthInputConfigurationForConsent = getOAuthInputConfigurationForConsent(spec,
hydratedSourceConnection.getConfiguration(),
sourceOauthConsentRequest.getoAuthInputConfiguration() );
}

result = new OAuthConsentRead().consentUrl(oAuthFlowImplementation.getSourceConsentUrl(
sourceDefinitionIdRequestBody.getWorkspaceId(),
sourceDefinitionIdRequestBody.getSourceDefinitionId(),
sourceDefinitionIdRequestBody.getRedirectUrl(),
sourceDefinitionIdRequestBody.getoAuthInputConfiguration(),
sourceOauthConsentRequest.getWorkspaceId(),
sourceOauthConsentRequest.getSourceDefinitionId(),
sourceOauthConsentRequest.getRedirectUrl(),
oAuthInputConfigurationForConsent,
spec.getAdvancedAuth().getOauthConfigSpecification()));
} else {
result = new OAuthConsentRead().consentUrl(oAuthFlowImplementation.getSourceConsentUrl(
sourceDefinitionIdRequestBody.getWorkspaceId(),
sourceDefinitionIdRequestBody.getSourceDefinitionId(),
sourceDefinitionIdRequestBody.getRedirectUrl(), Jsons.emptyObject(), null));
sourceOauthConsentRequest.getWorkspaceId(),
sourceOauthConsentRequest.getSourceDefinitionId(),
sourceOauthConsentRequest.getRedirectUrl(), Jsons.emptyObject(), null));
}
try {
trackingClient.track(sourceDefinitionIdRequestBody.getWorkspaceId(), "Get Oauth Consent URL - Backend", metadata);
trackingClient.track(sourceOauthConsentRequest.getWorkspaceId(), "Get Oauth Consent URL - Backend", metadata);
} catch (final Exception e) {
LOGGER.error(ERROR_MESSAGE, e);
}
return result;
}

public OAuthConsentRead getDestinationOAuthConsent(final DestinationOauthConsentRequest destinationDefinitionIdRequestBody)
public OAuthConsentRead getDestinationOAuthConsent(final DestinationOauthConsentRequest destinationOauthConsentRequest)
throws JsonValidationException, ConfigNotFoundException, IOException {
final StandardDestinationDefinition destinationDefinition =
configRepository.getStandardDestinationDefinition(destinationDefinitionIdRequestBody.getDestinationDefinitionId());
configRepository.getStandardDestinationDefinition(destinationOauthConsentRequest.getDestinationDefinitionId());
final OAuthFlowImplementation oAuthFlowImplementation = oAuthImplementationFactory.create(destinationDefinition);
final ConnectorSpecification spec = destinationDefinition.getSpec();
final Map<String, Object> metadata = generateDestinationMetadata(destinationDefinitionIdRequestBody.getDestinationDefinitionId());
final Map<String, Object> metadata = generateDestinationMetadata(destinationOauthConsentRequest.getDestinationDefinitionId());
final OAuthConsentRead result;
if (OAuthConfigSupplier.hasOAuthConfigSpecification(spec)) {
final JsonNode oAuthInputConfigurationForConsent;

if (destinationOauthConsentRequest.getDestinationId() == null) {
oAuthInputConfigurationForConsent = destinationOauthConsentRequest.getoAuthInputConfiguration();
} else {
final DestinationConnection hydratedSourceConnection =
secretsRepositoryReader.getDestinationConnectionWithSecrets(destinationOauthConsentRequest.getDestinationId());

oAuthInputConfigurationForConsent = getOAuthInputConfigurationForConsent(spec,
hydratedSourceConnection.getConfiguration(),
destinationOauthConsentRequest.getoAuthInputConfiguration() );

}

result = new OAuthConsentRead().consentUrl(oAuthFlowImplementation.getDestinationConsentUrl(
destinationDefinitionIdRequestBody.getWorkspaceId(),
destinationDefinitionIdRequestBody.getDestinationDefinitionId(),
destinationDefinitionIdRequestBody.getRedirectUrl(),
destinationDefinitionIdRequestBody.getoAuthInputConfiguration(),
destinationOauthConsentRequest.getWorkspaceId(),
destinationOauthConsentRequest.getDestinationDefinitionId(),
destinationOauthConsentRequest.getRedirectUrl(),
oAuthInputConfigurationForConsent,
spec.getAdvancedAuth().getOauthConfigSpecification()));
} else {
result = new OAuthConsentRead().consentUrl(oAuthFlowImplementation.getDestinationConsentUrl(
destinationDefinitionIdRequestBody.getWorkspaceId(),
destinationDefinitionIdRequestBody.getDestinationDefinitionId(),
destinationDefinitionIdRequestBody.getRedirectUrl(), Jsons.emptyObject(), null));
destinationOauthConsentRequest.getWorkspaceId(),
destinationOauthConsentRequest.getDestinationDefinitionId(),
destinationOauthConsentRequest.getRedirectUrl(), Jsons.emptyObject(), null));
}
try {
trackingClient.track(destinationDefinitionIdRequestBody.getWorkspaceId(), "Get Oauth Consent URL - Backend", metadata);
trackingClient.track(destinationOauthConsentRequest.getWorkspaceId(), "Get Oauth Consent URL - Backend", metadata);
} catch (final Exception e) {
LOGGER.error(ERROR_MESSAGE, e);
}
Expand Down Expand Up @@ -195,6 +237,19 @@ public void setDestinationInstancewideOauthParams(final SetInstancewideDestinati
configRepository.writeDestinationOAuthParam(param);
}

private JsonNode getOAuthInputConfigurationForConsent(final ConnectorSpecification spec,
final JsonNode hydratedSourceConnectionConfiguration,
final JsonNode destinationDefinitionIdRequestBody) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: rename this last parameter to oAuthInputConfiguration, since this current parameter name is not really quite right

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

final List<String> fieldsToGet =
buildJsonPathFromOAuthFlowInitParameters(OAuthPathExtractor.extractOauthConfigurationPaths(
spec.getAdvancedAuth().getOauthConfigSpecification().getOauthUserInputFromConnectorConfigSpecification()));

final JsonNode oAuthInputConfigurationFromDB = getOAuthInputConfiguration(hydratedSourceConnectionConfiguration, fieldsToGet);

return getOauthFromDBIfNeeded(oAuthInputConfigurationFromDB,
destinationDefinitionIdRequestBody);
}

private Map<String, Object> generateSourceMetadata(final UUID sourceDefinitionId)
throws JsonValidationException, ConfigNotFoundException, IOException {
final StandardSourceDefinition sourceDefinition = configRepository.getStandardSourceDefinition(sourceDefinitionId);
Expand All @@ -207,4 +262,41 @@ private Map<String, Object> generateDestinationMetadata(final UUID destinationDe
return TrackingMetadata.generateDestinationDefinitionMetadata(destinationDefinition);
}

@VisibleForTesting
List<String> buildJsonPathFromOAuthFlowInitParameters(final List<List<String>> oAuthFlowInitParameters) {
return oAuthFlowInitParameters.stream()
.map(path -> "$." + String.join(".", path))
.toList();
}

@VisibleForTesting
JsonNode getOauthFromDBIfNeeded(final JsonNode oAuthInputConfigurationFromDB, final JsonNode oAuthInputConfigurationFromInput) {
final Map<String, String> result = new HashMap<>();

Jsons.deserializeToStringMap(oAuthInputConfigurationFromInput)
.forEach((k, v) -> {
if (AirbyteSecretConstants.SECRETS_MASK.equals(v)) {
if (oAuthInputConfigurationFromDB.has(k)) {
result.put(k, oAuthInputConfigurationFromDB.get(k).textValue());
} else {
LOGGER.warn("Missing the k {} in the config store in DB", k);
}

} else {
result.put(k, v);
}
});

return Jsons.jsonNode(result);
}

@VisibleForTesting
JsonNode getOAuthInputConfiguration(final JsonNode hydratedSourceConnectionConfiguration, final List<String> pathsToGet) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the prototype of this function, looks like we could run into conflicts in the Collectors.toMap.
Is that a valid concern? Should we have some type of error handling here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done,

return Jsons.jsonNode(pathsToGet.stream().map(path -> Map.entry(path,
JsonPaths.getSingleValue(hydratedSourceConnectionConfiguration, path)))
.collect(Collectors.toMap(
entry -> Iterables.getLast(List.of(entry.getKey().split("\\."))),
entry -> entry.getValue().get())));
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.server.handlers.helpers;

import com.fasterxml.jackson.databind.JsonNode;
import java.util.ArrayList;
import java.util.List;

public class OAuthPathExtractor {

private static final String PROPERTIES = "properties";
private static final String PATH_IN_CONNECTOR_CONFIG = "path_in_connector_config";

public static List<List<String>> extractOauthConfigurationPaths(final JsonNode configuration) {

if (configuration.has(PROPERTIES) && configuration.get(PROPERTIES).isObject()) {
final List<List<String>> result = new ArrayList<>();

configuration.get(PROPERTIES).fields().forEachRemaining(entry -> {
final JsonNode value = entry.getValue();
if (value.isObject() && value.has(PATH_IN_CONNECTOR_CONFIG) && value.get(PATH_IN_CONNECTOR_CONFIG).isArray()) {
final List<String> path = new ArrayList<>();
for (final JsonNode pathPart : value.get(PATH_IN_CONNECTOR_CONFIG)) {
path.add(pathPart.textValue());
}
result.add(path);
}
});

return result;
} else {
return new ArrayList<>();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.analytics.TrackingClient;
import io.airbyte.api.model.generated.SetInstancewideDestinationOauthParamsRequestBody;
import io.airbyte.api.model.generated.SetInstancewideSourceOauthParamsRequestBody;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.DestinationOAuthParameter;
import io.airbyte.config.SourceOAuthParameter;
import io.airbyte.config.persistence.ConfigRepository;
import io.airbyte.config.persistence.SecretsRepositoryReader;
import io.airbyte.validation.json.JsonValidationException;
import java.io.IOException;
import java.net.http.HttpClient;
Expand All @@ -34,6 +36,7 @@ class OAuthHandlerTest {
private OAuthHandler handler;
private TrackingClient trackingClient;
private HttpClient httpClient;
private SecretsRepositoryReader secretsRepositoryReader;
private static final String CLIENT_ID = "123";
private static final String CLIENT_ID_KEY = "client_id";
private static final String CLIENT_SECRET_KEY = "client_secret";
Expand All @@ -44,7 +47,8 @@ public void init() {
configRepository = Mockito.mock(ConfigRepository.class);
trackingClient = mock(TrackingClient.class);
httpClient = Mockito.mock(HttpClient.class);
handler = new OAuthHandler(configRepository, httpClient, trackingClient);
secretsRepositoryReader = mock(SecretsRepositoryReader.class);
handler = new OAuthHandler(configRepository, httpClient, trackingClient, secretsRepositoryReader);
}

@Test
Expand Down Expand Up @@ -151,4 +155,75 @@ void resetDestinationInstancewideOauthParams() throws JsonValidationException, I
assertEquals(oauthParameterId, capturedValues.get(1).getOauthParameterId());
}

@Test
void testBuildJsonPathFromOAuthFlowInitParameters() {
List<List<String>> input = List.of(
List.of("1"),
List.of("2", "3"));

List<String> expected = List.of("$.1", "$.2.3");

assertEquals(expected, handler.buildJsonPathFromOAuthFlowInitParameters(input));
}

@Test
void testGetOAuthInputConfiguration() {
JsonNode hydratedConfig = Jsons.deserialize(
"""
{
"field1": "1",
"field2": "2",
"field3": {
"field3_1": "3_1",
"field3_2": "3_2"
}
}
""");

List<String> pathsToGet = List.of(
"$.field1",
"$.field3.field3_1",
"$.field3.field3_2");

JsonNode expected = Jsons.deserialize(
"""
{
"field1": "1",
"field3_1": "3_1",
"field3_2": "3_2"
}
""");

assertEquals(expected, handler.getOAuthInputConfiguration(hydratedConfig, pathsToGet));
}

@Test
void testGetOauthFromDBIfNeeded() {
JsonNode fromInput = Jsons.deserialize(
"""
{
"testMask": "**********",
"testNotMask": "this"
}
""");

JsonNode fromDb = Jsons.deserialize(
"""
{
"testMask": "mask",
"testNotMask": "notThis"
}
""");

JsonNode expected = Jsons.deserialize(
"""
{
"testMask": "mask",
"testNotMask": "this"
}
""");

assertEquals(expected, handler.getOauthFromDBIfNeeded(fromDb, fromInput));
}

}
Loading