Skip to content

Commit

Permalink
Reimplement SSHRemoteIO with datalad_next.shell
Browse files Browse the repository at this point in the history
This takes out all of the old remote shell implementation, and uses the
new one. It does not touch the get/put implementations (yet). They
can also be done with the new shell feature, but that is a different
problem.
  • Loading branch information
mih committed Apr 18, 2024
1 parent 69b0885 commit ecb516e
Showing 1 changed file with 78 additions and 113 deletions.
191 changes: 78 additions & 113 deletions datalad_next/patches/replace_sshremoteio.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,30 @@
import logging

from urllib.parse import urlparse
from urllib.request import unquote

from datalad.distributed.ora_remote import (
DEFAULT_BUFFER_SIZE,
IOBase,
RemoteError,
RemoteCommandFailedError,
RIARemoteError,
contextmanager,
functools,
on_osx,
sh_quote,
ssh_manager,
stat,
subprocess,
)

from datalad_next.exceptions import CapturedException
from datalad_next.patches import apply_patch
from datalad_next.runners import CommandError
from datalad_next.shell import shell


class SSHRemoteIO(IOBase):
"""IO operation if the object tree is SSH-accessible
It doesn't even think about a windows server.
"""

# output markers to detect possible command failure as well as end of output
# from a particular command:
REMOTE_CMD_FAIL = "ora-remote: end - fail"
REMOTE_CMD_OK = "ora-remote: end - ok"

def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
"""
Parameters
Expand All @@ -43,6 +35,7 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
"""
parsed_url = urlparse(ssh_url)

self.url = ssh_url
# the connection to the remote
# we don't open it yet, not yet clear if needed
self.ssh = ssh_manager.get_connection(
Expand All @@ -52,29 +45,22 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
self.ssh.open()
# open a remote shell
cmd = ['ssh'] + self.ssh._ssh_args + [self.ssh.sshri.as_str()]
self.shell = subprocess.Popen(cmd,
stderr=subprocess.DEVNULL,
stdout=subprocess.PIPE,
stdin=subprocess.PIPE)
# swallow login message(s):
self.shell.stdin.write(b"echo RIA-REMOTE-LOGIN-END\n")
self.shell.stdin.flush()
while True:
line = self.shell.stdout.readline()
if line == b"RIA-REMOTE-LOGIN-END\n":
break
# TODO: Same for stderr?

# make sure default is used when None was passed, too.
self.buffer_size = buffer_size if buffer_size else DEFAULT_BUFFER_SIZE
# we settle on `bash` as a shell. It should be around and then we
# can count on it
cmd.append('bash')
self.servershell_context = shell(
cmd,
chunk_size=buffer_size,
)
self.servershell = self.servershell_context.__enter__()

# if the URL had a path, we try to 'cd' into it to make operations on
# relative paths intelligible
if parsed_url.path:
# unquote path
real_path = unquote(parsed_url.path)
try:
self._run(
self.servershell(
f'cd {sh_quote(real_path)}',
check=True,
)
Expand All @@ -84,22 +70,7 @@ def __init__(self, ssh_url, buffer_size=DEFAULT_BUFFER_SIZE):
CapturedException(e)

def close(self):
# try exiting shell clean first
self.shell.stdin.write(b"exit\n")
self.shell.stdin.flush()
exitcode = self.shell.wait(timeout=0.5)
# be more brutal if it doesn't work
if exitcode is None: # timed out
# TODO: Theoretically terminate() can raise if not successful.
# How to deal with that?
self.shell.terminate()

def _append_end_markers(self, cmd):
"""Append end markers to remote command"""

return cmd + " && printf '%s\\n' {} || printf '%s\\n' {}\n".format(
sh_quote(self.REMOTE_CMD_OK),
sh_quote(self.REMOTE_CMD_FAIL))
self.servershell_context.__exit__(None, None, None)

def _get_download_size_from_key(self, key):
"""Get the size of an annex object file from its key
Expand All @@ -110,7 +81,7 @@ def _get_download_size_from_key(self, key):
Parameter
---------
key: str
annex key of the file
annex key of the file
Returns
-------
Expand Down Expand Up @@ -154,38 +125,6 @@ def _get_download_size_from_key(self, key):
else:
raise RIARemoteError("invalid key: {}".format(key))

def _run(self, cmd, no_output=True, check=False):

# TODO: we might want to redirect stderr to stdout here (or have
# additional end marker in stderr) otherwise we can't empty stderr
# to be ready for next command. We also can't read stderr for
# better error messages (RemoteError) without making sure there's
# something to read in any case (it's blocking!).
# However, if we are sure stderr can only ever happen if we would
# raise RemoteError anyway, it might be okay.
call = self._append_end_markers(cmd)
self.shell.stdin.write(call.encode())
self.shell.stdin.flush()

lines = []
while True:
line = self.shell.stdout.readline().decode()
lines.append(line)
if line == self.REMOTE_CMD_OK + '\n':
# end reading
break
elif line == self.REMOTE_CMD_FAIL + '\n':
if check:
raise RemoteCommandFailedError(
"{cmd} failed: {msg}".format(cmd=cmd,
msg="".join(lines[:-1]))
)
else:
break
if no_output and len(lines) > 1:
raise RIARemoteError("{}: {}".format(call, "".join(lines)))
return "".join(lines[:-1])

@contextmanager
def ensure_writeable(self, path):
"""Context manager to get write permission on `path` and restore
Expand Down Expand Up @@ -215,12 +154,14 @@ def ensure_writeable(self, path):
# needed.
conversion = functools.partial(int, base=16)

output = self._run(f"stat {format_option} {path}",
no_output=False, check=True)
output = self.servershell(
f"stat {format_option} {path}",
check=True,
).stdout.decode()
mode = conversion(output)
if not mode & stat.S_IWRITE:
new_mode = oct(mode | stat.S_IWRITE)[-3:]
self._run(f"chmod {new_mode} {path}")
self.servershell(f"chmod {new_mode} {path}", check=True)
changed = True
else:
changed = False
Expand All @@ -229,16 +170,23 @@ def ensure_writeable(self, path):
finally:
if changed:
# restore original mode
self._run("chmod {mode} {file}".format(mode=oct(mode)[-3:],
file=path),
check=False) # don't fail if path doesn't exist
# anymore
self.servershell(
f"chmod {oct(mode)[-3:]} {path}",
# don't fail if path doesn't exist anymore
check=False,
)

def mkdir(self, path):
self._run('mkdir -p {}'.format(sh_quote(str(path))))
self.servershell(
f'mkdir -p {sh_quote(str(path))}',
check=True,
)

def symlink(self, target, link_name):
self._run('ln -s {} {}'.format(sh_quote(str(target)), sh_quote(str(link_name))))
self.servershell(
f'ln -s {sh_quote(str(target))} {sh_quote(str(link_name))}',
check=True,
)

def put(self, src, dst, progress_cb):
self.ssh.put(str(src), str(dst))
Expand Down Expand Up @@ -295,25 +243,38 @@ def get(self, src, dst, progress_cb):

def rename(self, src, dst):
with self.ensure_writeable(dst.parent):
self._run('mv {} {}'.format(sh_quote(str(src)), sh_quote(str(dst))))
self.servershell(
f'mv {sh_quote(str(src))} {sh_quote(str(dst))}',
check=True,
)

def remove(self, path):
try:
with self.ensure_writeable(path.parent):
self._run('rm {}'.format(sh_quote(str(path))), check=True)
except RemoteCommandFailedError as e:
raise RIARemoteError(f"Unable to remove {path} "
"or to obtain write permission in parent directory.") from e
self.servershell(
f'rm {sh_quote(str(path))}',
check=True,
)
except CommandError as e:
raise RIARemoteError(
f"Unable to remove {path} "
"or to obtain write permission in parent directory.") from e

def remove_dir(self, path):
with self.ensure_writeable(path.parent):
self._run('rmdir {}'.format(sh_quote(str(path))))
self.servershell(
f'rmdir {sh_quote(str(path))}',
check=True,
)

def exists(self, path):
try:
self._run('test -e {}'.format(sh_quote(str(path))), check=True)
self.servershell(
f'test -e {sh_quote(str(path))}',
check=True,
)
return True
except RemoteCommandFailedError:
except CommandError:
return False

def in_archive(self, archive_path, file_path):
Expand All @@ -324,14 +285,15 @@ def in_archive(self, archive_path, file_path):
loc = str(file_path)
# query 7z for the specific object location, keeps the output
# lean, even for big archives
cmd = '7z l {} {}'.format(
sh_quote(str(archive_path)),
sh_quote(loc))
cmd = f'7z l {sh_quote(str(archive_path))} {sh_quote(loc)}'

# Note: Currently relies on file_path not showing up in the output in case
# of failure, including a non-existent archive. If need be, this could be
# more sophisticated: call with check=True and catch CommandError
out = self._run(cmd, no_output=False, check=False)
out = self.servershell(
cmd,
check=False,
).stdout.decode()

return loc in out

Expand Down Expand Up @@ -373,10 +335,13 @@ def get_from_archive(self, archive, src, dst, progress_cb):

def read_file(self, file_path):

cmd = "cat {}".format(sh_quote(str(file_path)))
cmd = f"cat {sh_quote(str(file_path))}"
try:
out = self._run(cmd, no_output=False, check=True)
except RemoteCommandFailedError as e:
out = self.servershell(
cmd,
check=True,
).stdout.decode()
except CommandError as e:
# Currently we don't read stderr. All we know is, we couldn't read.
# Try narrowing it down by calling a subsequent exists()
if not self.exists(file_path):
Expand All @@ -397,13 +362,16 @@ def write_file(self, file_path, content, mode='w'):
if not content.endswith('\n'):
content += '\n'

cmd = "printf '%s' {} {} {}".format(
sh_quote(content),
mode,
sh_quote(str(file_path)))
# it really should read from stdin, but MIH cannot make it happen
stdin = content.encode()
cmd = f"head -c {len(stdin)} | cat {mode} {sh_quote(str(file_path))}"
try:
self._run(cmd, check=True)
except RemoteCommandFailedError as e:
self.servershell(
cmd,
check=True,
stdin=[stdin],
)
except CommandError as e:
raise RIARemoteError(f"Could not write to {file_path}") from e

def get_7z(self):
Expand All @@ -416,17 +384,14 @@ def get_7z(self):
# to just call 7z and see whether it returns zero.

try:
self._run("7z", check=True, no_output=False)
self.servershell(
"7z",
check=True,
)
return True
except RemoteCommandFailedError:
except CommandError:
return False

# try:
# out = self._run("which 7z", check=True, no_output=False)
# return out
# except RemoteCommandFailedError:
# return None


# replace the whole class
apply_patch('datalad.distributed.ora_remote', None, 'SSHRemoteIO', SSHRemoteIO)

0 comments on commit ecb516e

Please sign in to comment.