Skip to content

Commit

Permalink
Trigger FAB provider tests on API change (#39010)
Browse files Browse the repository at this point in the history
Follow up after #38924 which was not triggered when API changed
  • Loading branch information
potiuk authored Apr 14, 2024
1 parent 7fc2169 commit ac1f744
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 78 deletions.
134 changes: 66 additions & 68 deletions dev/breeze/src/airflow_breeze/utils/selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,60 +304,6 @@ def find_provider_affected(changed_file: str, include_docs: bool) -> str | None:
return "Providers"


def find_all_providers_affected(
changed_files: tuple[str, ...], include_docs: bool, fail_if_suspended_providers_affected: bool
) -> list[str] | str | None:
all_providers: set[str] = set()

all_providers_affected = False
suspended_providers: set[str] = set()
for changed_file in changed_files:
provider = find_provider_affected(changed_file, include_docs=include_docs)
if provider == "Providers":
all_providers_affected = True
elif provider is not None:
if provider not in DEPENDENCIES:
suspended_providers.add(provider)
else:
all_providers.add(provider)
if all_providers_affected:
return "ALL_PROVIDERS"
if suspended_providers:
# We check for suspended providers only after we have checked if all providers are affected.
# No matter if we found that we are modifying a suspended provider individually, if all providers are
# affected, then it means that we are ok to proceed because likely we are running some kind of
# global refactoring that affects multiple providers including the suspended one. This is a
# potential escape hatch if someone would like to modify suspended provider,
# but it can be found at the review time and is anyway harmless as the provider will not be
# released nor tested nor used in CI anyway.
get_console().print("[yellow]You are modifying suspended providers.\n")
get_console().print(
"[info]Some providers modified by this change have been suspended, "
"and before attempting such changes you should fix the reason for suspension."
)
get_console().print(
"[info]When fixing it, you should set suspended = false in provider.yaml "
"to make changes to the provider."
)
get_console().print(f"Suspended providers: {suspended_providers}")
if fail_if_suspended_providers_affected:
get_console().print(
"[error]This PR did not have `allow suspended provider changes` label set so it will fail."
)
sys.exit(1)
else:
get_console().print(
"[info]This PR had `allow suspended provider changes` label set so it will continue"
)
if not all_providers:
return None
for provider in list(all_providers):
all_providers.update(
get_related_providers(provider, upstream_dependencies=True, downstream_dependencies=True)
)
return sorted(all_providers)


def _match_files_with_regexps(files: tuple[str, ...], matched_files, matching_regexps):
for file in files:
if any(re.match(regexp, file) for regexp in matching_regexps):
Expand Down Expand Up @@ -747,7 +693,7 @@ def _are_all_providers_affected(self) -> bool:
# prepare all providers packages and build all providers documentation
return "Providers" in self._get_test_types_to_run()

def _fail_if_suspended_providers_affected(self):
def _fail_if_suspended_providers_affected(self) -> bool:
return "allow suspended provider changes" not in self._pr_labels

def _get_test_types_to_run(self) -> list[str]:
Expand Down Expand Up @@ -800,14 +746,17 @@ def _get_test_types_to_run(self) -> list[str]:
get_console().print(remaining_files)
candidate_test_types.update(all_selective_test_types())
else:
if "Providers" in candidate_test_types:
affected_providers = find_all_providers_affected(
changed_files=self._files,
if "Providers" in candidate_test_types or "API" in candidate_test_types:
affected_providers = self.find_all_providers_affected(
include_docs=False,
fail_if_suspended_providers_affected=self._fail_if_suspended_providers_affected(),
)
if affected_providers != "ALL_PROVIDERS" and affected_providers is not None:
candidate_test_types.remove("Providers")
try:
candidate_test_types.remove("Providers")
except KeyError:
# In case of API tests Providers could not be in the list originally so we can ignore
# Providers missing in the list.
pass
candidate_test_types.add(f"Providers[{','.join(sorted(affected_providers))}]")
get_console().print(
"[warning]There are no core/other files. Only tests relevant to the changed files are run.[/]"
Expand Down Expand Up @@ -988,10 +937,8 @@ def docs_list_as_string(self) -> str | None:
return "apache-airflow docker-stack"
if self.full_tests_needed:
return _ALL_DOCS_LIST
providers_affected = find_all_providers_affected(
changed_files=self._files,
providers_affected = self.find_all_providers_affected(
include_docs=True,
fail_if_suspended_providers_affected=self._fail_if_suspended_providers_affected(),
)
if (
providers_affected == "ALL_PROVIDERS"
Expand Down Expand Up @@ -1100,11 +1047,7 @@ def affected_providers_list_as_string(self) -> str | None:
return _ALL_PROVIDERS_LIST
if self._are_all_providers_affected():
return _ALL_PROVIDERS_LIST
affected_providers = find_all_providers_affected(
changed_files=self._files,
include_docs=True,
fail_if_suspended_providers_affected=self._fail_if_suspended_providers_affected(),
)
affected_providers = self.find_all_providers_affected(include_docs=True)
if not affected_providers:
return None
if affected_providers == "ALL_PROVIDERS":
Expand Down Expand Up @@ -1259,3 +1202,58 @@ def is_committer_build(self):
if NON_COMMITTER_BUILD_LABEL in self._pr_labels:
return False
return self._github_actor in COMMITTERS

def find_all_providers_affected(self, include_docs: bool) -> list[str] | str | None:
all_providers: set[str] = set()

all_providers_affected = False
suspended_providers: set[str] = set()
for changed_file in self._files:
provider = find_provider_affected(changed_file, include_docs=include_docs)
if provider == "Providers":
all_providers_affected = True
elif provider is not None:
if provider not in DEPENDENCIES:
suspended_providers.add(provider)
else:
all_providers.add(provider)
if self.needs_api_tests:
all_providers.add("fab")
if all_providers_affected:
return "ALL_PROVIDERS"
if suspended_providers:
# We check for suspended providers only after we have checked if all providers are affected.
# No matter if we found that we are modifying a suspended provider individually,
# if all providers are
# affected, then it means that we are ok to proceed because likely we are running some kind of
# global refactoring that affects multiple providers including the suspended one. This is a
# potential escape hatch if someone would like to modify suspended provider,
# but it can be found at the review time and is anyway harmless as the provider will not be
# released nor tested nor used in CI anyway.
get_console().print("[yellow]You are modifying suspended providers.\n")
get_console().print(
"[info]Some providers modified by this change have been suspended, "
"and before attempting such changes you should fix the reason for suspension."
)
get_console().print(
"[info]When fixing it, you should set suspended = false in provider.yaml "
"to make changes to the provider."
)
get_console().print(f"Suspended providers: {suspended_providers}")
if self._fail_if_suspended_providers_affected():
get_console().print(
"[error]This PR did not have `allow suspended provider changes`"
" label set so it will fail."
)
sys.exit(1)
else:
get_console().print(
"[info]This PR had `allow suspended provider changes` label set so it will continue"
)
if not all_providers:
return None
for provider in list(all_providers):
all_providers.update(
get_related_providers(provider, upstream_dependencies=True, downstream_dependencies=True)
)
return sorted(all_providers)
20 changes: 10 additions & 10 deletions dev/breeze/tests/test_selective_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
pytest.param(
("airflow/api/file.py",),
{
"affected-providers-list-as-string": None,
"affected-providers-list-as-string": "fab",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"python-versions": "['3.8']",
Expand All @@ -138,11 +138,11 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"skip-pre-commits": "check-provider-yaml-valid,identity,lint-helm-chart,mypy-airflow,mypy-dev,"
"mypy-docs,mypy-providers,ts-compile-format-lint-www",
"upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "API Always",
"parallel-test-types-list-as-string": "API Always Providers[fab]",
"needs-mypy": "true",
"mypy-folders": "['airflow']",
},
id="Only API tests and DOCS should run",
id="Only API tests and DOCS and FAB provider should run",
)
),
(
Expand Down Expand Up @@ -228,7 +228,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"tests/providers/postgres/file.py",
),
{
"affected-providers-list-as-string": "amazon common.sql google openlineage "
"affected-providers-list-as-string": "amazon common.sql fab google openlineage "
"pgvector postgres",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
Expand All @@ -244,7 +244,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
"ts-compile-format-lint-www",
"upgrade-to-newer-dependencies": "false",
"parallel-test-types-list-as-string": "API Always Providers[amazon] "
"Providers[common.sql,openlineage,pgvector,postgres] Providers[google]",
"Providers[common.sql,fab,openlineage,pgvector,postgres] Providers[google]",
"needs-mypy": "true",
"mypy-folders": "['airflow', 'providers']",
},
Expand Down Expand Up @@ -1211,24 +1211,24 @@ def test_expected_output_pull_request_v2_7(
"airflow/api/file.py",
),
{
"affected-providers-list-as-string": None,
"affected-providers-list-as-string": "fab",
"all-python-versions": "['3.8']",
"all-python-versions-list-as-string": "3.8",
"ci-image-build": "true",
"prod-image-build": "false",
"needs-helm-tests": "false",
"run-tests": "true",
"docs-build": "true",
"docs-list-as-string": "apache-airflow",
"docs-list-as-string": "apache-airflow fab",
"skip-pre-commits": "check-provider-yaml-valid,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,ts-compile-format-lint-www",
"run-kubernetes-tests": "false",
"upgrade-to-newer-dependencies": "false",
"skip-provider-tests": "true",
"parallel-test-types-list-as-string": "API Always CLI Operators WWW",
"skip-provider-tests": "false",
"parallel-test-types-list-as-string": "API Always CLI Operators Providers[fab] WWW",
"needs-mypy": "true",
"mypy-folders": "['airflow']",
},
id="No providers tests should run if only CLI/API/Operators/WWW file changed",
id="No providers tests except fab should run if only CLI/API/Operators/WWW file changed",
),
pytest.param(
("airflow/models/test.py",),
Expand Down

0 comments on commit ac1f744

Please sign in to comment.