diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index e85c077..377be64 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -21,10 +21,10 @@ jobs:
           - '3.11'
     steps:
       - name: 'Set up Python ${{ matrix.python-version }}'
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: '${{ matrix.python-version }}'
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - run: pip install -e . -r requirements-test.txt
       - run: py.test -vvv --cov .
       - uses: codecov/codecov-action@v3
@@ -36,23 +36,23 @@ jobs:
   Lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/setup-python@v2
+      - uses: actions/setup-python@v4
        with:
          python-version: '3.11'
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - run: pip install -e . pre-commit
       - run: pre-commit run --all-files

   Build:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/setup-python@v2
+      - uses: actions/setup-python@v4
        with:
          python-version: '3.11'
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - run: pip install build
       - run: python -m build .
-      - uses: actions/upload-artifact@v2
+      - uses: actions/upload-artifact@v3
        with:
          name: dist
          path: dist/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9b8c719..c8c53f7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,8 +7,8 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace

-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: v0.0.254
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.0.278
     hooks:
       - id: ruff
         args:
@@ -16,7 +16,7 @@ repos:
           - --exit-non-zero-on-fix

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.0.1
+    rev: v1.4.1
     hooks:
       - id: mypy
         exclude: hai_tests/test_.*
@@ -24,3 +24,8 @@ repos:
           - --install-types
           - --non-interactive
           - --scripts-are-modules
+
+  - repo: https://github.com/psf/black
+    rev: 23.7.0
+    hooks:
+      - id: black
diff --git a/LICENSE b/LICENSE
index 49a87db..6723980 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,4 +1,3 @@
-
 The MIT License (MIT)

 Copyright (c) 2018 Valohai
diff --git a/hai/__init__.py b/hai/__init__.py
index fc79d63..3ced358 100644
--- a/hai/__init__.py
+++ b/hai/__init__.py
@@ -1 +1 @@
-__version__ = '0.2.1'
+__version__ = "0.2.1"
diff --git a/hai/boto3_multipart_upload.py b/hai/boto3_multipart_upload.py
index f90ff94..f92f395 100644
--- a/hai/boto3_multipart_upload.py
+++ b/hai/boto3_multipart_upload.py
@@ -13,8 +13,8 @@

 class MultipartUploader(EventEmitter):
     event_types = {
-        'progress',
-        'part-error',
+        "progress",
+        "part-error",
     }
     part_retry_attempts = 10
     minimum_file_size = S3_MINIMUM_MULTIPART_FILE_SIZE
@@ -28,7 +28,7 @@ def __init__(self, s3: BaseClient, log: Optional[logging.Logger] = None) -> None:
         :param log: A logger, if desired.
""" self.s3 = s3 - self.log = (log or logging.getLogger(self.__class__.__name__)) + self.log = log or logging.getLogger(self.__class__.__name__) def upload_parts( self, @@ -63,44 +63,55 @@ def upload_parts( Bucket=bucket, Key=key, PartNumber=part_number, - UploadId=mpu['UploadId'], + UploadId=mpu["UploadId"], Body=chunk, ) except Exception as exc: - self.log.error(f'Error uploading part {part_number} (attempt {attempt})', exc_info=True) - self.emit('part-error', { - 'chunk': part_number, - 'attempt': attempt, - 'attempts_left': self.part_retry_attempts - attempt, - 'exception': exc, - }) + self.log.error( + f"Error uploading part {part_number} (attempt {attempt})", + exc_info=True, + ) + self.emit( + "part-error", + { + "chunk": part_number, + "attempt": attempt, + "attempts_left": self.part_retry_attempts - attempt, + "exception": exc, + }, + ) if attempt == self.part_retry_attempts - 1: raise else: bytes += len(chunk) - part_infos.append({'PartNumber': part_number, 'ETag': part['ETag']}) - self.emit('progress', { - 'part_number': part_number, - 'part': part, - 'bytes_uploaded': bytes, - }) + part_infos.append( + {"PartNumber": part_number, "ETag": part["ETag"]}, + ) + self.emit( + "progress", + { + "part_number": part_number, + "part": part, + "bytes_uploaded": bytes, + }, + ) break except: # noqa - self.log.debug('Aborting multipart upload') + self.log.debug("Aborting multipart upload") self.s3.abort_multipart_upload( Bucket=bucket, Key=key, - UploadId=mpu['UploadId'], + UploadId=mpu["UploadId"], ) raise - self.log.info('Completing multipart upload') + self.log.info("Completing multipart upload") return self.s3.complete_multipart_upload( # type: ignore[no-any-return] Bucket=bucket, Key=key, - UploadId=mpu['UploadId'], - MultipartUpload={'Parts': part_infos}, + UploadId=mpu["UploadId"], + MultipartUpload={"Parts": part_infos}, ) def read_chunk(self, fp: IO[bytes], size: int) -> bytes: @@ -130,8 +141,8 @@ def upload_file( These roughly correspond to what one might be able to pass to `put_object`. :return: The return value of the `complete_multipart_upload` call. 
""" - if not hasattr(fp, 'read'): # pragma: no cover - raise TypeError('`fp` must have a `read()` method') + if not hasattr(fp, "read"): # pragma: no cover + raise TypeError("`fp` must have a `read()` method") try: size = os.stat(fp.fileno()).st_size @@ -140,8 +151,8 @@ def upload_file( if size and size <= self.minimum_file_size: raise ValueError( - f'File is too small to upload as multipart {size} bytes ' - f'(must be at least {self.minimum_file_size} bytes)' + f"File is too small to upload as multipart {size} bytes " + f"(must be at least {self.minimum_file_size} bytes)", ) if not chunk_size: @@ -150,20 +161,29 @@ def upload_file( maximum = min(S3_MAXIMUM_MULTIPART_CHUNK_SIZE, self.maximum_chunk_size) chunk_size = int(max(minimum, min(chunk_size, maximum))) - if not S3_MINIMUM_MULTIPART_CHUNK_SIZE <= chunk_size < S3_MAXIMUM_MULTIPART_CHUNK_SIZE: + if ( + not S3_MINIMUM_MULTIPART_CHUNK_SIZE + <= chunk_size + < S3_MAXIMUM_MULTIPART_CHUNK_SIZE + ): raise ValueError( - f'Chunk size {chunk_size} is outside the protocol limits ' - f'({S3_MINIMUM_MULTIPART_CHUNK_SIZE}..{S3_MAXIMUM_MULTIPART_CHUNK_SIZE})' + f"Chunk size {chunk_size} is outside the protocol limits " + f"({S3_MINIMUM_MULTIPART_CHUNK_SIZE}..{S3_MAXIMUM_MULTIPART_CHUNK_SIZE})", ) def chunk_generator() -> Generator[bytes, None, None]: while True: - chunk = self.read_chunk(fp, chunk_size) # type: ignore[arg-type] + chunk = self.read_chunk(fp, chunk_size) if not chunk: break yield chunk - return self.upload_parts(bucket, key, parts=chunk_generator(), create_params=create_params) + return self.upload_parts( + bucket, + key, + parts=chunk_generator(), + create_params=create_params, + ) def determine_chunk_size_from_file_size(self, file_size: Optional[int]) -> int: if file_size: diff --git a/hai/event_emitter.py b/hai/event_emitter.py index 5e2c63a..3e61088 100644 --- a/hai/event_emitter.py +++ b/hai/event_emitter.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Dict, Optional, Set -DICT_NAME = '_event_emitter_dict' +DICT_NAME = "_event_emitter_dict" Handler = Callable[..., Any] @@ -13,24 +13,27 @@ class EventEmitter: event_types: Set[str] = set() def on(self, event: str, handler: Handler) -> None: - if event != '*' and event not in self.event_types: - raise ValueError(f'event type {event} is not known') + if event != "*" and event not in self.event_types: + raise ValueError(f"event type {event} is not known") _get_event_emitter_dict(self).setdefault(event, set()).add(handler) def off(self, event: str, handler: Handler) -> None: _get_event_emitter_dict(self).get(event, set()).discard(handler) - def emit(self, event: str, args: Optional[Dict[str, Any]] = None, quiet: bool = True) -> None: + def emit( + self, + event: str, + args: Optional[Dict[str, Any]] = None, + quiet: bool = True, + ) -> None: if event not in self.event_types: - raise ValueError(f'event type {event} is not known') + raise ValueError(f"event type {event} is not known") emitter_dict = _get_event_emitter_dict(self) - handlers = ( - emitter_dict.get(event, set()) | emitter_dict.get('*', set()) - ) - args = (args or {}) - args.setdefault('sender', self) - args.setdefault('event', event) + handlers = emitter_dict.get(event, set()) | emitter_dict.get("*", set()) + args = args or {} + args.setdefault("sender", self) + args.setdefault("event", event) for handler in handlers: try: handler(**args) diff --git a/hai/parallel.py b/hai/parallel.py index 613bb59..f20294f 100644 --- a/hai/parallel.py +++ b/hai/parallel.py @@ -11,7 +11,12 @@ class ParallelException(Exception): class 
 class TaskFailed(ParallelException):
-    def __init__(self, message: str, task: "ApplyResult[Any]", exception: Exception) -> None:
+    def __init__(
+        self,
+        message: str,
+        task: "ApplyResult[Any]",
+        exception: Exception,
+    ) -> None:
         super().__init__(message)
         self.task = task
         self.task_name = str(getattr(task, "name", None))
@@ -43,7 +48,9 @@ class ParallelRun:
     """

     def __init__(self, parallelism: Optional[int] = None) -> None:
-        self.pool = ThreadPool(processes=(parallelism or (int(os.cpu_count() or 1) * 2)))
+        self.pool = ThreadPool(
+            processes=(parallelism or (int(os.cpu_count() or 1) * 2)),
+        )
         self.task_complete_event = threading.Event()
         self.tasks: List[ApplyResult[Any]] = []
         self.completed_tasks: WeakSet[ApplyResult[Any]] = WeakSet()
@@ -79,7 +86,7 @@ def add_task(
         :param kwargs: Keyword arguments, if any.
         """
         if not name:
-            name = (getattr(task, '__name__' or None) or str(task))  # type: ignore[arg-type]
+            name = getattr(task, "__name__", None) or str(task)  # type: ignore[arg-type]
         p_task = self.pool.apply_async(
             task,
             args=args,
@@ -102,7 +109,7 @@ def wait(
         fail_fast: bool = True,
         interval: float = 0.5,
         callback: Optional[Callable[["ParallelRun"], None]] = None,
-        max_wait: Optional[float] = None
+        max_wait: Optional[float] = None,
     ) -> List["ApplyResult[Any]"]:
         """
         Wait until all of the current tasks have finished,
@@ -136,7 +143,7 @@ def wait(

         while True:
             if max_wait:
-                waited_for = (time.time() - start_time)
+                waited_for = time.time() - start_time
                 if waited_for > max_wait:
                     raise TimeoutError(f"Waited for {waited_for}/{max_wait} seconds.")
@@ -159,7 +166,9 @@ def wait(

         # Reset the flag in case it had been set
         self.task_complete_event.clear()

-        return list(self.completed_tasks)  # We can just as well return the completed tasks.
+        return list(
+            self.completed_tasks,
+        )  # We can just as well return the completed tasks.

     def _wait_tick(self, fail_fast: bool) -> bool:
         # Keep track of whether there were any incomplete tasks this loop.
@@ -189,7 +198,7 @@ def _wait_tick(self, fail_fast: bool) -> bool:
             # raising the exception directly.
             if fail_fast and not task._success:  # type: ignore[attr-defined]
                 exc = task._value  # type: ignore[attr-defined]
-                message = f'[{task.name}] {str(exc)}'  # type: ignore[attr-defined]
+                message = f"[{task.name}] {str(exc)}"  # type: ignore[attr-defined]
                 raise TaskFailed(
                     message,
                     task=task,
@@ -205,7 +214,7 @@ def maybe_raise(self) -> None:
         exceptions = self.exceptions
         if exceptions:
             raise TasksFailed(
-                '%d exceptions occurred' % len(exceptions),
+                f"{len(exceptions)} exceptions occurred",
                 exception_map=exceptions,
             )
diff --git a/hai/pipe_pump.py b/hai/pipe_pump.py
index 8a68799..809342f 100644
--- a/hai/pipe_pump.py
+++ b/hai/pipe_pump.py
@@ -8,6 +8,7 @@ class BasePipePump:
     """
     Pump file objects into buffers.
     """
+
     read_size = 1024

     def __init__(self) -> None:
@@ -23,7 +24,7 @@ def register(self, key: str, fileobj: Optional[IO[bytes]]) -> None:
         :param fileobj: File object to poll.
""" key = str(key) - self.buffers[key] = b'' + self.buffers[key] = b"" if fileobj: self.selector.register(fileobj, selectors.EVENT_READ, data=key) @@ -45,7 +46,7 @@ def pump(self, timeout: float = 0, max_reads: int = 1) -> int: while read_num < max_reads: read_num += 1 should_repeat = False - for (key, _event) in self.selector.select(timeout=timeout): + for key, _event in self.selector.select(timeout=timeout): fileobj: IO[bytes] = key.fileobj # type: ignore[assignment] data = fileobj.read(self.read_size) self.feed(key.data, data) @@ -106,7 +107,7 @@ def pumper() -> None: while self.selector is not None: self.pump(timeout=interval) - return threading.Thread(target=pumper, name=f'Thread for {self!r}') + return threading.Thread(target=pumper, name=f"Thread for {self!r}") LineHandler = Callable[[str, List[bytes]], None] @@ -118,7 +119,7 @@ class LinePipePump(BasePipePump): separated by a given bytestring. """ - def __init__(self, separator: bytes = b'\n') -> None: + def __init__(self, separator: bytes = b"\n") -> None: """ :param separator: Line separator byte sequence. """ @@ -155,7 +156,7 @@ def add_line(self, key: str, line: bytes) -> None: """ key = str(key) if not isinstance(line, bytes): - line = line.encode('utf-8') + line = line.encode("utf-8") line_list = self.lines.setdefault(key, []) line_list.append(line) @@ -210,7 +211,7 @@ def add_chunk_handler(self, handler: ChunkHandler) -> None: def _process_buffer(self, key: str, buffer: bytes) -> bytes: while len(buffer) >= self.chunk_size: - chunk, buffer = buffer[:self.chunk_size], buffer[self.chunk_size:] + chunk, buffer = buffer[: self.chunk_size], buffer[self.chunk_size :] self._handle_chunk(key, chunk) return buffer @@ -237,7 +238,7 @@ class CRLFPipePump(BasePipePump): Unlike LinePipePump, this does not buffer any history in its own state, only the last line. 
""" - CRLF_SEP_RE = re.compile(br'^(.*?)([\r\n])') + CRLF_SEP_RE = re.compile(rb"^(.*?)([\r\n])") def __init__(self) -> None: super().__init__() @@ -263,8 +264,8 @@ def _process_buffer(self, key: str, buffer: bytes) -> bytes: m = self.CRLF_SEP_RE.match(buffer) if not m: break - self._process_line(key, m.group(1), is_replace=(m.group(2) == b'\r')) - buffer = buffer[m.end():] + self._process_line(key, m.group(1), is_replace=(m.group(2) == b"\r")) + buffer = buffer[m.end() :] return buffer def _process_line(self, key: str, new_content: bytes, is_replace: bool) -> None: diff --git a/hai/rate_limiter.py b/hai/rate_limiter.py index 45e1d3b..24bfdca 100644 --- a/hai/rate_limiter.py +++ b/hai/rate_limiter.py @@ -4,10 +4,10 @@ class StateChange(Enum): - BECAME_OPEN = 'became_open' - BECAME_THROTTLED = 'became_throttled' - STILL_THROTTLED = 'still_throttled' - STILL_OPEN = 'still_open' + BECAME_OPEN = "became_open" + BECAME_THROTTLED = "became_throttled" + STILL_THROTTLED = "still_throttled" + STILL_OPEN = "still_open" STATE_CHANGE_MAP = { @@ -20,19 +20,19 @@ class StateChange(Enum): class Rate: - __slots__ = ('rate', 'period', 'rate_per_period') + __slots__ = ("rate", "period", "rate_per_period") def __init__(self, rate: int, period: Union[float, int] = 1) -> None: self.rate = float(rate) self.period = float(period) if self.rate < 0: - raise ValueError(f'`rate` must be >= 0 (not {self.rate!r})') + raise ValueError(f"`rate` must be >= 0 (not {self.rate!r})") if self.period <= 0: - raise ValueError(f'`period` must be > 0 (not {self.period!r})') - self.rate_per_period = (self.rate / self.period) + raise ValueError(f"`period` must be > 0 (not {self.period!r})") + self.rate_per_period = self.rate / self.period def __repr__(self) -> str: - return f'' + return f"" class TickResult: @@ -48,7 +48,7 @@ class TickResult: change state (as a `StateChange` value), it's available as `.state_change`. """ - __slots__ = ('state', 'did_change') + __slots__ = ("state", "did_change") def __init__(self, state: bool, did_change: bool) -> None: self.state = bool(state) @@ -62,8 +62,8 @@ def __bool__(self) -> bool: return self.state def __repr__(self) -> str: - state_text = ('throttled' if not self.state else 'open') - return f'' + state_text = "throttled" if not self.state else "open" + return f"" class RateLimiter: @@ -78,9 +78,9 @@ class RateLimiter: #: the `period` of the RateLimiter) as a floating-point number. #: By default, the high-resolution performance counter is used. #: This can be overwritten, or overridden in subclasses. 
-    clock = (time.perf_counter if hasattr(time, 'perf_counter') else time.time)
+    clock = time.perf_counter if hasattr(time, "perf_counter") else time.time

-    __slots__ = ('rate', 'allow_underflow', 'last_check', 'allowance', 'current_state')
+    __slots__ = ("rate", "allow_underflow", "last_check", "allowance", "current_state")

     def __init__(self, rate: Rate, allow_underflow: bool = False) -> None:
         """
@@ -96,7 +96,11 @@ def __init__(self, rate: Rate, allow_underflow: bool = False) -> None:
         self.current_state: Optional[bool] = None

     @classmethod
-    def from_per_second(cls, per_second: int, allow_underflow: bool = False) -> "RateLimiter":
+    def from_per_second(
+        cls,
+        per_second: int,
+        allow_underflow: bool = False,
+    ) -> "RateLimiter":
         return cls(rate=Rate(rate=per_second), allow_underflow=allow_underflow)

     def _tick(self) -> bool:
@@ -112,12 +116,15 @@ def _tick(self) -> bool:
         time_passed = current - last_check
         self.last_check = current
         self.allowance += time_passed * self.rate.rate_per_period  # type: ignore[operator]
-        self.allowance = min(self.allowance, self.rate.rate)  # Do not allow allowance to grow unbounded
-        throttled = (self.allowance < 1)
+        self.allowance = min(
+            self.allowance,
+            self.rate.rate,
+        )  # Do not allow allowance to grow unbounded
+        throttled = self.allowance < 1

         if self.allow_underflow or not throttled:
             self.allowance -= 1

-        return (not throttled)
+        return not throttled

     def reset(self) -> None:
         """
@@ -136,14 +143,16 @@ def tick(self) -> TickResult:
         if self.current_state is None:
             self.current_state = new_state

-        did_change = (new_state is not self.current_state)
+        did_change = new_state is not self.current_state
         self.current_state = new_state
         return TickResult(new_state, did_change)

     def __repr__(self) -> str:
-        state_text = ('throttled' if not self.current_state else 'open')
-        return f'<RateLimiter ({self.rate!r}): {state_text} (allowance: {self.allowance})>'
+        state_text = "throttled" if not self.current_state else "open"
+        return (
+            f"<RateLimiter ({self.rate!r}): {state_text} (allowance: {self.allowance})>"
+        )


 class MultiRateLimiter:
@@ -154,7 +163,11 @@ class MultiRateLimiter:

     rate_limiter_class = RateLimiter
     allow_underflow = False

-    def __init__(self, default_limit: Rate, per_name_limits: Optional[Dict[str, Rate]] = None) -> None:
+    def __init__(
+        self,
+        default_limit: Rate,
+        per_name_limits: Optional[Dict[str, Rate]] = None,
+    ) -> None:
         self.limiters: Dict[str, RateLimiter] = {}
         self.default_limit = default_limit
         self.per_name_limits = dict(per_name_limits or {})
diff --git a/hai_tests/test_crlf_pipe_pump.py b/hai_tests/test_crlf_pipe_pump.py
index 867493d..b2d0590 100644
--- a/hai_tests/test_crlf_pipe_pump.py
+++ b/hai_tests/test_crlf_pipe_pump.py
@@ -12,11 +12,11 @@ def __init__(self):
     def handle_crlf_input(self, key, old_content, new_content, is_replace):
         if is_replace:
             if new_content:
-                self.log.append(f'Replace {old_content} with {new_content}')
+                self.log.append(f"Replace {old_content} with {new_content}")
                 self.lines[-1] = new_content
                 self.raw_lines.append(new_content)
         else:
-            self.log.append(f'Print {new_content}')
+            self.log.append(f"Print {new_content}")
             self.lines.append(new_content)
             self.raw_lines.append(new_content)
@@ -25,34 +25,41 @@ def do_crlf_test(input, chunk_size=64):
     handler = CrlfTestHandler()
     cpp = CRLFPipePump()
     cpp.add_handler(handler.handle_crlf_input)
-    cpp.register('test', None)
+    cpp.register("test", None)
     input_io = io.BytesIO(input)
     while True:
         chunk = input_io.read(chunk_size)
         if not chunk:
             break
-        cpp.feed('test', chunk)
+        cpp.feed("test", chunk)
     cpp.close()
     return handler


 def test_crlf_pipe_pump():
-    input = b'''first\rreplaced first\nsecond\r\rreplaced second\r\n\r\r\rthird\n\n\nfourth'''
+    input = b"""first\rreplaced first\nsecond\r\rreplaced second\r\n\r\r\rthird\n\n\nfourth"""
     handler = do_crlf_test(input)
-    assert handler.lines == [b'replaced first', b'replaced second', b'third', b'', b'', b'fourth']
+    assert handler.lines == [
+        b"replaced first",
+        b"replaced second",
+        b"third",
+        b"",
+        b"",
+        b"fourth",
+    ]
     assert handler.raw_lines == [
-        b'first',
-        b'replaced first',
-        b'second',
-        b'replaced second',
-        b'',
-        b'third',
-        b'',
-        b'',
-        b'fourth',
+        b"first",
+        b"replaced first",
+        b"second",
+        b"replaced second",
+        b"",
+        b"third",
+        b"",
+        b"",
+        b"fourth",
     ]


 def test_crlf_pipe_pump_rn():
-    handler = do_crlf_test(b'''oispa\r\nkaljaa''')
-    assert handler.lines == handler.raw_lines == [b'oispa', b'kaljaa']
+    handler = do_crlf_test(b"""oispa\r\nkaljaa""")
+    assert handler.lines == handler.raw_lines == [b"oispa", b"kaljaa"]
diff --git a/hai_tests/test_event_emitter.py b/hai_tests/test_event_emitter.py
index 6501e43..6c5ba9f 100644
--- a/hai_tests/test_event_emitter.py
+++ b/hai_tests/test_event_emitter.py
@@ -4,10 +4,10 @@


 class Thing(EventEmitter):
-    event_types = {'one', 'two'}
+    event_types = {"one", "two"}


-@pytest.mark.parametrize('omni', (False, True))
+@pytest.mark.parametrize("omni", (False, True))
 def test_event_emitter(omni):
     t = Thing()
     events = []
@@ -17,23 +17,23 @@ def handle(sender, **args):
         events.append(args)

     if omni:
-        t.on('*', handle)
+        t.on("*", handle)
     else:
-        t.on('one', handle)
-    t.emit('one')
-    t.emit('two')
-    t.off('one', handle)
-    t.emit('one', {'oh': 'no'})
+        t.on("one", handle)
+    t.emit("one")
+    t.emit("two")
+    t.off("one", handle)
+    t.emit("one", {"oh": "no"})

     if omni:
         assert events == [
-            {'event': 'one'},
-            {'event': 'two'},
-            {'event': 'one', 'oh': 'no'},
+            {"event": "one"},
+            {"event": "two"},
+            {"event": "one", "oh": "no"},
         ]
     else:
         assert events == [
-            {'event': 'one'},
+            {"event": "one"},
         ]


@@ -41,19 +41,19 @@
 def test_event_emitter_exceptions():
     t = Thing()

     def handle(**args):
-        raise OSError('oh no')
+        raise OSError("oh no")

-    t.on('*', handle)
-    t.emit('one')
+    t.on("*", handle)
+    t.emit("one")

     with pytest.raises(IOError):
-        t.emit('one', quiet=False)
+        t.emit("one", quiet=False)


 def test_event_emitter_unknown_event_types():
     t = Thing()
     with pytest.raises(ValueError):
-        t.on('hullo', None)
+        t.on("hullo", None)
     with pytest.raises(ValueError):
-        t.emit('hello')
+        t.emit("hello")
diff --git a/hai_tests/test_multipart_upload.py b/hai_tests/test_multipart_upload.py
index 5d391ff..a2b1cb5 100644
--- a/hai_tests/test_multipart_upload.py
+++ b/hai_tests/test_multipart_upload.py
@@ -18,28 +18,30 @@ def read_chunk(self, fp, size):


 @mock_s3
-@pytest.mark.parametrize('file_type', ('real', 'imaginary'))
+@pytest.mark.parametrize("file_type", ("real", "imaginary"))
 @pytest.mark.parametrize(
-    'mpu_class', (MultipartUploader, ChunkCallbackMultipartUploader), ids=('no-func', 'chunk-func')
+    "mpu_class",
+    (MultipartUploader, ChunkCallbackMultipartUploader),
+    ids=("no-func", "chunk-func"),
 )
 def test_multipart_upload(tmpdir, file_type, mpu_class):
-    if file_type == 'real':
-        temp_path = tmpdir.join('temp.dat')
-        with temp_path.open('wb') as outf:
+    if file_type == "real":
+        temp_path = tmpdir.join("temp.dat")
+        with temp_path.open("wb") as outf:
             for chunk in range(17):
                 outf.write(bytes((chunk,)) * 1024 * 1024)
             expected_size = outf.tell()
-        file = temp_path.open('rb')
-    elif file_type == 'imaginary':
+        file = temp_path.open("rb")
+    elif file_type == "imaginary":
         expected_size = S3_MINIMUM_MULTIPART_FILE_SIZE * 2
-        file = BytesIO(b'\xC0' * expected_size)
+        file = BytesIO(b"\xC0" * expected_size)
         file.seek(0)
     else:  # pragma: no cover
-        raise NotImplementedError('...')
+        raise NotImplementedError("...")

-    s3 = boto3.client('s3', region_name='us-east-1')
-    bucket_name = 'mybucket'
-    key_name = 'hello/world'
+    s3 = boto3.client("s3", region_name="us-east-1")
+    bucket_name = "mybucket"
+    key_name = "hello/world"
     s3.create_bucket(Bucket=bucket_name)
     mpu = mpu_class(s3)
     events = []
@@ -47,7 +49,7 @@ def test_multipart_upload(tmpdir, file_type, mpu_class):
     def event_handler(**args):
         events.append(args)

-    mpu.on('*', event_handler)
+    mpu.on("*", event_handler)

     if mpu_class is ChunkCallbackMultipartUploader:
         mpu.chunk_sizes = []
@@ -56,8 +58,8 @@ def test_multipart_upload(tmpdir, file_type, mpu_class):
     mpu.upload_file(bucket_name, key_name, file)

     obj = s3.get_object(Bucket=bucket_name, Key=key_name)
-    assert obj['ContentLength'] == expected_size
-    assert any(e['event'] == 'progress' for e in events)
+    assert obj["ContentLength"] == expected_size
+    assert any(e["event"] == "progress" for e in events)

     if mpu_class is ChunkCallbackMultipartUploader:
         assert sum(mpu.chunk_sizes) == expected_size
@@ -65,31 +67,31 @@ def test_multipart_upload(tmpdir, file_type, mpu_class):

 @mock_s3
 def test_invalid_chunk_size():
-    s3 = boto3.client('s3', region_name='us-east-1')
+    s3 = boto3.client("s3", region_name="us-east-1")
     mpu = MultipartUploader(s3)
     with pytest.raises(ValueError):
-        mpu.upload_file('foo', 'foo', BytesIO(), chunk_size=300)
+        mpu.upload_file("foo", "foo", BytesIO(), chunk_size=300)


 @mock_s3
 def test_invalid_file_size(tmpdir):
-    s3 = boto3.client('s3', region_name='us-east-1')
-    pth = tmpdir.join('temp.dat')
-    pth.write('foofoo')
+    s3 = boto3.client("s3", region_name="us-east-1")
+    pth = tmpdir.join("temp.dat")
+    pth.write("foofoo")
     mpu = MultipartUploader(s3)
     with pytest.raises(ValueError):
-        mpu.upload_file('foo', 'foo', pth.open())
+        mpu.upload_file("foo", "foo", pth.open())


 @mock_s3
 def test_error_handling():
-    s3 = boto3.client('s3', region_name='us-east-1')
-    s3.create_bucket(Bucket='foo')
+    s3 = boto3.client("s3", region_name="us-east-1")
+    s3.create_bucket(Bucket="foo")

     def upload_fn(**args):
-        raise OSError('the internet is dead')
+        raise OSError("the internet is dead")

     s3.upload_part = upload_fn
     mpu = MultipartUploader(s3)
     with pytest.raises(IOError):
-        mpu.upload_parts('foo', 'foo', [b'\x00' * S3_MINIMUM_MULTIPART_FILE_SIZE])
+        mpu.upload_parts("foo", "foo", [b"\x00" * S3_MINIMUM_MULTIPART_FILE_SIZE])
diff --git a/hai_tests/test_parallel.py b/hai_tests/test_parallel.py
index 2360197..4e3c443 100644
--- a/hai_tests/test_parallel.py
+++ b/hai_tests/test_parallel.py
@@ -8,7 +8,7 @@

 def agh():
     time.sleep(0.1)
-    raise RuntimeError('agh!')
+    raise RuntimeError("agh!")


 def return_true():
@@ -21,43 +21,43 @@ def test_parallel_crash():
         failing_task = parallel.add_task(agh)
         with pytest.raises(TaskFailed) as ei:
             parallel.wait(fail_fast=True)
-        assert str(ei.value.__cause__) == str(ei.value.exception) == 'agh!'
+        assert str(ei.value.__cause__) == str(ei.value.exception) == "agh!"
         assert ei.value.task == failing_task


 def test_parallel_retval():
     with ParallelRun() as parallel:
-        parallel.add_task(return_true, name='blerg')
+        parallel.add_task(return_true, name="blerg")
         parallel.add_task(return_true)
         parallel.wait()
-        assert parallel.return_values == {'blerg': True, 'return_true': True}
+        assert parallel.return_values == {"blerg": True, "return_true": True}


 def test_parallel_wait_without_fail_fast():
     with ParallelRun() as parallel:
-        parallel.add_task(return_true, name='true')
-        parallel.add_task(agh, name='agh')
+        parallel.add_task(return_true, name="true")
+        parallel.add_task(agh, name="agh")
         parallel.wait(fail_fast=False)
-        assert parallel.exceptions['agh'].args[0] == 'agh!'
-        assert parallel.return_values['true'] is True
+        assert parallel.exceptions["agh"].args[0] == "agh!"
+        assert parallel.return_values["true"] is True
         with pytest.raises(TasksFailed) as ei:
             parallel.maybe_raise()
         assert len(ei.value.exception_map) == 1
-        assert isinstance(ei.value.exception_map['agh'], RuntimeError)
-        assert ei.value.failed_task_names == {'agh'}
+        assert isinstance(ei.value.exception_map["agh"], RuntimeError)
+        assert ei.value.failed_task_names == {"agh"}


-@pytest.mark.parametrize('is_empty_run', (False, True))
+@pytest.mark.parametrize("is_empty_run", (False, True))
 def test_parallel_callback_is_called_at_least_once_on_wait(is_empty_run):
     with ParallelRun() as parallel:
         stub = MagicMock()
         if not is_empty_run:
-            parallel.add_task(return_true, name='true')
+            parallel.add_task(return_true, name="true")
         parallel.wait(callback=stub)
         stub.assert_called_with(parallel)


-@pytest.mark.parametrize('fail', (False, True))
+@pytest.mark.parametrize("fail", (False, True))
 def test_parallel_limit(fail):
     """
     Test that parallelism limits work.
@@ -70,7 +70,7 @@ def tick():
         nonlocal count
         assert count <= 3
         count += 1
-        time.sleep(.1)
+        time.sleep(0.1)
         count -= 1

     with ParallelRun(parallelism=(5 if fail else 3)) as parallel:
@@ -89,7 +89,7 @@ def test_parallel_long_interval_interruptible():
     Test that even with long poll intervals, completion events interrupt the sleep
     """
     with ParallelRun() as parallel:
-        parallel.add_task(time.sleep, args=(.5,))  # will only wait for half a second
+        parallel.add_task(time.sleep, args=(0.5,))  # will only wait for half a second
         t0 = time.time()
         parallel.wait(interval=10)  # would wait for 10
         t1 = time.time()
@@ -100,4 +100,4 @@ def test_parallel_max_wait():
     with ParallelRun() as parallel:
         parallel.add_task(time.sleep, args=(1,))
         with pytest.raises(TimeoutError):
-            parallel.wait(interval=.1, max_wait=.5)
+            parallel.wait(interval=0.1, max_wait=0.5)
diff --git a/hai_tests/test_pipe_pump.py b/hai_tests/test_pipe_pump.py
index 2425ffd..b6c7e05 100644
--- a/hai_tests/test_pipe_pump.py
+++ b/hai_tests/test_pipe_pump.py
@@ -20,21 +20,23 @@ def add_handler(key, lines):

     with contextlib.closing(LinePipePump()) as pp:
         pp.add_line_handler(add_handler)
-        pp.register('stdout', proc.stdout)
+        pp.register("stdout", proc.stdout)
         pp.as_thread().start()
         proc.wait()

     assert line_lists == [
-        [b'hello'],
-        [b'olleh', b'world'],  # olleh due to mutation
+        [b"hello"],
+        [b"olleh", b"world"],  # olleh due to mutation
     ]
-    assert pp.get_value('stdout') == b'hello\ndlrow'  # and again re-reversed for the final list
+    assert (
+        pp.get_value("stdout") == b"hello\ndlrow"
+    )  # and again re-reversed for the final list


 def test_chunk_pipe_pump():
     proc = subprocess.Popen(
-        args='dd if=/dev/urandom bs=100 count=10',
+        args="dd if=/dev/urandom bs=100 count=10",
         stdout=subprocess.PIPE,
         shell=True,
         bufsize=0,
     )
@@ -47,7 +49,7 @@ def add_handler(key, chunk):

     with contextlib.closing(ChunkPipePump()) as pp:
         pp.add_chunk_handler(add_handler)
-        pp.register('stdout', proc.stdout)
+        pp.register("stdout", proc.stdout)

         while proc.poll() is None:  # demonstrate hand-pumping
             pp.pump(0.05)
diff --git a/hai_tests/test_rate_limiter.py b/hai_tests/test_rate_limiter.py
index 89908cd..1d2170c 100644
--- a/hai_tests/test_rate_limiter.py
+++ b/hai_tests/test_rate_limiter.py
@@ -25,13 +25,15 @@ def test_rate_limiter():
     assert not r.did_change
     assert r.state_change == StateChange.STILL_THROTTLED

-    time.sleep(.02)  # Wait for a very short time...
+    time.sleep(0.02)  # Wait for a very short time...
     assert not l.tick()  # Still throttled...

-    time.sleep(.2)  # Wait for enough to get 2 tokens.
+    time.sleep(0.2)  # Wait for enough to get 2 tokens.
     assert l.allowance < 1  # No tokens?
     r = l.tick()
-    assert 1 <= l.allowance <= 2  # We used one token in the tick, but one and some should be left
+    assert (
+        1 <= l.allowance <= 2
+    )  # We used one token in the tick, but one and some should be left
     assert r
     assert r.did_change
     assert r.state_change == StateChange.BECAME_OPEN
@@ -98,17 +100,17 @@ def test_rate_construction_validation():
 def test_multi_limiter():
     ml = MultiRateLimiter(default_limit=Rate(1, 0.1))
     # Tick two limiters:
-    assert ml.tick('foo')
-    assert ml.tick('bar')
-    assert not ml.tick('foo')
-    assert not ml.tick('bar')
+    assert ml.tick("foo")
+    assert ml.tick("bar")
+    assert not ml.tick("foo")
+    assert not ml.tick("bar")
     # Reset one of them and assert it is now open, yet the other is not
-    assert ml.reset('foo')  # Check that it got reset
-    assert ml.tick('foo')
-    assert not ml.tick('bar')
+    assert ml.reset("foo")  # Check that it got reset
+    assert ml.tick("foo")
+    assert not ml.tick("bar")
     # Wait for the other to open up again
     time.sleep(0.11)
-    assert ml.tick('bar')
+    assert ml.tick("bar")


 def test_smoke_reprs():
diff --git a/pyproject.toml b/pyproject.toml
index 74bb88b..ebebf3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,28 @@
 [build-system]
-requires = [
-    "setuptools>=42",
-    "wheel"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "hai"
+dynamic = ["version"]
+description = "Toolbelt library"
+readme = "README.md"
+license = "MIT"
+requires-python = ">=3.7"
+authors = [
+    { name = "Valohai", email = "dev@valohai.com" },
+]
+
+[project.urls]
+Homepage = "https://github.com/valohai/hai"
+
+[tool.hatch.version]
+path = "hai/__init__.py"
+
+[tool.hatch.build.targets.sdist]
+include = [
+    "/hai",
 ]
-build-backend = "setuptools.build_meta"

 [tool.mypy]
 strict = true
@@ -28,12 +47,17 @@ mccabe.max-complexity = 10
 select = [
     "B",  # bugbear
     "C90",  # mccabe
+    "COM",  # trailing commas
     "E",  # pycodestyle
     "F",  # pyflakes
     "I",  # isort
     "T",  # debugger and print
+    "UP",  # upgrade
     "W",  # pycodestyle
 ]
 ignore = [
     "E741",  # Ambiguous variable name
 ]
+
+[tool.pytest.ini_options]
+norecursedirs = [".git", ".tox"]
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index ca676ff..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,21 +0,0 @@
-[metadata]
-name = hai
-version = attr: hai.__version__
-author = Valohai
-author_email = dev@valohai.com
-maintainer = Aarni Koskela
-maintainer_email = akx@iki.fi
-description = Toolbelt library
-url = https://github.com/valohai/hai
-
-[options]
-packages = find:
-python_requires = >=3.7
-include_package_data = true
-
-[options.packages.find]
-where = .
-exclude = hai_tests
-
-[tool:pytest]
-norecursedirs = .git .tox
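
The EventEmitter surface touched above is small: subclasses declare `event_types`, listeners attach with `on()` (or the `"*"` wildcard), and `emit()` injects `sender` and `event` into each handler's keyword arguments. A minimal usage sketch; the `Downloader` class and its event names are hypothetical, not part of the library:

    from hai.event_emitter import EventEmitter

    class Downloader(EventEmitter):
        # Only names listed in event_types (or the "*" wildcard) may be
        # subscribed to or emitted.
        event_types = {"progress", "done"}

        def run(self) -> None:
            for pct in (25, 50, 75, 100):
                self.emit("progress", {"pct": pct})
            self.emit("done")

    def log_event(sender, event, **args):
        # emit() setdefaults `sender` and `event` into the handler kwargs.
        print(f"{event}: {args}")

    dl = Downloader()
    dl.on("*", log_event)   # wildcard: receive every event type
    dl.run()
    dl.off("*", log_event)  # handlers can be detached again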
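MultipartUploader is itself an EventEmitter with the two event types seen above, `progress` and `part-error`. A sketch of driving it against a boto3 client; the bucket, key, and file names are made up for illustration, and the file must exceed the S3 minimum multipart size or `upload_file` raises ValueError:

    import boto3

    from hai.boto3_multipart_upload import MultipartUploader

    s3 = boto3.client("s3")
    uploader = MultipartUploader(s3)

    def on_progress(sender, event, part_number, part, bytes_uploaded):
        # The kwargs mirror the dict passed to emit("progress", ...) above.
        print(f"part {part_number} uploaded, {bytes_uploaded} bytes so far")

    uploader.on("progress", on_progress)

    # upload_file() clamps the chunk size to the S3 protocol limits,
    # retries failed parts, and aborts the multipart upload on failure.
    with open("big.dat", "rb") as fp:
        uploader.upload_file("my-bucket", "some/key", fp)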
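ParallelRun executes callables on a thread pool; `wait()` blocks until completion, `maybe_raise()` re-raises collected failures, and `return_values` maps task names to results. A compact sketch using the API as the tests exercise it; the `fetch` task is hypothetical:

    from hai.parallel import ParallelRun

    def fetch(name):
        return f"data for {name}"

    with ParallelRun() as run:
        run.add_task(fetch, name="a", args=("a",))
        run.add_task(fetch, name="b", args=("b",))
        run.wait(fail_fast=False)  # don't raise mid-wait on the first failure
        run.maybe_raise()          # raises TasksFailed if any task failed
        print(run.return_values)   # {"a": "data for a", "b": "data for b"}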
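RateLimiter's `tick()` consumes a token when the limiter is open and reports transitions through `TickResult.state_change`. A sketch assuming a limit of five ticks per second:

    import time

    from hai.rate_limiter import RateLimiter, StateChange

    limiter = RateLimiter.from_per_second(5)

    allowed = 0
    for _ in range(100):
        result = limiter.tick()  # truthy while the limiter is open
        if result:
            allowed += 1
        if result.state_change == StateChange.BECAME_THROTTLED:
            print("now throttled")
        time.sleep(0.01)

    # Far fewer than 100 ticks pass; tokens replenish at roughly 5/second.
    print(f"{allowed} of 100 ticks allowed")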