
Commit c993677

Consolidate import and usage of itertools

1 parent abef61f, commit c993677

24 files changed: +66 / -70 lines

airflow/configuration.py (3 additions, 3 deletions)

@@ -19,7 +19,7 @@
import datetime
import functools
import io
-import itertools as it
+import itertools
import json
import logging
import multiprocessing
@@ -473,7 +473,7 @@ def get_sections_including_defaults(self) -> list[str]:

        :return: list of section names
        """
-        return list(dict.fromkeys(it.chain(self.configuration_description, self.sections())))
+        return list(dict.fromkeys(itertools.chain(self.configuration_description, self.sections())))

    def get_options_including_defaults(self, section: str) -> list[str]:
        """
@@ -485,7 +485,7 @@ def get_options_including_defaults(self, section: str) -> list[str]:
        """
        my_own_options = self.options(section) if self.has_section(section) else []
        all_options_from_defaults = self.configuration_description.get(section, {}).get("options", {})
-        return list(dict.fromkeys(it.chain(all_options_from_defaults, my_own_options)))
+        return list(dict.fromkeys(itertools.chain(all_options_from_defaults, my_own_options)))

    def optionxform(self, optionstr: str) -> str:
        """

airflow/decorators/base.py (2 additions, 2 deletions)

@@ -17,9 +17,9 @@
from __future__ import annotations

import inspect
+import itertools
import warnings
from functools import cached_property
-from itertools import chain
from textwrap import dedent
from typing import (
    Any,
@@ -226,7 +226,7 @@ def __init__(
    def execute(self, context: Context):
        # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
        # as well
-        for arg in chain(self.op_args, self.op_kwargs.values()):
+        for arg in itertools.chain(self.op_args, self.op_kwargs.values()):
            if isinstance(arg, Dataset):
                self.inlets.append(arg)
        return_value = super().execute(context)

airflow/lineage/__init__.py (1 addition, 2 deletions)

@@ -18,7 +18,6 @@
"""Provides lineage support functions."""
from __future__ import annotations

-import itertools
import logging
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
@@ -142,7 +141,7 @@ def wrapper(self, context, *args, **kwargs):
            _inlets = self.xcom_pull(
                context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
            )
-            self.inlets.extend(itertools.chain.from_iterable(_inlets))
+            self.inlets.extend(i for it in _inlets for i in it)

        elif self.inlets:
            raise AttributeError("inlets is not a list, operator, string or attr annotated object")
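
Note: the replacement generator expression and itertools.chain.from_iterable both flatten exactly one level of nesting, so the behaviour is unchanged while the module drops its itertools import. A standalone sketch with made-up data (not real lineage objects):

import itertools

pulled = [["a", "b"], [], ["c"]]

flat_genexp = [i for it in pulled for i in it]
flat_chain = list(itertools.chain.from_iterable(pulled))
assert flat_genexp == flat_chain == ["a", "b", "c"]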

airflow/providers/amazon/aws/hooks/batch_client.py (2 additions, 2 deletions)

@@ -26,7 +26,7 @@
"""
from __future__ import annotations

-import itertools as it
+import itertools
from random import uniform
from time import sleep
from typing import Callable
@@ -488,7 +488,7 @@ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:

        # cross stream names with options (i.e. attempts X nodes) to generate all log infos
        result = []
-        for stream, option in it.product(stream_names, log_options):
+        for stream, option in itertools.product(stream_names, log_options):
            result.append(
                {
                    "awslogs_stream_name": stream,

airflow/providers/amazon/aws/triggers/batch.py (2 additions, 2 deletions)

@@ -17,7 +17,7 @@
from __future__ import annotations

import asyncio
-import itertools as it
+import itertools
from functools import cached_property
from typing import Any

@@ -162,7 +162,7 @@ async def run(self):
        """
        async with self.hook.async_conn as client:
            waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
-            for attempt in it.count(1):
+            for attempt in itertools.count(1):
                try:
                    await waiter.wait(
                        jobs=[self.job_id],
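
Note: itertools.count(1) is an unbounded attempt counter (1, 2, 3, ...); the surrounding loop is expected to exit via break/return or by raising once a cap is hit. A standalone polling sketch with a made-up check callable and limits (not the trigger's actual retry policy):

import itertools
import time

def poll(check, max_attempts=5, interval=0.1):
    # Count attempts from 1; stop on success or after max_attempts.
    for attempt in itertools.count(1):
        if check():
            return attempt
        if attempt >= max_attempts:
            raise TimeoutError(f"gave up after {attempt} attempts")
        time.sleep(interval)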

airflow/providers/cncf/kubernetes/utils/pod_manager.py (2 additions, 2 deletions)

@@ -18,7 +18,7 @@
from __future__ import annotations

import enum
-import itertools as it
+import itertools
import json
import logging
import math
@@ -628,7 +628,7 @@ def read_pod(self, pod: V1Pod) -> V1Pod:

    def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None:
        self.log.info("Checking if xcom sidecar container is started.")
-        for attempt in it.count():
+        for attempt in itertools.count():
            if self.container_is_running(pod, PodDefaults.SIDECAR_CONTAINER_NAME):
                self.log.info("The xcom sidecar container is started.")
                break

airflow/utils/helpers.py (3 additions, 3 deletions)

@@ -18,12 +18,12 @@
from __future__ import annotations

import copy
+import itertools
import re
import signal
import warnings
from datetime import datetime
from functools import reduce
-from itertools import filterfalse, tee
from typing import TYPE_CHECKING, Any, Callable, Generator, Iterable, Mapping, MutableMapping, TypeVar, cast

from lazy_object_proxy import Proxy
@@ -216,8 +216,8 @@ def merge_dicts(dict1: dict, dict2: dict) -> dict:

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """Use a predicate to partition entries into false entries and true entries."""
-    iter_1, iter_2 = tee(iterable)
-    return filterfalse(pred, iter_1), filter(pred, iter_2)
+    iter_1, iter_2 = itertools.tee(iterable)
+    return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)


def chain(*args, **kwargs):
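
Note: partition() relies on itertools.tee to duplicate the input iterator so filterfalse and filter can each consume their own copy; the first returned iterator holds the entries where the predicate is false, the second where it is true. Usage sketch:

import itertools

def partition(pred, iterable):
    iter_1, iter_2 = itertools.tee(iterable)
    return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)

odds, evens = partition(lambda n: n % 2 == 0, range(10))
assert list(odds) == [1, 3, 5, 7, 9]   # predicate False
assert list(evens) == [0, 2, 4, 6, 8]  # predicate True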

airflow/www/decorators.py (3 additions, 3 deletions)

@@ -19,10 +19,10 @@

import functools
import gzip
+import itertools
import json
import logging
from io import BytesIO as IO
-from itertools import chain
from typing import Callable, TypeVar, cast

import pendulum
@@ -94,15 +94,15 @@ def wrapper(*args, **kwargs):
        fields_skip_logging = {"csrf_token", "_csrf_token"}
        extra_fields = [
            (k, secrets_masker.redact(v, k))
-            for k, v in chain(request.values.items(multi=True), request.view_args.items())
+            for k, v in itertools.chain(request.values.items(multi=True), request.view_args.items())
            if k not in fields_skip_logging
        ]
        if event and event.startswith("variable."):
            extra_fields = _mask_variable_fields(extra_fields)
        if event and event.startswith("connection."):
            extra_fields = _mask_connection_fields(extra_fields)

-        params = {k: v for k, v in chain(request.values.items(), request.view_args.items())}
+        params = {k: v for k, v in itertools.chain(request.values.items(), request.view_args.items())}

        log = Log(
            event=event or f.__name__,
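
Note: itertools.chain concatenates the two .items() views so a single dict comprehension can merge request values and view args; as with any dict built from pairs, later keys win. A standalone sketch with made-up request data (not a real Flask request):

import itertools

values = {"dag_id": "example_dag", "csrf_token": "secret"}
view_args = {"task_id": "extract"}

params = {k: v for k, v in itertools.chain(values.items(), view_args.items())}
assert params == {"dag_id": "example_dag", "csrf_token": "secret", "task_id": "extract"}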

dev/check_files.py (2 additions, 2 deletions)

@@ -16,9 +16,9 @@
# under the License.
from __future__ import annotations

+import itertools
import os
import re
-from itertools import product

import rich_click as click
from rich import print
@@ -141,7 +141,7 @@ def check_release(files: list[str], version: str):


def expand_name_variations(files):
-    return sorted(base + suffix for base, suffix in product(files, ["", ".asc", ".sha512"]))
+    return sorted(base + suffix for base, suffix in itertools.product(files, ["", ".asc", ".sha512"]))


def check_upgrade_check(files: list[str], version: str):
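
Note: expand_name_variations() crosses each base name with every suffix via itertools.product and sorts the result. A quick illustration with a made-up artifact name:

import itertools

def expand_name_variations(files):
    return sorted(base + suffix for base, suffix in itertools.product(files, ["", ".asc", ".sha512"]))

print(expand_name_variations(["apache_airflow-2.7.0.tar.gz"]))
# ['apache_airflow-2.7.0.tar.gz', 'apache_airflow-2.7.0.tar.gz.asc', 'apache_airflow-2.7.0.tar.gz.sha512']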

docs/build_docs.py (3 additions, 3 deletions)

@@ -23,11 +23,11 @@
from __future__ import annotations

import argparse
+import itertools
import multiprocessing
import os
import sys
from collections import defaultdict
-from itertools import filterfalse, tee
from typing import Callable, Iterable, NamedTuple, TypeVar

from rich.console import Console
@@ -74,8 +74,8 @@

def partition(pred: Callable[[T], bool], iterable: Iterable[T]) -> tuple[Iterable[T], Iterable[T]]:
    """Use a predicate to partition entries into false entries and true entries"""
-    iter_1, iter_2 = tee(iterable)
-    return filterfalse(pred, iter_1), filter(pred, iter_2)
+    iter_1, iter_2 = itertools.tee(iterable)
+    return itertools.filterfalse(pred, iter_1), filter(pred, iter_2)


def _promote_new_flags():

docs/exts/docs_build/fetch_inventories.py (2 additions, 2 deletions)

@@ -19,11 +19,11 @@
import concurrent
import concurrent.futures
import datetime
+import itertools
import os
import shutil
import sys
import traceback
-from itertools import repeat
from tempfile import NamedTemporaryFile
from typing import Iterator

@@ -142,7 +142,7 @@ def fetch_inventories():
    with requests.Session() as session, concurrent.futures.ThreadPoolExecutor(DEFAULT_POOLSIZE) as pool:
        download_results: Iterator[tuple[str, bool]] = pool.map(
            _fetch_file,
-            repeat(session, len(to_download)),
+            itertools.repeat(session, len(to_download)),
            (pkg_name for pkg_name, _, _ in to_download),
            (url for _, url, _ in to_download),
            (path for _, _, path in to_download),
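
Note: Executor.map() zips its argument iterables, so itertools.repeat(session, n) supplies the same session object as the first argument of every call. A standalone sketch with a made-up worker function and URLs:

import itertools
from concurrent.futures import ThreadPoolExecutor

def fetch(session, url):
    # Stand-in for a real download; just shows which arguments arrive together.
    return f"{session} -> {url}"

urls = ["https://example.org/a", "https://example.org/b"]
with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(fetch, itertools.repeat("shared-session", len(urls)), urls))
assert len(results) == len(urls)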

docs/exts/docs_build/lint_checks.py (3 additions, 3 deletions)

@@ -17,10 +17,10 @@
from __future__ import annotations

import ast
+import itertools
import os
import re
from glob import glob
-from itertools import chain
from typing import Iterable

from docs.exts.docs_build.docs_builder import ALL_PROVIDER_YAMLS
@@ -87,7 +87,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
            operator_names=find_existing_guide_operator_names(
                f"{DOCS_DIR}/apache-airflow/howto/operator/**/*.rst"
            ),
-            python_module_paths=chain(
+            python_module_paths=itertools.chain(
                glob(f"{ROOT_PACKAGE_DIR}/operators/*.py"),
                glob(f"{ROOT_PACKAGE_DIR}/sensors/*.py"),
            ),
@@ -101,7 +101,7 @@ def check_guide_links_in_operator_descriptions() -> list[DocBuildError]:
        }

        # Extract all potential python modules that can contain operators
-        python_module_paths = chain(
+        python_module_paths = itertools.chain(
            glob(f"{provider['package-dir']}/**/operators/*.py", recursive=True),
            glob(f"{provider['package-dir']}/**/sensors/*.py", recursive=True),
            glob(f"{provider['package-dir']}/**/transfers/*.py", recursive=True),

scripts/ci/pre_commit/pre_commit_check_deferrable_default.py (1 addition, 1 deletion)

@@ -76,7 +76,7 @@ def iter_check_deferrable_default_errors(module_filename: str) -> Iterator[str]:
        args = node.args
        arguments = reversed([*args.args, *args.kwonlyargs])
        defaults = reversed([*args.defaults, *args.kw_defaults])
-        for argument, default in itertools.zip_longest(arguments, defaults, fillvalue=None):
+        for argument, default in zip(arguments, defaults):
            if argument is None or default is None:
                continue
            if argument.arg != "deferrable" or _is_valid_deferrable_default(default):
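
Note: context for this hunk: zip() stops at the shorter input, while zip_longest() pads the shorter one with fillvalue, so arguments without a matching default are now dropped by truncation instead of being skipped by the "is None" guard. A standalone illustration of the difference:

import itertools

arguments = ["deferrable", "timeout", "retries"]
defaults = [False, 30]  # one fewer default than arguments

print(list(zip(arguments, defaults)))
# [('deferrable', False), ('timeout', 30)]
print(list(itertools.zip_longest(arguments, defaults, fillvalue=None)))
# [('deferrable', False), ('timeout', 30), ('retries', None)]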

scripts/ci/pre_commit/pre_commit_sort_installed_providers.py (1 addition, 2 deletions)

@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

-import itertools
from pathlib import Path

if __name__ not in ("__main__", "__mp_main__"):
@@ -35,7 +34,7 @@ def stable_sort(x):


def sort_uniq(sequence):
-    return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
+    return sorted(set(sequence), key=stable_sort)


if __name__ == "__main__":
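
Note: both versions of sort_uniq() drop duplicates and order by the same key; the old one de-duplicated adjacent items with groupby() after sorting and returned a generator, the new one goes through set() first and returns a list. A standalone sketch, assuming a stable_sort key of the kind this script defines:

import itertools

def stable_sort(x):
    return x.casefold(), x  # assumed key: case-insensitive, original string as tie-breaker

def sort_uniq_old(sequence):
    return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))

def sort_uniq_new(sequence):
    return sorted(set(sequence), key=stable_sort)

lines = ["beta", "Alpha", "beta", "alpha"]
assert list(sort_uniq_old(lines)) == sort_uniq_new(lines) == ["Alpha", "alpha", "beta"]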

scripts/ci/pre_commit/pre_commit_sort_spelling_wordlist.py (1 addition, 2 deletions)

@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

-import itertools
from pathlib import Path

if __name__ not in ("__main__", "__mp_main__"):
@@ -35,7 +34,7 @@ def stable_sort(x):


def sort_uniq(sequence):
-    return (x[0] for x in itertools.groupby(sorted(sequence, key=stable_sort)))
+    return sorted(set(sequence), key=stable_sort)


if __name__ == "__main__":

scripts/in_container/run_provider_yaml_files_check.py (7 additions, 7 deletions)

@@ -19,6 +19,7 @@

import importlib
import inspect
+import itertools
import json
import os
import pathlib
@@ -27,7 +28,6 @@
import textwrap
from collections import Counter
from enum import Enum
-from itertools import chain, product
from typing import Any, Iterable

import jsonschema
@@ -219,7 +219,7 @@ def check_if_objects_exist_and_belong_to_package(
def parse_module_data(provider_data, resource_type, yaml_file_path):
    package_dir = ROOT_DIR.joinpath(yaml_file_path).parent
    provider_package = pathlib.Path(yaml_file_path).parent.as_posix().replace("/", ".")
-    py_files = chain(
+    py_files = itertools.chain(
        package_dir.glob(f"**/{resource_type}/*.py"),
        package_dir.glob(f"{resource_type}/*.py"),
        package_dir.glob(f"**/{resource_type}/**/*.py"),
@@ -233,7 +233,7 @@ def parse_module_data(provider_data, resource_type, yaml_file_path):
def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict[str, dict]):
    print("Checking completeness of list of {sensors, hooks, operators, triggers}")
    print(" -- {sensors, hooks, operators, triggers} - Expected modules (left) : Current modules (right)")
-    for (yaml_file_path, provider_data), resource_type in product(
+    for (yaml_file_path, provider_data), resource_type in itertools.product(
        yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
    ):
        expected_modules, provider_package, resource_data = parse_module_data(
@@ -257,7 +257,7 @@ def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict

def check_duplicates_in_integrations_names_of_hooks_sensors_operators(yaml_files: dict[str, dict]):
    print("Checking for duplicates in list of {sensors, hooks, operators, triggers}")
-    for (yaml_file_path, provider_data), resource_type in product(
+    for (yaml_file_path, provider_data), resource_type in itertools.product(
        yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
    ):
        resource_data = provider_data.get(resource_type, [])
@@ -362,7 +362,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
    print("Detect unregistered integrations")
    all_integration_names = set(get_all_integration_names(yaml_files))

-    for (yaml_file_path, provider_data), resource_type in product(
+    for (yaml_file_path, provider_data), resource_type in itertools.product(
        yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
    ):
        resource_data = provider_data.get(resource_type, [])
@@ -374,7 +374,7 @@ def check_invalid_integration(yaml_files: dict[str, dict]):
            f"Invalid values: {invalid_names}"
        )

-    for (yaml_file_path, provider_data), key in product(
+    for (yaml_file_path, provider_data), key in itertools.product(
        yaml_files.items(), ["source-integration-name", "target-integration-name"]
    ):
        resource_data = provider_data.get("transfers", [])
@@ -409,7 +409,7 @@ def check_doc_files(yaml_files: dict[str, dict]):
        console.print("[yellow]Suspended providers:[/]")
        console.print(suspended_providers)

-    expected_doc_files = chain(
+    expected_doc_files = itertools.chain(
        DOCS_DIR.glob("apache-airflow-providers-*/operators/**/*.rst"),
        DOCS_DIR.glob("apache-airflow-providers-*/transfer/**/*.rst"),
    )
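
Note: itertools.product(a, b) walks every combination of its inputs, which is what the repeated "for each provider YAML, for each resource type" loops need here. A standalone sketch with made-up paths:

import itertools

yaml_files = {"providers/amazon/provider.yaml": {}, "providers/google/provider.yaml": {}}
resource_types = ["sensors", "operators", "hooks", "triggers"]

pairs = list(itertools.product(yaml_files.items(), resource_types))
assert len(pairs) == len(yaml_files) * len(resource_types)  # 2 * 4 = 8 combinations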

tests/always/test_project_structure.py (4 additions, 6 deletions)

@@ -455,12 +455,10 @@ class TestDockerProviderProjectStructure(ExampleCoverageTest):
class TestOperatorsHooks:
    def test_no_illegal_suffixes(self):
        illegal_suffixes = ["_operator.py", "_hook.py", "_sensor.py"]
-        files = itertools.chain(
-            *(
-                glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
-                for resource_type in ["operators", "hooks", "sensors", "example_dags"]
-                for part in ["airflow", "tests"]
-            )
+        files = itertools.chain.from_iterable(
+            glob.glob(f"{ROOT_FOLDER}/{part}/providers/**/{resource_type}/*.py", recursive=True)
+            for resource_type in ["operators", "hooks", "sensors", "example_dags"]
+            for part in ["airflow", "tests"]
        )

        invalid_files = [f for f in files if f.endswith(tuple(illegal_suffixes))]