Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for subcommands in the API #254

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/zino/api/legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def requires_authentication(func: Callable) -> Callable:
class Zino1BaseServerProtocol(asyncio.Protocol):
"""Base implementation of the Zino 1 protocol, with a basic command dispatcher for subclasses to utilize."""

HAS_SUBCOMMANDS: tuple[str] = tuple()

def __init__(
self,
server: Optional["ZinoServer"] = None,
Expand Down Expand Up @@ -120,9 +122,14 @@ def message_received(self, message: str):
return self._dispatch_command(*args)

def _dispatch_command(self, command, *args):
responder = self._get_responder(command)
commands = [command]
if command.lower() in self.HAS_SUBCOMMANDS and args:
commands.append(args[0])
# Remove subcommand from args
args = args[1:]
responder = self._get_responder(*commands)
if not responder:
return self._respond_error(f'unknown command: "{command}"')
return self._respond_error(f'unknown command: "{" ".join(commands)}"')

if getattr(responder, "requires_authentication", False) and not self.is_authenticated:
return self._respond_error("Not authenticated")
Expand All @@ -136,7 +143,7 @@ def _dispatch_command(self, command, *args):
_logger.debug("client %s sent %r, ignoring garbage args at end: %r", self.peer_name, args, garbage_args)
args = args[: len(required_args)]

self._current_task = asyncio.create_task(self._run_async_responder(command, responder, *args))
self._current_task = asyncio.create_task(self._run_async_responder(" ".join(commands), responder, *args))
return self._current_task

async def _run_async_responder(self, command: str, responder: Callable, *args):
Expand All @@ -149,11 +156,13 @@ async def _run_async_responder(self, command: str, responder: Callable, *args):
finally:
self._current_task = None

def _get_responder(self, command: str):
if not command.isalpha():
return

func = getattr(self, f"do_{command.lower()}", None)
def _get_responder(self, *commands):
command_string = "do"
for command in commands:
if not command.isalpha():
return
command_string += f"_{command.lower()}"
func = getattr(self, command_string, None)
if callable(func):
return func

Expand Down
17 changes: 17 additions & 0 deletions tests/api/legacy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,23 @@ def test_when_disconnected_it_should_deregister_instance_from_server(self, event

assert protocol not in server.active_clients

@pytest.mark.asyncio
async def test_subcommand_should_be_dispatched_if_command_with_support_of_subcommand_is_used(self):
args = []

class TestProtocol(Zino1BaseServerProtocol):
HAS_SUBCOMMANDS = ("testcommand",)

async def do_testcommand_subcommand(self, one, two):
args.extend((one, two))

protocol = TestProtocol()
fake_transport = Mock()
protocol.connection_made(fake_transport)
await protocol.message_received("testcommand subcommand foo bar")

assert args == ["foo", "bar"], "do_testcommand_subcommand() was apparently not called"


class TestZino1ServerProtocolTranslateCaseIdToEvent:
@pytest.mark.asyncio
Expand Down
Loading