Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context access settings #415

Merged
merged 5 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ where = ["src"]

[project]
name = "aiogram_dialog"
version = "2.2.0a4"
version = "2.2.0a5"
readme = "README.md"
authors = [
{ name = "Andrey Tikhonov", email = "[email protected]" },
Expand Down
3 changes: 2 additions & 1 deletion src/aiogram_dialog/api/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
"DialogStartEvent", "DialogSwitchEvent", "DialogUpdate",
]

from .access import AccessSettings
from .context import Context, Data
from .events import ChatEvent, EVENT_CONTEXT_KEY, EventContext
from .launch_mode import LaunchMode
from .media import MediaAttachment, MediaId
from .modes import ShowMode, StartMode
from .new_message import MarkupVariant, NewMessage, OldMessage, UnknownText
from .stack import AccessSettings, DEFAULT_STACK_ID, GROUP_STACK_ID, Stack
from .stack import DEFAULT_STACK_ID, GROUP_STACK_ID, Stack
from .update_event import (
DIALOG_EVENT_NAME, DialogAction, DialogStartEvent, DialogSwitchEvent,
DialogUpdate, DialogUpdateEvent,
Expand Down
11 changes: 11 additions & 0 deletions src/aiogram_dialog/api/entities/access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from dataclasses import dataclass
from typing import Any, List, Optional

from aiogram.enums import ChatMemberStatus


@dataclass
class AccessSettings:
user_ids: List[int]
member_status: Optional[ChatMemberStatus] = None
custom: Any = None
6 changes: 5 additions & 1 deletion src/aiogram_dialog/api/entities/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass, field
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

from aiogram.fsm.state import State
from .access import AccessSettings

Data = Union[Dict, List, int, str, float, None]
DataDict = Dict[str, Data]
Expand All @@ -15,6 +16,9 @@ class Context:
start_data: Data = field(compare=False)
dialog_data: DataDict = field(compare=False, default_factory=dict)
widget_data: DataDict = field(compare=False, default_factory=dict)
access_settings: Optional[AccessSettings] = field(
compare=False, default=None,
)

@property
def id(self) -> str:
Expand Down
11 changes: 2 additions & 9 deletions src/aiogram_dialog/api/entities/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import string
import time
from dataclasses import dataclass, field
from typing import Any, List, Optional
from typing import List, Optional

from aiogram.enums import ChatMemberStatus
from aiogram.fsm.state import State

from aiogram_dialog.api.exceptions import DialogStackOverflow
from .access import AccessSettings
from .context import Context, Data

DEFAULT_STACK_ID = ""
Expand Down Expand Up @@ -35,13 +35,6 @@ def new_id():
return id_to_str(new_int_id())


@dataclass
class AccessSettings:
user_ids: List[int]
member_status: Optional[ChatMemberStatus] = None
custom: Any = None


@dataclass(unsafe_hash=True)
class Stack:
_id: str = field(compare=True, default_factory=new_id)
Expand Down
10 changes: 7 additions & 3 deletions src/aiogram_dialog/api/protocols/stack_access.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from abc import abstractmethod
from typing import Protocol
from typing import Protocol, Optional

from aiogram_dialog import ChatEvent
from aiogram_dialog.api.entities import Stack
from aiogram_dialog.api.entities import Stack, Context


class StackAccessValidator(Protocol):
@abstractmethod
async def is_allowed(
self, stack: Stack, event: ChatEvent, data: dict,
self,
stack: Stack,
context: Optional[Context],
event: ChatEvent,
data: dict,
) -> bool:
raise NotImplementedError
20 changes: 15 additions & 5 deletions src/aiogram_dialog/context/access_validator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from logging import getLogger
from typing import Optional

from aiogram.enums import ChatType

from aiogram_dialog import ChatEvent
from aiogram_dialog.api.entities import (
Stack,
Stack, Context,
)
from aiogram_dialog.api.protocols import StackAccessValidator

Expand All @@ -13,15 +14,24 @@

class DefaultAccessValidator(StackAccessValidator):
async def is_allowed(
self, stack: Stack, event: ChatEvent, data: dict,
self,
stack: Stack,
context: Optional[Context],
event: ChatEvent,
data: dict,
) -> bool:
if not stack.access_settings:
if context:
access_settings = context.access_settings
else:
access_settings = stack.access_settings

if not access_settings:
return True
chat = data["event_chat"]
if chat.type is ChatType.PRIVATE:
return True
if stack.access_settings.user_ids:
if access_settings.user_ids:
user = data["event_from_user"]
if user.id not in stack.access_settings.user_ids:
if user.id not in access_settings.user_ids:
return False
return True
46 changes: 36 additions & 10 deletions src/aiogram_dialog/context/intent_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,6 @@ async def _load_stack(
if stack_id is None:
raise InvalidStackIdError("Both stack id and intent id are None")
stack = await proxy.load_stack(stack_id)
if not await self.access_validator.is_allowed(stack, event, data):
logger.debug(
"Stack %s is not allowed for user %s",
stack.id, proxy.user_id,
)
data[FORBIDDEN_STACK_KEY] = True
await proxy.unlock()
return
return stack

async def _load_context_by_stack(
Expand All @@ -198,6 +190,17 @@ async def _load_context_by_stack(
except: # noqa: B001,B901,E722
await proxy.unlock()
raise

if not await self.access_validator.is_allowed(
stack, context, event, data,
):
logger.debug(
"Stack %s is not allowed for user %s",
stack.id, proxy.user_id,
)
data[FORBIDDEN_STACK_KEY] = True
await proxy.unlock()
return
data[STORAGE_KEY] = proxy
data[STACK_KEY] = stack
data[CONTEXT_KEY] = context
Expand All @@ -223,6 +226,16 @@ async def _load_context_by_intent(
await proxy.unlock()
raise

if not await self.access_validator.is_allowed(
stack, context, event, data,
):
logger.debug(
"Stack %s is not allowed for user %s",
stack.id, proxy.user_id,
)
data[FORBIDDEN_STACK_KEY] = True
await proxy.unlock()
return
data[STORAGE_KEY] = proxy
data[STACK_KEY] = stack
data[CONTEXT_KEY] = context
Expand Down Expand Up @@ -408,11 +421,13 @@ class IntentErrorMiddleware(BaseMiddleware):
def __init__(
self,
registry: DialogRegistryProtocol,
access_validator: StackAccessValidator,
events_isolation: BaseEventIsolation,
):
super().__init__()
self.registry = registry
self.events_isolation = events_isolation
self.access_validator = access_validator

def _is_error_supported(
self, event: ErrorEvent, data: Dict[str, Any],
Expand Down Expand Up @@ -484,8 +499,19 @@ async def __call__(
storage=proxy,
stack=stack,
)
data[STACK_KEY] = stack
data[CONTEXT_KEY] = context

if await self.access_validator.is_allowed(
stack, context, event.update.event, data,
):
data[STACK_KEY] = stack
data[CONTEXT_KEY] = context
else:
logger.debug(
"Stack %s is not allowed for user %s",
stack.id, proxy.user_id,
)
data[FORBIDDEN_STACK_KEY] = True
await proxy.unlock()
return await handler(event, data)
finally:
proxy: StorageProxy = data.pop(STORAGE_KEY, None)
Expand Down
18 changes: 9 additions & 9 deletions src/aiogram_dialog/context/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,14 @@ async def load_context(self, intent_id: str) -> Context:
raise UnknownIntent(
f"Context not found for intent id: {intent_id}",
)
data["access_settings"] = self._parse_access_settings(
data.pop("access_settings", None),
)
data["state"] = self._state(data["state"])
return Context(**data)

def _default_access_settings(self, stack_id: str) -> AccessSettings:
if stack_id == DEFAULT_STACK_ID:
if stack_id == DEFAULT_STACK_ID and self.user_id:
return AccessSettings(user_ids=[self.user_id])
else:
return AccessSettings(user_ids=[])
Expand All @@ -67,20 +70,20 @@ async def load_stack(self, stack_id: str = DEFAULT_STACK_ID) -> Stack:
key = self._stack_key(fixed_stack_id)
await self.lock(key)
data = await self.storage.get_data(key)
data.pop("access_settings", None) # compat with 2.2a5
access_settings = self._default_access_settings(stack_id)
if not data:
access_settings = self._default_access_settings(stack_id)
return Stack(_id=fixed_stack_id, access_settings=access_settings)

access_settings = self._parse_access_settings(
data.pop("access_settings", None),
)
return Stack(access_settings=access_settings, **data)

async def save_context(self, context: Optional[Context]) -> None:
if not context:
return
data = copy(vars(context))
data["state"] = data["state"].state
data["access_settings"] = self._dump_access_settings(
context.access_settings,
)
await self.storage.set_data(
key=self._context_key(context.id),
data=data,
Expand Down Expand Up @@ -108,9 +111,6 @@ async def save_stack(self, stack: Optional[Stack]) -> None:
)
else:
data = copy(vars(stack))
data["access_settings"] = self._dump_access_settings(
stack.access_settings,
)
await self.storage.set_data(
key=self._stack_key(stack.id),
data=data,
Expand Down
9 changes: 7 additions & 2 deletions src/aiogram_dialog/manager/manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from logging import getLogger
from typing import Any, cast, Dict, Optional, Union

Expand Down Expand Up @@ -255,8 +256,6 @@ async def _start_normal(
access_settings: Optional[AccessSettings],
) -> None:
stack = self.current_stack()
if access_settings is not None:
stack.access_settings = access_settings
old_dialog: Optional[DialogProtocol] = None
if not stack.empty():
old_dialog = self.dialog()
Expand All @@ -270,7 +269,13 @@ async def _start_normal(
await self._process_launch_mode(old_dialog, new_dialog)
if self.has_context():
await self.storage().save_context(self.current_context())
if access_settings is None:
access_settings = self.current_context().access_settings
if access_settings is None:
access_settings = stack.access_settings

context = stack.push(state, data)
context.access_settings = deepcopy(access_settings)
self._data[CONTEXT_KEY] = context
await self.dialog().process_start(self, data, state)
new_context = self._current_context_unsafe()
Expand Down
1 change: 1 addition & 0 deletions src/aiogram_dialog/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _register_middleware(
router.errors.middleware(IntentErrorMiddleware(
registry=registry,
events_isolation=events_isolation,
access_validator=stack_access_validator,
))

router.message.middleware(manager_middleware)
Expand Down
39 changes: 37 additions & 2 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import pytest
from aiogram import Dispatcher
from aiogram.filters import CommandStart
from aiogram.filters import CommandStart, Command
from aiogram.fsm.state import State, StatesGroup

from aiogram_dialog import (
Dialog, DialogManager, setup_dialogs, StartMode, Window,
)
from aiogram_dialog.api.entities import GROUP_STACK_ID
from aiogram_dialog.api.entities import GROUP_STACK_ID, AccessSettings
from aiogram_dialog.test_tools import BotClient, MockMessageManager
from aiogram_dialog.test_tools.keyboard import InlineButtonTextLocator
from aiogram_dialog.widgets.kbd import Button
Expand All @@ -36,6 +36,12 @@ async def start_shared(event: Any, dialog_manager: DialogManager):
await dialog_manager.start(MainSG.start, mode=StartMode.RESET_STACK)


async def add_shared(event: Any, dialog_manager: DialogManager):
await dialog_manager.start(MainSG.start, access_settings=AccessSettings(
user_ids=[1, 2],
))


@pytest.fixture()
def message_manager():
return MockMessageManager()
Expand Down Expand Up @@ -74,6 +80,35 @@ async def test_second_user(dp, client, second_client, message_manager):
assert not message_manager.sent_messages


@pytest.mark.asyncio
async def test_change_seettings(dp, client, second_client, message_manager):
dp.message.register(start, CommandStart())
dp.message.register(add_shared, Command("add"))

await client.send("/start")
message_manager.reset_history()

await client.send("/add")
window_message = message_manager.one_message()
message_manager.reset_history()

await second_client.click(
window_message, InlineButtonTextLocator("Button"),
)
window_message = message_manager.one_message()
message_manager.reset_history()
assert window_message.text == "stub"

await client.send("/start")
window_message = message_manager.one_message()
message_manager.reset_history()

await second_client.click(
window_message, InlineButtonTextLocator("Button"),
)
assert not message_manager.sent_messages


@pytest.mark.asyncio
async def test_same_user(dp, client, message_manager):
dp.message.register(start, CommandStart())
Expand Down