Skip to content

Commit

Permalink
Merge pull request #660 from christian-monch/download-progress
Browse files Browse the repository at this point in the history
Add progress-callback to `datalad_next.shell.operations.posix.download`
  • Loading branch information
mih authored Apr 18, 2024
2 parents 5a12b73 + e727df2 commit 10b6e04
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
27 changes: 17 additions & 10 deletions datalad_next/shell/operations/posix.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def download(
shell: ShellCommandExecutor,
remote_path: PurePosixPath,
local_path: Path,
progress_callback: Callable[[int, int], None] | None = None,
*,
response_generator_class: type[
DownloadResponseGenerator
Expand All @@ -166,10 +167,7 @@ def download(
) -> ExecutionResult:
"""Download a file from the connected shell
This method downloads a file from the connected shell. It uses ``base64`` in
the shell to encode the file. The encoding is mainly done to ensure that the
end-marker is significant, i.e. not contained in the transferred file
content, and to ensure that no control-sequences are sent.
This method downloads a file from the connected shell.
The requirements for download via instances of class
:class:`DownloadResponseGeneratorPosix` are:
Expand All @@ -187,6 +185,9 @@ def download(
downloaded.
local_path : Path
The name of the local file that will contain the downloaded content.
progress_callback : callable[[int, int], None], optional, default: None
If given, the callback is called with the number of bytes that have
been received and the total number of bytes that should be received.
response_generator_class : type[DownloadResponseGenerator], optional, default: DownloadResponseGeneratorPosix
The response generator that should be used to handle the download
output. It must be a subclass of :class:`DownloadResponseGenerator`.
Expand All @@ -211,17 +212,23 @@ def download(
output.
"""
command = remote_path.as_posix().encode()
response_generator = shell.start(
response_generator = response_generator_class(shell.stdout)
result_generator = shell.start(
command,
response_generator=response_generator_class(shell.stdout),
response_generator=response_generator,
)
with local_path.open("wb") as local_file:
for chunk in response_generator:
processed = 0
for chunk in result_generator:
local_file.write(chunk)
stderr = b''.join(response_generator.stderr_deque)
response_generator.stderr_deque.clear()
processed += len(chunk)
if progress_callback is not None:
progress_callback(processed, response_generator.length)

stderr = b''.join(result_generator.stderr_deque)
result_generator.stderr_deque.clear()
return create_result(
response_generator,
result_generator,
command,
stdout=b'',
stderr=stderr,
Expand Down
40 changes: 33 additions & 7 deletions datalad_next/shell/tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def test_upload(sshserver, tmp_path):
test_file_name = 'upload_123'
upload_file = tmp_path / test_file_name
upload_file.write_text(content)
progress = []
with shell(ssh_args) as ssh_executor:
# perform an operation on the remote shell
_check_ls_result(ssh_executor, common_files[0])
Expand All @@ -156,10 +157,12 @@ def test_upload(sshserver, tmp_path):
result = posix.upload(
ssh_executor,
upload_file,
PurePosixPath(ssh_path + '/' + test_file_name)
PurePosixPath(ssh_path + '/' + test_file_name),
progress_callback=lambda a,b: progress.append((a,b))
)
assert result.returncode == 0
assert (local_path / test_file_name).read_text() == content
assert len(progress) > 0

# perform another operation on the remote shell to ensure functionality
_check_ls_result(ssh_executor, common_files[0])
Expand All @@ -173,6 +176,7 @@ def test_download_ssh(sshserver, tmp_path):
server_file = local_path / test_file_name
server_file.write_text(content)
download_file = tmp_path / test_file_name
progress = []
with shell(ssh_args) as ssh_executor:
# perform an operation on the remote shell
_check_ls_result(ssh_executor, common_files[0])
Expand All @@ -182,9 +186,11 @@ def test_download_ssh(sshserver, tmp_path):
ssh_executor,
PurePosixPath(ssh_path + '/' + test_file_name),
download_file,
progress_callback=lambda a,b: progress.append((a,b))
)
assert result.returncode == 0
assert download_file.read_text() == content
assert len(progress) > 0

# perform another operation on the remote shell to ensure functionality
_check_ls_result(ssh_executor, common_files[0])
Expand All @@ -198,16 +204,19 @@ def test_download_local_bash(tmp_path):
download_file = tmp_path / 'download_123'
download_file.write_text(content)
result_file = tmp_path / 'result_123'
progress = []
with shell(['bash']) as bash:
_check_ls_result(bash, common_files[0])

# download file from server and verify its content
posix.download(
bash,
PurePosixPath(download_file),
result_file
result_file,
progress_callback=lambda a,b: progress.append((a,b)),
)
assert result_file.read_text() == content
assert len(progress) > 0

# perform another operation on the remote shell to ensure functionality
_check_ls_result(bash, common_files[0])
Expand All @@ -220,12 +229,19 @@ def test_upload_local_bash(tmp_path):
upload_file = tmp_path / 'upload_123'
upload_file.write_text(content)
result_file = tmp_path / 'result_123'
progress = []
with shell(['bash']) as bash:
_check_ls_result(bash, common_files[0])

# upload file to server and verify its content
posix.upload(bash, upload_file, PurePosixPath(result_file))
posix.upload(
bash,
upload_file,
PurePosixPath(result_file),
progress_callback=lambda a,b: progress.append((a,b)),
)
assert result_file.read_text() == content
assert len(progress) > 0

# perform another operation on the remote shell to ensure functionality
_check_ls_result(bash, common_files[0])
Expand All @@ -238,12 +254,19 @@ def test_upload_local_bash_error(tmp_path):
source_file = tmp_path / 'upload_123'
source_file.write_text(content)
destination_file = PurePosixPath('/result_123')
progress = []
with shell(['bash']) as bash:
_check_ls_result(bash, common_files[0])

# upload file to a root on the server
result = posix.upload(bash, source_file, destination_file)
result = posix.upload(
bash,
source_file,
destination_file,
progress_callback=lambda a,b: progress.append((a,b)),
)
assert result.returncode != 0
assert len(progress) > 0

# perform another operation on the remote shell to ensure functionality
_check_ls_result(bash, common_files[0])
Expand Down Expand Up @@ -477,21 +500,24 @@ def test_download_length_error():
# This test does not work on Windows systems because it executes a local bash.
@skip_if(on_windows)
def test_download_error(tmp_path):
progress = []
with shell(['bash']) as bash:
with pytest.raises(CommandError):
posix.download(
bash,
PurePosixPath('/thisdoesnotexist'),
tmp_path / 'downloaded_file',
check=True
progress_callback=lambda a,b: progress.append((a,b)),
check=True,
)
_check_ls_result(bash, common_files[0])

result = posix.download(
bash,
PurePosixPath('/thisdoesnotexist'),
tmp_path / 'downloaded_file',
check=False
tmp_path / 'downloaded_file',
progress_callback=lambda a,b: progress.append((a,b)),
check=False,
)
assert result.returncode not in (0, None)
_check_ls_result(bash, common_files[0])

0 comments on commit 10b6e04

Please sign in to comment.