2
2
3
3
import json
4
4
import logging
5
+ import os
5
6
from pathlib import Path
6
7
from typing import Dict , Optional , Set
7
8
9
+ import rich
8
10
from connection_retriever import ConnectionObject , retrieve_objects # type: ignore
9
11
from connection_retriever .errors import NotPermittedError # type: ignore
10
12
11
13
from .models import AirbyteCatalog , Command , ConfiguredAirbyteCatalog , ConnectionObjects , SecretDict
12
14
13
15
LOGGER = logging .getLogger (__name__ )
16
+ console = rich .get_console ()
14
17
15
18
16
19
def parse_config (config : Dict | str | None ) -> Optional [SecretDict ]:
@@ -32,14 +35,17 @@ def parse_catalog(catalog: Dict | str | None) -> Optional[AirbyteCatalog]:
32
35
33
36
34
37
def parse_configured_catalog (
35
- configured_catalog : Dict | str | None ,
38
+ configured_catalog : Dict | str | None , selected_streams : Set [ str ] | None = None
36
39
) -> Optional [ConfiguredAirbyteCatalog ]:
37
40
if not configured_catalog :
38
41
return None
39
42
if isinstance (configured_catalog , str ):
40
- return ConfiguredAirbyteCatalog .parse_obj (json .loads (configured_catalog ))
43
+ catalog = ConfiguredAirbyteCatalog .parse_obj (json .loads (configured_catalog ))
41
44
else :
42
- return ConfiguredAirbyteCatalog .parse_obj (configured_catalog )
45
+ catalog = ConfiguredAirbyteCatalog .parse_obj (configured_catalog )
46
+ if selected_streams :
47
+ return ConfiguredAirbyteCatalog (streams = [stream for stream in catalog .streams if stream .stream .name in selected_streams ])
48
+ return catalog
43
49
44
50
45
51
def parse_state (state : Dict | str | None ) -> Optional [Dict ]:
@@ -59,8 +65,8 @@ def get_state_from_path(state_path: Path) -> Optional[Dict]:
59
65
return parse_state (state_path .read_text ())
60
66
61
67
62
- def get_configured_catalog_from_path (path : Path ) -> Optional [ConfiguredAirbyteCatalog ]:
63
- return parse_configured_catalog (path .read_text ())
68
+ def get_configured_catalog_from_path (path : Path , selected_streams : Optional [ Set [ str ]] = None ) -> Optional [ConfiguredAirbyteCatalog ]:
69
+ return parse_configured_catalog (path .read_text (), selected_streams )
64
70
65
71
66
72
COMMAND_TO_REQUIRED_OBJECT_TYPES = {
@@ -85,6 +91,8 @@ def get_connection_objects(
85
91
retrieval_reason : Optional [str ],
86
92
fail_if_missing_objects : bool = True ,
87
93
connector_image : Optional [str ] = None ,
94
+ auto_select_connection : bool = False ,
95
+ selected_streams : Optional [Set [str ]] = None ,
88
96
) -> ConnectionObjects :
89
97
"""This function retrieves the connection objects values.
90
98
It checks that the required objects are available and raises a UsageError if they are not.
@@ -100,18 +108,26 @@ def get_connection_objects(
100
108
retrieval_reason (Optional[str]): The reason to access the connection objects.
101
109
fail_if_missing_objects (bool, optional): Whether to raise a ValueError if a required object is missing. Defaults to True.
102
110
connector_image (Optional[str]): The image name for the connector under test.
111
+ auto_select_connection (bool, optional): Whether to automatically select a connection if no connection id is passed. Defaults to False.
112
+ selected_streams (Optional[Set[str]]): The set of selected streams to use when auto selecting a connection.
103
113
Raises:
104
114
click.UsageError: If a required object is missing for the command.
105
115
click.UsageError: If a retrieval reason is missing when passing a connection id.
106
116
Returns:
107
117
ConnectionObjects: The connection objects values.
108
118
"""
119
+ if connection_id is None and not auto_select_connection :
120
+ raise ValueError ("A connection id or auto_select_connection must be provided to retrieve the connection objects." )
121
+ if auto_select_connection and not connector_image :
122
+ raise ValueError ("A connector image must be provided when using auto_select_connection." )
109
123
110
124
custom_config = get_connector_config_from_path (custom_config_path ) if custom_config_path else None
111
- custom_configured_catalog = get_configured_catalog_from_path (custom_configured_catalog_path ) if custom_configured_catalog_path else None
125
+ custom_configured_catalog = (
126
+ get_configured_catalog_from_path (custom_configured_catalog_path , selected_streams ) if custom_configured_catalog_path else None
127
+ )
112
128
custom_state = get_state_from_path (custom_state_path ) if custom_state_path else None
113
129
114
- if not connection_id :
130
+ if not connection_id and not auto_select_connection :
115
131
connection_object = ConnectionObjects (
116
132
source_config = custom_config ,
117
133
destination_config = custom_config ,
@@ -121,15 +137,35 @@ def get_connection_objects(
121
137
workspace_id = None ,
122
138
source_id = None ,
123
139
destination_id = None ,
140
+ connection_id = None ,
141
+ source_docker_image = None ,
124
142
)
125
143
else :
126
144
if not retrieval_reason :
127
145
raise ValueError ("A retrieval reason is required to access the connection objects when passing a connection id." )
128
- retrieved_objects = retrieve_objects (connection_id , requested_objects , retrieval_reason = retrieval_reason )
146
+ LOGGER .info ("Retrieving connection objects from the database..." )
147
+ if auto_select_connection :
148
+ is_ci = os .getenv ("CI" , False )
149
+ connection_id , retrieved_objects = retrieve_objects (
150
+ requested_objects ,
151
+ retrieval_reason = retrieval_reason ,
152
+ source_docker_repository = connector_image ,
153
+ prompt_for_connection_selection = not is_ci ,
154
+ with_streams = selected_streams ,
155
+ )
156
+ else :
157
+ connection_id , retrieved_objects = retrieve_objects (
158
+ requested_objects ,
159
+ retrieval_reason = retrieval_reason ,
160
+ connection_id = connection_id ,
161
+ with_streams = selected_streams ,
162
+ )
129
163
retrieved_source_config = parse_config (retrieved_objects .get (ConnectionObject .SOURCE_CONFIG ))
130
164
rerieved_destination_config = parse_config (retrieved_objects .get (ConnectionObject .DESTINATION_CONFIG ))
131
165
retrieved_catalog = parse_catalog (retrieved_objects .get (ConnectionObject .CATALOG ))
132
- retrieved_configured_catalog = parse_configured_catalog (retrieved_objects .get (ConnectionObject .CONFIGURED_CATALOG ))
166
+ retrieved_configured_catalog = parse_configured_catalog (
167
+ retrieved_objects .get (ConnectionObject .CONFIGURED_CATALOG ), selected_streams
168
+ )
133
169
retrieved_state = parse_state (retrieved_objects .get (ConnectionObject .STATE ))
134
170
135
171
retrieved_source_docker_image = retrieved_objects .get (ConnectionObject .SOURCE_DOCKER_IMAGE )
@@ -149,6 +185,8 @@ def get_connection_objects(
149
185
workspace_id = retrieved_objects .get (ConnectionObject .WORKSPACE_ID ),
150
186
source_id = retrieved_objects .get (ConnectionObject .SOURCE_ID ),
151
187
destination_id = retrieved_objects .get (ConnectionObject .DESTINATION_ID ),
188
+ source_docker_image = retrieved_source_docker_image ,
189
+ connection_id = connection_id ,
152
190
)
153
191
if fail_if_missing_objects :
154
192
if not connection_object .source_config and ConnectionObject .SOURCE_CONFIG in requested_objects :
0 commit comments