Skip to content

Commit

Permalink
Implement batch_is_authorized_* APIs in AWS auth manager (#37430)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored Feb 22, 2024
1 parent 45eeff4 commit 0c2d2c6
Show file tree
Hide file tree
Showing 7 changed files with 400 additions and 41 deletions.
1 change: 0 additions & 1 deletion airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
{
"method": "GET",
"details": DagDetails(id=id),
"user": g.user,
}
for id in dag_ids
]
Expand Down
47 changes: 20 additions & 27 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,24 @@ def is_authorized_custom_view(
"""
raise AirflowException(f"The resource `{fab_resource_name}` does not exist in the environment.")

def batch_is_authorized_connection(
    self,
    requests: Sequence[IsAuthorizedConnectionRequest],
) -> bool:
    """
    Check a batch of connection authorization requests.

    This default implementation invokes ``is_authorized_connection`` once per request and stops at
    the first request that is not authorized. Since that means one authorization call per item,
    performance can be poor; auth manager implementations are encouraged to override this method
    with a more efficient one.

    :param requests: a list of requests containing the parameters for ``is_authorized_connection``
    :return: ``True`` if every request is authorized (or ``requests`` is empty), ``False`` otherwise
    """
    for item in requests:
        allowed = self.is_authorized_connection(method=item["method"], details=item.get("details"))
        if not allowed:
            # Same short-circuit as ``all()``: later requests are not evaluated.
            return False
    return True

def batch_is_authorized_dag(
self,
requests: Sequence[IsAuthorizedDagRequest],
Expand All @@ -272,27 +290,6 @@ def batch_is_authorized_dag(
method=request["method"],
access_entity=request.get("access_entity"),
details=request.get("details"),
user=request.get("user"),
)
for request in requests
)

def batch_is_authorized_connection(
self,
requests: Sequence[IsAuthorizedConnectionRequest],
) -> bool:
"""
Batch version of ``is_authorized_connection``.
By default, calls individually the ``is_authorized_connection`` API on each item in the list of
requests. Can lead to some poor performance. It is recommended to override this method in the auth
manager implementation to provide a more efficient implementation.
:param requests: a list of requests containing the parameters for ``is_authorized_connection``
"""
return all(
self.is_authorized_connection(
method=request["method"], details=request.get("details"), user=request.get("user")
)
for request in requests
)
Expand All @@ -311,9 +308,7 @@ def batch_is_authorized_pool(
:param requests: a list of requests containing the parameters for ``is_authorized_pool``
"""
return all(
self.is_authorized_pool(
method=request["method"], details=request.get("details"), user=request.get("user")
)
self.is_authorized_pool(method=request["method"], details=request.get("details"))
for request in requests
)

Expand All @@ -331,9 +326,7 @@ def batch_is_authorized_variable(
:param requests: a list of requests containing the parameters for ``is_authorized_variable``
"""
return all(
self.is_authorized_variable(
method=request["method"], details=request.get("details"), user=request.get("user")
)
self.is_authorized_variable(method=request["method"], details=request.get("details"))
for request in requests
)

Expand Down
5 changes: 0 additions & 5 deletions airflow/auth/managers/models/batch_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
ConnectionDetails,
DagAccessEntity,
Expand All @@ -36,7 +35,6 @@ class IsAuthorizedConnectionRequest(TypedDict, total=False):

method: ResourceMethod
details: ConnectionDetails | None
user: BaseUser | None


class IsAuthorizedDagRequest(TypedDict, total=False):
Expand All @@ -45,20 +43,17 @@ class IsAuthorizedDagRequest(TypedDict, total=False):
method: ResourceMethod
access_entity: DagAccessEntity | None
details: DagDetails | None
user: BaseUser | None


class IsAuthorizedPoolRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_pool`` API in the auth manager."""

method: ResourceMethod
details: PoolDetails | None
user: BaseUser | None


class IsAuthorizedVariableRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_variable`` API in the auth manager."""

method: ResourceMethod
details: VariableDetails | None
user: BaseUser | None
68 changes: 67 additions & 1 deletion airflow/providers/amazon/aws/auth_manager/avp/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence, TypedDict

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand All @@ -37,6 +37,15 @@
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser


class IsAuthorizedRequest(TypedDict, total=False):
    """
    Describe one authorization check submitted to the AVP facade.

    The keys mirror the parameters of the facade's ``is_authorized`` method. The dictionary is
    declared with ``total=False``, so no key is required by the type checker.
    """

    # Method the user wants to perform on the resource.
    method: ResourceMethod
    # Type of the resource the check applies to.
    entity_type: AvpEntities
    # ID of the targeted resource; callers pass ``None`` when no specific resource is targeted.
    entity_id: str | None
    # Additional context to attach to the authorization request, if any.
    context: dict | None


class AwsAuthManagerAmazonVerifiedPermissionsFacade(LoggingMixin):
"""
Facade for Amazon Verified Permissions.
Expand Down Expand Up @@ -116,6 +125,63 @@ def is_authorized(

return resp["decision"] == "ALLOW"

def batch_is_authorized(
    self,
    *,
    requests: Sequence[IsAuthorizedRequest],
    user: AwsAuthManagerUser | None,
) -> bool:
    """
    Make a batch authorization decision against Amazon Verified Permissions.

    Check whether the user has permissions to access all the given resources. The whole batch is
    submitted through a single ``batch_is_authorized`` call on the AVP client.

    :param requests: the list of requests containing the method, the entity_type and the entity ID.
        A request whose ``entity_id`` is missing, ``None`` or empty targets the wildcard ID ``"*"``.
    :param user: the user
    :return: ``False`` if ``user`` is ``None``; otherwise ``True`` only if every result of the batch
        carries the decision ``ALLOW``
    :raises AirflowException: if at least one result of the batch reports errors
    """
    if user is None:
        return False

    entity_list = self._get_user_role_entities(user)
    # The principal is the same for every request of the batch: resolve its ID once.
    user_id = user.get_id()

    self.log.debug("Making batch authorization request for user=%s, requests=%s", user_id, requests)

    avp_requests = [
        prune_dict(
            {
                "principal": {"entityType": get_entity_type(AvpEntities.USER), "entityId": user_id},
                "action": {
                    "actionType": get_entity_type(AvpEntities.ACTION),
                    "actionId": get_action_id(request["entity_type"], request["method"]),
                },
                "resource": {
                    "entityType": get_entity_type(request["entity_type"]),
                    # ``request.get("entity_id", "*")`` only falls back to the wildcard when the key
                    # is absent, but callers pass an explicit ``None`` when no resource is
                    # targeted. ``or`` covers both the missing and the ``None`` case.
                    "entityId": request.get("entity_id") or "*",
                },
                "context": self._build_context(request.get("context")),
            }
        )
        for request in requests
    ]

    resp = self.avp_client.batch_is_authorized(
        policyStoreId=self.avp_policy_store_id,
        requests=avp_requests,
        entities={"entityList": entity_list},
    )

    self.log.debug("Authorization response: %s", resp)

    results = resp["results"]

    # Any error reported for any item invalidates the whole batch decision.
    if any(result.get("errors") for result in results):
        self.log.error("Error occurred while making a batch authorization decision. Result: %s", results)
        raise AirflowException("Error occurred while making a batch authorization decision.")

    return all(result["decision"] == "ALLOW" for result in results)

@staticmethod
def _get_user_role_entities(user: AwsAuthManagerUser) -> list[dict]:
user_entity = {
Expand Down
112 changes: 105 additions & 7 deletions airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@

import argparse
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence, cast

from flask import session, url_for

from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
from airflow.configuration import conf
from airflow.exceptions import AirflowOptionalProviderFeatureException
from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
from airflow.providers.amazon.aws.auth_manager.avp.facade import AwsAuthManagerAmazonVerifiedPermissionsFacade
from airflow.providers.amazon.aws.auth_manager.avp.facade import (
AwsAuthManagerAmazonVerifiedPermissionsFacade,
IsAuthorizedRequest,
)
from airflow.providers.amazon.aws.auth_manager.cli.definition import (
AWS_AUTH_MANAGER_COMMANDS,
)
Expand All @@ -40,22 +43,30 @@

try:
from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
from airflow.auth.managers.models.resource_details import (
ConnectionDetails,
DagAccessEntity,
DagDetails,
PoolDetails,
VariableDetails,
)
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Failed to import BaseUser. This feature is only available in Airflow versions >= 2.8.0"
)

if TYPE_CHECKING:
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
IsAuthorizedPoolRequest,
IsAuthorizedVariableRequest,
)
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
ConnectionDetails,
DagAccessEntity,
DagDetails,
DatasetDetails,
PoolDetails,
VariableDetails,
)
from airflow.providers.amazon.aws.auth_manager.user import AwsAuthManagerUser
from airflow.www.extensions.init_appbuilder import AirflowAppBuilder
Expand Down Expand Up @@ -191,6 +202,93 @@ def is_authorized_view(
entity_id=access_view.value,
)

def batch_is_authorized_connection(
    self,
    requests: Sequence[IsAuthorizedConnectionRequest],
) -> bool:
    """
    Authorize a batch of connection requests through the AVP facade.

    Each request is translated into a facade request for the connection entity type. The
    connection ID is used as entity ID when connection details are provided, ``None`` otherwise.
    The translated batch is then submitted in one facade call for the user returned by
    ``get_user``.

    :param requests: a list of requests containing the parameters for ``is_authorized_connection``
    :return: the result of the facade's ``batch_is_authorized`` call
    """
    facade_requests: list[IsAuthorizedRequest] = []
    for item in requests:
        details = item.get("details")
        facade_requests.append(
            {
                "method": item["method"],
                "entity_type": AvpEntities.CONNECTION,
                "entity_id": cast(ConnectionDetails, details).conn_id if details else None,
            }
        )
    return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def batch_is_authorized_dag(
    self,
    requests: Sequence[IsAuthorizedDagRequest],
) -> bool:
    """
    Authorize a batch of DAG requests through the AVP facade.

    Each request is translated into a facade request for the DAG entity type. The DAG ID is used
    as entity ID when DAG details are provided, ``None`` otherwise. When an access entity is
    provided, its value is forwarded in the request context under ``dag_entity``. The translated
    batch is then submitted in one facade call for the user returned by ``get_user``.

    :param requests: a list of requests containing the parameters for ``is_authorized_dag``
    :return: the result of the facade's ``batch_is_authorized`` call
    """
    facade_requests: list[IsAuthorizedRequest] = []
    for item in requests:
        details = item.get("details")
        access_entity = item.get("access_entity")
        if access_entity:
            context = {
                "dag_entity": {
                    "string": cast(DagAccessEntity, access_entity).value,
                },
            }
        else:
            context = None
        facade_requests.append(
            {
                "method": item["method"],
                "entity_type": AvpEntities.DAG,
                "entity_id": cast(DagDetails, details).id if details else None,
                "context": context,
            }
        )
    return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def batch_is_authorized_pool(
    self,
    requests: Sequence[IsAuthorizedPoolRequest],
) -> bool:
    """
    Authorize a batch of pool requests through the AVP facade.

    Each request is translated into a facade request for the pool entity type. The pool name is
    used as entity ID when pool details are provided, ``None`` otherwise. The translated batch is
    then submitted in one facade call for the user returned by ``get_user``.

    :param requests: a list of requests containing the parameters for ``is_authorized_pool``
    :return: the result of the facade's ``batch_is_authorized`` call
    """
    facade_requests: list[IsAuthorizedRequest] = []
    for item in requests:
        details = item.get("details")
        facade_requests.append(
            {
                "method": item["method"],
                "entity_type": AvpEntities.POOL,
                "entity_id": cast(PoolDetails, details).name if details else None,
            }
        )
    return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def batch_is_authorized_variable(
    self,
    requests: Sequence[IsAuthorizedVariableRequest],
) -> bool:
    """
    Authorize a batch of variable requests through the AVP facade.

    Each request is translated into a facade request for the variable entity type. The variable
    key is used as entity ID when variable details are provided, ``None`` otherwise. The
    translated batch is then submitted in one facade call for the user returned by ``get_user``.

    :param requests: a list of requests containing the parameters for ``is_authorized_variable``
    :return: the result of the facade's ``batch_is_authorized`` call
    """
    facade_requests: list[IsAuthorizedRequest] = []
    for item in requests:
        details = item.get("details")
        facade_requests.append(
            {
                "method": item["method"],
                "entity_type": AvpEntities.VARIABLE,
                "entity_id": cast(VariableDetails, details).key if details else None,
            }
        )
    return self.avp_facade.batch_is_authorized(requests=facade_requests, user=self.get_user())

def get_url_login(self, **kwargs) -> str:
    """
    Return the login URL.

    The URL is resolved with ``url_for`` from the ``AwsAuthManagerAuthenticationViews.login``
    endpoint. Keyword arguments are accepted but not used.
    """
    login_endpoint = "AwsAuthManagerAuthenticationViews.login"
    return url_for(login_endpoint)

Expand Down
Loading

0 comments on commit 0c2d2c6

Please sign in to comment.