3
3
#
4
4
5
5
from os import getenv
6
- from typing import Any , Dict , List , Mapping , MutableMapping , Tuple
6
+ from typing import Any , List , Mapping , MutableMapping , Optional , Tuple
7
7
from urllib .parse import urlparse
8
8
9
9
from airbyte_cdk import AirbyteLogger
@@ -65,7 +65,9 @@ class SourceGithub(AbstractSource):
65
65
continue_sync_on_stream_failure = True
66
66
67
67
@staticmethod
68
- def _get_org_repositories (config : Mapping [str , Any ], authenticator : MultipleTokenAuthenticator ) -> Tuple [List [str ], List [str ]]:
68
+ def _get_org_repositories (
69
+ config : Mapping [str , Any ], authenticator : MultipleTokenAuthenticator
70
+ ) -> Tuple [List [str ], List [str ], Optional [str ]]:
69
71
"""
70
72
Parse config/repositories and produce two lists: organizations, repositories.
71
73
Args:
@@ -78,16 +80,19 @@ def _get_org_repositories(config: Mapping[str, Any], authenticator: MultipleToke
78
80
organizations = set ()
79
81
unchecked_repos = set ()
80
82
unchecked_orgs = set ()
83
+ pattern = None
81
84
82
85
for org_repos in config_repositories :
83
- org , _ , repos = org_repos .partition ("/" )
84
- if repos == "*" :
85
- unchecked_orgs .add (org )
86
+ _ , _ , repos = org_repos .partition ("/" )
87
+ if "*" in repos :
88
+ unchecked_orgs .add (org_repos )
86
89
else :
87
90
unchecked_repos .add (org_repos )
88
91
89
92
if unchecked_orgs :
90
- stream = Repositories (authenticator = authenticator , organizations = unchecked_orgs , api_url = config .get ("api_url" ))
93
+ org_names = [org .split ("/" )[0 ] for org in unchecked_orgs ]
94
+ pattern = "|" .join ([f"({ org .replace ('*' , '.*' )} )" for org in unchecked_orgs ])
95
+ stream = Repositories (authenticator = authenticator , organizations = org_names , api_url = config .get ("api_url" ), pattern = pattern )
91
96
for record in read_full_refresh (stream ):
92
97
repositories .add (record ["full_name" ])
93
98
organizations .add (record ["organization" ])
@@ -96,7 +101,7 @@ def _get_org_repositories(config: Mapping[str, Any], authenticator: MultipleToke
96
101
if unchecked_repos :
97
102
stream = RepositoryStats (
98
103
authenticator = authenticator ,
99
- repositories = unchecked_repos ,
104
+ repositories = list ( unchecked_repos ) ,
100
105
api_url = config .get ("api_url" ),
101
106
# This parameter is deprecated and in future will be used sane default, page_size: 10
102
107
page_size_for_large_streams = config .get ("page_size_for_large_streams" , constants .DEFAULT_PAGE_SIZE_FOR_LARGE_STREAM ),
@@ -107,7 +112,7 @@ def _get_org_repositories(config: Mapping[str, Any], authenticator: MultipleToke
107
112
if organization :
108
113
organizations .add (organization )
109
114
110
- return list (organizations ), list (repositories )
115
+ return list (organizations ), list (repositories ), pattern
111
116
112
117
@staticmethod
113
118
def get_access_token (config : Mapping [str , Any ]):
@@ -169,45 +174,6 @@ def _validate_branches(self, config: MutableMapping[str, Any]) -> MutableMapping
169
174
def _is_http_allowed () -> bool :
170
175
return getenv ("DEPLOYMENT_MODE" , "" ).upper () != "CLOUD"
171
176
172
- @staticmethod
173
- def _get_branches_data (
174
- selected_branches : List , full_refresh_args : Dict [str , Any ] = None
175
- ) -> Tuple [Dict [str , str ], Dict [str , List [str ]]]:
176
- selected_branches = set (selected_branches )
177
-
178
- # Get the default branch for each repository
179
- default_branches = {}
180
- repository_stats_stream = RepositoryStats (** full_refresh_args )
181
- for stream_slice in repository_stats_stream .stream_slices (sync_mode = SyncMode .full_refresh ):
182
- default_branches .update (
183
- {
184
- repo_stats ["full_name" ]: repo_stats ["default_branch" ]
185
- for repo_stats in repository_stats_stream .read_records (sync_mode = SyncMode .full_refresh , stream_slice = stream_slice )
186
- }
187
- )
188
-
189
- all_branches = []
190
- branches_stream = Branches (** full_refresh_args )
191
- for stream_slice in branches_stream .stream_slices (sync_mode = SyncMode .full_refresh ):
192
- for branch in branches_stream .read_records (sync_mode = SyncMode .full_refresh , stream_slice = stream_slice ):
193
- all_branches .append (f"{ branch ['repository' ]} /{ branch ['name' ]} " )
194
-
195
- # Create mapping of repository to list of branches to pull commits for
196
- # If no branches are specified for a repo, use its default branch
197
- branches_to_pull : Dict [str , List [str ]] = {}
198
- for repo in full_refresh_args ["repositories" ]:
199
- repo_branches = []
200
- for branch in selected_branches :
201
- branch_parts = branch .split ("/" , 2 )
202
- if "/" .join (branch_parts [:2 ]) == repo and branch in all_branches :
203
- repo_branches .append (branch_parts [- 1 ])
204
- if not repo_branches :
205
- repo_branches = [default_branches [repo ]]
206
-
207
- branches_to_pull [repo ] = repo_branches
208
-
209
- return default_branches , branches_to_pull
210
-
211
177
def user_friendly_error_message (self , message : str ) -> str :
212
178
user_message = ""
213
179
if "404 Client Error: Not Found for url: https://api.github.com/repos/" in message :
@@ -229,7 +195,7 @@ def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) ->
229
195
config = self ._validate_and_transform_config (config )
230
196
try :
231
197
authenticator = self ._get_authenticator (config )
232
- _ , repositories = self ._get_org_repositories (config = config , authenticator = authenticator )
198
+ _ , repositories , _ = self ._get_org_repositories (config = config , authenticator = authenticator )
233
199
if not repositories :
234
200
return (
235
201
False ,
@@ -246,7 +212,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
246
212
authenticator = self ._get_authenticator (config )
247
213
config = self ._validate_and_transform_config (config )
248
214
try :
249
- organizations , repositories = self ._get_org_repositories (config = config , authenticator = authenticator )
215
+ organizations , repositories , pattern = self ._get_org_repositories (config = config , authenticator = authenticator )
250
216
except Exception as e :
251
217
message = repr (e )
252
218
user_message = self .user_friendly_error_message (message )
@@ -291,7 +257,6 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
291
257
}
292
258
repository_args_with_start_date = {** repository_args , "start_date" : start_date }
293
259
294
- default_branches , branches_to_pull = self ._get_branches_data (config .get ("branch" , []), repository_args )
295
260
pull_requests_stream = PullRequests (** repository_args_with_start_date )
296
261
projects_stream = Projects (** repository_args_with_start_date )
297
262
project_columns_stream = ProjectColumns (projects_stream , ** repository_args_with_start_date )
@@ -307,7 +272,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
307
272
Comments (** repository_args_with_start_date ),
308
273
CommitCommentReactions (** repository_args_with_start_date ),
309
274
CommitComments (** repository_args_with_start_date ),
310
- Commits (** repository_args_with_start_date , branches_to_pull = branches_to_pull , default_branches = default_branches ),
275
+ Commits (** repository_args_with_start_date , branches_to_pull = config . get ( "branches" , []) ),
311
276
ContributorActivity (** repository_args ),
312
277
Deployments (** repository_args_with_start_date ),
313
278
Events (** repository_args_with_start_date ),
@@ -327,7 +292,7 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
327
292
ProjectsV2 (** repository_args_with_start_date ),
328
293
pull_requests_stream ,
329
294
Releases (** repository_args_with_start_date ),
330
- Repositories (** organization_args_with_start_date ),
295
+ Repositories (** organization_args_with_start_date , pattern = pattern ),
331
296
ReviewComments (** repository_args_with_start_date ),
332
297
Reviews (** repository_args_with_start_date ),
333
298
Stargazers (** repository_args_with_start_date ),
0 commit comments