Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement batch_is_authorized_* APIs in AWS auth manager #37430

Merged
merged 3 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
{
"method": "GET",
"details": DagDetails(id=id),
"user": g.user,
}
for id in dag_ids
]
Expand Down
47 changes: 20 additions & 27 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,24 @@ def is_authorized_custom_view(
"""
raise AirflowException(f"The resource `{fab_resource_name}` does not exist in the environment.")

def batch_is_authorized_connection(
self,
requests: Sequence[IsAuthorizedConnectionRequest],
) -> bool:
"""
Batch version of ``is_authorized_connection``.

By default, calls individually the ``is_authorized_connection`` API on each item in the list of
requests. Can lead to some poor performance. It is recommended to override this method in the auth
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
manager implementation to provide a more efficient implementation.

:param requests: a list of requests containing the parameters for ``is_authorized_connection``
"""
return all(
self.is_authorized_connection(method=request["method"], details=request.get("details"))
for request in requests
)

def batch_is_authorized_dag(
self,
requests: Sequence[IsAuthorizedDagRequest],
Expand All @@ -272,27 +290,6 @@ def batch_is_authorized_dag(
method=request["method"],
access_entity=request.get("access_entity"),
details=request.get("details"),
user=request.get("user"),
)
for request in requests
)

def batch_is_authorized_connection(
self,
requests: Sequence[IsAuthorizedConnectionRequest],
) -> bool:
"""
Batch version of ``is_authorized_connection``.

By default, calls individually the ``is_authorized_connection`` API on each item in the list of
requests. Can lead to some poor performance. It is recommended to override this method in the auth
manager implementation to provide a more efficient implementation.

:param requests: a list of requests containing the parameters for ``is_authorized_connection``
"""
return all(
self.is_authorized_connection(
method=request["method"], details=request.get("details"), user=request.get("user")
)
for request in requests
)
Expand All @@ -311,9 +308,7 @@ def batch_is_authorized_pool(
:param requests: a list of requests containing the parameters for ``is_authorized_pool``
"""
return all(
self.is_authorized_pool(
method=request["method"], details=request.get("details"), user=request.get("user")
)
self.is_authorized_pool(method=request["method"], details=request.get("details"))
for request in requests
)

Expand All @@ -331,9 +326,7 @@ def batch_is_authorized_variable(
:param requests: a list of requests containing the parameters for ``is_authorized_variable``
"""
return all(
self.is_authorized_variable(
method=request["method"], details=request.get("details"), user=request.get("user")
)
self.is_authorized_variable(method=request["method"], details=request.get("details"))
for request in requests
)

Expand Down
5 changes: 0 additions & 5 deletions airflow/auth/managers/models/batch_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
ConnectionDetails,
DagAccessEntity,
Expand All @@ -36,7 +35,6 @@ class IsAuthorizedConnectionRequest(TypedDict, total=False):

method: ResourceMethod
details: ConnectionDetails | None
user: BaseUser | None


class IsAuthorizedDagRequest(TypedDict, total=False):
Expand All @@ -45,20 +43,17 @@ class IsAuthorizedDagRequest(TypedDict, total=False):
method: ResourceMethod
access_entity: DagAccessEntity | None
details: DagDetails | None
user: BaseUser | None


class IsAuthorizedPoolRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_pool`` API in the auth manager."""

method: ResourceMethod
details: PoolDetails | None
user: BaseUser | None


class IsAuthorizedVariableRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_variable`` API in the auth manager."""

method: ResourceMethod
details: VariableDetails | None
user: BaseUser | None
68 changes: 67 additions & 1 deletion airflow/providers/amazon/aws/auth_manager/avp/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence, TypedDict

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand All @@ -37,6 +37,15 @@
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser


class IsAuthorizedRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized`` method in AVP facade."""

method: ResourceMethod
entity_type: AvpEntities
entity_id: str | None
context: dict | None


class AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
"""
Facade for Amazon Verified Permissions.
Expand Down Expand Up @@ -116,6 +125,63 @@ def is_authorized(

return resp["decision"] == "ALLOW"

def batch_is_authorized(
self,
*,
requests: Sequence[IsAuthorizedRequest],
user: AwsAuthManagerUser | None,
) -> bool:
"""
Make a batch authorization decision against Amazon Verified Permissions.

Check whether the user has permissions to access given resources.

:param requests: the list of requests containing the method, the entity_type and the entity ID
:param user: the user
"""
if user is None:
return False

entity_list = self._get_user_role_entities(user)

self.log.debug("Making batch authorization request for user=%s, requests=%s", user.get_id(), requests)

avp_requests = [
prune_dict(
{
"principal": {"entityType": get_entity_type(AvpEntities.USER), "entityId": user.get_id()},
"action": {
"actionType": get_entity_type(AvpEntities.ACTION),
"actionId": get_action_id(request["entity_type"], request["method"]),
},
"resource": {
"entityType": get_entity_type(request["entity_type"]),
"entityId": request.get("entity_id", "*"),
},
"context": self._build_context(request.get("context")),
}
)
for request in requests
]

resp = self.avp_client.batch_is_authorized(
policyStoreId=self.avp_policy_store_id,
requests=avp_requests,
entities={"entityList": entity_list},
)

self.log.debug("Authorization response: %s", resp)

has_errors = any(len(result.get("errors", [])) > 0 for result in resp["results"])

if has_errors:
self.log.error(
"Error occurred while making a batch authorization decision. Result: %s", resp["results"]
)
raise AirflowException("Error occurred while making a batch authorization decision.")
vincbeck marked this conversation as resolved.
Show resolved Hide resolved

return all(result["decision"] == "ALLOW" for result in resp["results"])

@staticmethod
def _get_user_role_entities(user: AwsAuthManagerUser) -> list[dict]:
user_entity = {
Expand Down
112 changes: 105 additions & 7 deletions airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@

import argparse
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence, cast

from flask import session, url_for

from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
from airflow.configuration import conf
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade
from airflow.providers.amazon.aws.auth_manager.avp.facade import (
AwsAuthManagerAmazonVerifiedPermissionsFacade,
IsAuthorizedRequest,
)
from airflow.providers.amazon.aws.auth_manager.cli.definition import (
AWS_AUTH_MANAGER_COMMANDS,
)
Expand All @@ -40,22 +43,30 @@

try:
from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
from airflow.auth.managers.models.resource_details import (
ConnectionDetails,
DagAccessEntity,
DagDetails,
PoolDetails,
VariableDetails,
)
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
)

if TYPE_CHECKING:
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
IsAuthorizedPoolRequest,
IsAuthorizedVariableRequest,
)
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
ConnectionDetails,
DagAccessEntity,
DagDetails,
DatasetDetails,
PoolDetails,
VariableDetails,
)
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
Expand Down Expand Up @@ -191,6 +202,93 @@ def is_authorized_view(
entity_id=access_view.value,
)

def batch_is_authorized_connection(
self,
requests: Sequence[IsAuthorizedConnectionRequest],
) -> bool:
"""
Batch version of ``is_authorized_connection``.

:param requests: a list of requests containing the parameters for ``is_authorized_connection``
"""
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.CONNECTION,
"entity_id": cast(ConnectionDetails, request["details"]).conn_id
if request.get("details")
else None,
}
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def batch_is_authorized_dag(
self,
requests: Sequence[IsAuthorizedDagRequest],
) -> bool:
"""
Batch version of ``is_authorized_dag``.

:param requests: a list of requests containing the parameters for ``is_authorized_dag``
"""
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.DAG,
"entity_id": cast(DagDetails, request["details"]).id if request.get("details") else None,
"context": {
"dag_entity": {
"string": cast(DagAccessEntity, request["access_entity"]).value,
},
}
if request.get("access_entity")
else None,
}
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def batch_is_authorized_pool(
self,
requests: Sequence[IsAuthorizedPoolRequest],
) -> bool:
"""
Batch version of ``is_authorized_pool``.

:param requests: a list of requests containing the parameters for ``is_authorized_pool``
"""
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.POOL,
"entity_id": cast(PoolDetails, request["details"]).name if request.get("details") else None,
}
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def batch_is_authorized_variable(
self,
requests: Sequence[IsAuthorizedVariableRequest],
) -> bool:
"""
Batch version of ``is_authorized_variable``.

:param requests: a list of requests containing the parameters for ``is_authorized_variable``
"""
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.VARIABLE,
"entity_id": cast(VariableDetails, request["details"]).key
if request.get("details")
else None,
}
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def get_url_login(self, **kwargs) -> str:
return url_for("AwsAuthManagerAuthenticationViews.login")

Expand Down
Loading