
Improve provider verification pre-commit #33640

Merged
merged 1 commit into from
Aug 23, 2023
@@ -51,6 +51,8 @@
*get_extra_docker_flags(MOUNT_SELECTED),
"-e",
"SKIP_ENVIRONMENT_INITIALIZATION=true",
"-e",
"PYTHONWARNINGS=default",
"--pull",
"never",
airflow_image,
138 changes: 114 additions & 24 deletions scripts/in_container/run_provider_yaml_files_check.py
@@ -26,6 +26,7 @@
import platform
import sys
import textwrap
import warnings
from collections import Counter
from enum import Enum
from typing import Any, Iterable
@@ -37,18 +38,22 @@
from tabulate import tabulate

from airflow.cli.commands.info_command import Architecture
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers_manager import ProvidersManager

# Those are deprecated modules that contain removed Hooks/Sensors/Operators that we left in the code
# so that users can get a very specific error message when they try to use them.

EXCLUDED_MODULES = [
DEPRECATED_MODULES = [
"airflow.providers.apache.hdfs.sensors.hdfs",
"airflow.providers.apache.hdfs.hooks.hdfs",
"airflow.providers.cncf.kubernetes.triggers.kubernetes_pod",
"airflow.providers.cncf.kubernetes.operators.kubernetes_pod",
]

KNOWN_DEPRECATED_CLASSES = [
"airflow.providers.google.cloud.links.dataproc.DataprocLink",
]

try:
from yaml import CSafeLoader as SafeLoader
@@ -71,6 +76,13 @@
errors: list[str] = []

console = Console(width=400, color_system="standard")
# you need to enable warnings for all deprecations - needed by importlib library to show deprecations
if os.environ.get("PYTHONWARNINGS") != "default":
console.print(
"[red]Error: PYTHONWARNINGS not set[/]\n"
"You must set `PYTHONWARNINGS=default` environment variable to run this script"
)
sys.exit(1)

suspended_providers: set[str] = set()
suspended_logos: set[str] = set()
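
Illustrative aside (not part of the diff): the guard above exists because Python's default warning filters hide `DeprecationWarning` raised outside `__main__`, so deprecation warnings emitted while importing provider modules would never reach the `catch_warnings(record=True)` blocks used later in the script. A minimal sketch of that effect, using `warnings.simplefilter("default")` as a rough stand-in for what `PYTHONWARNINGS=default` applies process-wide (the function and message below are invented for illustration):

```python
import warnings


def deprecations_seen(reset_filters: bool) -> int:
    """Count recorded deprecation warnings, with or without resetting the filters."""
    with warnings.catch_warnings(record=True) as caught:
        if reset_filters:
            # roughly what PYTHONWARNINGS=default gives the whole process
            warnings.simplefilter("default")
        # stand-in for a DeprecationWarning raised while importing a provider module
        warnings.warn_explicit(
            "this provider module is deprecated",
            DeprecationWarning,
            "some_provider_module.py",
            1,
        )
    return len(caught)


# Under stock interpreter filters the warning is filtered out and never recorded;
# with the filters reset it is recorded, so the check script can report it.
print(deprecations_seen(reset_filters=False))  # expected: 0 (stock filters, no PYTHONWARNINGS)
print(deprecations_seen(reset_filters=True))   # expected: 1
```
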
@@ -136,7 +148,14 @@ def check_integration_duplicates(yaml_files: dict[str, dict]):
sys.exit(3)


def assert_sets_equal(set1, set2, allow_extra_in_set2=False):
def assert_sets_equal(
set1: set[str],
set_name_1: str,
set2: set[str],
set_name_2: str,
allow_extra_in_set2=False,
extra_message: str = "",
):
try:
difference1 = set1.difference(set2)
except TypeError as e:
@@ -153,6 +172,8 @@ def assert_sets_equal(set1, set2, allow_extra_in_set2=False):

if difference1 or (difference2 and not allow_extra_in_set2):
lines = []
lines.append(f" Left set:{set_name_1}")
lines.append(f" Right set:{set_name_2}")
if difference1:
lines.append(" -- Items in the left set but not the right:")
for item in sorted(difference1):
@@ -163,6 +184,8 @@
lines.append(f" {item!r}")

standard_msg = "\n".join(lines)
if extra_message:
standard_msg += f"\n{extra_message}"
raise AssertionError(standard_msg)
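
Illustrative aside (not part of the diff): a hypothetical call to the `assert_sets_equal` helper above, showing the kind of message the new `set_name_1`/`set_name_2`/`extra_message` arguments produce. The module names and labels are invented, and the snippet assumes the definition above is in scope.

```python
expected = {"airflow.providers.example.hooks.a", "airflow.providers.example.hooks.b"}
configured = {"airflow.providers.example.hooks.a"}

try:
    assert_sets_equal(
        expected,
        "modules found in the provider package",
        configured,
        "modules listed in provider.yaml",
        extra_message="[yellow]Add the missing module to provider.yaml[/]",
    )
except AssertionError as err:
    # The message now names both sets, lists the differing items, and appends
    # the extra hint, roughly:
    #   Left set:modules found in the provider package
    #   Right set:modules listed in provider.yaml
    #    -- Items in the left set but not the right:
    #       'airflow.providers.example.hooks.b'
    #   [yellow]Add the missing module to provider.yaml[/]
    print(err)
```
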


@@ -174,12 +197,37 @@ class ObjectType(Enum):
def check_if_object_exist(object_name: str, resource_type: str, yaml_file_path: str, object_type: ObjectType):
try:
if object_type == ObjectType.CLASS:
module_name, object_name = object_name.rsplit(".", maxsplit=1)
the_class = getattr(importlib.import_module(module_name), object_name)
module_name, class_name = object_name.rsplit(".", maxsplit=1)
with warnings.catch_warnings(record=True) as w:
the_class = getattr(importlib.import_module(module_name), class_name)
for warn in w:
if warn.category == AirflowProviderDeprecationWarning:
if object_name in KNOWN_DEPRECATED_CLASSES:
console.print(
f"[yellow]The {object_name} class is deprecated and we know about it. "
f"It should be removed in the future."
)
continue
errors.append(
f"The `{class_name}` class in {resource_type} list in {yaml_file_path} "
f"is deprecated with this message: '{warn.message}'.\n"
f"[yellow]How to fix it[/]: Please remove it from provider.yaml and replace with "
f"the new class."
)
if the_class and inspect.isclass(the_class):
return
elif object_type == ObjectType.MODULE:
module = importlib.import_module(object_name)
with warnings.catch_warnings(record=True) as w:
module = importlib.import_module(object_name)
for warn in w:
if warn.category == AirflowProviderDeprecationWarning:
errors.append(
f"The `{object_name}` module in {resource_type} list in {yaml_file_path} "
f"is deprecated with this message: '{warn.message}'.\n"
f"[yellow]How to fix it[/]: Please remove it from provider.yaml and replace it "
f"with the new module. If you see warnings in classes - fix the classes so that "
f"they are not raising Deprecation Warnings when module is imported."
)
if inspect.ismodule(module):
return
else:
@@ -231,23 +279,32 @@ def parse_module_data(provider_data, resource_type, yaml_file_path):
return expected_modules, provider_package, resource_data


def check_correctness_of_list_of_sensors_operators_hook_modules(yaml_files: dict[str, dict]):
print("Checking completeness of list of {sensors, hooks, operators, triggers}")
print(" -- {sensors, hooks, operators, triggers} - Expected modules (left) : Current modules (right)")
def check_correctness_of_list_of_sensors_operators_hook_trigger_modules(yaml_files: dict[str, dict]):
print(" -- Checking completeness of list of {sensors, hooks, operators, triggers}")
for (yaml_file_path, provider_data), resource_type in itertools.product(
yaml_files.items(), ["sensors", "operators", "hooks", "triggers"]
):
expected_modules, provider_package, resource_data = parse_module_data(
provider_data, resource_type, yaml_file_path
)
expected_modules = {module for module in expected_modules if module not in EXCLUDED_MODULES}
expected_modules = {module for module in expected_modules if module not in DEPRECATED_MODULES}
current_modules = {str(i) for r in resource_data for i in r.get("python-modules", [])}

check_if_objects_exist_and_belong_to_package(
current_modules, provider_package, yaml_file_path, resource_type, ObjectType.MODULE
)
try:
assert_sets_equal(set(expected_modules), set(current_modules))
package_name = os.fspath(ROOT_DIR.joinpath(yaml_file_path).parent.relative_to(ROOT_DIR)).replace(
"/", "."
)
assert_sets_equal(
set(expected_modules),
f"Found list of {resource_type} modules in provider package: {package_name}",
set(current_modules),
f"Currently configured list of {resource_type} modules in {yaml_file_path}",
extra_message="[yellow]If there are deprecated modules in the list, please add them to "
f"DEPRECATED_MODULES in {pathlib.Path(__file__).relative_to(ROOT_DIR)}[/]",
)
except AssertionError as ex:
nested_error = textwrap.indent(str(ex), " ")
errors.append(
@@ -276,19 +333,27 @@ def check_completeness_of_list_of_transfers(yaml_files: dict[str, dict]):
print("Checking completeness of list of transfers")
resource_type = "transfers"

print(" -- Expected transfers modules(Left): Current transfers Modules(Right)")
print(" -- Checking transfers modules")
for yaml_file_path, provider_data in yaml_files.items():
expected_modules, provider_package, resource_data = parse_module_data(
provider_data, resource_type, yaml_file_path
)
expected_modules = {module for module in expected_modules if module not in EXCLUDED_MODULES}
expected_modules = {module for module in expected_modules if module not in DEPRECATED_MODULES}
current_modules = {r.get("python-module") for r in resource_data}

check_if_objects_exist_and_belong_to_package(
current_modules, provider_package, yaml_file_path, resource_type, ObjectType.MODULE
)
try:
assert_sets_equal(set(expected_modules), set(current_modules))
package_name = os.fspath(ROOT_DIR.joinpath(yaml_file_path).parent.relative_to(ROOT_DIR)).replace(
"/", "."
)
assert_sets_equal(
set(expected_modules),
f"Found list of transfer modules in provider package: {package_name}",
set(current_modules),
f"Currently configured list of transfer modules in {yaml_file_path}",
)
except AssertionError as ex:
nested_error = textwrap.indent(str(ex), " ")
errors.append(
@@ -337,6 +402,18 @@ def check_extra_link_classes(yaml_files: dict[str, dict]):
)


def check_notification_classes(yaml_files: dict[str, dict]):
print("Checking notifications belong to package, exist and are classes")
resource_type = "notifications"
for yaml_file_path, provider_data in yaml_files.items():
provider_package = pathlib.Path(yaml_file_path).parent.as_posix().replace("/", ".")
notifications = provider_data.get(resource_type)
if notifications:
check_if_objects_exist_and_belong_to_package(
notifications, provider_package, yaml_file_path, resource_type, ObjectType.CLASS
)


def check_duplicates_in_list_of_transfers(yaml_files: dict[str, dict]):
print("Checking for duplicates in list of transfers")
errors = []
@@ -435,11 +512,20 @@ def check_doc_files(yaml_files: dict[str, dict]):
}

try:
print(" -- Checking document urls: expected (left), current (right)")
assert_sets_equal(set(expected_doc_urls), set(current_doc_urls))

print(" -- Checking logo urls: expected (left), current (right)")
assert_sets_equal(set(expected_logo_urls), set(current_logo_urls))
print(" -- Checking document urls")
assert_sets_equal(
set(expected_doc_urls),
"Document urls found in airflow/docs",
set(current_doc_urls),
"Document urls configured in provider.yaml files",
)
print(" -- Checking logo urls")
assert_sets_equal(
set(expected_logo_urls),
"Logo urls found in airflow/docs/integration-logos",
set(current_logo_urls),
"Logo urls configured in provider.yaml files",
)
except AssertionError as ex:
print(ex)
sys.exit(1)
@@ -465,12 +551,15 @@ def check_providers_are_mentioned_in_issue_template(yaml_files: dict[str, dict])
issue_template = yaml.safe_load(issue_file)
all_mentioned_providers = [match.value for match in jsonpath_expr.find(issue_template)]
try:
print(
f" -- Checking providers: present in code (left), "
f"mentioned in {PROVIDER_ISSUE_TEMPLATE_PATH} (right)"
)
print(f" -- Checking providers are mentioned in {PROVIDER_ISSUE_TEMPLATE_PATH}")
# in case of suspended providers, we still want to have them in the issue template
assert_sets_equal(set(short_provider_names), set(all_mentioned_providers), allow_extra_in_set2=True)
assert_sets_equal(
set(short_provider_names),
"Provider names found in provider.yaml files",
set(all_mentioned_providers),
f"Provider names mentioned in {PROVIDER_ISSUE_TEMPLATE_PATH}",
allow_extra_in_set2=True,
)
except AssertionError as ex:
print(ex)
sys.exit(1)
@@ -512,7 +601,8 @@ def check_providers_have_all_documentation_files(yaml_files: dict[str, dict]):
check_hook_connection_classes(all_parsed_yaml_files)
check_plugin_classes(all_parsed_yaml_files)
check_extra_link_classes(all_parsed_yaml_files)
check_correctness_of_list_of_sensors_operators_hook_modules(all_parsed_yaml_files)
check_correctness_of_list_of_sensors_operators_hook_trigger_modules(all_parsed_yaml_files)
check_notification_classes(all_parsed_yaml_files)
check_unique_provider_name(all_parsed_yaml_files)
check_providers_have_all_documentation_files(all_parsed_yaml_files)
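
One more illustrative aside: a standalone sketch of the import-and-record pattern that `check_if_object_exist` now uses, with plain `DeprecationWarning` standing in for `AirflowProviderDeprecationWarning` and an invented allow-list entry. This is not the script's code, just the shape of the technique under those assumptions.

```python
from __future__ import annotations

import importlib
import warnings

# Hypothetical allow-list, analogous to KNOWN_DEPRECATED_CLASSES in the script above.
TOLERATED_DEPRECATED_CLASSES = {"example_pkg.links.OldLink"}


def import_class_and_collect_deprecations(dotted_path: str) -> tuple[object, list[str]]:
    """Import a class by dotted path and report deprecation warnings raised on import."""
    module_name, class_name = dotted_path.rsplit(".", maxsplit=1)
    problems: list[str] = []
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("default")  # stand-in for PYTHONWARNINGS=default
        the_class = getattr(importlib.import_module(module_name), class_name)
    for warn in caught:
        if issubclass(warn.category, DeprecationWarning):
            if dotted_path in TOLERATED_DEPRECATED_CLASSES:
                # known deprecation, tolerated like the script's allow-list
                continue
            problems.append(f"{dotted_path} is deprecated: {warn.message}")
    return the_class, problems
```
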
