From 67e6e492642d13d2561ff313e84849dd21a298cb Mon Sep 17 00:00:00 2001
From: sellth
Date: Fri, 8 Dec 2023 17:43:16 +0100
Subject: [PATCH] feat: Add common functions for interfacing with python-irodsclient (#202)

---
 cubi_tk/irods_common.py    | 306 +++++++++++++++++++++++++++++++++++++
 tests/test_irods_common.py | 169 ++++++++++++++++++++
 2 files changed, 475 insertions(+)
 create mode 100644 cubi_tk/irods_common.py
 create mode 100644 tests/test_irods_common.py

diff --git a/cubi_tk/irods_common.py b/cubi_tk/irods_common.py
new file mode 100644
index 00000000..ee998b8e
--- /dev/null
+++ b/cubi_tk/irods_common.py
@@ -0,0 +1,306 @@
+from collections.abc import Collection
+import getpass
+import os.path
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+import attrs
+from irods.exception import (
+    CAT_INVALID_AUTHENTICATION,
+    CAT_INVALID_USER,
+    CAT_PASSWORD_EXPIRED,
+    PAM_AUTH_PASSWORD_FAILED,
+)
+from irods.password_obfuscation import encode
+from irods.session import NonAnonymousLoginWithoutPassword, iRODSSession
+import logzero
+from logzero import logger
+from tqdm import tqdm
+
+# no-frills logger
+formatter = logzero.LogFormatter(fmt="%(message)s")
+output_logger = logzero.setup_logger(formatter=formatter)
+
+
+@attrs.frozen(auto_attribs=True)
+class TransferJob:
+    """
+    Encodes a transfer job between the local file system
+    and a remote iRODS collection.
+    """
+
+    #: Source path.
+    path_local: str
+
+    #: Destination path.
+    path_remote: str
+
+    #: Number of bytes to transfer (optional). Defaults to the size of
+    #: ``path_local``, or -1 if that file does not exist.
+    bytes: int = attrs.field()
+
+    @bytes.default
+    def _get_file_size(self):
+        """Default for ``bytes``: size of the local file, -1 if it is missing."""
+        try:
+            return Path(self.path_local).stat().st_size
+        except FileNotFoundError:
+            return -1
+
+
+class iRODSCommon:
+    """
+    Implementation of common iRODS utility functions.
+
+    :param ask: Confirm with user before certain actions.
+    :type ask: bool, optional
+    :param irods_env_path: Path to irods_environment.json
+    :type irods_env_path: pathlib.Path, optional
+    """
+
+    def __init__(self, ask: bool = False, irods_env_path: Optional[Path] = None):
+        # Path to iRODS environment file
+        if irods_env_path is None:
+            self.irods_env_path = Path.home().joinpath(".irods", "irods_environment.json")
+        else:
+            self.irods_env_path = irods_env_path
+        self.ask = ask
+
+    @property
+    def _irods_auth_path(self) -> Path:
+        """Path of the ``.irodsA`` token file next to the environment file."""
+        # Wrapped in Path() so that a plain string ``irods_env_path`` works, too.
+        return Path(self.irods_env_path).parent.joinpath(".irodsA")
+
+    @staticmethod
+    def get_irods_error(e: Exception):
+        """Return logger friendly iRODS exception."""
+        es = str(e)
+        # Fall back to the class name for exceptions without a message.
+        return es if es and es != "None" else e.__class__.__name__
+
+    def _init_irods(self) -> iRODSSession:
+        """
+        Connect to iRODS. Login if needed.
+
+        Retries after each interactive login until a session could be set up.
+
+        :return: Session created from ``irods_env_path``.
+        :raises Exception: Errors other than the login-related ones handled
+            here are logged and re-raised.
+        """
+        while True:
+            try:
+                session = iRODSSession(irods_env_file=self.irods_env_path)
+                session.connection_timeout = 600
+                # NOTE(review): bare attribute access, presumably evaluated to
+                # force a server round trip so that login problems surface here.
+                session.server_version
+                return session
+            except NonAnonymousLoginWithoutPassword as e:  # pragma: no cover
+                logger.info(self.get_irods_error(e))
+                self._irods_login()
+            except (
+                CAT_INVALID_AUTHENTICATION,
+                CAT_INVALID_USER,
+                CAT_PASSWORD_EXPIRED,
+            ):  # pragma: no cover
+                logger.warning("Problem with your session token.")
+                try:
+                    self._irods_auth_path.unlink()
+                except FileNotFoundError:
+                    # No stale token file to remove; a new login is needed anyway.
+                    pass
+                self._irods_login()
+            except Exception as e:  # pragma: no cover
+                logger.error(f"iRODS connection failed: {self.get_irods_error(e)}")
+                raise
+
+    def _irods_login(self):
+        """
+        Ask user to log into iRODS.
+
+        Prompts for the password up to three times. The negotiated token is
+        saved for later sessions; with ``ask`` set, only after confirmation.
+
+        :raises PAM_AUTH_PASSWORD_FAILED: After the third wrong password.
+        :raises RuntimeError: On any other error during login.
+        """
+        # No valid .irodsA file. Query user for password.
+        attempts = 0
+        while attempts < 3:
+            try:
+                session = iRODSSession(
+                    irods_env_file=self.irods_env_path,
+                    password=getpass.getpass(prompt="Please enter SODAR password:"),
+                )
+                token = session.pam_pw_negotiated
+                session.cleanup()
+                break
+            except PAM_AUTH_PASSWORD_FAILED:  # pragma: no cover
+                if attempts < 2:
+                    logger.warning("Wrong password. Please try again.")
+                    attempts += 1
+                    continue
+                else:
+                    logger.error("iRODS connection failed.")
+                    raise
+            except Exception as e:  # pragma: no cover
+                logger.error(f"iRODS connection failed: {self.get_irods_error(e)}")
+                # Keep the original error attached as the cause.
+                raise RuntimeError(f"iRODS login failed: {self.get_irods_error(e)}") from e
+
+        if self.ask and input(
+            "Save iRODS session for passwordless operation? [y/N] "
+        ).lower().startswith("y"):
+            self._save_irods_token(token)  # pragma: no cover
+        elif not self.ask:
+            self._save_irods_token(token)
+
+    def _save_irods_token(self, token: List[str]):
+        """
+        'Obfuscate' the PAM temp auth token and save it to disk.
+
+        :param token: Negotiated tokens; only the first entry is saved.
+            Anything but a non-empty list is reported and ignored.
+        """
+        irods_auth_path = self._irods_auth_path
+        irods_auth_path.parent.mkdir(parents=True, exist_ok=True)
+
+        if isinstance(token, list) and token:
+            # Restrict access *before* the secret is written, so that it is
+            # never readable by other users, not even for a short moment.
+            irods_auth_path.touch(mode=0o600, exist_ok=True)
+            irods_auth_path.chmod(0o600)
+            irods_auth_path.write_text(encode(token[0]))
+        else:
+            logger.warning("No token found to be saved.")
+
+    @property
+    def session(self):
+        """Newly created iRODS session; every access sets up a new one."""
+        return self._init_irods()
+
+
+class iRODSTransfer(iRODSCommon):
+    """
+    Transfer files to iRODS.
+
+    :param jobs: Iterable of TransferJob objects
+    :type jobs: Union[list,tuple,dict,set]
+    """
+
+    def __init__(self, jobs: Iterable[TransferJob], **kwargs):
+        super().__init__(**kwargs)
+        # Jobs are iterated several times and counted for progress output,
+        # so one-shot iterables (e.g. generators) have to be materialized.
+        self.__jobs = jobs if isinstance(jobs, Collection) else tuple(jobs)
+        self.__total_bytes = sum(job.bytes for job in self.__jobs)
+        self.__destinations = [job.path_remote for job in self.__jobs]
+
+    @property
+    def jobs(self):
+        return self.__jobs
+
+    @property
+    def size(self):
+        return self.__total_bytes
+
+    @property
+    def destinations(self):
+        return self.__destinations
+
+    def _create_collections(self, job: TransferJob):
+        """Create the remote parent collection of ``job.path_remote``."""
+        collection = str(Path(job.path_remote).parent)
+        with self.session as session:
+            session.collections.create(collection)
+
+    def put(self, recursive: bool = False, sync: bool = False):
+        """
+        Upload all jobs to iRODS.
+
+        Failed transfers are logged and do not stop the remaining jobs.
+
+        :param recursive: Create the remote parent collection of each file first.
+        :param sync: Skip files whose remote path already exists.
+        """
+        # Double tqdm for currently transferred file info
+        # TODO: add more parenthesis after python 3.10
+        with tqdm(
+            total=self.__total_bytes,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+            position=1,
+        ) as t, tqdm(total=0, position=0, bar_format="{desc}", leave=False) as file_log:
+            for n, job in enumerate(self.__jobs):
+                file_log.set_description_str(
+                    f"File [{n + 1}/{len(self.__jobs)}]: {Path(job.path_local).name}"
+                )
+                try:
+                    with self.session as session:
+                        if recursive:
+                            self._create_collections(job)
+                        if sync and session.data_objects.exists(job.path_remote):
+                            t.update(job.bytes)
+                            continue
+                        session.data_objects.put(job.path_local, job.path_remote)
+                        t.update(job.bytes)
+                except Exception as e:  # pragma: no cover
+                    logger.error(f"Problem during transfer of {job.path_local}")
+                    logger.error(self.get_irods_error(e))
+            t.clear()
+
+    def chksum(self):
+        """Compute remote md5 checksums for all jobs."""
+        checkjobs = tuple(job for job in self.__jobs if not job.path_remote.endswith(".md5"))
+        logger.info(f"Triggering remote checksum computation for {len(checkjobs)} files.")
+        if not checkjobs:
+            # os.path.commonpath() raises ValueError on an empty sequence.
+            return
+        common_prefix = os.path.commonpath(self.__destinations)
+        for n, job in enumerate(checkjobs):
+            output_logger.info(
+                f"[{n + 1}/{len(checkjobs)}]: {Path(job.path_remote).relative_to(common_prefix)}"
+            )
+            try:
+                with self.session as session:
+                    data_object = session.data_objects.get(job.path_remote)
+                    if not data_object.checksum:
+                        data_object.chksum()
+            except Exception as e:  # pragma: no cover
+                logger.error("Problem during iRODS checksumming.")
+                logger.error(self.get_irods_error(e))
+
+    def get(self):
+        """Download files from SODAR."""
+        with self.session as session:
+            self.__jobs = [
+                attrs.evolve(job, bytes=session.data_objects.get(job.path_remote).size)
+                for job in self.__jobs
+            ]
+        self.__total_bytes = sum(job.bytes for job in self.__jobs)
+        # Double tqdm for currently transferred file info
+        # TODO: add more parenthesis after python 3.10
+        with tqdm(
+            total=self.__total_bytes,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+            position=1,
+        ) as t, tqdm(total=0, position=0, bar_format="{desc}", leave=False) as file_log:
+            for n, job in enumerate(self.__jobs):
+                file_log.set_description_str(
+                    f"File [{n + 1}/{len(self.__jobs)}]: {Path(job.path_local).name}"
+                )
+                try:
+                    with self.session as session:
+                        session.data_objects.get(job.path_remote, job.path_local)
+                    t.update(job.bytes)
+                except FileNotFoundError:  # pragma: no cover
+                    raise
+                except Exception as e:  # pragma: no cover
+                    logger.error(f"Problem during transfer of {job.path_remote}")
+                    logger.error(self.get_irods_error(e))
+            t.clear()
diff --git a/tests/test_irods_common.py b/tests/test_irods_common.py
new file mode 100644
index 00000000..156ae241
--- /dev/null
+++ b/tests/test_irods_common.py
@@ -0,0 +1,169 @@
+from pathlib import Path
+from unittest.mock import ANY, MagicMock, call, patch
+
+import irods.exception
+import pytest
+
+from cubi_tk.irods_common import TransferJob, iRODSCommon, iRODSTransfer
+
+
+def test_transfer_job_bytes(fs):
+    fs.create_file("test_file", st_size=123)
+    assert TransferJob("test_file", "remote/path").bytes == 123
+    assert TransferJob("no_file.no", "remote/path").bytes == -1
+
+
+@patch("cubi_tk.irods_common.iRODSSession")
+def test_common_init(mocksession):  # defaults, custom env path and session property
+    assert iRODSCommon().irods_env_path is not None
+    icommon = iRODSCommon(irods_env_path="a/b/c.json")
+    assert icommon.irods_env_path == "a/b/c.json"  # value is stored as given
+    assert type(iRODSCommon().ask) is bool
+    assert iRODSCommon().session is mocksession.return_value
+
+
+@patch("cubi_tk.irods_common.iRODSSession")
+def test_get_irods_error(mocksession):  # class name without message, else the message
+    e = irods.exception.NetworkException()
+    assert iRODSCommon().get_irods_error(e) == "NetworkException"
+    e = irods.exception.NetworkException("Connection reset")
+    assert iRODSCommon().get_irods_error(e) == "Connection reset"
+
+
+@patch("cubi_tk.irods_common.iRODSSession")
+def test_init_irods(mocksession, fs):  # a (mocked) session gets created
+    fs.create_file(".irods/irods_environment.json")
+    fs.create_file(".irods/.irodsA")
+
+    iRODSCommon()._init_irods()
+    mocksession.assert_called()
+
+
+@patch("getpass.getpass")
+@patch("cubi_tk.irods_common.iRODSSession")
+def test_irods_login(mocksession, mockpass, fs):  # prompted password reaches the session
+    fs.create_file(".irods/irods_environment.json")
+    password = "1234"
+    icommon = iRODSCommon()
+    mockpass.return_value = password
+
+    icommon._irods_login()
+    mockpass.assert_called()
+    mocksession.assert_any_call(irods_env_file=ANY, password=password)
+
+
+@patch("cubi_tk.irods_common.encode", return_value="it works")
+@patch("cubi_tk.irods_common.iRODSSession")
+def test_save_irods_token(mocksession, mockencode, fs):  # first token entry is encoded and saved
+    token = [
+        "secure",
+    ]
+    icommon = iRODSCommon()
+    icommon.irods_env_path = Path("testdir/env.json")
+    icommon._save_irods_token(token=token)
+
+    assert icommon.irods_env_path.parent.joinpath(".irodsA").exists()
+    mockencode.assert_called_with("secure")
+
+
+# Tests for iRODSTransfer; sessions are mocked in all of them #########
+@pytest.fixture
+def jobs():  # two jobs with explicit sizes, so the size default is not used
+    return (
+        TransferJob(path_local="myfile.csv", path_remote="dest_dir/myfile.csv", bytes=123),
+        TransferJob(
+            path_local="folder/file.csv", path_remote="dest_dir/folder/file.csv", bytes=1024
+        ),
+    )
+
+
+def test_irods_transfer_init(jobs):  # kwargs reach iRODSCommon, properties mirror the jobs
+    with patch("cubi_tk.irods_common.iRODSSession"):
+        itransfer = iRODSTransfer(jobs=jobs, irods_env_path="a/b/c", ask=True)
+        assert itransfer.irods_env_path == "a/b/c"
+        assert itransfer.ask is True
+        assert itransfer.jobs == jobs
+        assert itransfer.size == sum([job.bytes for job in jobs])
+        assert itransfer.destinations == [job.path_remote for job in jobs]
+
+
+@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
+@patch("cubi_tk.irods_common.iRODSTransfer._create_collections")
+def test_irods_transfer_put(mockrecursive, mocksession, jobs):  # plain, recursive and sync upload
+    mockput = MagicMock()
+    mockexists = MagicMock(return_value=True)  # every remote object "exists"
+    mockobj = MagicMock()
+    mockobj.put = mockput
+    mockobj.exists = mockexists
+
+    # data_objects as seen inside ``with self.session as session:``
+    mocksession.return_value.__enter__.return_value.data_objects = mockobj
+    itransfer = iRODSTransfer(jobs)
+
+    # plain put: every job is uploaded
+    itransfer.put()
+    calls = [call(j.path_local, j.path_remote) for j in jobs]
+    mockput.assert_has_calls(calls)
+
+    # recursive: _create_collections is called for every job
+    itransfer.put(recursive=True)
+    calls = [call(j) for j in jobs]
+    mockrecursive.assert_has_calls(calls)
+
+    # sync: existing remote objects are not uploaded again
+    mockput.reset_mock()
+    itransfer.put(sync=True)
+    mockput.assert_not_called()
+    mockexists.assert_called()
+
+
+@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
+def test_create_collections(mocksession, jobs):  # parent collection of the remote path is created
+    mockcreate = MagicMock()
+    mockcoll = MagicMock()
+    mockcoll.create = mockcreate
+    mocksession.return_value.__enter__.return_value.collections = mockcoll
+    itransfer = iRODSTransfer(jobs)
+
+    itransfer._create_collections(itransfer.jobs[1])
+    coll_path = str(Path(itransfer.jobs[1].path_remote).parent)
+    mockcreate.assert_called_with(coll_path)
+
+
+@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
+def test_irods_transfer_chksum(mocksession, jobs):  # chksum() is triggered once per destination
+    mockget = MagicMock()
+    mockobj = MagicMock()
+    mockobj.get = mockget
+    mocksession.return_value.__enter__.return_value.data_objects = mockobj
+
+    mock_data_object = MagicMock()
+    mock_data_object.checksum = None  # no checksum yet, so chksum() has to be called
+    mock_data_object.chksum = MagicMock()
+    mockget.return_value = mock_data_object
+
+    itransfer = iRODSTransfer(jobs)
+    itransfer.chksum()
+
+    assert mock_data_object.chksum.call_count == len(itransfer.destinations)
+    for path in itransfer.destinations:
+        mockget.assert_any_call(path)
+
+
+@patch("cubi_tk.irods_common.iRODSTransfer._init_irods")
+def test_irods_transfer_get(mocksession, jobs):  # sizes are read remotely, then files are fetched
+    mockget = MagicMock()
+    mockobj = MagicMock()
+    mockobj.get = mockget
+    mocksession.return_value.__enter__.return_value.data_objects = mockobj
+    itransfer = iRODSTransfer(jobs)
+
+    mockget.return_value.size = 111  # size reported for every remote object
+    itransfer.get()
+
+    for job in jobs:
+        # lookup of the remote object to read its size
+        mockget.assert_any_call(job.path_remote)
+        # download to the local path
+        mockget.assert_any_call(job.path_remote, job.path_local)
+    assert itransfer.size == 222  # 2 jobs * 111 bytes