Skip to content

Commit

Permalink
rename execute_command to __call__
Browse files Browse the repository at this point in the history
This commit addresses reviewer comment:
<datalad#596 (comment)>
It renames `ShellCommandExecutor.execute_command` to
`ShellCommandExecutor.__call__`. With that change the
following code becomes possible:

>>> with shell_connection(['bash']) as shell:
...     r = shell(b'uname')
  • Loading branch information
christian-monch committed Jan 22, 2024
1 parent 1026ae6 commit a869e0e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
18 changes: 9 additions & 9 deletions datalad_next/utils/shell_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def train(queue: Queue):
cmd_executor = ShellCommandExecutor(subprocess_inputs, shell_output)
try:
# Skip initial login messages
result_0 = cmd_executor.execute_command(b'test')
result_0 = cmd_executor(b'test')
for line in result_0:
lgr.debug('skipped login message line: %s', line)
if result_0.returncode != 1:
Expand Down Expand Up @@ -118,11 +118,11 @@ def __init__(self,
self.stdout = stdout
self.end_marker, self.command_postfix = self._get_marker()

def execute_command(self,
command: bytes,
*,
stdin: Iterable[bytes] | None = None,
) -> ShellCommandResponseGenerator:
def __call__(self,
command: bytes,
*,
stdin: Iterable[bytes] | None = None,
) -> ShellCommandResponseGenerator:
"""Execute a command in the connected shell
Parameters
Expand Down Expand Up @@ -181,7 +181,7 @@ def upload(self,
file_size = local_path.stat().st_size
cmd_line = f'dd bs=1 of="{remote_path.as_posix()}" count={file_size}'
with local_path.open('rb') as local_file:
result = self.execute_command(
result = self(
cmd_line.encode(),
stdin=iter(local_file.read, b''))
consume(result)
Expand Down Expand Up @@ -213,7 +213,7 @@ def download(self,
# wire. That ensures that the end-marker is not in the downloaded data,
# even if it is contained in the remote file.
cmd_line = f'7z a dummy -tgzip -so "{remote_path.as_posix()}"'
result = self.execute_command(cmd_line.encode())
result = self(cmd_line.encode())
dco = decompressobj(wbits=31)
with local_path.open('wb') as local_file:
for chunk in result:
Expand Down Expand Up @@ -247,7 +247,7 @@ def delete(self,
'rm ' \
+ ('-f ' if force else '') \
+ ' '.join(f'"{f.as_posix()}"' for f in files)
result = self.execute_command(cmd_line.encode())
result = self(cmd_line.encode())
consume(result)
return result.returncode, b''.join(result.stderr_deque)

Expand Down
10 changes: 5 additions & 5 deletions datalad_next/utils/tests/test_shell_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ def test_basic_functionality_multi(sshserver):


def _check_ls_result(ssh_executor, file_name: bytes):
results = ssh_executor.execute_command(b'ls ' + file_name)
results = ssh_executor(b'ls ' + file_name)
assert b''.join(results) == file_name + b'\n'
assert results.returncode == 0


def test_return_code_functionality(sshserver):
ssh_url = sshserver[0]
with shell_connection(_get_cmdline(ssh_url)[0]) as ssh:
results = ssh.execute_command(b'bash -c "exit 123"')
results = ssh(b'bash -c "exit 123"')
consume(results)
assert results.returncode == 123

Expand All @@ -77,7 +77,7 @@ def test_stdout_forwarding_multi(sshserver):


def _check_echo_result(ssh: ShellCommandExecutor, cmd: bytes, expected: bytes):
results = ssh.execute_command(cmd)
results = ssh(cmd)
assert b''.join(results) == expected
assert results.returncode == 0

Expand All @@ -94,7 +94,7 @@ def test_exit_if_unlimited_stdin_is_closed(sshserver):
shell_connection(ssh_args) as ssh_executor, \
iter_subproc([sys.executable, '-c', 'print("0123456789")']) as cat_feed:

results = ssh_executor.execute_command(b'cat >' + ssh_path, stdin=cat_feed)
results = ssh_executor(b'cat >' + ssh_path, stdin=cat_feed)
ssh_executor.close()
consume(results)
assert results.returncode == 0
Expand All @@ -115,7 +115,7 @@ def test_continuation_after_stdin_reading(sshserver):

for file_name, feed in (('dd-123', dd_feed_1), ('dd-456', dd_feed_2)):
server_path = (ssh_path + '/' + file_name).encode()
results = ssh_executor.execute_command(
results = ssh_executor(
b'dd bs=1 count=10 of=' + server_path,
stdin=feed)
consume(results)
Expand Down

0 comments on commit a869e0e

Please sign in to comment.