Skip to content

Commit 08d4e01

Browse files
eumiroephraimbuddy
authored andcommitted
Simplify 'X for X in Y' to 'Y' where applicable (#33453)
(cherry picked from commit 7700fb1)
1 parent c3580fc commit 08d4e01

File tree

20 files changed

+26
-34
lines changed

20 files changed

+26
-34
lines changed

airflow/lineage/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def wrapper(self, context, *args, **kwargs):
142142
_inlets = self.xcom_pull(
143143
context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS, session=session
144144
)
145-
self.inlets.extend(i for i in itertools.chain.from_iterable(_inlets))
145+
self.inlets.extend(itertools.chain.from_iterable(_inlets))
146146

147147
elif self.inlets:
148148
raise AttributeError("inlets is not a list, operator, string or attr annotated object")

airflow/models/dagrun.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) ->
12381238
TI.run_id == self.run_id,
12391239
)
12401240
)
1241-
existing_indexes = {i for i in query}
1241+
existing_indexes = set(query)
12421242

12431243
removed_indexes = existing_indexes.difference(range(total_length))
12441244
if removed_indexes:

airflow/providers/apache/hive/hooks/hive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def run_cli(
235235

236236
invalid_chars_list = re.findall(r"[^a-z0-9_]", schema)
237237
if invalid_chars_list:
238-
invalid_chars = "".join(char for char in invalid_chars_list)
238+
invalid_chars = "".join(invalid_chars_list)
239239
raise RuntimeError(f"The schema `{schema}` contains invalid characters: {invalid_chars}")
240240

241241
if schema:

airflow/providers/microsoft/azure/operators/batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _check_inputs(self) -> Any:
189189
)
190190

191191
if self.use_latest_image:
192-
if not all(elem for elem in [self.vm_publisher, self.vm_offer]):
192+
if not self.vm_publisher or not self.vm_offer:
193193
raise AirflowException(
194194
f"If use_latest_image_and_sku is set to True then the parameters vm_publisher, "
195195
f"vm_offer, must all be set. "

airflow/providers/smtp/hooks/smtp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def _get_email_list_from_str(self, addresses: str) -> list[str]:
329329
:return: A list of email addresses.
330330
"""
331331
pattern = r"\s*[,;]\s*"
332-
return [address for address in re.split(pattern, addresses)]
332+
return re.split(pattern, addresses)
333333

334334
@property
335335
def conn(self) -> Connection:

airflow/utils/email.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,4 +340,4 @@ def _get_email_list_from_str(addresses: str) -> list[str]:
340340
:return: A list of email addresses.
341341
"""
342342
pattern = r"\s*[,;]\s*"
343-
return [address for address in re2.split(pattern, addresses)]
343+
return re2.split(pattern, addresses)

airflow/www/views.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2707,7 +2707,7 @@ def confirm(self):
27072707
return redirect_or_json(
27082708
origin, msg=f"TaskGroup {group_id} could not be found", status="error", status_code=404
27092709
)
2710-
tasks = [task for task in task_group.iter_tasks()]
2710+
tasks = list(task_group.iter_tasks())
27112711
elif task_id:
27122712
try:
27132713
task = dag.get_task(task_id)

dev/airflow-license

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,5 @@ if __name__ == "__main__":
7979
license = parse_license_file(notice[1])
8080
print(f"{notice[1]:<30}|{notice[2][:50]:<50}||{notice[0]:<20}||{license:<10}")
8181

82-
file_count = len([name for name in os.listdir("../licenses")])
82+
file_count = len(os.listdir("../licenses"))
8383
print(f"Defined licenses: {len(notices)} Files found: {file_count}")

docker_tests/test_prod_image.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ def test_required_providers_are_installed(self):
8585
lines = PREINSTALLED_PROVIDERS
8686
else:
8787
lines = (d.strip() for d in INSTALLED_PROVIDER_PATH.read_text().splitlines())
88-
lines = (d for d in lines)
8988
packages_to_install = {f"apache-airflow-providers-{d.replace('.', '-')}" for d in lines}
9089
assert len(packages_to_install) != 0
9190

scripts/ci/pre_commit/common_precommit_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ def insert_documentation(file_path: Path, content: list[str], header: str, foote
6464

6565

6666
def get_directory_hash(directory: Path, skip_path_regexp: str | None = None) -> str:
67-
files = [file for file in directory.rglob("*")]
68-
files.sort()
67+
files = sorted(directory.rglob("*"))
6968
if skip_path_regexp:
7069
matcher = re.compile(skip_path_regexp)
7170
files = [file for file in files if not matcher.match(os.fspath(file.resolve()))]

tests/always/test_connection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
347347
),
348348
]
349349

350-
@pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
350+
@pytest.mark.parametrize("test_config", test_from_uri_params)
351351
def test_connection_from_uri(self, test_config: UriTestCaseConfig):
352352

353353
connection = Connection(uri=test_config.test_uri)
@@ -369,7 +369,7 @@ def test_connection_from_uri(self, test_config: UriTestCaseConfig):
369369

370370
self.mask_secret.assert_has_calls(expected_calls)
371371

372-
@pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
372+
@pytest.mark.parametrize("test_config", test_from_uri_params)
373373
def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
374374
"""
375375
This test verifies that when we create a conn_1 from URI, and we generate a URI from that conn, that
@@ -390,7 +390,7 @@ def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
390390
assert connection.schema == new_conn.schema
391391
assert connection.extra_dejson == new_conn.extra_dejson
392392

393-
@pytest.mark.parametrize("test_config", [x for x in test_from_uri_params])
393+
@pytest.mark.parametrize("test_config", test_from_uri_params)
394394
def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
395395
"""
396396
This test verifies that if we create conn_1 from attributes (rather than from URI), and we generate a

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def pytest_print(text):
134134
# It is very unlikely that the user wants to display only numbers, but probably
135135
# the user just wants to count the queries.
136136
exit_stack.enter_context(count_queries(print_fn=pytest_print))
137-
elif any(c for c in ["time", "trace", "sql", "parameters"]):
137+
elif any(c in columns for c in ["time", "trace", "sql", "parameters"]):
138138
exit_stack.enter_context(
139139
trace_queries(
140140
display_num="num" in columns,

tests/models/test_mappedoperator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def execute(self, context):
668668

669669
class ConsumeXcomOperator(PushXcomOperator):
670670
def execute(self, context):
671-
assert {i for i in self.arg1} == {1, 2, 3}
671+
assert set(self.arg1) == {1, 2, 3}
672672

673673
with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"):
674674
op1 = PushXcomOperator.partial(task_id="op1").expand(arg1=[1, 2, 3])

tests/models/test_skipmixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def task_group_op(k):
147147
branch_b = EmptyOperator(task_id="branch_b")
148148
branch_op(k) >> [branch_a, branch_b]
149149

150-
task_group_op.expand(k=[i for i in range(2)])
150+
task_group_op.expand(k=[0, 1])
151151

152152
dag_maker.create_dagrun()
153153
branch_op_ti_0 = TI(dag.get_task("task_group_op.branch_op"), execution_date=DEFAULT_DATE, map_index=0)

tests/providers/amazon/aws/sensors/test_eks.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,9 @@
4242
NODEGROUP_NAME = "test_nodegroup"
4343
TASK_ID = "test_eks_sensor"
4444

45-
CLUSTER_PENDING_STATES = frozenset(frozenset({state for state in ClusterStates}) - CLUSTER_TERMINAL_STATES)
46-
FARGATE_PENDING_STATES = frozenset(
47-
frozenset({state for state in FargateProfileStates}) - FARGATE_TERMINAL_STATES
48-
)
49-
NODEGROUP_PENDING_STATES = frozenset(
50-
frozenset({state for state in NodegroupStates}) - NODEGROUP_TERMINAL_STATES
51-
)
45+
CLUSTER_PENDING_STATES = frozenset(ClusterStates) - frozenset(CLUSTER_TERMINAL_STATES)
46+
FARGATE_PENDING_STATES = frozenset(FargateProfileStates) - frozenset(FARGATE_TERMINAL_STATES)
47+
NODEGROUP_PENDING_STATES = frozenset(NodegroupStates) - frozenset(NODEGROUP_TERMINAL_STATES)
5248

5349

5450
class TestEksClusterStateSensor:

tests/providers/google/cloud/log/test_stackdriver_task_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred
311311

312312
entry = mock.MagicMock(json_payload={"message": "TEXT"})
313313
page = mock.MagicMock(entries=[entry, entry], next_page_token=None)
314-
mock_client.return_value.list_log_entries.return_value.pages = (n for n in [page])
314+
mock_client.return_value.list_log_entries.return_value.pages = iter([page])
315315

316316
logs, metadata = stackdriver_task_handler.read(self.ti)
317317
mock_client.return_value.list_log_entries.assert_called_once_with(

tests/providers/google/cloud/transfers/test_sql_to_gcs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,9 +555,7 @@ def test__write_local_data_files_csv_does_not_write_on_empty_rows(self):
555555
files = op._write_local_data_files(cursor)
556556
# Raises StopIteration when next is called because generator returns no files
557557
with pytest.raises(StopIteration):
558-
next(files)["file_handle"]
559-
560-
assert len([f for f in files]) == 0
558+
next(files)
561559

562560
def test__write_local_data_files_csv_writes_empty_file_with_write_on_empty(self):
563561
op = DummySQLToGCSOperator(

tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def mock_jobs(names: list[str], latest_operation_names: list[str | None]):
6969
for job, name in zip(jobs, names):
7070
job.name = name
7171
mock_obj = mock.MagicMock()
72-
mock_obj.__aiter__.return_value = (job for job in jobs)
72+
mock_obj.__aiter__.return_value = iter(jobs)
7373
return mock_obj
7474

7575

tests/sensors/test_external_task_sensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def dummy_mapped_task(x: int):
135135
return x
136136

137137
dummy_task()
138-
dummy_mapped_task.expand(x=[i for i in map_indexes])
138+
dummy_mapped_task.expand(x=list(map_indexes))
139139

140140
SerializedDagModel.write_dag(dag)
141141

@@ -1089,7 +1089,7 @@ def run_tasks(dag_bag, execution_date=DEFAULT_DATE, session=None):
10891089
# this is equivalent to topological sort. It would not work in general case
10901090
# but it works for our case because we specifically constructed test DAGS
10911091
# in the way that those two sort methods are equivalent
1092-
tasks = sorted((ti for ti in dagrun.task_instances), key=lambda ti: ti.task_id)
1092+
tasks = sorted(dagrun.task_instances, key=lambda ti: ti.task_id)
10931093
for ti in tasks:
10941094
ti.refresh_from_task(dag.get_task(ti.task_id))
10951095
tis[ti.task_id] = ti
@@ -1478,7 +1478,7 @@ def dummy_task(x: int):
14781478
mode="reschedule",
14791479
)
14801480

1481-
body = dummy_task.expand(x=[i for i in range(5)])
1481+
body = dummy_task.expand(x=range(5))
14821482
tail = ExternalTaskMarker(
14831483
task_id="tail",
14841484
external_dag_id=dag.dag_id,
@@ -1524,7 +1524,7 @@ def test_clear_overlapping_external_task_marker_mapped_tasks(dag_bag_head_tail_m
15241524
include_downstream=True,
15251525
include_upstream=False,
15261526
)
1527-
task_ids = [tid for tid in dag.task_dict]
1527+
task_ids = list(dag.task_dict)
15281528
assert (
15291529
dag.clear(
15301530
start_date=DEFAULT_DATE,

tests/system/providers/amazon/aws/example_s3_to_sql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def parse_csv_to_list(filepath):
177177
import csv
178178

179179
with open(filepath, newline="") as file:
180-
return [row for row in csv.reader(file)]
180+
return list(csv.reader(file))
181181

182182
transfer_s3_to_sql = S3ToSqlOperator(
183183
task_id="transfer_s3_to_sql",

0 commit comments

Comments
 (0)