diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 7335cc831..9869beccd 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -2,38 +2,62 @@ name: CI
 on:
   push:
-    branches: '*'
+    branches: "*"
   pull_request:
-    branches: '*'
+    branches: "*"
 
 jobs:
+  # Run "pre-commit run --all-files"
+  pre-commit:
+    runs-on: ubuntu-20.04
+    timeout-minutes: 2
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+
+      # ref: https://github.com/pre-commit/action
+      - uses: pre-commit/action@v2.0.0
+      - name: Help message if pre-commit fails
+        if: ${{ failure() }}
+        run: |
+          echo "You can install pre-commit hooks to automatically run formatting"
+          echo "on each commit with:"
+          echo "    pre-commit install"
+          echo "or you can run by hand on staged files with"
+          echo "    pre-commit run"
+          echo "or after-the-fact on already committed files with"
+          echo "    pre-commit run --all-files"
+
   build-n-test-n-coverage:
     name: Build, test and code coverage
     runs-on: ${{ matrix.os }}
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-latest]
-        python-version: [ 3.6, 3.7, 3.8, 3.9 ]
+        os: [ubuntu-latest, macos-latest, windows-latest]
+        python-version: [3.6, 3.7, 3.8, 3.9]
     env:
-      OS: ${{ matrix.os }}
-      PYTHON: '3.9'
+      OS: ${{ matrix.os }}
+      PYTHON: "3.9"
     steps:
-    - name: Checkout
-      uses: actions/checkout@v2
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v1
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: |
-        pip install --upgrade setuptools pip
-        pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5'
-        pip freeze
-    - name: Check types
-      run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py jupyter_client/channels.py jupyter_client/session.py jupyter_client/adapter.py jupyter_client/connect.py jupyter_client/consoleapp.py jupyter_client/jsonutil.py jupyter_client/kernelapp.py jupyter_client/launcher.py
-    - name: Run the tests
-      run: py.test --cov jupyter_client -v jupyter_client
-    - name: Code coverage
-      run: codecov
+      - name: Checkout
+        uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install dependencies
+        run: |
+          pip install --upgrade setuptools pip
+          pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5'
+          pip freeze
+      - name: Check types
+        run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py jupyter_client/channels.py jupyter_client/session.py jupyter_client/adapter.py jupyter_client/connect.py jupyter_client/consoleapp.py jupyter_client/jsonutil.py jupyter_client/kernelapp.py jupyter_client/launcher.py
+      - name: Run the tests
+        run: py.test --cov jupyter_client -v jupyter_client
+      - name: Code coverage
+        run: codecov
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 000000000..200c35923
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,25 @@
+repos:
+  - repo: https://github.com/asottile/reorder_python_imports
+    rev: v1.9.0
+    hooks:
+      - id: reorder-python-imports
+  - repo: https://github.com/psf/black
+    rev: 20.8b1
+    hooks:
+      - id: black
+        args: ["--line-length", "100"]
+  - repo: https://github.com/pre-commit/mirrors-prettier
+    rev: v2.2.1
+    hooks:
+      - id: prettier
+  - repo: https://gitlab.com/pycqa/flake8
+    rev: "3.8.4"
+    hooks:
+      - id: flake8
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0
+    hooks:
+      - id: end-of-file-fixer
+      - id: check-case-conflict
+      - id: check-executables-have-shebangs
+      - id: requirements-txt-fixer
diff --git a/COPYING.md b/COPYING.md
index bd6397d45..7cfb970db 100644
--- a/COPYING.md
+++ b/COPYING.md
@@ -25,7 +25,7 @@ software without specific prior written permission.
 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
 FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
@@ -47,8 +47,8 @@ Jupyter uses a shared copyright model. Each contributor maintains copyright
 over their contributions to Jupyter. But, it is important to note that these
 contributions are typically only changes to the repositories. Thus, the Jupyter
 source code, in its entirety is not the copyright of any single person or
-institution. Instead, it is the collective copyright of the entire Jupyter
-Development Team. If individual contributors want to maintain a record of what
+institution. Instead, it is the collective copyright of the entire Jupyter
+Development Team. If individual contributors want to maintain a record of what
 changes/contributions they have specific copyright on, they should indicate
 their copyright in the commit message of the change, when they commit the
 change to one of the Jupyter repositories.
diff --git a/README.md b/README.md
index e588f2022..34fb5a97c 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 [![Build Status](https://github.com/jupyter/jupyter_client/workflows/CI/badge.svg)](https://github.com/jupyter/jupyter_client/actions)
 [![Code Health](https://landscape.io/github/jupyter/jupyter_client/master/landscape.svg?style=flat)](https://landscape.io/github/jupyter/jupyter_client/master)
-
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
 
 `jupyter_client` contains the reference implementation of the [Jupyter protocol][].
 It also provides client and kernel management APIs for working with kernels.
@@ -10,8 +10,7 @@ It also provides client and kernel management APIs for working with kernels.
 It also provides the `jupyter kernelspec` entrypoint
 for installing kernelspecs for use with Jupyter frontends.
 
-[Jupyter protocol]: https://jupyter-client.readthedocs.io/en/latest/messaging.html
-
+[jupyter protocol]: https://jupyter-client.readthedocs.io/en/latest/messaging.html
 
 # Development Setup
@@ -43,5 +42,4 @@ The following commands build the documentation in HTML format and check for brok
 
 Point your browser to the following URL to access the generated documentation:
 
-_file:///my/projects/jupyter\_client/docs/\_build/html/index.html_
-
+_file:///my/projects/jupyter_client/docs/\_build/html/index.html_
diff --git a/RELEASING.md b/RELEASING.md
index d362acec1..927923008 100644
--- a/RELEASING.md
+++ b/RELEASING.md
@@ -25,7 +25,6 @@ twine upload dist/*
 
 - Load `jupyter_client/_version.py` and bump the patch version and add the 'dev' tag back to the end of the version tuple.
 
-
 ## Push to GitHub
 
 ```bash
diff --git a/docs/api/manager.rst b/docs/api/manager.rst
index f1d991383..659f95db0 100644
--- a/docs/api/manager.rst
+++ b/docs/api/manager.rst
@@ -55,4 +55,3 @@ Utility functions
 -----------------
 
 .. autofunction:: run_kernel
-
diff --git a/docs/conf.py b/docs/conf.py
index a19150705..6f9d47492 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -11,20 +11,17 @@
 #
 # All configuration values have a default; values that are commented out
 # serve to show the default.
-
-import sys
 import os
-import shlex
 
 # If extensions (or modules to document with autodoc) are in another directory,
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.
-#sys.path.insert(0, os.path.abspath('.'))
+# sys.path.insert(0, os.path.abspath('.'))
 
 # -- General configuration ------------------------------------------------
 
 # If your documentation needs a minimal Sphinx version, state it here.
-#needs_sphinx = '1.0'
+# needs_sphinx = '1.0'
 
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
@@ -45,7 +42,7 @@
 source_suffix = '.rst'
 
 # The encoding of source files.
-#source_encoding = 'utf-8-sig'
+# source_encoding = 'utf-8-sig'
 
 # The master toctree document.
 master_doc = 'index'
@@ -81,9 +78,9 @@
 # There are two options for replacing |today|: either, you set today to some
 # non-false value, then it is used:
-#today = ''
+# today = ''
 # Else, today_fmt is used as the format for a strftime call.
-#today_fmt = '%B %d, %Y'
+# today_fmt = '%B %d, %Y'
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
@@ -91,27 +88,27 @@
 
 # The reST default role (used for this markup: `text`) to use for all
 # documents.
-#default_role = None
+# default_role = None
 
 # If true, '()' will be appended to :func: etc. cross-reference text.
-#add_function_parentheses = True
+# add_function_parentheses = True
 
 # If true, the current module name will be prepended to all description
 # unit titles (such as .. function::).
-#add_module_names = True
+# add_module_names = True
 
 # If true, sectionauthor and moduleauthor directives will be shown in the
 # output. They are ignored by default.
-#show_authors = False
+# show_authors = False
 
 # The name of the Pygments (syntax highlighting) style to use.
 pygments_style = 'sphinx'
 
 # A list of ignored prefixes for module index sorting.
-#modindex_common_prefix = []
+# modindex_common_prefix = []
 
 # If true, keep warnings as "system message" paragraphs in the built documents.
-#keep_warnings = False
+# keep_warnings = False
 
 # If true, `todo` and `todoList` produce output, else they produce nothing.
 todo_include_todos = False
@@ -126,91 +123,91 @@
 # Theme options are theme-specific and customize the look and feel of a theme
 # further. For a list of options available for each theme, see the
 # documentation.
-#html_theme_options = {}
+# html_theme_options = {}
 
 # Add any paths that contain custom themes here, relative to this directory.
-#html_theme_path = []
+# html_theme_path = []
 
 # The name for this set of Sphinx documents. If None, it defaults to
 # "<project> v<release> documentation".
-#html_title = None
+# html_title = None
 
 # A shorter title for the navigation bar. Default is the same as html_title.
-#html_short_title = None
+# html_short_title = None
 
 # The name of an image file (relative to this directory) to place at the top
 # of the sidebar.
-#html_logo = None
+# html_logo = None
 
 # The name of an image file (within the static path) to use as favicon of the
 # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
 # pixels large.
-#html_favicon = None
+# html_favicon = None
 
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-#html_static_path = ['_static']
+# html_static_path = ['_static']
 
 # Add any extra paths that contain custom files (such as robots.txt or
 # .htaccess) here, relative to this directory. These files are copied
 # directly to the root of the documentation.
-#html_extra_path = []
+# html_extra_path = []
 
 # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
 # using the given strftime format.
-#html_last_updated_fmt = '%b %d, %Y'
+# html_last_updated_fmt = '%b %d, %Y'
 
 # If true, SmartyPants will be used to convert quotes and dashes to
 # typographically correct entities.
-#html_use_smartypants = True
+# html_use_smartypants = True
 
 # Custom sidebar templates, maps document names to template names.
-#html_sidebars = {}
+# html_sidebars = {}
 
 # Additional templates that should be rendered to pages, maps page names to
 # template names.
-#html_additional_pages = {}
+# html_additional_pages = {}
 
 # If false, no module index is generated.
-#html_domain_indices = True
+# html_domain_indices = True
 
 # If false, no index is generated.
-#html_use_index = True
+# html_use_index = True
 
 # If true, the index is split into individual pages for each letter.
-#html_split_index = False
+# html_split_index = False
 
 # If true, links to the reST sources are added to the pages.
-#html_show_sourcelink = True
+# html_show_sourcelink = True
 
 # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-#html_show_sphinx = True
+# html_show_sphinx = True
 
 # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-#html_show_copyright = True
+# html_show_copyright = True
 
 # If true, an OpenSearch description file will be output, and all pages will
 # contain a <link> tag referring to it. The value of this option must be the
 # base URL from which the finished HTML is served.
-#html_use_opensearch = ''
+# html_use_opensearch = ''
 
 # This is the file name suffix for HTML files (e.g. ".xhtml").
-#html_file_suffix = None
+# html_file_suffix = None
 
 # Language to be used for generating the HTML full-text search index.
 # Sphinx supports the following languages:
 #   'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja'
 #   'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr'
-#html_search_language = 'en'
+# html_search_language = 'en'
 
 # A dictionary with options for the search language support, empty by default.
 # Now only 'ja' uses this config value
-#html_search_options = {'type': 'default'}
+# html_search_options = {'type': 'default'}
 
 # The name of a javascript file (relative to the configuration directory) that
 # implements a search results scorer. If empty, the default will be used.
-#html_search_scorer = 'scorer.js'
+# html_search_scorer = 'scorer.js'
 
 # Output file base name for HTML help builder.
 htmlhelp_basename = 'jupyter_clientdoc'
@@ -218,59 +215,58 @@
 # -- Options for LaTeX output ---------------------------------------------
 
 latex_elements = {
-# The paper size ('letterpaper' or 'a4paper').
-#'papersize': 'letterpaper',
-
-# The font size ('10pt', '11pt' or '12pt').
-#'pointsize': '10pt',
-
-# Additional stuff for the LaTeX preamble.
-#'preamble': '',
-
-# Latex figure (float) alignment
-#'figure_align': 'htbp',
+    # The paper size ('letterpaper' or 'a4paper').
+    # 'papersize': 'letterpaper',
+    # The font size ('10pt', '11pt' or '12pt').
+    # 'pointsize': '10pt',
+    # Additional stuff for the LaTeX preamble.
+    # 'preamble': '',
+    # Latex figure (float) alignment
+    # 'figure_align': 'htbp',
 }
 
 # Grouping the document tree into LaTeX files. List of tuples
 # (source start file, target name, title,
 #  author, documentclass [howto, manual, or own class]).
 latex_documents = [
-    (master_doc, 'jupyter_client.tex', 'jupyter\\_client Documentation',
-     'Jupyter Development Team', 'manual'),
+    (
+        master_doc,
+        'jupyter_client.tex',
+        'jupyter\\_client Documentation',
+        'Jupyter Development Team',
+        'manual',
+    ),
 ]
 
 # The name of an image file (relative to this directory) to place at the top of
 # the title page.
-#latex_logo = None
+# latex_logo = None
 
 # For "manual" documents, if this is true, then toplevel headings are parts,
 # not chapters.
-#latex_use_parts = False
+# latex_use_parts = False
 
 # If true, show page references after internal links.
-#latex_show_pagerefs = False
+# latex_show_pagerefs = False
 
 # If true, show URL addresses after external links.
-#latex_show_urls = False
+# latex_show_urls = False
 
 # Documents to append as an appendix to all manuals.
-#latex_appendices = []
+# latex_appendices = []
 
 # If false, no module index is generated.
-#latex_domain_indices = True
+# latex_domain_indices = True
 
 
 # -- Options for manual page output ---------------------------------------
 
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [
-    (master_doc, 'jupyter_client', 'jupyter_client Documentation',
-     [author], 1)
-]
+man_pages = [(master_doc, 'jupyter_client', 'jupyter_client Documentation', [author], 1)]
 
 # If true, show URL addresses after external links.
-#man_show_urls = False
+# man_show_urls = False
 
 
 # -- Options for Texinfo output -------------------------------------------
@@ -279,22 +275,28 @@
 # (source start file, target name, title, author,
 #  dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, 'jupyter_client', 'jupyter_client Documentation',
-     author, 'jupyter_client', 'One line description of project.',
-     'Miscellaneous'),
+    (
+        master_doc,
+        'jupyter_client',
+        'jupyter_client Documentation',
+        author,
+        'jupyter_client',
+        'One line description of project.',
+        'Miscellaneous',
+    ),
 ]
 
 # Documents to append as an appendix to all manuals.
-#texinfo_appendices = []
+# texinfo_appendices = []
 
 # If false, no module index is generated.
-#texinfo_domain_indices = True
+# texinfo_domain_indices = True
 
 # How to display URL addresses: 'footnote', 'no', or 'inline'.
-#texinfo_show_urls = 'footnote'
+# texinfo_show_urls = 'footnote'
 
 # If true, do not generate a @detailmenu in the "Top" node's menu.
-#texinfo_no_detailmenu = False
+# texinfo_no_detailmenu = False
 
 
 # Example configuration for intersphinx: refer to the Python standard library.
@@ -306,6 +308,7 @@
 if not on_rtd:  # only import and set the theme if we're building docs locally
     import sphinx_rtd_theme
+
     html_theme = 'sphinx_rtd_theme'
     html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
 # otherwise, readthedocs.org uses their theme by default, so no need to specify it
diff --git a/docs/environment.yml b/docs/environment.yml
index b7ed943c1..93259abdb 100644
--- a/docs/environment.yml
+++ b/docs/environment.yml
@@ -2,11 +2,11 @@ name: jup_client
 channels:
   - conda-forge
 dependencies:
-- pyzmq
-- python==3.7
-- traitlets>=4.1
-- jupyter_core
-- sphinx>=1.3.6
-- sphinx_rtd_theme
-- pip:
-  - sphinxcontrib_github_alt
+  - pyzmq
+  - python==3.7
+  - traitlets>=4.1
+  - jupyter_core
+  - sphinx>=1.3.6
+  - sphinx_rtd_theme
+  - pip:
+      - sphinxcontrib_github_alt
diff --git a/docs/index.rst b/docs/index.rst
index a0b8855cc..a238aba40 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -43,4 +43,3 @@ Indices and tables
 * :ref:`genindex`
 * :ref:`modindex`
 * :ref:`search`
-
diff --git a/jupyter_client/__init__.py b/jupyter_client/__init__.py
index f72c516d3..400bdea63 100644
--- a/jupyter_client/__init__.py
+++ b/jupyter_client/__init__.py
@@ -1,10 +1,15 @@
 """Client-side implementations of the Jupyter protocol"""
-
-from ._version import version_info, __version__, protocol_version_info, protocol_version
-from .connect import *
-from .launcher import *
-from .client import KernelClient
-from .manager import KernelManager, AsyncKernelManager, run_kernel
-from .blocking import BlockingKernelClient
-from .asynchronous import AsyncKernelClient
-from .multikernelmanager import MultiKernelManager, AsyncMultiKernelManager
+from ._version import __version__  # noqa
+from ._version import protocol_version  # noqa
+from ._version import protocol_version_info  # noqa
+from ._version import version_info  # noqa
+from .asynchronous import AsyncKernelClient  # noqa
+from .blocking import BlockingKernelClient  # noqa
+from .client import KernelClient  # noqa
+from .connect import *  # noqa
+from .launcher import *  # noqa
+from .manager import AsyncKernelManager  # noqa
+from .manager import KernelManager  # noqa
+from .manager import run_kernel  # noqa
+from .multikernelmanager import AsyncMultiKernelManager  # noqa
+from .multikernelmanager import MultiKernelManager  # noqa
diff --git a/jupyter_client/adapter.py b/jupyter_client/adapter.py
index e4e09a54c..838f7150d 100644
--- a/jupyter_client/adapter.py
+++ b/jupyter_client/adapter.py
@@ -1,18 +1,17 @@
 """Adapters for Jupyter msg spec versions."""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
-import re
 import json
-from typing import List, Tuple, Dict, Any
+import re
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Tuple
 
 from jupyter_client import protocol_version_info
 
-def code_to_line(
-    code: str,
-    cursor_pos: int
-) -> Tuple[str, int]:
+
+def code_to_line(code: str, cursor_pos: int) -> Tuple[str, int]:
     """Turn a multiline code block and cursor position into a single line
     and new cursor position.
@@ -29,14 +28,12 @@ def code_to_line(
     return line, cursor_pos
 
 
-_match_bracket = re.compile(r'\([^\(\)]+\)', re.UNICODE)
-_end_bracket = re.compile(r'\([^\(]*$', re.UNICODE)
-_identifier = re.compile(r'[a-z_][0-9a-z._]*', re.I|re.UNICODE)
+_match_bracket = re.compile(r"\([^\(\)]+\)", re.UNICODE)
+_end_bracket = re.compile(r"\([^\(]*$", re.UNICODE)
+_identifier = re.compile(r"[a-z_][0-9a-z._]*", re.I | re.UNICODE)
+
 
-def extract_oname_v4(
-    code: str,
-    cursor_pos: int
-) -> str:
+def extract_oname_v4(code: str, cursor_pos: int) -> str:
     """Reimplement token-finding logic from IPython 2.x javascript
 
     for adapting object_info_request from v5 to v4
@@ -45,18 +42,18 @@ def extract_oname_v4(
     line, _ = code_to_line(code, cursor_pos)
 
     oldline = line
-    line = _match_bracket.sub('', line)
+    line = _match_bracket.sub("", line)
     while oldline != line:
         oldline = line
-        line = _match_bracket.sub('', line)
+        line = _match_bracket.sub("", line)
 
     # remove everything after last open bracket
-    line = _end_bracket.sub('', line)
+    line = _end_bracket.sub("", line)
     matches = _identifier.findall(line)
     if matches:
         return matches[-1]
     else:
-        return ''
+        return ""
 
 
 class Adapter(object):
@@ -67,201 +64,155 @@ class Adapter(object):
 
     msg_type_map: Dict[str, str] = {}
 
-    def update_header(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
         return msg
 
-    def update_metadata(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def update_metadata(self, msg: Dict[str, Any]) -> Dict[str, Any]:
         return msg
 
-    def update_msg_type(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        header = msg['header']
-        msg_type = header['msg_type']
+    def update_msg_type(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        header = msg["header"]
+        msg_type = header["msg_type"]
         if msg_type in self.msg_type_map:
-            msg['msg_type'] = header['msg_type'] = self.msg_type_map[msg_type]
+            msg["msg_type"] = header["msg_type"] = self.msg_type_map[msg_type]
         return msg
 
-    def handle_reply_status_error(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def handle_reply_status_error(self, msg: Dict[str, Any]) -> Dict[str, Any]:
         """This will be called *instead of* the regular handler
 
         on any reply with status != ok
         """
         return msg
 
-    def __call__(
-        self,
-        msg: Dict[str, Any]
-    ):
+    def __call__(self, msg: Dict[str, Any]):
         msg = self.update_header(msg)
         msg = self.update_metadata(msg)
         msg = self.update_msg_type(msg)
-        header = msg['header']
+        header = msg["header"]
 
-        handler = getattr(self, header['msg_type'], None)
+        handler = getattr(self, header["msg_type"], None)
         if handler is None:
             return msg
 
         # handle status=error replies separately (no change, at present)
-        if msg['content'].get('status', None) in {'error', 'aborted'}:
+        if msg["content"].get("status", None) in {"error", "aborted"}:
             return self.handle_reply_status_error(msg)
         return handler(msg)
 
-def _version_str_to_list(
-    version: str
-) -> List[int]:
+
+def _version_str_to_list(version: str) -> List[int]:
     """convert a version string to a list of ints
 
     non-int segments are excluded
     """
     v = []
-    for part in version.split('.'):
+    for part in version.split("."):
         try:
             v.append(int(part))
         except ValueError:
             pass
     return v
 
+
 class V5toV4(Adapter):
     """Adapt msg protocol v5 to v4"""
 
-    version = '4.1'
+    version = "4.1"
 
     msg_type_map = {
-        'execute_result' : 'pyout',
-        'execute_input' : 'pyin',
-        'error' : 'pyerr',
-        'inspect_request' : 'object_info_request',
-        'inspect_reply' : 'object_info_reply',
+        "execute_result": "pyout",
+        "execute_input": "pyin",
+        "error": "pyerr",
+        "inspect_request": "object_info_request",
+        "inspect_reply": "object_info_reply",
     }
 
-    def update_header(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        msg['header'].pop('version', None)
-        msg['parent_header'].pop('version', None)
+    def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        msg["header"].pop("version", None)
+        msg["parent_header"].pop("version", None)
         return msg
 
     # shell channel
 
-    def kernel_info_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def kernel_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
         v4c = {}
-        content = msg['content']
-        for key in ('language_version', 'protocol_version'):
+        content = msg["content"]
+        for key in ("language_version", "protocol_version"):
             if key in content:
                 v4c[key] = _version_str_to_list(content[key])
-        if content.get('implementation', '') == 'ipython' \
-            and 'implementation_version' in content:
-            v4c['ipython_version'] = _version_str_to_list(content['implementation_version'])
-        language_info = content.get('language_info', {})
-        language = language_info.get('name', '')
-        v4c.setdefault('language', language)
-        if 'version' in language_info:
-            v4c.setdefault('language_version', _version_str_to_list(language_info['version']))
-        msg['content'] = v4c
+        if content.get("implementation", "") == "ipython" and "implementation_version" in content:
+            v4c["ipython_version"] = _version_str_to_list(content["implementation_version"])
+        language_info = content.get("language_info", {})
+        language = language_info.get("name", "")
+        v4c.setdefault("language", language)
+        if "version" in language_info:
+            v4c.setdefault("language_version", _version_str_to_list(language_info["version"]))
+        msg["content"] = v4c
         return msg
 
-    def execute_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        content.setdefault('user_variables', [])
+    def execute_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        content.setdefault("user_variables", [])
         return msg
 
-    def execute_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        content.setdefault('user_variables', {})
+    def execute_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        content.setdefault("user_variables", {})
         # TODO: handle payloads
         return msg
 
-    def complete_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        code = content['code']
-        cursor_pos = content['cursor_pos']
+    def complete_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        code = content["code"]
+        cursor_pos = content["cursor_pos"]
         line, cursor_pos = code_to_line(code, cursor_pos)
 
-        new_content = msg['content'] = {}
-        new_content['text'] = ''
-        new_content['line'] = line
-        new_content['block'] = None
-        new_content['cursor_pos'] = cursor_pos
+        new_content = msg["content"] = {}
+        new_content["text"] = ""
+        new_content["line"] = line
+        new_content["block"] = None
+        new_content["cursor_pos"] = cursor_pos
         return msg
 
-    def complete_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        cursor_start = content.pop('cursor_start')
-        cursor_end = content.pop('cursor_end')
+    def complete_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        cursor_start = content.pop("cursor_start")
+        cursor_end = content.pop("cursor_end")
         match_len = cursor_end - cursor_start
-        content['matched_text'] = content['matches'][0][:match_len]
-        content.pop('metadata', None)
+        content["matched_text"] = content["matches"][0][:match_len]
+        content.pop("metadata", None)
         return msg
 
-    def object_info_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        code = content['code']
-        cursor_pos = content['cursor_pos']
+    def object_info_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        code = content["code"]
+        cursor_pos = content["cursor_pos"]
         line, _ = code_to_line(code, cursor_pos)
 
-        new_content = msg['content'] = {}
-        new_content['oname'] = extract_oname_v4(code, cursor_pos)
-        new_content['detail_level'] = content['detail_level']
+        new_content = msg["content"] = {}
+        new_content["oname"] = extract_oname_v4(code, cursor_pos)
+        new_content["detail_level"] = content["detail_level"]
         return msg
 
-    def object_info_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def object_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
         """inspect_reply can't be easily backward compatible"""
-        msg['content'] = {'found' : False, 'oname' : 'unknown'}
+        msg["content"] = {"found": False, "oname": "unknown"}
         return msg
 
     # iopub channel
 
-    def stream(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        content['data'] = content.pop('text')
+    def stream(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        content["data"] = content.pop("text")
         return msg
 
-    def display_data(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
+    def display_data(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
         content.setdefault("source", "display")
-        data = content['data']
-        if 'application/json' in data:
+        data = content["data"]
+        if "application/json" in data:
             try:
-                data['application/json'] = json.dumps(data['application/json'])
+                data["application/json"] = json.dumps(data["application/json"])
             except Exception:
                 # warn?
                 pass
@@ -269,176 +220,144 @@ def display_data(
 
     # stdin channel
 
-    def input_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        msg['content'].pop('password', None)
+    def input_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        msg["content"].pop("password", None)
         return msg
 
 
 class V4toV5(Adapter):
     """Convert msg spec V4 to V5"""
-    version = '5.0'
+
+    version = "5.0"
 
     # invert message renames above
-    msg_type_map = {v:k for k,v in V5toV4.msg_type_map.items()}
-
-    def update_header(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        msg['header']['version'] = self.version
-        if msg['parent_header']:
-            msg['parent_header']['version'] = self.version
+    msg_type_map = {v: k for k, v in V5toV4.msg_type_map.items()}
+
+    def update_header(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        msg["header"]["version"] = self.version
+        if msg["parent_header"]:
+            msg["parent_header"]["version"] = self.version
         return msg
 
     # shell channel
 
-    def kernel_info_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        for key in ('protocol_version', 'ipython_version'):
+    def kernel_info_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        for key in ("protocol_version", "ipython_version"):
             if key in content:
-                content[key] = '.'.join(map(str, content[key]))
+                content[key] = ".".join(map(str, content[key]))
 
-        content.setdefault('protocol_version', '4.1')
+        content.setdefault("protocol_version", "4.1")
 
-        if content['language'].startswith('python') and 'ipython_version' in content:
-            content['implementation'] = 'ipython'
-            content['implementation_version'] = content.pop('ipython_version')
+        if content["language"].startswith("python") and "ipython_version" in content:
+            content["implementation"] = "ipython"
+            content["implementation_version"] = content.pop("ipython_version")
 
-        language = content.pop('language')
-        language_info = content.setdefault('language_info', {})
-        language_info.setdefault('name', language)
-        if 'language_version' in content:
-            language_version = '.'.join(map(str, content.pop('language_version')))
-            language_info.setdefault('version', language_version)
+        language = content.pop("language")
+        language_info = content.setdefault("language_info", {})
+        language_info.setdefault("name", language)
+        if "language_version" in content:
+            language_version = ".".join(map(str, content.pop("language_version")))
+            language_info.setdefault("version", language_version)
 
-        content['banner'] = ''
+        content["banner"] = ""
         return msg
 
-    def execute_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        user_variables = content.pop('user_variables', [])
-        user_expressions = content.setdefault('user_expressions', {})
+    def execute_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        user_variables = content.pop("user_variables", [])
+        user_expressions = content.setdefault("user_expressions", {})
         for v in user_variables:
             user_expressions[v] = v
         return msg
 
-    def execute_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        user_expressions = content.setdefault('user_expressions', {})
-        user_variables = content.pop('user_variables', {})
+    def execute_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        user_expressions = content.setdefault("user_expressions", {})
+        user_variables = content.pop("user_variables", {})
         if user_variables:
             user_expressions.update(user_variables)
 
         # Pager payloads became a mime bundle
-        for payload in content.get('payload', []):
-            if payload.get('source', None) == 'page' and ('text' in payload):
-                if 'data' not in payload:
-                    payload['data'] = {}
-                payload['data']['text/plain'] = payload.pop('text')
+        for payload in content.get("payload", []):
+            if payload.get("source", None) == "page" and ("text" in payload):
+                if "data" not in payload:
+                    payload["data"] = {}
+                payload["data"]["text/plain"] = payload.pop("text")
 
         return msg
 
-    def complete_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        old_content = msg['content']
+    def complete_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        old_content = msg["content"]
 
-        new_content = msg['content'] = {}
-        new_content['code'] = old_content['line']
-        new_content['cursor_pos'] = old_content['cursor_pos']
+        new_content = msg["content"] = {}
+        new_content["code"] = old_content["line"]
+        new_content["cursor_pos"] = old_content["cursor_pos"]
         return msg
 
-    def complete_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def complete_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
         # complete_reply needs more context than we have to get cursor_start and end.
         # use special end=null to indicate current cursor position and negative offset
         # for start relative to the cursor.
         # start=None indicates that start == end (accounts for no -0).
-        content = msg['content']
-        new_content = msg['content'] = {'status' : 'ok'}
-        new_content['matches'] = content['matches']
-        if content['matched_text']:
-            new_content['cursor_start'] = -len(content['matched_text'])
+        content = msg["content"]
+        new_content = msg["content"] = {"status": "ok"}
+        new_content["matches"] = content["matches"]
+        if content["matched_text"]:
+            new_content["cursor_start"] = -len(content["matched_text"])
         else:
             # no -0, use None to indicate that start == end
-            new_content['cursor_start'] = None
-        new_content['cursor_end'] = None
-        new_content['metadata'] = {}
+            new_content["cursor_start"] = None
+        new_content["cursor_end"] = None
+        new_content["metadata"] = {}
         return msg
 
-    def inspect_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        name = content['oname']
-
-        new_content = msg['content'] = {}
-        new_content['code'] = name
-        new_content['cursor_pos'] = len(name)
-        new_content['detail_level'] = content['detail_level']
+    def inspect_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        name = content["oname"]
+
+        new_content = msg["content"] = {}
+        new_content["code"] = name
+        new_content["cursor_pos"] = len(name)
+        new_content["detail_level"] = content["detail_level"]
         return msg
 
-    def inspect_reply(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def inspect_reply(self, msg: Dict[str, Any]) -> Dict[str, Any]:
         """inspect_reply can't be easily backward compatible"""
-        content = msg['content']
-        new_content = msg['content'] = {'status' : 'ok'}
-        found = new_content['found'] = content['found']
-        new_content['data'] = data = {}
-        new_content['metadata'] = {}
+        content = msg["content"]
+        new_content = msg["content"] = {"status": "ok"}
+        found = new_content["found"] = content["found"]
+        new_content["data"] = data = {}
+        new_content["metadata"] = {}
         if found:
             lines = []
-            for key in ('call_def', 'init_definition', 'definition'):
+            for key in ("call_def", "init_definition", "definition"):
                 if content.get(key, False):
                     lines.append(content[key])
                     break
-            for key in ('call_docstring', 'init_docstring', 'docstring'):
+            for key in ("call_docstring", "init_docstring", "docstring"):
                 if content.get(key, False):
                     lines.append(content[key])
                     break
             if not lines:
                 lines.append("")
-            data['text/plain'] = '\n'.join(lines)
+            data["text/plain"] = "\n".join(lines)
         return msg
 
     # iopub channel
 
-    def stream(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
-        content['text'] = content.pop('data')
+    def stream(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
+        content["text"] = content.pop("data")
         return msg
 
-    def display_data(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        content = msg['content']
+    def display_data(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        content = msg["content"]
         content.pop("source", None)
-        data = content['data']
-        if 'application/json' in data:
+        data = content["data"]
+        if "application/json" in data:
             try:
-                data['application/json'] = json.loads(data['application/json'])
+                data["application/json"] = json.loads(data["application/json"])
             except Exception:
                 # warn?
                 pass
@@ -446,19 +365,12 @@ def display_data(
 
     # stdin channel
 
-    def input_request(
-        self,
-        msg: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        msg['content'].setdefault('password', False)
+    def input_request(self, msg: Dict[str, Any]) -> Dict[str, Any]:
+        msg["content"].setdefault("password", False)
         return msg
 
-
-def adapt(
-    msg: Dict[str, Any],
-    to_version: int =protocol_version_info[0]
-    ) -> Dict[str, Any]:
+def adapt(msg: Dict[str, Any], to_version: int = protocol_version_info[0]) -> Dict[str, Any]:
     """Adapt a single message to a target version
 
     Parameters
@@ -477,11 +389,12 @@ def adapt(
     A Jupyter message appropriate in the new version.
     """
     from .session import utcnow
-    header = msg['header']
-    if 'date' not in header:
-        header['date'] = utcnow()
-    if 'version' in header:
-        from_version = int(header['version'].split('.')[0])
+
+    header = msg["header"]
+    if "date" not in header:
+        header["date"] = utcnow()
+    if "version" in header:
+        from_version = int(header["version"].split(".")[0])
     else:
         # assume last version before adding the key to the header
         from_version = 4
@@ -493,6 +406,6 @@
 
 # one adapter per major version from,to
 adapters = {
-    (5,4) : V5toV4(),
-    (4,5) : V4toV5(),
+    (5, 4): V5toV4(),
+    (4, 5): V4toV5(),
 }
diff --git a/jupyter_client/asynchronous/__init__.py b/jupyter_client/asynchronous/__init__.py
index 19a739f05..36f2c8469 100644
--- a/jupyter_client/asynchronous/__init__.py
+++ b/jupyter_client/asynchronous/__init__.py
@@ -1 +1 @@
-from .client import AsyncKernelClient
+from .client import AsyncKernelClient  # noqa
diff --git a/jupyter_client/asynchronous/client.py b/jupyter_client/asynchronous/client.py
index 86fb8737e..d67dcd482 100644
--- a/jupyter_client/asynchronous/client.py
+++ b/jupyter_client/asynchronous/client.py
@@ -1,20 +1,23 @@
 """Implements an async kernel client"""
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
+from traitlets import Type  # type: ignore
 
-from traitlets import (Type, Instance)  # type: ignore
-
-from jupyter_client.channels import HBChannel, ZMQSocketChannel
-from jupyter_client.client import KernelClient, reqrep
+from jupyter_client.channels import HBChannel
+from jupyter_client.channels import ZMQSocketChannel
+from jupyter_client.client import KernelClient
+from jupyter_client.client import reqrep
 
 
 def wrapped(meth, channel):
     def _(self, *args, **kwargs):
-        reply = kwargs.pop('reply', False)
-        timeout = kwargs.pop('timeout', None)
+        reply = kwargs.pop("reply", False)
+        timeout = kwargs.pop("timeout", None)
         msg_id = meth(self, *args, **kwargs)
         if not reply:
             return msg_id
         return self._async_recv_reply(msg_id, timeout=timeout, channel=channel)
+
     return _
 
 
@@ -25,9 +28,9 @@ class AsyncKernelClient(KernelClient):
     raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
     """
 
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # Channel proxy methods
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
 
     get_shell_msg = KernelClient._async_get_shell_msg
     get_iopub_msg = KernelClient._async_get_iopub_msg
@@ -43,10 +46,8 @@ class AsyncKernelClient(KernelClient):
     hb_channel_class = Type(HBChannel)
     control_channel_class = Type(ZMQSocketChannel)
 
-
     _recv_reply = KernelClient._async_recv_reply
 
-
     # replies come on the shell channel
     execute = reqrep(wrapped, KernelClient._execute)
     history = reqrep(wrapped, KernelClient._history)
@@ -59,4 +60,4 @@ class AsyncKernelClient(KernelClient):
     execute_interactive = KernelClient._async_execute_interactive
 
     # replies come on the control channel
-    shutdown = reqrep(wrapped, KernelClient._shutdown, channel='control')
+    shutdown = reqrep(wrapped, KernelClient._shutdown, channel="control")
diff --git a/jupyter_client/blocking/__init__.py b/jupyter_client/blocking/__init__.py
index dc38f2403..74b09b9b1 100644
--- a/jupyter_client/blocking/__init__.py
+++ b/jupyter_client/blocking/__init__.py
@@ -1 +1 @@
-from .client import BlockingKernelClient
\ No newline at end of file
+from .client import BlockingKernelClient  # noqa
diff --git a/jupyter_client/blocking/client.py b/jupyter_client/blocking/client.py
index 34dafdf43..bbc80f9f1 100644
--- a/jupyter_client/blocking/client.py
+++ b/jupyter_client/blocking/client.py
@@ -4,21 +4,24 @@
 """
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 from traitlets import Type  # type: ignore
-from jupyter_client.channels import HBChannel, ZMQSocketChannel
-from jupyter_client.client import KernelClient, reqrep
+
 from ..utils import run_sync
+from jupyter_client.channels import HBChannel
+from jupyter_client.channels import ZMQSocketChannel
+from jupyter_client.client import KernelClient
+from jupyter_client.client import reqrep
 
 
 def wrapped(meth, channel):
     def _(self, *args, **kwargs):
-        reply = kwargs.pop('reply', False)
-        timeout = kwargs.pop('timeout', None)
+        reply = kwargs.pop("reply", False)
+        timeout = kwargs.pop("timeout", None)
         msg_id = meth(self, *args, **kwargs)
         if not reply:
             return msg_id
         return run_sync(self._async_recv_reply)(msg_id, timeout=timeout, channel=channel)
+
     return _
 
 
@@ -29,9 +32,9 @@ class BlockingKernelClient(KernelClient):
     raising :exc:`queue.Empty` if no message arrives within ``timeout`` seconds.
""" - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Channel proxy methods - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- get_shell_msg = run_sync(KernelClient._async_get_shell_msg) get_iopub_msg = run_sync(KernelClient._async_get_iopub_msg) @@ -47,10 +50,8 @@ class BlockingKernelClient(KernelClient): hb_channel_class = Type(HBChannel) control_channel_class = Type(ZMQSocketChannel) - _recv_reply = run_sync(KernelClient._async_recv_reply) - # replies come on the shell channel execute = reqrep(wrapped, KernelClient._execute) history = reqrep(wrapped, KernelClient._history) @@ -63,4 +64,4 @@ class BlockingKernelClient(KernelClient): execute_interactive = run_sync(KernelClient._async_execute_interactive) # replies come on the control channel - shutdown = reqrep(wrapped, KernelClient._shutdown, channel='control') + shutdown = reqrep(wrapped, KernelClient._shutdown, channel="control") diff --git a/jupyter_client/channels.py b/jupyter_client/channels.py index a94e959fa..abec3eaa4 100644 --- a/jupyter_client/channels.py +++ b/jupyter_client/channels.py @@ -1,36 +1,36 @@ """Base classes to manage a Client's interaction with a running kernel""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - +import asyncio import atexit import errno -from threading import Thread, Event import time -import asyncio -from queue import Empty import typing as t +from queue import Empty +from threading import Event +from threading import Thread -import zmq import zmq.asyncio -# import ZMQError in top-level namespace, to avoid ugly attribute-error messages -# during garbage collection of threads at exit: from zmq import ZMQError -from jupyter_client import protocol_version_info - from .channelsabc import HBChannelABC from .session import Session +from jupyter_client import protocol_version_info + +# import ZMQError in top-level namespace, to avoid ugly attribute-error messages +# during garbage collection of threads at exit -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Constants and exceptions -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- major_protocol_version = protocol_version_info[0] + class InvalidPortNumber(Exception): pass + class HBChannel(Thread): """The heartbeat channel which monitors the kernel heartbeat. @@ -38,12 +38,13 @@ class HBChannel(Thread): this channel, the kernel manager will ensure that it is paused and un-paused as appropriate. """ + session = None socket = None address = None _exiting = False - time_to_dead: float = 1. + time_to_dead: float = 1.0 _running = None _pause = None _beating = None @@ -52,7 +53,7 @@ def __init__( self, context: zmq.asyncio.Context, session: t.Optional[Session] = None, - address: t.Union[t.Tuple[str, int], str] = '' + address: t.Union[t.Tuple[str, int], str] = "", ): """Create the heartbeat monitor thread. @@ -72,7 +73,7 @@ def __init__( self.session = session if isinstance(address, tuple): if address[1] == 0: - message = 'The port number for a channel cannot be 0.' + message = "The port number for a channel cannot be 0." 
+                message = "The port number for a channel cannot be 0."
                 raise InvalidPortNumber(message)
             address_str = "tcp://%s:%i" % address
         else:
@@ -105,10 +106,7 @@ def _create_socket(self) -> None:
 
         self.poller.register(self.socket, zmq.POLLIN)
 
-    def _poll(
-        self,
-        start_time: float
-    ) -> t.List[t.Any]:
+    def _poll(self, start_time: float) -> t.List[t.Any]:
         """poll for heartbeat replies until we reach self.time_to_dead.
 
         Ignores interrupts, and returns the result of poll(), which
@@ -162,7 +160,7 @@ async def _async_run(self) -> None:
             since_last_heartbeat = 0.0
             # no need to catch EFSM here, because the previous event was
             # either a recv or connect, which cannot be followed by EFSM
-            await self.socket.send(b'ping')
+            await self.socket.send(b"ping")
             request_time = time.time()
             ready = self._poll(request_time)
             if ready:
@@ -213,10 +211,7 @@ def close(self) -> None:
                 pass
             self.socket = None
 
-    def call_handlers(
-        self,
-        since_last_heartbeat: float
-    ) -> None:
+    def call_handlers(self, since_last_heartbeat: float) -> None:
        """This method is called in the ioloop thread when a message arrives.
 
         Subclasses should override this method to handle incoming messages.
@@ -234,10 +229,7 @@ class ZMQSocketChannel(object):
     """A ZMQ socket in an async API"""
 
     def __init__(
-        self,
-        socket: zmq.sugar.socket.Socket,
-        session: Session,
-        loop: t.Any = None
+        self, socket: zmq.sugar.socket.Socket, session: Session, loop: t.Any = None
     ) -> None:
         """Create a channel.
 
@@ -261,10 +253,7 @@ async def _recv(self, **kwargs) -> t.Dict[str, t.Any]:
         ident, smsg = self.session.feed_identities(msg)
         return self.session.deserialize(smsg)
 
-    async def get_msg(
-        self,
-        timeout: t.Optional[float] = None
-    ) -> t.Dict[str, t.Any]:
+    async def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
         """ Gets a message if there is one that is ready. """
         if timeout is not None:
             timeout *= 1000  # seconds to ms
@@ -299,17 +288,14 @@ def close(self) -> None:
             except Exception:
                 pass
             self.socket = None
+
     stop = close
 
     def is_alive(self) -> bool:
-        return (self.socket is not None)
+        return self.socket is not None
 
-    def send(
-        self,
-        msg: t.Dict[str, t.Any]
-    ) -> None:
-        """Pass a message to the ZMQ socket to send
-        """
+    def send(self, msg: t.Dict[str, t.Any]) -> None:
+        """Pass a message to the ZMQ socket to send"""
         assert self.socket is not None
         self.session.send(self.socket, msg)
diff --git a/jupyter_client/channelsabc.py b/jupyter_client/channelsabc.py
index c0825ef39..aee5d9d60 100644
--- a/jupyter_client/channelsabc.py
+++ b/jupyter_client/channelsabc.py
@@ -1,8 +1,6 @@
 """Abstract base classes for kernel client channels"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 import abc
diff --git a/jupyter_client/client.py b/jupyter_client/client.py
index b4ba3f004..323263bae 100644
--- a/jupyter_client/client.py
+++ b/jupyter_client/client.py
@@ -1,68 +1,60 @@
 """Base class to manage the interaction with a running kernel"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
-import sys
 import asyncio
+import sys
 import time
+import typing as t
 from functools import partial
 from getpass import getpass
 from queue import Empty
-import socket
-import typing as t
-
-from jupyter_client.channels import major_protocol_version
-import zmq
 import zmq.asyncio
+from traitlets import Any  # type: ignore
+from traitlets import Instance
+from traitlets import Type
 
-from traitlets import (  # type: ignore
-    Any, Instance, Type, default
-)
-
-from .channelsabc import (ChannelABC, HBChannelABC)
+from .channelsabc import ChannelABC
+from .channelsabc import HBChannelABC
 from .clientabc import KernelClientABC
 from .connect import ConnectionFileMixin
 from .session import Session
 from .utils import ensure_async
-
+from jupyter_client.channels import major_protocol_version
 
 # some utilities to validate message structure, these might get moved elsewhere
 # if they prove to have more generic utility
 
-def validate_string_dict(
-    dct: t.Dict[str, str]
-) -> None:
+
+def validate_string_dict(dct: t.Dict[str, str]) -> None:
     """Validate that the input is a dict with string keys and values.
 
     Raises ValueError if not."""
     for k, v in dct.items():
         if not isinstance(k, str):
-            raise ValueError('key %r in dict must be a string' % k)
+            raise ValueError("key %r in dict must be a string" % k)
         if not isinstance(v, str):
-            raise ValueError('value %r in dict must be a string' % v)
+            raise ValueError("value %r in dict must be a string" % v)
 
 
-def reqrep(
-    wrapped: t.Callable,
-    meth: t.Callable,
-    channel: str = 'shell'
-) -> t.Callable:
+def reqrep(wrapped: t.Callable, meth: t.Callable, channel: str = "shell") -> t.Callable:
     wrapped = wrapped(meth, channel)
     if not meth.__doc__:
         # python -OO removes docstrings,
         # so don't bother building the wrapped docstring
         return wrapped
 
-    basedoc, _ = meth.__doc__.split('Returns\n', 1)
+    basedoc, _ = meth.__doc__.split("Returns\n", 1)
     parts = [basedoc.strip()]
-    if 'Parameters' not in basedoc:
-        parts.append("""
+    if "Parameters" not in basedoc:
+        parts.append(
+            """
         Parameters
         ----------
-        """)
-    parts.append("""
+        """
+        )
+    parts.append(
+        """
         reply: bool (default: False)
             Whether to wait for and return reply
         timeout: float or None (default: None)
@@ -74,8 +66,9 @@ def reqrep(
             The msg_id of the request sent, if reply=False (default)
         reply: dict
             The reply message for this request, if reply=True
-    """)
-    wrapped.__doc__ = '\n'.join(parts)
+    """
+    )
+    wrapped.__doc__ = "\n".join(parts)
     return wrapped
 
 
@@ -98,6 +91,7 @@ class KernelClient(ConnectionFileMixin):
 
     # The PyZMQ Context to use for communication with the kernel.
     context = Instance(zmq.asyncio.Context)
+
     def _context_default(self) -> zmq.asyncio.Context:
         return zmq.asyncio.Context()
 
@@ -118,9 +112,9 @@ def _context_default(self) -> zmq.asyncio.Context:
     # flag for whether execute requests should be allowed to call raw_input:
     allow_stdin: bool = True
 
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # Channel proxy methods
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
 
     async def _async_get_shell_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]:
         """Get a message from the shell channel"""
@@ -138,10 +132,7 @@ async def _async_get_control_msg(self, *args, **kwargs) -> t.Dict[str, t.Any]:
         """Get a message from the control channel"""
         return await self.control_channel.get_msg(*args, **kwargs)
 
-    async def _async_wait_for_ready(
-        self,
-        timeout: t.Optional[float] = None
-    ) -> None:
+    async def _async_wait_for_ready(self, timeout: t.Optional[float] = None) -> None:
         """Waits for a response when a client is blocked
 
         - Sets future time for timeout
@@ -151,17 +142,20 @@ async def _async_wait_for_ready(
         - Flush the IOPub channel
         """
         if timeout is None:
-            timeout = float('inf')
+            timeout = float("inf")
         abs_timeout = time.time() + timeout
 
         from .manager import KernelManager
+
         if not isinstance(self.parent, KernelManager):
             # This Client was not created by a KernelManager,
             # so wait for kernel to become responsive to heartbeats
             # before checking for kernel_info reply
             while not await ensure_async(self.is_alive()):
                 if time.time() > abs_timeout:
-                    raise RuntimeError("Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout)
+                    raise RuntimeError(
+                        "Kernel didn't respond to heartbeats in %d seconds and timed out" % timeout
+                    )
                 await asyncio.sleep(0.2)
 
         # Wait for kernel info reply on shell channel
@@ -172,7 +166,7 @@ async def _async_wait_for_ready(
             except Empty:
                 pass
             else:
-                if msg['msg_type'] == 'kernel_info_reply':
+                if msg["msg_type"] == "kernel_info_reply":
                     # Checking that IOPub is connected. If it is not connected, start over.
                     try:
                         await self.iopub_channel.get_msg(timeout=0.2)
@@ -183,7 +177,7 @@ async def _async_wait_for_ready(
                     break
 
         if not await ensure_async(self.is_alive()):
-            raise RuntimeError('Kernel died before replying to kernel_info')
+            raise RuntimeError("Kernel died before replying to kernel_info")
 
         # Check if current time is ready check time plus timeout
         if time.time() > abs_timeout:
@@ -197,10 +191,7 @@ async def _async_wait_for_ready(
                 break
 
     async def _async_recv_reply(
-        self,
-        msg_id: str,
-        timeout: t.Optional[float] = None,
-        channel: str = 'shell'
+        self, msg_id: str, timeout: t.Optional[float] = None, channel: str = "shell"
     ) -> t.Dict[str, t.Any]:
         """Receive and return the reply for a given request"""
         if timeout is not None:
@@ -209,25 +200,21 @@ async def _async_recv_reply(
             if timeout is not None:
                 timeout = max(0, deadline - time.monotonic())
             try:
-                if channel == 'control':
+                if channel == "control":
                     reply = await self._async_get_control_msg(timeout=timeout)
                 else:
                     reply = await self._async_get_shell_msg(timeout=timeout)
             except Empty as e:
                 raise TimeoutError("Timeout waiting for reply") from e
-            if reply['parent_header'].get('msg_id') != msg_id:
+            if reply["parent_header"].get("msg_id") != msg_id:
                 # not my reply, someone may have forgotten to retrieve theirs
                 continue
             return reply
 
-
-    def _stdin_hook_default(
-        self,
-        msg: t.Dict[str, t.Any]
-    ) -> None:
+    def _stdin_hook_default(self, msg: t.Dict[str, t.Any]) -> None:
         """Handle an input request"""
-        content = msg['content']
-        if content.get('password', False):
+        content = msg["content"]
+        if content.get("password", False):
             prompt = getpass
         else:
             prompt = input  # type: ignore
@@ -236,9 +223,9 @@ def _stdin_hook_default(
             raw_data = prompt(content["prompt"])
         except EOFError:
             # turn EOFError into EOF character
-            raw_data = '\x04'
+            raw_data = "\x04"
         except KeyboardInterrupt:
-            sys.stdout.write('\n')
+            sys.stdout.write("\n")
             return
 
         # only send stdin reply if there *was not* another request
@@ -246,42 +233,38 @@ def _stdin_hook_default(
         if not (self.stdin_channel.msg_ready() or self.shell_channel.msg_ready()):
             self.input(raw_data)
 
-    def _output_hook_default(
-        self,
-        msg: t.Dict[str, t.Any]
-    ) -> None:
+    def _output_hook_default(self, msg: t.Dict[str, t.Any]) -> None:
         """Default hook for redisplaying plain-text output"""
-        msg_type = msg['header']['msg_type']
-        content = msg['content']
-        if msg_type == 'stream':
-            stream = getattr(sys, content['name'])
-            stream.write(content['text'])
-        elif msg_type in ('display_data', 'execute_result'):
-            sys.stdout.write(content['data'].get('text/plain', ''))
-        elif msg_type == 'error':
-            print('\n'.join(content['traceback']), file=sys.stderr)
+        msg_type = msg["header"]["msg_type"]
+        content = msg["content"]
+        if msg_type == "stream":
+            stream = getattr(sys, content["name"])
+            stream.write(content["text"])
+        elif msg_type in ("display_data", "execute_result"):
+            sys.stdout.write(content["data"].get("text/plain", ""))
+        elif msg_type == "error":
+            print("\n".join(content["traceback"]), file=sys.stderr)
 
     def _output_hook_kernel(
         self,
         session: Session,
         socket: zmq.sugar.socket.Socket,
         parent_header,
-        msg: t.Dict[str, t.Any]
+        msg: t.Dict[str, t.Any],
     ) -> None:
         """Output hook when running inside an IPython kernel
 
         adds rich output support.
""" - msg_type = msg['header']['msg_type'] - if msg_type in ('display_data', 'execute_result', 'error'): - session.send(socket, msg_type, msg['content'], parent=parent_header) + msg_type = msg["header"]["msg_type"] + if msg_type in ("display_data", "execute_result", "error"): + session.send(socket, msg_type, msg["content"], parent=parent_header) else: self._output_hook_default(msg) - - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Channel management methods - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- def start_channels( self, @@ -289,7 +272,7 @@ def start_channels( iopub: bool = True, stdin: bool = True, hb: bool = True, - control: bool = True + control: bool = True, ) -> None: """Starts the channels for this kernel. @@ -332,9 +315,13 @@ def stop_channels(self) -> None: @property def channels_running(self) -> bool: """Are any of the channels created and running?""" - return (self.shell_channel.is_alive() or self.iopub_channel.is_alive() or - self.stdin_channel.is_alive() or self.hb_channel.is_alive() or - self.control_channel.is_alive()) + return ( + self.shell_channel.is_alive() + or self.iopub_channel.is_alive() + or self.stdin_channel.is_alive() + or self.hb_channel.is_alive() + or self.control_channel.is_alive() + ) ioloop = None # Overridden in subclasses that use pyzmq event loop @@ -342,64 +329,55 @@ def channels_running(self) -> bool: def shell_channel(self) -> t.Any: """Get the shell channel object for this kernel.""" if self._shell_channel is None: - url = self._make_url('shell') + url = self._make_url("shell") self.log.debug("connecting shell channel to %s", url) socket = self.connect_shell(identity=self.session.bsession) - self._shell_channel = self.shell_channel_class( - socket, self.session, self.ioloop - ) + self._shell_channel = self.shell_channel_class(socket, self.session, self.ioloop) return self._shell_channel @property def iopub_channel(self) -> t.Any: """Get the iopub channel object for this kernel.""" if self._iopub_channel is None: - url = self._make_url('iopub') + url = self._make_url("iopub") self.log.debug("connecting iopub channel to %s", url) socket = self.connect_iopub() - self._iopub_channel = self.iopub_channel_class( - socket, self.session, self.ioloop - ) + self._iopub_channel = self.iopub_channel_class(socket, self.session, self.ioloop) return self._iopub_channel @property def stdin_channel(self) -> t.Any: """Get the stdin channel object for this kernel.""" if self._stdin_channel is None: - url = self._make_url('stdin') + url = self._make_url("stdin") self.log.debug("connecting stdin channel to %s", url) socket = self.connect_stdin(identity=self.session.bsession) - self._stdin_channel = self.stdin_channel_class( - socket, self.session, self.ioloop - ) + self._stdin_channel = self.stdin_channel_class(socket, self.session, self.ioloop) return self._stdin_channel @property def hb_channel(self) -> t.Any: """Get the hb channel object for this kernel.""" if self._hb_channel is None: - url = self._make_url('hb') + url = self._make_url("hb") self.log.debug("connecting heartbeat channel to %s", url) - self._hb_channel = self.hb_channel_class( - self.context, self.session, url - ) + self._hb_channel = self.hb_channel_class(self.context, self.session, url) return self._hb_channel @property def control_channel(self) -> t.Any: """Get the control 
channel object for this kernel.""" if self._control_channel is None: - url = self._make_url('control') + url = self._make_url("control") self.log.debug("connecting control channel to %s", url) socket = self.connect_control(identity=self.session.bsession) - self._control_channel = self.control_channel_class( - socket, self.session, self.ioloop - ) + self._control_channel = self.control_channel_class(socket, self.session, self.ioloop) return self._control_channel async def _async_is_alive(self) -> bool: """Is the kernel process still running?""" - from .manager import KernelManager, AsyncKernelManager + from .manager import KernelManager + if isinstance(self.parent, KernelManager): # This KernelClient was created by a KernelManager, # we can ask the parent KernelManager: @@ -413,7 +391,6 @@ async def _async_is_alive(self) -> bool: # so naively return True return True - async def _async_execute_interactive( self, code: str, @@ -424,7 +401,7 @@ async def _async_execute_interactive( stop_on_error: bool = True, timeout: t.Optional[float] = None, output_hook: t.Optional[t.Callable] = None, - stdin_hook: t.Optional[t.Callable] =None, + stdin_hook: t.Optional[t.Callable] = None, ) -> t.Dict[str, t.Any]: """Execute code in the kernel interactively @@ -486,21 +463,23 @@ async def _async_execute_interactive( allow_stdin = self.allow_stdin if allow_stdin and not self.stdin_channel.is_alive(): raise RuntimeError("stdin channel must be running to allow input") - msg_id = self._execute(code, - silent=silent, - store_history=store_history, - user_expressions=user_expressions, - allow_stdin=allow_stdin, - stop_on_error=stop_on_error, + msg_id = self._execute( + code, + silent=silent, + store_history=store_history, + user_expressions=user_expressions, + allow_stdin=allow_stdin, + stop_on_error=stop_on_error, ) if stdin_hook is None: stdin_hook = self._stdin_hook_default if output_hook is None: # detect IPython kernel - if 'IPython' in sys.modules: + if "IPython" in sys.modules: from IPython import get_ipython # type: ignore + ip = get_ipython() - in_kernel = getattr(ip, 'kernel', False) + in_kernel = getattr(ip, "kernel", False) if in_kernel: output_hook = partial( self._output_hook_kernel, @@ -544,14 +523,16 @@ async def _async_execute_interactive( msg = await self.iopub_channel.get_msg(timeout=0) - if msg['parent_header'].get('msg_id') != msg_id: + if msg["parent_header"].get("msg_id") != msg_id: # not from my request continue output_hook(msg) # stop on idle - if msg['header']['msg_type'] == 'status' and \ - msg['content']['execution_state'] == 'idle': + if ( + msg["header"]["msg_type"] == "status" + and msg["content"]["execution_state"] == "idle" + ): break # output is done, get the reply @@ -559,7 +540,6 @@ async def _async_execute_interactive( timeout = max(0, deadline - time.monotonic()) return await self._async_recv_reply(msg_id, timeout=timeout) - # Methods to send specific messages on channels def _execute( self, @@ -568,7 +548,7 @@ def _execute( store_history: bool = True, user_expressions: t.Optional[t.Dict[str, t.Any]] = None, allow_stdin: t.Optional[bool] = None, - stop_on_error: bool = True + stop_on_error: bool = True, ) -> str: """Execute code in the kernel. 
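# Illustrative aside, not part of the patch: execute_interactive (the sync
# wrapper around _async_execute_interactive above) streams IOPub output
# through the hooks while waiting for the shell reply. A sketch, assuming
# ipykernel is installed:
from jupyter_client.manager import start_new_kernel

km, kc = start_new_kernel(kernel_name="python3")
try:
    reply = kc.execute_interactive("print('hello')", timeout=10)
    assert reply["content"]["status"] == "ok"
finally:
    kc.stop_channels()
    km.shutdown_kernel()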
@@ -609,27 +589,26 @@ def _execute( if allow_stdin is None: allow_stdin = self.allow_stdin - # Don't waste network traffic if inputs are invalid if not isinstance(code, str): - raise ValueError('code %r must be a string' % code) + raise ValueError("code %r must be a string" % code) validate_string_dict(user_expressions) # Create class for content/msg creation. Related to, but possibly # not in Session. - content = dict(code=code, silent=silent, store_history=store_history, - user_expressions=user_expressions, - allow_stdin=allow_stdin, stop_on_error=stop_on_error - ) - msg = self.session.msg('execute_request', content) + content = dict( + code=code, + silent=silent, + store_history=store_history, + user_expressions=user_expressions, + allow_stdin=allow_stdin, + stop_on_error=stop_on_error, + ) + msg = self.session.msg("execute_request", content) self.shell_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] - def _complete( - self, - code: str, - cursor_pos: t.Optional[int] = None - ) -> str: + def _complete(self, code: str, cursor_pos: t.Optional[int] = None) -> str: """Tab complete text in the kernel's namespace. Parameters @@ -648,16 +627,11 @@ def _complete( if cursor_pos is None: cursor_pos = len(code) content = dict(code=code, cursor_pos=cursor_pos) - msg = self.session.msg('complete_request', content) + msg = self.session.msg("complete_request", content) self.shell_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] - def _inspect( - self, - code: str, - cursor_pos: t.Optional[int] = None, - detail_level: int = 0 - ) -> str: + def _inspect(self, code: str, cursor_pos: t.Optional[int] = None, detail_level: int = 0) -> str: """Get metadata information about an object in the kernel's namespace. It is up to the kernel to determine the appropriate object to inspect. @@ -679,19 +653,21 @@ def _inspect( """ if cursor_pos is None: cursor_pos = len(code) - content = dict(code=code, cursor_pos=cursor_pos, + content = dict( + code=code, + cursor_pos=cursor_pos, detail_level=detail_level, ) - msg = self.session.msg('inspect_request', content) + msg = self.session.msg("inspect_request", content) self.shell_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] def _history( self, raw: bool = True, output: bool = False, - hist_access_type: str = 'range', - **kwargs + hist_access_type: str = "range", + **kwargs, ) -> str: """Get entries from the kernel's history list. @@ -724,14 +700,13 @@ def _history( ------- The ID of the message sent. 
""" - if hist_access_type == 'range': - kwargs.setdefault('session', 0) - kwargs.setdefault('start', 0) - content = dict(raw=raw, output=output, hist_access_type=hist_access_type, - **kwargs) - msg = self.session.msg('history_request', content) + if hist_access_type == "range": + kwargs.setdefault("session", 0) + kwargs.setdefault("start", 0) + content = dict(raw=raw, output=output, hist_access_type=hist_access_type, **kwargs) + msg = self.session.msg("history_request", content) self.shell_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] def _kernel_info(self) -> str: """Request kernel info @@ -740,14 +715,11 @@ def _kernel_info(self) -> str: ------- The msg_id of the message sent """ - msg = self.session.msg('kernel_info_request') + msg = self.session.msg("kernel_info_request") self.shell_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] - def _comm_info( - self, - target_name: t.Optional[str] = None - ) -> str: + def _comm_info(self, target_name: t.Optional[str] = None) -> str: """Request comm info Returns @@ -758,49 +730,37 @@ def _comm_info( content = {} else: content = dict(target_name=target_name) - msg = self.session.msg('comm_info_request', content) + msg = self.session.msg("comm_info_request", content) self.shell_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] - def _handle_kernel_info_reply( - self, - msg: t.Dict[str, t.Any] - ) -> None: + def _handle_kernel_info_reply(self, msg: t.Dict[str, t.Any]) -> None: """handle kernel info reply sets protocol adaptation version. This might be run from a separate thread. """ - adapt_version = int(msg['content']['protocol_version'].split('.')[0]) + adapt_version = int(msg["content"]["protocol_version"].split(".")[0]) if adapt_version != major_protocol_version: self.session.adapt_version = adapt_version - def is_complete( - self, - code: str - ) -> str: + def is_complete(self, code: str) -> str: """Ask the kernel whether some code is complete and ready to execute.""" - msg = self.session.msg('is_complete_request', {'code': code}) + msg = self.session.msg("is_complete_request", {"code": code}) self.shell_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] - def input( - self, - string: str - ) -> None: + def input(self, string: str) -> None: """Send a string of raw input to the kernel. This should only be called in response to the kernel sending an ``input_request`` message on the stdin channel. """ content = dict(value=string) - msg = self.session.msg('input_reply', content) + msg = self.session.msg("input_reply", content) self.stdin_channel.send(msg) - def _shutdown( - self, - restart: bool = False - ) -> str: + def _shutdown(self, restart: bool = False) -> str: """Request an immediate kernel shutdown on the control channel. Upon receipt of the (empty) reply, client code can safely assume that @@ -817,8 +777,9 @@ def _shutdown( """ # Send quit message to kernel. Once we implement kernel-side setattr, # this should probably be done that way, but for now this will do. 
- msg = self.session.msg('shutdown_request', {'restart':restart}) + msg = self.session.msg("shutdown_request", {"restart": restart}) self.control_channel.send(msg) - return msg['header']['msg_id'] + return msg["header"]["msg_id"] + KernelClientABC.register(KernelClient) diff --git a/jupyter_client/clientabc.py b/jupyter_client/clientabc.py index d4701c421..8227cb377 100644 --- a/jupyter_client/clientabc.py +++ b/jupyter_client/clientabc.py @@ -1,22 +1,19 @@ """Abstract base class for kernel clients""" - -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Copyright (c) The Jupyter Development Team # # Distributed under the terms of the BSD License. The full license is in # the file COPYING, distributed as part of this software. -#----------------------------------------------------------------------------- - -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Imports -#----------------------------------------------------------------------------- - +# ----------------------------------------------------------------------------- import abc - -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Main kernel client class -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- + class KernelClientABC(object, metaclass=abc.ABCMeta): """KernelManager ABC. @@ -50,9 +47,9 @@ def stdin_channel_class(self): def control_channel_class(self): pass - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Channel management methods - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- @abc.abstractmethod def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=True): diff --git a/jupyter_client/connect.py b/jupyter_client/connect.py index 2f8e70352..418d73207 100644 --- a/jupyter_client/connect.py +++ b/jupyter_client/connect.py @@ -3,10 +3,8 @@ The :class:`ConnectionFileMixin` class in this module encapsulates the logic related to writing and reading connections files. """ - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
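# Illustrative aside, not part of the patch: the "connection file" this
# module reads and writes is a small JSON document. A representative
# example as a Python dict (every value here is made up):
connection_info = {
    "shell_port": 53794,
    "iopub_port": 53795,
    "stdin_port": 53796,
    "control_port": 53797,
    "hb_port": 53798,
    "ip": "127.0.0.1",
    "key": "a0436f6c-1916-498b-8eb9-e81ab9368e84",
    "transport": "tcp",
    "signature_scheme": "hmac-sha256",
    "kernel_name": "python3",
}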
- import errno import glob import json @@ -16,18 +14,28 @@ import tempfile import warnings from getpass import getpass -from contextlib import contextmanager -from typing import Union, Optional, List, Tuple, Dict, Any, cast - -import zmq - +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import zmq # type: ignore +from jupyter_core.paths import jupyter_data_dir # type: ignore +from jupyter_core.paths import jupyter_runtime_dir +from jupyter_core.paths import secure_write +from traitlets import Bool # type: ignore +from traitlets import CaselessStrEnum +from traitlets import Instance +from traitlets import Integer +from traitlets import observe +from traitlets import Type +from traitlets import Unicode from traitlets.config import LoggingConfigurable # type: ignore -from .localinterfaces import localhost -from traitlets import ( # type: ignore - Bool, Integer, Unicode, CaselessStrEnum, Instance, Type, observe -) -from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write # type: ignore +from .localinterfaces import localhost from .utils import _filefind @@ -38,11 +46,11 @@ def write_connection_file( stdin_port: int = 0, hb_port: int = 0, control_port: int = 0, - ip: str = '', - key: bytes = b'', - transport: str = 'tcp', - signature_scheme: str = 'hmac-sha256', - kernel_name: str = '' + ip: str = "", + key: bytes = b"", + transport: str = "tcp", + signature_scheme: str = "hmac-sha256", + kernel_name: str = "", ) -> Tuple[str, Dict[str, Union[int, str]]]: """Generates a JSON config file, including the selection of random ports. @@ -88,23 +96,25 @@ def write_connection_file( ip = localhost() # default to temporary connector file if not fname: - fd, fname = tempfile.mkstemp('.json') + fd, fname = tempfile.mkstemp(".json") os.close(fd) # Find open ports as necessary. 
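# Illustrative aside, not part of the patch: the port-finding code below
# relies on a standard trick, binding to port 0 so the OS picks a free
# ephemeral port. The same trick in isolation:
import socket

sock = socket.socket()
sock.bind(("127.0.0.1", 0))
free_port = sock.getsockname()[1]  # the port the OS chose
sock.close()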
ports: List[int] = [] sockets: List[socket.socket] = [] - ports_needed = int(shell_port <= 0) + \ - int(iopub_port <= 0) + \ - int(stdin_port <= 0) + \ - int(control_port <= 0) + \ - int(hb_port <= 0) - if transport == 'tcp': + ports_needed = ( + int(shell_port <= 0) + + int(iopub_port <= 0) + + int(stdin_port <= 0) + + int(control_port <= 0) + + int(hb_port <= 0) + ) + if transport == "tcp": for i in range(ports_needed): sock = socket.socket() # struct.pack('ii', (0,0)) is 8 null bytes - sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8) sock.bind((ip, 0)) sockets.append(sock) for sock in sockets: @@ -129,17 +139,18 @@ def write_connection_file( if hb_port <= 0: hb_port = ports.pop(0) - cfg: Dict[str, Union[int, str]] = dict( shell_port=shell_port, - iopub_port=iopub_port, - stdin_port=stdin_port, - control_port=control_port, - hb_port=hb_port, - ) - cfg['ip'] = ip - cfg['key'] = key.decode() - cfg['transport'] = transport - cfg['signature_scheme'] = signature_scheme - cfg['kernel_name'] = kernel_name + cfg: Dict[str, Union[int, str]] = dict( + shell_port=shell_port, + iopub_port=iopub_port, + stdin_port=stdin_port, + control_port=control_port, + hb_port=hb_port, + ) + cfg["ip"] = ip + cfg["key"] = key.decode() + cfg["transport"] = transport + cfg["signature_scheme"] = signature_scheme + cfg["kernel_name"] = kernel_name # Only ever write this file as user read/writeable # This would otherwise introduce a vulnerability as a file has secrets @@ -147,7 +158,7 @@ def write_connection_file( with secure_write(fname) as f: f.write(json.dumps(cfg, indent=2)) - if hasattr(stat, 'S_ISVTX'): + if hasattr(stat, "S_ISVTX"): # set the sticky bit on the file and its parent directory # to avoid periodic cleanup paths = [fname] @@ -169,7 +180,8 @@ def write_connection_file( # failed to set sticky bit, probably not a big deal warnings.warn( "Failed to set sticky bit on %r: %s" - "\nProbably not a big deal, but runtime files may be cleaned up periodically." % (path, e), + "\nProbably not a big deal, but runtime files may be cleaned up " + "periodically." % (path, e), RuntimeWarning, ) @@ -177,9 +189,9 @@ def write_connection_file( def find_connection_file( - filename: str ='kernel-*.json', + filename: str = "kernel-*.json", path: Optional[Union[str, List[str]]] = None, - profile: Optional[str] = None + profile: Optional[str] = None, ) -> str: """find a connection file, and return its absolute path. @@ -204,7 +216,7 @@ def find_connection_file( if profile is not None: warnings.warn("Jupyter has no profiles. profile=%s has been ignored." 
% profile) if path is None: - path = ['.', jupyter_runtime_dir()] + path = [".", jupyter_runtime_dir()] if isinstance(path, str): path = [path] @@ -216,18 +228,18 @@ def find_connection_file( # not found by full name - if '*' in filename: + if "*" in filename: # given as a glob already pat = filename else: # accept any substring match - pat = '*%s*' % filename + pat = "*%s*" % filename matches = [] for p in path: matches.extend(glob.glob(os.path.join(p, pat))) - matches = [ os.path.abspath(m) for m in matches ] + matches = [os.path.abspath(m) for m in matches] if not matches: raise IOError("Could not find %r in %r" % (filename, path)) elif len(matches) == 1: @@ -240,7 +252,7 @@ def find_connection_file( def tunnel_to_kernel( connection_info: Union[str, Dict[str, Any]], sshserver: str, - sshkey: Optional[str] = None + sshkey: Optional[str] = None, ) -> Tuple[Any, ...]: """tunnel connections to a kernel via ssh @@ -268,6 +280,7 @@ def tunnel_to_kernel( The five ports on localhost that have been forwarded to the kernel. """ from .ssh import tunnel + if isinstance(connection_info, str): # it's a path, unpack it with open(connection_info) as f: @@ -276,9 +289,15 @@ def tunnel_to_kernel( cf = cast(Dict[str, Any], connection_info) lports = tunnel.select_random_ports(5) - rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port'], cf['control_port'] + rports = ( + cf["shell_port"], + cf["iopub_port"], + cf["stdin_port"], + cf["hb_port"], + cf["control_port"], + ) - remote_ip = cf['ip'] + remote_ip = cf["ip"] if tunnel.try_passwordless_ssh(sshserver, sshkey): password: Union[bool, str] = False @@ -291,95 +310,95 @@ def tunnel_to_kernel( return tuple(lports) -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Mixin for classes that work with connection files -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- channel_socket_types = { - 'hb' : zmq.REQ, - 'shell' : zmq.DEALER, - 'iopub' : zmq.SUB, - 'stdin' : zmq.DEALER, - 'control': zmq.DEALER, + "hb": zmq.REQ, + "shell": zmq.DEALER, + "iopub": zmq.SUB, + "stdin": zmq.DEALER, + "control": zmq.DEALER, } -port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')] +port_names = ["%s_port" % channel for channel in ("shell", "stdin", "iopub", "hb", "control")] + class ConnectionFileMixin(LoggingConfigurable): """Mixin for configurable classes that work with connection files""" data_dir = Unicode() + def _data_dir_default(self): return jupyter_data_dir() # The addresses for the communication channels - connection_file = Unicode('', config=True, - help="""JSON file in which to store connection info [default: kernel-.json] + connection_file = Unicode( + "", + config=True, + help="""JSON file in which to store connection info [default: kernel-.json] This file will contain the IP, ports, and authentication key needed to connect clients to this kernel. By default, this file will be created in the security dir of the current profile, but can be specified by absolute path. 
- """) + """, + ) _connection_file_written = Bool(False) - transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True) + transport = CaselessStrEnum(["tcp", "ipc"], default_value="tcp", config=True) kernel_name = Unicode() - ip = Unicode(config=True, + ip = Unicode( + config=True, help="""Set the kernel\'s IP address [default localhost]. If the IP address is something other than localhost, then Consoles on other machines will be able to connect - to the Kernel, so be careful!""" + to the Kernel, so be careful!""", ) def _ip_default(self): - if self.transport == 'ipc': + if self.transport == "ipc": if self.connection_file: - return os.path.splitext(self.connection_file)[0] + '-ipc' + return os.path.splitext(self.connection_file)[0] + "-ipc" else: - return 'kernel-ipc' + return "kernel-ipc" else: return localhost() - @observe('ip') + @observe("ip") def _ip_changed(self, change): - if change['new'] == '*': - self.ip = '0.0.0.0' + if change["new"] == "*": + self.ip = "0.0.0.0" # protected traits - hb_port = Integer(0, config=True, - help="set the heartbeat port [default: random]") - shell_port = Integer(0, config=True, - help="set the shell (ROUTER) port [default: random]") - iopub_port = Integer(0, config=True, - help="set the iopub (PUB) port [default: random]") - stdin_port = Integer(0, config=True, - help="set the stdin (ROUTER) port [default: random]") - control_port = Integer(0, config=True, - help="set the control (ROUTER) port [default: random]") + hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]") + shell_port = Integer(0, config=True, help="set the shell (ROUTER) port [default: random]") + iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]") + stdin_port = Integer(0, config=True, help="set the stdin (ROUTER) port [default: random]") + control_port = Integer(0, config=True, help="set the control (ROUTER) port [default: random]") # names of the ports with random assignment _random_port_names: Optional[List[str]] = None @property def ports(self) -> List[int]: - return [ getattr(self, name) for name in port_names ] + return [getattr(self, name) for name in port_names] # The Session to use for communication with the kernel. - session = Instance('jupyter_client.session.Session') + session = Instance("jupyter_client.session.Session") + def _session_default(self): from jupyter_client.session import Session + return Session(parent=self) - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Connection and ipc file management - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- - def get_connection_info( - self, - session: bool =False - ) -> Dict[str, Any]: + def get_connection_info(self, session: bool = False) -> Dict[str, Any]: """Return the connection info as a dict Parameters @@ -406,21 +425,24 @@ def get_connection_info( if session: # add *clone* of my session, # so that state such as digest_history is not shared. 
- info['session'] = self.session.clone() + info["session"] = self.session.clone() else: # add session info - info.update(dict( - signature_scheme=self.session.signature_scheme, - key=self.session.key, - )) + info.update( + dict( + signature_scheme=self.session.signature_scheme, + key=self.session.key, + ) + ) return info # factory for blocking clients - blocking_class = Type(klass=object, default_value='jupyter_client.BlockingKernelClient') + blocking_class = Type(klass=object, default_value="jupyter_client.BlockingKernelClient") + def blocking_client(self): """Make a blocking client connected to my kernel""" info = self.get_connection_info() - info['parent'] = self + info["parent"] = self bc = self.blocking_class(**info) bc.session.key = self.session.key return bc @@ -440,7 +462,7 @@ def cleanup_connection_file(self) -> None: def cleanup_ipc_files(self) -> None: """Cleanup ipc files if we wrote them.""" - if self.transport != 'ipc': + if self.transport != "ipc": return for port in self.ports: ipcfile = "%s-%i" % (self.ip, port) @@ -455,7 +477,7 @@ def _record_random_port_names(self) -> None: Records on first invocation, if the transport is tcp. Does nothing on later invocations.""" - if self.transport != 'tcp': + if self.transport != "tcp": return if self._random_port_names is not None: return @@ -485,13 +507,18 @@ def write_connection_file(self) -> None: if self._connection_file_written and os.path.exists(self.connection_file): return - self.connection_file, cfg = write_connection_file(self.connection_file, - transport=self.transport, ip=self.ip, key=self.session.key, - stdin_port=self.stdin_port, iopub_port=self.iopub_port, - shell_port=self.shell_port, hb_port=self.hb_port, + self.connection_file, cfg = write_connection_file( + self.connection_file, + transport=self.transport, + ip=self.ip, + key=self.session.key, + stdin_port=self.stdin_port, + iopub_port=self.iopub_port, + shell_port=self.shell_port, + hb_port=self.hb_port, control_port=self.control_port, signature_scheme=self.session.signature_scheme, - kernel_name=self.kernel_name + kernel_name=self.kernel_name, ) # write_connection_file also sets default ports: self._record_random_port_names() @@ -500,10 +527,7 @@ def write_connection_file(self) -> None: self._connection_file_written = True - def load_connection_file( - self, - connection_file: Optional[str] = None - ) -> None: + def load_connection_file(self, connection_file: Optional[str] = None) -> None: """Load connection info from JSON dict in self.connection_file. Parameters @@ -519,10 +543,7 @@ def load_connection_file( info = json.load(f) self.load_connection_info(info) - def load_connection_info( - self, - info: Dict[str, int] - ) -> None: + def load_connection_info(self, info: Dict[str, int]) -> None: """Load connection info from a dict containing connection info. Typically this data comes from a connection file @@ -534,8 +555,8 @@ def load_connection_info( Dictionary containing connection_info. See the connection_file spec for details. 
""" - self.transport = info.get('transport', self.transport) - self.ip = info.get('ip', self._ip_default()) + self.transport = info.get("transport", self.transport) + self.ip = info.get("ip", self._ip_default()) self._record_random_port_names() for name in port_names: @@ -550,31 +571,26 @@ def load_connection_info( assert isinstance(key, bytes) self.session.key = key - if 'signature_scheme' in info: - self.session.signature_scheme = info['signature_scheme'] + if "signature_scheme" in info: + self.session.signature_scheme = info["signature_scheme"] - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- # Creating connected sockets - #-------------------------------------------------------------------------- + # -------------------------------------------------------------------------- - def _make_url( - self, - channel: str - ) -> str: + def _make_url(self, channel: str) -> str: """Make a ZeroMQ URL for a given channel.""" transport = self.transport ip = self.ip - port = getattr(self, '%s_port' % channel) + port = getattr(self, "%s_port" % channel) - if transport == 'tcp': + if transport == "tcp": return "tcp://%s:%i" % (ip, port) else: return "%s://%s-%s" % (transport, ip, port) def _create_connected_socket( - self, - channel: str, - identity: Optional[bytes] = None + self, channel: str, identity: Optional[bytes] = None ) -> zmq.sugar.socket.Socket: """Create a zmq Socket and connect it to the kernel.""" url = self._make_url(channel) @@ -588,46 +604,31 @@ def _create_connected_socket( sock.connect(url) return sock - def connect_iopub( - self, - identity: Optional[bytes] = None - ) -> zmq.sugar.socket.Socket: + def connect_iopub(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the IOPub channel""" - sock = self._create_connected_socket('iopub', identity=identity) - sock.setsockopt(zmq.SUBSCRIBE, b'') + sock = self._create_connected_socket("iopub", identity=identity) + sock.setsockopt(zmq.SUBSCRIBE, b"") return sock - def connect_shell( - self, - identity: Optional[bytes] = None - ) -> zmq.sugar.socket.Socket: + def connect_shell(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Shell channel""" - return self._create_connected_socket('shell', identity=identity) + return self._create_connected_socket("shell", identity=identity) - def connect_stdin( - self, - identity: Optional[bytes] = None - ) -> zmq.sugar.socket.Socket: + def connect_stdin(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the StdIn channel""" - return self._create_connected_socket('stdin', identity=identity) + return self._create_connected_socket("stdin", identity=identity) - def connect_hb( - self, - identity: Optional[bytes] = None - ) -> zmq.sugar.socket.Socket: + def connect_hb(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Heartbeat channel""" - return self._create_connected_socket('hb', identity=identity) + return self._create_connected_socket("hb", identity=identity) - def connect_control( - self, - identity: Optional[bytes] = None - ) -> zmq.sugar.socket.Socket: + def connect_control(self, identity: Optional[bytes] = None) -> zmq.sugar.socket.Socket: """return zmq Socket connected to the Control channel""" - return self._create_connected_socket('control', identity=identity) + return 
self._create_connected_socket("control", identity=identity) __all__ = [ - 'write_connection_file', - 'find_connection_file', - 'tunnel_to_kernel', + "write_connection_file", + "find_connection_file", + "tunnel_to_kernel", ] diff --git a/jupyter_client/consoleapp.py b/jupyter_client/consoleapp.py index e491dcc24..7afa5da4a 100644 --- a/jupyter_client/consoleapp.py +++ b/jupyter_client/consoleapp.py @@ -6,7 +6,6 @@ """ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - import atexit import os import signal @@ -15,28 +14,32 @@ import warnings from typing import cast - +from jupyter_core.application import base_aliases # type: ignore +from jupyter_core.application import base_flags +from traitlets import CBool # type: ignore +from traitlets import CUnicode +from traitlets import Dict +from traitlets import List +from traitlets import Type +from traitlets import Unicode from traitlets.config.application import boolean_flag # type: ignore -from traitlets import ( # type: ignore - Dict, List, Unicode, CUnicode, CBool, Any, Type -) - -from jupyter_core.application import base_flags, base_aliases # type: ignore +from . import connect +from . import find_connection_file +from . import KernelManager +from . import tunnel_to_kernel from .blocking import BlockingKernelClient -from .restarter import KernelRestarter -from . import KernelManager, tunnel_to_kernel, find_connection_file, connect from .kernelspec import NoSuchKernel +from .localinterfaces import localhost +from .restarter import KernelRestarter from .session import Session +from .utils import _filefind ConnectionFileMixin = connect.ConnectionFileMixin -from .localinterfaces import localhost -from .utils import _filefind - -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Aliases and Flags -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- flags = {} flags.update(base_flags) @@ -44,22 +47,27 @@ # these must be scrubbed before being passed to the kernel, # or it will raise an error on unrecognized flags app_flags = { - 'existing' : ({'JupyterConsoleApp' : {'existing' : 'kernel*.json'}}, - "Connect to an existing kernel. If no argument specified, guess most recent"), + "existing": ( + {"JupyterConsoleApp": {"existing": "kernel*.json"}}, + "Connect to an existing kernel. If no argument specified, guess most recent", + ), } -app_flags.update(boolean_flag( - 'confirm-exit', 'JupyterConsoleApp.confirm_exit', - """Set to display confirmation dialog on exit. You can always use 'exit' or +app_flags.update( + boolean_flag( + "confirm-exit", + "JupyterConsoleApp.confirm_exit", + """Set to display confirmation dialog on exit. You can always use 'exit' or 'quit', to force a direct exit without any confirmation. This can also be set in the config file by setting `c.JupyterConsoleApp.confirm_exit`. """, - """Don't prompt the user when exiting. This will terminate the kernel + """Don't prompt the user when exiting. This will terminate the kernel if it is owned by the frontend, and leave it alive if it is external. This can also be set in the config file by setting `c.JupyterConsoleApp.confirm_exit`. 
- """ -)) + """, + ) +) flags.update(app_flags) aliases = {} @@ -67,30 +75,29 @@ # also scrub aliases from the frontend app_aliases = dict( - ip = 'JupyterConsoleApp.ip', - transport = 'JupyterConsoleApp.transport', - hb = 'JupyterConsoleApp.hb_port', - shell = 'JupyterConsoleApp.shell_port', - iopub = 'JupyterConsoleApp.iopub_port', - stdin = 'JupyterConsoleApp.stdin_port', - control = 'JupyterConsoleApp.control_port', - existing = 'JupyterConsoleApp.existing', - f = 'JupyterConsoleApp.connection_file', - - kernel = 'JupyterConsoleApp.kernel_name', - - ssh = 'JupyterConsoleApp.sshserver', + ip="JupyterConsoleApp.ip", + transport="JupyterConsoleApp.transport", + hb="JupyterConsoleApp.hb_port", + shell="JupyterConsoleApp.shell_port", + iopub="JupyterConsoleApp.iopub_port", + stdin="JupyterConsoleApp.stdin_port", + control="JupyterConsoleApp.control_port", + existing="JupyterConsoleApp.existing", + f="JupyterConsoleApp.connection_file", + kernel="JupyterConsoleApp.kernel_name", + ssh="JupyterConsoleApp.sshserver", ) aliases.update(app_aliases) -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Classes -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- classes = [KernelManager, KernelRestarter, Session] + class JupyterConsoleApp(ConnectionFileMixin): - name = 'jupyter-console-mixin' + name = "jupyter-console-mixin" description = """ The Jupyter Console Mixin. @@ -115,7 +122,7 @@ class JupyterConsoleApp(ConnectionFileMixin): kernel_manager_class = Type( default_value=KernelManager, config=True, - help='The kernel manager class to use.' + help="The kernel manager class to use.", ) kernel_client_class = BlockingKernelClient @@ -123,21 +130,25 @@ class JupyterConsoleApp(ConnectionFileMixin): # connection info: - sshserver = Unicode('', config=True, - help="""The SSH server to use to connect to the kernel.""") - sshkey = Unicode('', config=True, - help="""Path to the ssh key to use for logging in to the ssh server.""") + sshserver = Unicode("", config=True, help="""The SSH server to use to connect to the kernel.""") + sshkey = Unicode( + "", + config=True, + help="""Path to the ssh key to use for logging in to the ssh server.""", + ) def _connection_file_default(self) -> str: - return 'kernel-%i.json' % os.getpid() + return "kernel-%i.json" % os.getpid() - existing = CUnicode('', config=True, - help="""Connect to an already running kernel""") + existing = CUnicode("", config=True, help="""Connect to an already running kernel""") - kernel_name = Unicode('python', config=True, - help="""The name of the default kernel to start.""") + kernel_name = Unicode( + "python", config=True, help="""The name of the default kernel to start.""" + ) - confirm_exit = CBool(True, config=True, + confirm_exit = CBool( + True, + config=True, help=""" Set to display confirmation dialog on exit. 
You can always use 'exit' or 'quit', to force a direct exit without any confirmation.""", @@ -167,9 +178,11 @@ def init_connection_file(self) -> None: """ if self.existing: try: - cf = find_connection_file(self.existing, ['.', self.runtime_dir]) + cf = find_connection_file(self.existing, [".", self.runtime_dir]) except Exception: - self.log.critical("Could not find existing kernel connection file %s", self.existing) + self.log.critical( + "Could not find existing kernel connection file %s", self.existing + ) self.exit(1) self.log.debug("Connecting to existing kernel: %s" % cf) self.connection_file = cf @@ -187,9 +200,7 @@ def init_connection_file(self) -> None: cf = self.connection_file self.connection_file = cf try: - self.connection_file = _filefind( - self.connection_file, [".", self.runtime_dir] - ) + self.connection_file = _filefind(self.connection_file, [".", self.runtime_dir]) except IOError: self.log.debug("Connection File not found: %s", self.connection_file) return @@ -200,7 +211,11 @@ def init_connection_file(self) -> None: try: self.load_connection_file() except Exception: - self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True) + self.log.error( + "Failed to load connection file: %r", + self.connection_file, + exc_info=True, + ) self.exit(1) def init_ssh(self) -> None: @@ -212,7 +227,7 @@ def init_ssh(self) -> None: transport = self.transport ip = self.ip - if transport != 'tcp': + if transport != "tcp": self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport) sys.exit(-1) @@ -222,45 +237,52 @@ def init_ssh(self) -> None: ip = localhost() # build connection dict for tunnels: - info = dict(ip=ip, - shell_port=self.shell_port, - iopub_port=self.iopub_port, - stdin_port=self.stdin_port, - hb_port=self.hb_port, - control_port=self.control_port + info = dict( + ip=ip, + shell_port=self.shell_port, + iopub_port=self.iopub_port, + stdin_port=self.stdin_port, + hb_port=self.hb_port, + control_port=self.control_port, ) - self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver)) + self.log.info("Forwarding connections to %s via %s" % (ip, self.sshserver)) # tunnels return a new set of ports, which will be on localhost: self.ip = localhost() try: newports = tunnel_to_kernel(info, self.sshserver, self.sshkey) - except: + except: # noqa # even catch KeyboardInterrupt self.log.error("Could not setup tunnels", exc_info=True) self.exit(1) - self.shell_port, self.iopub_port, self.stdin_port, self.hb_port, self.control_port = newports + ( + self.shell_port, + self.iopub_port, + self.stdin_port, + self.hb_port, + self.control_port, + ) = newports cf = self.connection_file root, ext = os.path.splitext(cf) - self.connection_file = root + '-ssh' + ext - self.write_connection_file() # write the new connection file + self.connection_file = root + "-ssh" + ext + self.write_connection_file() # write the new connection file self.log.info("To connect another client via this tunnel, use:") self.log.info("--existing %s" % os.path.basename(self.connection_file)) def _new_connection_file(self) -> str: - cf = '' + cf = "" while not cf: # we don't need a 128b id to distinguish kernels, use more readable # 48b node segment (12 hex chars). Users running more than 32k simultaneous # kernels can subclass. 
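# Illustrative aside, not part of the patch: the ssh path handled by
# init_ssh above boils down to tunnel_to_kernel, which returns five local
# ports forwarded to the remote kernel. Server and file names here are
# hypothetical, and a working ssh setup is assumed.
from jupyter_client import tunnel_to_kernel

ports = tunnel_to_kernel("kernel-remote.json", "me@gateway.example.com")
shell_port, iopub_port, stdin_port, hb_port, control_port = ports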
- ident = str(uuid.uuid4()).split('-')[-1] - cf = os.path.join(self.runtime_dir, 'kernel-%s.json' % ident) + ident = str(uuid.uuid4()).split("-")[-1] + cf = os.path.join(self.runtime_dir, "kernel-%s.json" % ident) # only keep if it's actually new. Protect against unlikely collision # in 48b random search space - cf = cf if not os.path.exists(cf) else '' + cf = cf if not os.path.exists(cf) else "" return cf def init_kernel_manager(self) -> None: @@ -273,18 +295,18 @@ def init_kernel_manager(self) -> None: # Create a KernelManager and start a kernel. try: self.kernel_manager = self.kernel_manager_class( - ip=self.ip, - session=self.session, - transport=self.transport, - shell_port=self.shell_port, - iopub_port=self.iopub_port, - stdin_port=self.stdin_port, - hb_port=self.hb_port, - control_port=self.control_port, - connection_file=self.connection_file, - kernel_name=self.kernel_name, - parent=self, - data_dir=self.data_dir, + ip=self.ip, + session=self.session, + transport=self.transport, + shell_port=self.shell_port, + iopub_port=self.iopub_port, + stdin_port=self.stdin_port, + hb_port=self.hb_port, + control_port=self.control_port, + connection_file=self.connection_file, + kernel_name=self.kernel_name, + parent=self, + data_dir=self.data_dir, ) except NoSuchKernel: self.log.critical("Could not find kernel %s", self.kernel_name) @@ -293,7 +315,7 @@ def init_kernel_manager(self) -> None: self.kernel_manager = cast(KernelManager, self.kernel_manager) self.kernel_manager.client_factory = self.kernel_client_class kwargs = {} - kwargs['extra_arguments'] = self.kernel_argv + kwargs["extra_arguments"] = self.kernel_argv self.kernel_manager.start_kernel(**kwargs) atexit.register(self.kernel_manager.cleanup_ipc_files) @@ -303,11 +325,11 @@ def init_kernel_manager(self) -> None: # in case KM defaults / ssh writing changes things: km = self.kernel_manager - self.shell_port=km.shell_port - self.iopub_port=km.iopub_port - self.stdin_port=km.stdin_port - self.hb_port=km.hb_port - self.control_port=km.control_port + self.shell_port = km.shell_port + self.iopub_port = km.iopub_port + self.stdin_port = km.stdin_port + self.hb_port = km.hb_port + self.control_port = km.control_port self.connection_file = km.connection_file atexit.register(self.kernel_manager.cleanup_connection_file) @@ -317,22 +339,20 @@ def init_kernel_client(self) -> None: self.kernel_client = self.kernel_manager.client() else: self.kernel_client = self.kernel_client_class( - session=self.session, - ip=self.ip, - transport=self.transport, - shell_port=self.shell_port, - iopub_port=self.iopub_port, - stdin_port=self.stdin_port, - hb_port=self.hb_port, - control_port=self.control_port, - connection_file=self.connection_file, - parent=self, + session=self.session, + ip=self.ip, + transport=self.transport, + shell_port=self.shell_port, + iopub_port=self.iopub_port, + stdin_port=self.stdin_port, + hb_port=self.hb_port, + control_port=self.control_port, + connection_file=self.connection_file, + parent=self, ) self.kernel_client.start_channels() - - def initialize(self, argv=None) -> None: """ Classes which mix this class in should call: @@ -345,6 +365,7 @@ def initialize(self, argv=None) -> None: self.init_kernel_manager() self.init_kernel_client() + class IPythonConsoleApp(JupyterConsoleApp): def __init__(self, *args, **kwargs): warnings.warn("IPythonConsoleApp is deprecated. 
Use JupyterConsoleApp") diff --git a/jupyter_client/ioloop/__init__.py b/jupyter_client/ioloop/__init__.py index 4203d2f82..204d5f8aa 100644 --- a/jupyter_client/ioloop/__init__.py +++ b/jupyter_client/ioloop/__init__.py @@ -1,2 +1,4 @@ -from .manager import IOLoopKernelManager, AsyncIOLoopKernelManager -from .restarter import IOLoopKernelRestarter, AsyncIOLoopKernelRestarter +from .manager import AsyncIOLoopKernelManager # noqa +from .manager import IOLoopKernelManager # noqa +from .restarter import AsyncIOLoopKernelRestarter # noqa +from .restarter import IOLoopKernelRestarter # noqa diff --git a/jupyter_client/ioloop/manager.py b/jupyter_client/ioloop/manager.py index 85d4f9d4b..23713ac33 100644 --- a/jupyter_client/ioloop/manager.py +++ b/jupyter_client/ioloop/manager.py @@ -1,30 +1,28 @@ """A kernel manager with a tornado IOLoop""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - from tornado import ioloop +from traitlets import Instance +from traitlets import Type from zmq.eventloop.zmqstream import ZMQStream -from traitlets import ( - Instance, - Type, -) - -from jupyter_client.manager import KernelManager, AsyncKernelManager -from .restarter import IOLoopKernelRestarter, AsyncIOLoopKernelRestarter +from .restarter import AsyncIOLoopKernelRestarter +from .restarter import IOLoopKernelRestarter +from jupyter_client.manager import AsyncKernelManager +from jupyter_client.manager import KernelManager def as_zmqstream(f): def wrapped(self, *args, **kwargs): socket = f(self, *args, **kwargs) return ZMQStream(socket, self.loop) + return wrapped class IOLoopKernelManager(KernelManager): - loop = Instance('tornado.ioloop.IOLoop') + loop = Instance("tornado.ioloop.IOLoop") def _loop_default(self): return ioloop.IOLoop.current() @@ -33,20 +31,19 @@ def _loop_default(self): default_value=IOLoopKernelRestarter, klass=IOLoopKernelRestarter, help=( - 'Type of KernelRestarter to use. ' - 'Must be a subclass of IOLoopKernelRestarter.\n' - 'Override this to customize how kernel restarts are managed.' + "Type of KernelRestarter to use. " + "Must be a subclass of IOLoopKernelRestarter.\n" + "Override this to customize how kernel restarts are managed." ), config=True, ) - _restarter = Instance('jupyter_client.ioloop.IOLoopKernelRestarter', allow_none=True) + _restarter = Instance("jupyter_client.ioloop.IOLoopKernelRestarter", allow_none=True) def start_restarter(self): if self.autorestart and self.has_kernel: if self._restarter is None: self._restarter = self.restarter_class( - kernel_manager=self, loop=self.loop, - parent=self, log=self.log + kernel_manager=self, loop=self.loop, parent=self, log=self.log ) self._restarter.start() @@ -64,7 +61,7 @@ def stop_restarter(self): class AsyncIOLoopKernelManager(AsyncKernelManager): - loop = Instance('tornado.ioloop.IOLoop') + loop = Instance("tornado.ioloop.IOLoop") def _loop_default(self): return ioloop.IOLoop.current() @@ -73,20 +70,19 @@ def _loop_default(self): default_value=AsyncIOLoopKernelRestarter, klass=AsyncIOLoopKernelRestarter, help=( - 'Type of KernelRestarter to use. ' - 'Must be a subclass of AsyncIOLoopKernelManager.\n' - 'Override this to customize how kernel restarts are managed.' + "Type of KernelRestarter to use. " + "Must be a subclass of AsyncIOLoopKernelManager.\n" + "Override this to customize how kernel restarts are managed." 
), config=True, ) - _restarter = Instance('jupyter_client.ioloop.AsyncIOLoopKernelRestarter', allow_none=True) + _restarter = Instance("jupyter_client.ioloop.AsyncIOLoopKernelRestarter", allow_none=True) def start_restarter(self): if self.autorestart and self.has_kernel: if self._restarter is None: self._restarter = self.restarter_class( - kernel_manager=self, loop=self.loop, - parent=self, log=self.log + kernel_manager=self, loop=self.loop, parent=self, log=self.log ) self._restarter.start() diff --git a/jupyter_client/ioloop/restarter.py b/jupyter_client/ioloop/restarter.py index d5a5628e1..4fbdc2977 100644 --- a/jupyter_client/ioloop/restarter.py +++ b/jupyter_client/ioloop/restarter.py @@ -3,28 +3,26 @@ This watches a kernel's state using KernelManager.is_alive and auto restarts the kernel if it dies. """ - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - import warnings +from traitlets import Instance from zmq.eventloop import ioloop from jupyter_client.restarter import KernelRestarter -from traitlets import ( - Instance, -) class IOLoopKernelRestarter(KernelRestarter): """Monitor and autorestart a kernel.""" - loop = Instance('tornado.ioloop.IOLoop') + loop = Instance("tornado.ioloop.IOLoop") def _loop_default(self): - warnings.warn("IOLoopKernelRestarter.loop is deprecated in jupyter-client 5.2", - DeprecationWarning, stacklevel=4, + warnings.warn( + "IOLoopKernelRestarter.loop is deprecated in jupyter-client 5.2", + DeprecationWarning, + stacklevel=4, ) return ioloop.IOLoop.current() @@ -34,7 +32,8 @@ def start(self): """Start the polling of the kernel.""" if self._pcallback is None: self._pcallback = ioloop.PeriodicCallback( - self.poll, 1000*self.time_to_dead, + self.poll, + 1000 * self.time_to_dead, ) self._pcallback.start() @@ -46,10 +45,9 @@ def stop(self): class AsyncIOLoopKernelRestarter(IOLoopKernelRestarter): - async def poll(self): if self.debug: - self.log.debug('Polling kernel...') + self.log.debug("Polling kernel...") is_alive = await self.kernel_manager.is_alive() if not is_alive: if self._restarting: @@ -59,18 +57,19 @@ async def poll(self): if self._restart_count >= self.restart_limit: self.log.warning("AsyncIOLoopKernelRestarter: restart failed") - self._fire_callbacks('dead') + self._fire_callbacks("dead") self._restarting = False self._restart_count = 0 self.stop() else: newports = self.random_ports_until_alive and self._initial_startup - self.log.info('AsyncIOLoopKernelRestarter: restarting kernel (%i/%i), %s random ports', + self.log.info( + "AsyncIOLoopKernelRestarter: restarting kernel (%i/%i), %s random ports", self._restart_count, self.restart_limit, - 'new' if newports else 'keep' + "new" if newports else "keep", ) - self._fire_callbacks('restart') + self._fire_callbacks("restart") await self.kernel_manager.restart_kernel(now=True, newports=newports) self._restarting = True else: diff --git a/jupyter_client/jsonutil.py b/jupyter_client/jsonutil.py index 667e33f1f..1095f85ef 100644 --- a/jupyter_client/jsonutil.py +++ b/jupyter_client/jsonutil.py @@ -1,33 +1,35 @@ """Utilities to manipulate JSON objects.""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
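# Illustrative aside, not part of the patch: a round trip through the date
# utilities defined below. extract_dates parses ISO8601 strings into
# tz-aware datetimes and squash_dates reverses it.
from jupyter_client.jsonutil import extract_dates, squash_dates

header = {"date": "2021-01-01T12:00:00.000000Z"}
parsed = extract_dates(header)  # the value becomes a tz-aware datetime
assert squash_dates(parsed)["date"].startswith("2021-01-01T12:00:00")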
- -from datetime import datetime import re import warnings -from typing import Optional, Union +from datetime import datetime +from typing import Optional +from typing import Union from dateutil.parser import parse as _dateutil_parse from dateutil.tz import tzlocal -next_attr_name = '__next__' # Not sure what downstream library uses this, but left it to be safe +next_attr_name = "__next__" # Not sure what downstream library uses this, but left it to be safe -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Globals and constants -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # timestamp formats ISO8601 = "%Y-%m-%dT%H:%M:%S.%f" -ISO8601_PAT = re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?(Z|([\+\-]\d{2}:?\d{2}))?$") +ISO8601_PAT = re.compile( + r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?(Z|([\+\-]\d{2}:?\d{2}))?$" +) # holy crap, strptime is not threadsafe. # Calling it once at import seems to help. datetime.strptime("1", "%d") -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Classes and functions -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- + def _ensure_tzinfo(dt: datetime) -> datetime: """Ensure a datetime object has tzinfo @@ -36,12 +38,15 @@ def _ensure_tzinfo(dt: datetime) -> datetime: """ if not dt.tzinfo: # No more naïve datetime objects! - warnings.warn("Interpreting naive datetime as local %s. Please add timezone info to timestamps." % dt, + warnings.warn( + "Interpreting naive datetime as local %s. Please add timezone info to timestamps." 
% dt, DeprecationWarning, - stacklevel=4) + stacklevel=4, + ) dt = dt.replace(tzinfo=tzlocal()) return dt + def parse_date(s: Optional[str]) -> Optional[Union[str, datetime]]: """parse an ISO8601 date string @@ -57,36 +62,38 @@ def parse_date(s: Optional[str]) -> Optional[Union[str, datetime]]: return _ensure_tzinfo(dt) return s + def extract_dates(obj): """extract ISO8601 dates from unpacked JSON""" if isinstance(obj, dict): - new_obj = {} # don't clobber - for k,v in obj.items(): + new_obj = {} # don't clobber + for k, v in obj.items(): new_obj[k] = extract_dates(v) obj = new_obj elif isinstance(obj, (list, tuple)): - obj = [ extract_dates(o) for o in obj ] + obj = [extract_dates(o) for o in obj] elif isinstance(obj, str): obj = parse_date(obj) return obj + def squash_dates(obj): """squash datetime objects into ISO8601 strings""" if isinstance(obj, dict): - obj = dict(obj) # don't clobber - for k,v in obj.items(): + obj = dict(obj) # don't clobber + for k, v in obj.items(): obj[k] = squash_dates(v) elif isinstance(obj, (list, tuple)): - obj = [ squash_dates(o) for o in obj ] + obj = [squash_dates(o) for o in obj] elif isinstance(obj, datetime): obj = obj.isoformat() return obj + def date_default(obj): """default function for packing datetime objects in JSON.""" if isinstance(obj, datetime): obj = _ensure_tzinfo(obj) - return obj.isoformat().replace('+00:00', 'Z') + return obj.isoformat().replace("+00:00", "Z") else: raise TypeError("%r is not JSON serializable" % obj) - diff --git a/jupyter_client/kernelapp.py b/jupyter_client/kernelapp.py index b95afb0b0..623a33fb5 100644 --- a/jupyter_client/kernelapp.py +++ b/jupyter_client/kernelapp.py @@ -2,64 +2,66 @@ import signal import uuid -from jupyter_core.application import JupyterApp, base_flags # type: ignore +from jupyter_core.application import base_flags # type: ignore +from jupyter_core.application import JupyterApp from tornado.ioloop import IOLoop from traitlets import Unicode # type: ignore from . import __version__ -from .kernelspec import KernelSpecManager, NATIVE_KERNEL_NAME +from .kernelspec import KernelSpecManager +from .kernelspec import NATIVE_KERNEL_NAME from .manager import KernelManager + class KernelApp(JupyterApp): - """Launch a kernel by name in a local subprocess. 
- """ + """Launch a kernel by name in a local subprocess.""" + version = __version__ description = "Run a kernel locally in a subprocess" classes = [KernelManager, KernelSpecManager] aliases = { - 'kernel': 'KernelApp.kernel_name', - 'ip': 'KernelManager.ip', + "kernel": "KernelApp.kernel_name", + "ip": "KernelManager.ip", } - flags = {'debug': base_flags['debug']} + flags = {"debug": base_flags["debug"]} - kernel_name = Unicode(NATIVE_KERNEL_NAME, - help = 'The name of a kernel type to start' - ).tag(config=True) + kernel_name = Unicode(NATIVE_KERNEL_NAME, help="The name of a kernel type to start").tag( + config=True + ) def initialize(self, argv=None): super().initialize(argv) - - cf_basename = 'kernel-%s.json' % uuid.uuid4() - self.config.setdefault('KernelManager', {}).setdefault('connection_file', os.path.join(self.runtime_dir, cf_basename)) - self.km = KernelManager(kernel_name=self.kernel_name, - config=self.config) - + + cf_basename = "kernel-%s.json" % uuid.uuid4() + self.config.setdefault("KernelManager", {}).setdefault( + "connection_file", os.path.join(self.runtime_dir, cf_basename) + ) + self.km = KernelManager(kernel_name=self.kernel_name, config=self.config) + self.loop = IOLoop.current() self.loop.add_callback(self._record_started) def setup_signals(self) -> None: """Shutdown on SIGTERM or SIGINT (Ctrl-C)""" - if os.name == 'nt': + if os.name == "nt": return def shutdown_handler(signo, frame): self.loop.add_callback_from_signal(self.shutdown, signo) + for sig in [signal.SIGTERM, signal.SIGINT]: signal.signal(sig, shutdown_handler) - def shutdown( - self, - signo: int - ) -> None: - self.log.info('Shutting down on signal %d' % signo) + def shutdown(self, signo: int) -> None: + self.log.info("Shutting down on signal %d" % signo) self.km.shutdown_kernel() self.loop.stop() def log_connection_info(self) -> None: cf = self.km.connection_file - self.log.info('Connection file: %s', cf) + self.log.info("Connection file: %s", cf) self.log.info("To connect a client: --existing %s", os.path.basename(cf)) def _record_started(self) -> None: @@ -67,13 +69,13 @@ def _record_started(self) -> None: Do not rely on this except in our own tests! """ - fn = os.environ.get('JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE') + fn = os.environ.get("JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE") if fn is not None: - with open(fn, 'wb'): + with open(fn, "wb"): pass def start(self) -> None: - self.log.info('Starting kernel %r', self.kernel_name) + self.log.info("Starting kernel %r", self.kernel_name) try: self.km.start_kernel() self.log_connection_info() diff --git a/jupyter_client/kernelspec.py b/jupyter_client/kernelspec.py index 772806f99..7ecc765dd 100644 --- a/jupyter_client/kernelspec.py +++ b/jupyter_client/kernelspec.py @@ -1,9 +1,6 @@ """Tools for managing kernel specs""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
-
-import errno
 import io
 import json
 import os
@@ -11,17 +8,22 @@
 import shutil
 import warnings
 
-pjoin = os.path.join
-
-from traitlets import (
-    HasTraits, List, Unicode, Dict, Set, Bool, Type, CaselessStrEnum
-)
-from traitlets.config import LoggingConfigurable
-
-from jupyter_core.paths import jupyter_data_dir, jupyter_path, SYSTEM_JUPYTER_PATH
+from jupyter_core.paths import jupyter_data_dir  # type: ignore
+from jupyter_core.paths import jupyter_path
+from jupyter_core.paths import SYSTEM_JUPYTER_PATH
+from traitlets import Bool  # type: ignore
+from traitlets import CaselessStrEnum
+from traitlets import Dict
+from traitlets import HasTraits
+from traitlets import List
+from traitlets import Set
+from traitlets import Type
+from traitlets import Unicode
+from traitlets.config import LoggingConfigurable  # type: ignore
 
+pjoin = os.path.join
 
-NATIVE_KERNEL_NAME = 'python3'
+NATIVE_KERNEL_NAME = "python3"
 
 
 class KernelSpec(HasTraits):
@@ -30,9 +32,7 @@ class KernelSpec(HasTraits):
     language = Unicode()
     env = Dict()
     resource_dir = Unicode()
-    interrupt_mode = CaselessStrEnum(
-        ['message', 'signal'], default_value='signal'
-    )
+    interrupt_mode = CaselessStrEnum(["message", "signal"], default_value="signal")
     metadata = Dict()
 
     @classmethod
@@ -41,19 +41,20 @@ def from_resource_dir(cls, resource_dir):
 
         Pass the path to the *directory* containing kernel.json.
         """
-        kernel_file = pjoin(resource_dir, 'kernel.json')
-        with io.open(kernel_file, 'r', encoding='utf-8') as f:
+        kernel_file = pjoin(resource_dir, "kernel.json")
+        with io.open(kernel_file, "r", encoding="utf-8") as f:
             kernel_dict = json.load(f)
         return cls(resource_dir=resource_dir, **kernel_dict)
 
     def to_dict(self):
-        d = dict(argv=self.argv,
-                 env=self.env,
-                 display_name=self.display_name,
-                 language=self.language,
-                 interrupt_mode=self.interrupt_mode,
-                 metadata=self.metadata,
-                )
+        d = dict(
+            argv=self.argv,
+            env=self.env,
+            display_name=self.display_name,
+            language=self.language,
+            interrupt_mode=self.interrupt_mode,
+            metadata=self.metadata,
+        )
 
         return d
 
@@ -65,7 +66,8 @@ def to_json(self):
         return json.dumps(self.to_dict())
 
 
-_kernel_name_pat = re.compile(r'^[a-z0-9._\-]+$', re.IGNORECASE)
+_kernel_name_pat = re.compile(r"^[a-z0-9._\-]+$", re.IGNORECASE)
+
 
 def _is_valid_kernel_name(name):
     """Check that a kernel name is valid."""
@@ -73,13 +75,15 @@ def _is_valid_kernel_name(name):
     return _kernel_name_pat.match(name)
 
 
-_kernel_name_description = "Kernel names can only contain ASCII letters and numbers and these separators:" \
-    " - . _ (hyphen, period, and underscore)."
+_kernel_name_description = (
+    "Kernel names can only contain ASCII letters and numbers and these separators:"
+    " - . _ (hyphen, period, and underscore)."
+)
 
 
 def _is_kernel_dir(path):
     """Is ``path`` a kernel directory?"""
-    return os.path.isdir(path) and os.path.isfile(pjoin(path, 'kernel.json'))
+    return os.path.isdir(path) and os.path.isfile(pjoin(path, "kernel.json"))
 
 
 def _list_kernels_in(dir):
@@ -96,8 +100,9 @@ def _list_kernels_in(dir):
             continue
         key = f.lower()
         if not _is_valid_kernel_name(key):
-            warnings.warn("Invalid kernelspec directory name (%s): %s"
-                % (_kernel_name_description, path), stacklevel=3,
+            warnings.warn(
+                "Invalid kernelspec directory name (%s): %s" % (_kernel_name_description, path),
+                stacklevel=3,
             )
         kernels[key] = path
     return kernels
@@ -113,49 +118,57 @@ def __str__(self):
 
 
 class KernelSpecManager(LoggingConfigurable):
 
-    kernel_spec_class = Type(KernelSpec, config=True,
+    kernel_spec_class = Type(
+        KernelSpec,
+        config=True,
         help="""The kernel spec class.  This is configurable to allow
         subclassing of the KernelSpecManager for customized behavior.
-        """
+        """,
     )
 
-    ensure_native_kernel = Bool(True, config=True,
+    ensure_native_kernel = Bool(
+        True,
+        config=True,
         help="""If there is no Python kernelspec registered and the IPython
         kernel is available, ensure it is added to the spec list.
-        """
+        """,
    )
 
     data_dir = Unicode()
+
     def _data_dir_default(self):
         return jupyter_data_dir()
 
     user_kernel_dir = Unicode()
+
     def _user_kernel_dir_default(self):
-        return pjoin(self.data_dir, 'kernels')
+        return pjoin(self.data_dir, "kernels")
 
-    whitelist = Set(config=True,
+    whitelist = Set(
+        config=True,
         help="""Whitelist of allowed kernel names.
 
        By default, all installed kernels are allowed.
-        """
+        """,
     )
     kernel_dirs = List(
         help="List of kernel directories to search. Later ones take priority over earlier."
     )
+
     def _kernel_dirs_default(self):
-        dirs = jupyter_path('kernels')
+        dirs = jupyter_path("kernels")
         # At some point, we should stop adding .ipython/kernels to the path,
         # but the cost to keeping it is very small.
         try:
-            from IPython.paths import get_ipython_dir
+            from IPython.paths import get_ipython_dir  # type: ignore
         except ImportError:
             try:
-                from IPython.utils.path import get_ipython_dir
+                from IPython.utils.path import get_ipython_dir  # type: ignore
             except ImportError:
                 # no IPython, no ipython dir
                 get_ipython_dir = None
         if get_ipython_dir is not None:
-            dirs.append(os.path.join(get_ipython_dir(), 'kernels'))
+            dirs.append(os.path.join(get_ipython_dir(), "kernels"))
         return dirs
 
     def find_kernel_specs(self):
@@ -170,21 +183,25 @@ def find_kernel_specs(self):
 
         if self.ensure_native_kernel and NATIVE_KERNEL_NAME not in d:
             try:
-                from ipykernel.kernelspec import RESOURCES
-                self.log.debug("Native kernel (%s) available from %s",
-                               NATIVE_KERNEL_NAME, RESOURCES)
+                from ipykernel.kernelspec import RESOURCES  # type: ignore
+
+                self.log.debug(
+                    "Native kernel (%s) available from %s",
+                    NATIVE_KERNEL_NAME,
+                    RESOURCES,
+                )
                 d[NATIVE_KERNEL_NAME] = RESOURCES
             except ImportError:
                 self.log.warning("Native kernel (%s) is not available", NATIVE_KERNEL_NAME)
 
         if self.whitelist:
             # filter if there's a whitelist
-            d = {name:spec for name,spec in d.items() if name in self.whitelist}
+            d = {name: spec for name, spec in d.items() if name in self.whitelist}
         return d
         # TODO: Caching?
 
     def _get_kernel_spec_by_name(self, kernel_name, resource_dir):
-        """ Returns a :class:`KernelSpec` instance for a given kernel_name
+        """Returns a :class:`KernelSpec` instance for a given kernel_name
         and resource_dir.
""" if kernel_name == NATIVE_KERNEL_NAME: @@ -222,8 +239,11 @@ def get_kernel_spec(self, kernel_name): Raises :exc:`NoSuchKernel` if the given kernel name is not found. """ if not _is_valid_kernel_name(kernel_name): - self.log.warning("Kernelspec name %r is invalid: %s", kernel_name, - _kernel_name_description) + self.log.warning( + "Kernelspec name %r is invalid: %s", + kernel_name, + _kernel_name_description, + ) resource_dir = self._find_spec_directory(kernel_name.lower()) if resource_dir is None: @@ -256,10 +276,7 @@ def get_all_specs(self): # and get_kernel_spec, but not the newer get_all_specs spec = self.get_kernel_spec(kname) - res[kname] = { - "resource_dir": resource_dir, - "spec": spec.to_dict() - } + res[kname] = {"resource_dir": resource_dir, "spec": spec.to_dict()} except Exception: self.log.warning("Error loading kernelspec %r", kname, exc_info=True) return res @@ -287,13 +304,13 @@ def _get_destination_dir(self, kernel_name, user=False, prefix=None): if user: return os.path.join(self.user_kernel_dir, kernel_name) elif prefix: - return os.path.join(os.path.abspath(prefix), 'share', 'jupyter', 'kernels', kernel_name) + return os.path.join(os.path.abspath(prefix), "share", "jupyter", "kernels", kernel_name) else: - return os.path.join(SYSTEM_JUPYTER_PATH[0], 'kernels', kernel_name) - + return os.path.join(SYSTEM_JUPYTER_PATH[0], "kernels", kernel_name) - def install_kernel_spec(self, source_dir, kernel_name=None, user=False, - replace=None, prefix=None): + def install_kernel_spec( + self, source_dir, kernel_name=None, user=False, replace=None, prefix=None + ): """Install a kernel spec by copying its directory. If ``kernel_name`` is not given, the basename of ``source_dir`` will @@ -307,45 +324,53 @@ def install_kernel_spec(self, source_dir, kernel_name=None, user=False, PREFIX/share/jupyter/kernels/KERNEL_NAME. This can be sys.prefix for installation inside virtual or conda envs. """ - source_dir = source_dir.rstrip('/\\') + source_dir = source_dir.rstrip("/\\") if not kernel_name: kernel_name = os.path.basename(source_dir) kernel_name = kernel_name.lower() if not _is_valid_kernel_name(kernel_name): - raise ValueError("Invalid kernel name %r. %s" % (kernel_name, _kernel_name_description)) + raise ValueError( + "Invalid kernel name %r. %s" % (kernel_name, _kernel_name_description) + ) if user and prefix: raise ValueError("Can't specify both user and prefix. Please choose one or the other.") if replace is not None: warnings.warn( - "replace is ignored. Installing a kernelspec always replaces an existing installation", + "replace is ignored. Installing a kernelspec always replaces an existing " + "installation", DeprecationWarning, stacklevel=2, ) destination = self._get_destination_dir(kernel_name, user=user, prefix=prefix) - self.log.debug('Installing kernelspec in %s', destination) + self.log.debug("Installing kernelspec in %s", destination) kernel_dir = os.path.dirname(destination) if kernel_dir not in self.kernel_dirs: - self.log.warning("Installing to %s, which is not in %s. The kernelspec may not be found.", - kernel_dir, self.kernel_dirs, + self.log.warning( + "Installing to %s, which is not in %s. 
+                kernel_dir,
+                self.kernel_dirs,
             )
 
         if os.path.isdir(destination):
-            self.log.info('Removing existing kernelspec in %s', destination)
+            self.log.info("Removing existing kernelspec in %s", destination)
             shutil.rmtree(destination)
 
         shutil.copytree(source_dir, destination)
-        self.log.info('Installed kernelspec %s in %s', kernel_name, destination)
+        self.log.info("Installed kernelspec %s in %s", kernel_name, destination)
         return destination
 
     def install_native_kernel_spec(self, user=False):
         """DEPRECATED: Use ipykernel.kernelspec.install"""
-        warnings.warn("install_native_kernel_spec is deprecated."
-            " Use ipykernel.kernelspec import install.", stacklevel=2)
+        warnings.warn(
+            "install_native_kernel_spec is deprecated." " Use ipykernel.kernelspec import install.",
+            stacklevel=2,
+        )
         from ipykernel.kernelspec import install
+
         install(self, user=user)
 
@@ -353,6 +378,7 @@ def find_kernel_specs():
     """Returns a dict mapping kernel names to resource directories."""
     return KernelSpecManager().find_kernel_specs()
 
+
 def get_kernel_spec(kernel_name):
     """Returns a :class:`KernelSpec` instance for the given kernel_name.
 
@@ -360,14 +386,16 @@ def get_kernel_spec(kernel_name):
     """
     return KernelSpecManager().get_kernel_spec(kernel_name)
 
-def install_kernel_spec(source_dir, kernel_name=None, user=False, replace=False,
-                        prefix=None):
-    return KernelSpecManager().install_kernel_spec(source_dir, kernel_name,
-                                                   user, replace, prefix)
+
+def install_kernel_spec(source_dir, kernel_name=None, user=False, replace=False, prefix=None):
+    return KernelSpecManager().install_kernel_spec(source_dir, kernel_name, user, replace, prefix)
+
 
 install_kernel_spec.__doc__ = KernelSpecManager.install_kernel_spec.__doc__
 
+
 def install_native_kernel_spec(user=False):
     return KernelSpecManager().install_native_kernel_spec(user=user)
 
+
 install_native_kernel_spec.__doc__ = KernelSpecManager.install_native_kernel_spec.__doc__
diff --git a/jupyter_client/kernelspecapp.py b/jupyter_client/kernelspecapp.py
index 095aad451..9d98e41e1 100644
--- a/jupyter_client/kernelspecapp.py
+++ b/jupyter_client/kernelspecapp.py
@@ -1,17 +1,19 @@
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 import errno
+import json
 import os.path
 import sys
-import json
 
+from jupyter_core.application import base_aliases
+from jupyter_core.application import base_flags
+from jupyter_core.application import JupyterApp
+from traitlets import Bool
+from traitlets import Dict
+from traitlets import Instance
+from traitlets import List
+from traitlets import Unicode
 from traitlets.config.application import Application
-from jupyter_core.application import (
-    JupyterApp, base_flags, base_aliases
-)
-from traitlets import Instance, Dict, Unicode, Bool, List
 
 from . import __version__
 from .kernelspec import KernelSpecManager
 
 
@@ -21,13 +23,19 @@ class ListKernelSpecs(JupyterApp):
     version = __version__
     description = """List installed kernel specifications."""
     kernel_spec_manager = Instance(KernelSpecManager)
-    json_output = Bool(False, help='output spec name and location as machine-readable json.',
-                       config=True)
+    json_output = Bool(
+        False,
+        help="output spec name and location as machine-readable json.",
+        config=True,
+    )
 
-    flags = {'json': ({'ListKernelSpecs': {'json_output': True}},
-                "output spec name and location as machine-readable json."),
-             'debug': base_flags['debug'],
-             }
+    flags = {
+        "json": (
+            {"ListKernelSpecs": {"json_output": True}},
+            "output spec name and location as machine-readable json.",
+        ),
+        "debug": base_flags["debug"],
+    }
 
     def _kernel_spec_manager_default(self):
         return KernelSpecManager(parent=self, data_dir=self.data_dir)
@@ -55,10 +63,7 @@ def path_key(item):
             for kernelname, path in sorted(paths.items(), key=path_key):
                 print(" %s    %s" % (kernelname.ljust(name_len), path))
         else:
-            print(json.dumps({
-                'kernelspecs': specs
-            }, indent=2))
-
+            print(json.dumps({"kernelspecs": specs}, indent=2))
 
 
 class InstallKernelSpec(JupyterApp):
@@ -80,41 +85,49 @@ def _kernel_spec_manager_default(self):
         return KernelSpecManager(data_dir=self.data_dir)
 
     sourcedir = Unicode()
-    kernel_name = Unicode("", config=True,
-        help="Install the kernel spec with this name"
-    )
+    kernel_name = Unicode("", config=True, help="Install the kernel spec with this name")
+
     def _kernel_name_default(self):
         return os.path.basename(self.sourcedir)
 
-    user = Bool(False, config=True,
+    user = Bool(
+        False,
+        config=True,
         help="""
         Try to install the kernel spec to the per-user directory instead of
         the system or environment directory.
-        """
+        """,
     )
-    prefix = Unicode('', config=True,
+    prefix = Unicode(
+        "",
+        config=True,
         help="""Specify a prefix to install to, e.g. an env.
        The kernelspec will be installed in PREFIX/share/jupyter/kernels/
-        """
-    )
-    replace = Bool(False, config=True,
-        help="Replace any existing kernel spec with this name."
+        """,
     )
+    replace = Bool(False, config=True, help="Replace any existing kernel spec with this name.")
 
     aliases = {
-        'name': 'InstallKernelSpec.kernel_name',
-        'prefix': 'InstallKernelSpec.prefix',
+        "name": "InstallKernelSpec.kernel_name",
+        "prefix": "InstallKernelSpec.prefix",
     }
     aliases.update(base_aliases)
 
-    flags = {'user': ({'InstallKernelSpec': {'user': True}},
-                "Install to the per-user kernel registry"),
-             'replace': ({'InstallKernelSpec': {'replace': True}},
-                "Replace any existing kernel spec with this name."),
-             'sys-prefix': ({'InstallKernelSpec': {'prefix': sys.prefix}},
-                "Install to Python's sys.prefix. Useful in conda/virtual environments."),
-             'debug': base_flags['debug'],
-             }
+    flags = {
+        "user": (
+            {"InstallKernelSpec": {"user": True}},
+            "Install to the per-user kernel registry",
+        ),
+        "replace": (
+            {"InstallKernelSpec": {"replace": True}},
+            "Replace any existing kernel spec with this name.",
+        ),
+        "sys-prefix": (
+            {"InstallKernelSpec": {"prefix": sys.prefix}},
+            "Install to Python's sys.prefix. Useful in conda/virtual environments.",
+        ),
+        "debug": base_flags["debug"],
+    }
 
     def parse_command_line(self, argv):
         super().parse_command_line(argv)
@@ -129,39 +142,46 @@ def start(self):
         if self.user and self.prefix:
             self.exit("Can't specify both user and prefix. Please choose one or the other.")
         try:
-            self.kernel_spec_manager.install_kernel_spec(self.sourcedir,
-                kernel_name=self.kernel_name,
-                user=self.user,
-                prefix=self.prefix,
-                replace=self.replace,
-            )
+            self.kernel_spec_manager.install_kernel_spec(
+                self.sourcedir,
+                kernel_name=self.kernel_name,
+                user=self.user,
+                prefix=self.prefix,
+                replace=self.replace,
+            )
         except OSError as e:
             if e.errno == errno.EACCES:
                 print(e, file=sys.stderr)
                 if not self.user:
-                    print("Perhaps you want to install with `sudo` or `--user`?", file=sys.stderr)
+                    print(
+                        "Perhaps you want to install with `sudo` or `--user`?",
+                        file=sys.stderr,
+                    )
                 self.exit(1)
             elif e.errno == errno.EEXIST:
-                print("A kernel spec is already present at %s" % e.filename, file=sys.stderr)
+                print(
+                    "A kernel spec is already present at %s" % e.filename,
+                    file=sys.stderr,
+                )
                 self.exit(1)
             raise
 
+
 class RemoveKernelSpec(JupyterApp):
     version = __version__
     description = """Remove one or more Jupyter kernelspecs by name."""
     examples = """jupyter kernelspec remove python2 [my_kernel ...]"""
 
-    force = Bool(False, config=True,
-        help="""Force removal, don't prompt for confirmation."""
-    )
+    force = Bool(False, config=True, help="""Force removal, don't prompt for confirmation.""")
     spec_names = List(Unicode())
 
     kernel_spec_manager = Instance(KernelSpecManager)
+
     def _kernel_spec_manager_default(self):
         return KernelSpecManager(data_dir=self.data_dir, parent=self)
 
     flags = {
-        'f': ({'RemoveKernelSpec': {'force': True}}, force.get_metadata('help')),
+        "f": ({"RemoveKernelSpec": {"force": True}}, force.get_metadata("help")),
     }
     flags.update(JupyterApp.flags)
 
@@ -169,7 +189,7 @@ def parse_command_line(self, argv):
         super().parse_command_line(argv)
         # accept positional args as kernelspec names
         if self.extra_args:
-            self.spec_names = sorted(set(self.extra_args)) # remove duplicates
+            self.spec_names = sorted(set(self.extra_args))  # remove duplicates
         else:
             self.exit("No kernelspec specified.")
 
@@ -178,14 +198,14 @@ def start(self):
         spec_paths = self.kernel_spec_manager.find_kernel_specs()
         missing = set(self.spec_names).difference(set(spec_paths))
         if missing:
-            self.exit("Couldn't find kernel spec(s): %s" % ', '.join(missing))
+            self.exit("Couldn't find kernel spec(s): %s" % ", ".join(missing))
 
         if not self.force:
             print("Kernel specs to remove:")
             for name in self.spec_names:
                 print("  %s\t%s" % (name.ljust(20), spec_paths[name]))
             answer = input("Remove %i kernel specs [y/N]: " % len(self.spec_names))
-            if not answer.lower().startswith('y'):
+            if not answer.lower().startswith("y"):
                 return
 
         for kernel_name in self.spec_names:
@@ -209,21 +229,28 @@ class InstallNativeKernelSpec(JupyterApp):
     def _kernel_spec_manager_default(self):
         return KernelSpecManager(data_dir=self.data_dir)
 
-    user = Bool(False, config=True,
+    user = Bool(
+        False,
+        config=True,
         help="""
        Try to install the kernel spec to the per-user directory instead of
        the system or environment directory.
-        """
+        """,
     )
 
-    flags = {'user': ({'InstallNativeKernelSpec': {'user': True}},
-                "Install to the per-user kernel registry"),
-             'debug': base_flags['debug'],
-             }
+    flags = {
+        "user": (
+            {"InstallNativeKernelSpec": {"user": True}},
+            "Install to the per-user kernel registry",
+        ),
+        "debug": base_flags["debug"],
+    }
 
     def start(self):
-        self.log.warning("`jupyter kernelspec install-self` is DEPRECATED as of 4.0."
-            " You probably want `ipython kernel install` to install the IPython kernelspec.")
+        self.log.warning(
+            "`jupyter kernelspec install-self` is DEPRECATED as of 4.0."
+ " You probably want `ipython kernel install` to install the IPython kernelspec." + ) try: from ipykernel import kernelspec except ImportError: @@ -235,29 +262,41 @@ def start(self): if e.errno == errno.EACCES: print(e, file=sys.stderr) if not self.user: - print("Perhaps you want to install with `sudo` or `--user`?", file=sys.stderr) + print( + "Perhaps you want to install with `sudo` or `--user`?", + file=sys.stderr, + ) self.exit(1) self.exit(e) + class KernelSpecApp(Application): version = __version__ name = "jupyter kernelspec" description = """Manage Jupyter kernel specifications.""" - subcommands = Dict({ - 'list': (ListKernelSpecs, ListKernelSpecs.description.splitlines()[0]), - 'install': (InstallKernelSpec, InstallKernelSpec.description.splitlines()[0]), - 'uninstall': (RemoveKernelSpec, "Alias for remove"), - 'remove': (RemoveKernelSpec, RemoveKernelSpec.description.splitlines()[0]), - 'install-self': (InstallNativeKernelSpec, InstallNativeKernelSpec.description.splitlines()[0]), - }) + subcommands = Dict( + { + "list": (ListKernelSpecs, ListKernelSpecs.description.splitlines()[0]), + "install": ( + InstallKernelSpec, + InstallKernelSpec.description.splitlines()[0], + ), + "uninstall": (RemoveKernelSpec, "Alias for remove"), + "remove": (RemoveKernelSpec, RemoveKernelSpec.description.splitlines()[0]), + "install-self": ( + InstallNativeKernelSpec, + InstallNativeKernelSpec.description.splitlines()[0], + ), + } + ) aliases = {} flags = {} def start(self): if self.subapp is None: - print("No subcommand specified. Must specify one of: %s"% list(self.subcommands)) + print("No subcommand specified. Must specify one of: %s" % list(self.subcommands)) print() self.print_description() self.print_subcommands() @@ -266,5 +305,5 @@ def start(self): return self.subapp.start() -if __name__ == '__main__': +if __name__ == "__main__": KernelSpecApp.launch_instance() diff --git a/jupyter_client/launcher.py b/jupyter_client/launcher.py index 930ee74b0..960413ee1 100644 --- a/jupyter_client/launcher.py +++ b/jupyter_client/launcher.py @@ -1,12 +1,13 @@ """Utilities for launching kernels""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - import os import sys -from subprocess import Popen, PIPE -from typing import List, Dict, Optional +from subprocess import PIPE +from subprocess import Popen +from typing import Dict +from typing import List +from typing import Optional from traitlets.log import get_logger # type: ignore @@ -19,9 +20,9 @@ def launch_kernel( env: Optional[Dict[str, str]] = None, independent: bool = False, cwd: Optional[str] = None, - **kw + **kw, ) -> Popen: - """ Launches a localhost kernel, binding to the specified ports. + """Launches a localhost kernel, binding to the specified ports. Parameters ---------- @@ -65,9 +66,9 @@ def launch_kernel( # If this process in running on pythonw, we know that stdin, stdout, and # stderr are all invalid. - redirect_out = sys.executable.endswith('pythonw.exe') + redirect_out = sys.executable.endswith("pythonw.exe") if redirect_out: - blackhole = open(os.devnull, 'w') + blackhole = open(os.devnull, "w") _stdout = blackhole if stdout is None else stdout _stderr = blackhole if stderr is None else stderr else: @@ -86,11 +87,12 @@ def launch_kernel( kwargs.update(main_args) # Spawn a kernel. 
-    if sys.platform == 'win32':
+    if sys.platform == "win32":
         if cwd:
-            kwargs['cwd'] = cwd
+            kwargs["cwd"] = cwd
 
         from .win_interrupt import create_interrupt_event
+
         # Create a Win32 event for interrupting the kernel
         # and store it in an environment variable.
         interrupt_event = create_interrupt_event()
@@ -99,59 +101,66 @@ def launch_kernel(
             env["IPY_INTERRUPT_EVENT"] = env["JPY_INTERRUPT_EVENT"]
 
         try:
-            from _winapi import (DuplicateHandle, GetCurrentProcess,
-                DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP)
-        except:
-            from _subprocess import (DuplicateHandle, GetCurrentProcess,  # type: ignore
-                DUPLICATE_SAME_ACCESS, CREATE_NEW_PROCESS_GROUP)  # type: ignore
+            from _winapi import (
+                CREATE_NEW_PROCESS_GROUP,
+                DUPLICATE_SAME_ACCESS,
+                DuplicateHandle,
+                GetCurrentProcess,
+            )
+        except:  # noqa
+            from _subprocess import GetCurrentProcess  # type: ignore
+            from _subprocess import CREATE_NEW_PROCESS_GROUP, DUPLICATE_SAME_ACCESS, DuplicateHandle
 
         # create a handle on the parent to be inherited
         if independent:
-            kwargs['creationflags'] = CREATE_NEW_PROCESS_GROUP
+            kwargs["creationflags"] = CREATE_NEW_PROCESS_GROUP
         else:
             pid = GetCurrentProcess()
-            handle = DuplicateHandle(pid, pid, pid, 0,
-                                     True, # Inheritable by new processes.
-                                     DUPLICATE_SAME_ACCESS)
-            env['JPY_PARENT_PID'] = str(int(handle))
+            handle = DuplicateHandle(
+                pid,
+                pid,
+                pid,
+                0,
+                True,
+                DUPLICATE_SAME_ACCESS,  # Inheritable by new processes.
+            )
+            env["JPY_PARENT_PID"] = str(int(handle))
 
         # Prevent creating new console window on pythonw
         if redirect_out:
-            kwargs['creationflags'] = kwargs.setdefault('creationflags', 0) | 0x08000000 # CREATE_NO_WINDOW
+            kwargs["creationflags"] = (
+                kwargs.setdefault("creationflags", 0) | 0x08000000
+            )  # CREATE_NO_WINDOW
 
         # Avoid closing the above parent and interrupt handles.
         # close_fds is True by default on Python >=3.7
         # or when no stream is captured on Python <3.7
         # (we always capture stdin, so this is already False by default on <3.7)
-        kwargs['close_fds'] = False
     else:
         # Create a new session.
         # This makes it easier to interrupt the kernel,
         # because we want to interrupt the whole process group.
         # We don't use setpgrp, which is known to cause problems for kernels starting
         # certain interactive subprocesses, such as bash -i.
-        kwargs['start_new_session'] = True
+        kwargs["start_new_session"] = True
         if not independent:
-            env['JPY_PARENT_PID'] = str(os.getpid())
+            env["JPY_PARENT_PID"] = str(os.getpid())
 
     try:
         # Allow to use ~/ in the command or its arguments
         cmd = [os.path.expanduser(s) for s in cmd]
         proc = Popen(cmd, **kwargs)
-    except Exception as exc:
-        msg = (
-            "Failed to run command:\n{}\n"
-            "    PATH={!r}\n"
-            "    with kwargs:\n{!r}\n"
-        )
+    except Exception:
+        msg = "Failed to run command:\n{}\n" "    PATH={!r}\n" "    with kwargs:\n{!r}\n"
         # exclude environment variables,
         # which may contain access tokens and the like.
-        without_env = {key:value for key, value in kwargs.items() if key != 'env'}
-        msg = msg.format(cmd, env.get('PATH', os.defpath), without_env)
+        without_env = {key: value for key, value in kwargs.items() if key != "env"}
+        msg = msg.format(cmd, env.get("PATH", os.defpath), without_env)
         get_logger().error(msg)
         raise
 
-    if sys.platform == 'win32':
+    if sys.platform == "win32":
         # Attach the interrupt event to the Popen object so it can be used later.
         proc.win32_interrupt_event = interrupt_event  # type: ignore
@@ -163,6 +172,7 @@ def launch_kernel(
 
     return proc
 
+
 __all__ = [
-    'launch_kernel',
+    "launch_kernel",
 ]
diff --git a/jupyter_client/localinterfaces.py b/jupyter_client/localinterfaces.py
index 20469cf68..80f96a324 100644
--- a/jupyter_client/localinterfaces.py
+++ b/jupyter_client/localinterfaces.py
@@ -1,21 +1,19 @@
 """Utilities for identifying local IP addresses."""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 import os
 import re
 import socket
 import subprocess
-from subprocess import Popen, PIPE
-
+from subprocess import PIPE
+from subprocess import Popen
+from typing import List
 from warnings import warn
 
+LOCAL_IPS: List = []
+PUBLIC_IPS: List = []
 
-LOCAL_IPS = []
-PUBLIC_IPS = []
-
-LOCALHOST = ''
+LOCALHOST = ""
 
 
 def _uniq_stable(elems):
@@ -27,76 +25,87 @@ def _uniq_stable(elems):
     seen = set()
     return [x for x in elems if x not in seen and not seen.add(x)]
 
+
 def _get_output(cmd):
     """Get output of a command, raising IOError if it fails"""
     startupinfo = None
-    if os.name == 'nt':
+    if os.name == "nt":
         startupinfo = subprocess.STARTUPINFO()
         startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
     p = Popen(cmd, stdout=PIPE, stderr=PIPE, startupinfo=startupinfo)
     stdout, stderr = p.communicate()
     if p.returncode:
-        raise IOError("Failed to run %s: %s" % (cmd, stderr.decode('utf8', 'replace')))
-    return stdout.decode('utf8', 'replace')
+        raise IOError("Failed to run %s: %s" % (cmd, stderr.decode("utf8", "replace")))
+    return stdout.decode("utf8", "replace")
 
+
 def _only_once(f):
     """decorator to only run a function once"""
     f.called = False
+
     def wrapped(**kwargs):
         if f.called:
             return
         ret = f(**kwargs)
         f.called = True
         return ret
+
     return wrapped
 
+
 def _requires_ips(f):
     """decorator to ensure load_ips has been run before f"""
+
     def ips_loaded(*args, **kwargs):
         _load_ips()
         return f(*args, **kwargs)
+
     return ips_loaded
 
+
 # subprocess-parsing ip finders
 class NoIPAddresses(Exception):
     pass
 
+
 def _populate_from_list(addrs):
     """populate local and public IPs from flat list of all IPs"""
     if not addrs:
         raise NoIPAddresses
-
+
     global LOCALHOST
     public_ips = []
     local_ips = []
-
+
     for ip in addrs:
         local_ips.append(ip)
-        if not ip.startswith('127.'):
+        if not ip.startswith("127."):
             public_ips.append(ip)
         elif not LOCALHOST:
             LOCALHOST = ip
-
-    if not LOCALHOST or LOCALHOST == '127.0.0.1':
-        LOCALHOST = '127.0.0.1'
+
+    if not LOCALHOST or LOCALHOST == "127.0.0.1":
+        LOCALHOST = "127.0.0.1"
         local_ips.insert(0, LOCALHOST)
-
-    local_ips.extend(['0.0.0.0', ''])
-
+
+    local_ips.extend(["0.0.0.0", ""])
+
     LOCAL_IPS[:] = _uniq_stable(local_ips)
     PUBLIC_IPS[:] = _uniq_stable(public_ips)
 
-_ifconfig_ipv4_pat = re.compile(r'inet\b.*?(\d+\.\d+\.\d+\.\d+)', re.IGNORECASE)
+
+_ifconfig_ipv4_pat = re.compile(r"inet\b.*?(\d+\.\d+\.\d+\.\d+)", re.IGNORECASE)
+
 
 def _load_ips_ifconfig():
     """load ip addresses from `ifconfig` output (posix)"""
-
     try:
-        out = _get_output('ifconfig')
+        out = _get_output("ifconfig")
     except (IOError, OSError):
         # no ifconfig, it's usually in /sbin and /sbin is not on everyone's PATH
-        out = _get_output('/sbin/ifconfig')
-
+        out = _get_output("/sbin/ifconfig")
+
     lines = out.splitlines()
     addrs = []
     for line in lines:
@@ -108,22 +117,24 @@ def _load_ips_ifconfig():
 
 def _load_ips_ip():
     """load ip addresses from `ip addr` output (Linux)"""
-    out = _get_output(['ip', '-f', 'inet', 'addr'])
-
+    out = _get_output(["ip", "-f", "inet", "addr"])
+
     lines = out.splitlines()
     addrs = []
     for line in lines:
         blocks = line.lower().split()
-        if (len(blocks) >= 2) and (blocks[0] == 'inet'):
-            addrs.append(blocks[1].split('/')[0])
+        if (len(blocks) >= 2) and (blocks[0] == "inet"):
+            addrs.append(blocks[1].split("/")[0])
     _populate_from_list(addrs)
 
-_ipconfig_ipv4_pat = re.compile(r'ipv4.*?(\d+\.\d+\.\d+\.\d+)$', re.IGNORECASE)
+
+_ipconfig_ipv4_pat = re.compile(r"ipv4.*?(\d+\.\d+\.\d+\.\d+)$", re.IGNORECASE)
+
 
 def _load_ips_ipconfig():
     """load ip addresses from `ipconfig` output (Windows)"""
-    out = _get_output('ipconfig')
-
+    out = _get_output("ipconfig")
+
     lines = out.splitlines()
     addrs = []
     for line in lines:
@@ -135,92 +146,95 @@ def _load_ips_ipconfig():
 
 def _load_ips_netifaces():
     """load ip addresses with netifaces"""
-    import netifaces
+    import netifaces  # type: ignore
+
     global LOCALHOST
     local_ips = []
     public_ips = []
-
+
     # list of iface names, 'lo0', 'eth0', etc.
     for iface in netifaces.interfaces():
         # list of ipv4 addrinfo dicts
         ipv4s = netifaces.ifaddresses(iface).get(netifaces.AF_INET, [])
         for entry in ipv4s:
-            addr = entry.get('addr')
+            addr = entry.get("addr")
             if not addr:
                 continue
-            if not (iface.startswith('lo') or addr.startswith('127.')):
+            if not (iface.startswith("lo") or addr.startswith("127.")):
                 public_ips.append(addr)
             elif not LOCALHOST:
                 LOCALHOST = addr
             local_ips.append(addr)
     if not LOCALHOST:
         # we never found a loopback interface (can this ever happen?), assume common default
-        LOCALHOST = '127.0.0.1'
+        LOCALHOST = "127.0.0.1"
         local_ips.insert(0, LOCALHOST)
-    local_ips.extend(['0.0.0.0', ''])
+    local_ips.extend(["0.0.0.0", ""])
     LOCAL_IPS[:] = _uniq_stable(local_ips)
     PUBLIC_IPS[:] = _uniq_stable(public_ips)
 
 
 def _load_ips_gethostbyname():
     """load ip addresses with socket.gethostbyname_ex
-
+
     This can be slow.
     """
     global LOCALHOST
     try:
-        LOCAL_IPS[:] = socket.gethostbyname_ex('localhost')[2]
+        LOCAL_IPS[:] = socket.gethostbyname_ex("localhost")[2]
     except socket.error:
         # assume common default
-        LOCAL_IPS[:] = ['127.0.0.1']
-
+        LOCAL_IPS[:] = ["127.0.0.1"]
+
     try:
         hostname = socket.gethostname()
         PUBLIC_IPS[:] = socket.gethostbyname_ex(hostname)[2]
         # try hostname.local, in case hostname has been short-circuited to loopback
-        if not hostname.endswith('.local') and all(ip.startswith('127') for ip in PUBLIC_IPS):
-            PUBLIC_IPS[:] = socket.gethostbyname_ex(socket.gethostname() + '.local')[2]
+        if not hostname.endswith(".local") and all(ip.startswith("127") for ip in PUBLIC_IPS):
+            PUBLIC_IPS[:] = socket.gethostbyname_ex(socket.gethostname() + ".local")[2]
     except socket.error:
         pass
     finally:
         PUBLIC_IPS[:] = _uniq_stable(PUBLIC_IPS)
         LOCAL_IPS.extend(PUBLIC_IPS)
-
+
     # include all-interface aliases: 0.0.0.0 and ''
-    LOCAL_IPS.extend(['0.0.0.0', ''])
+    LOCAL_IPS.extend(["0.0.0.0", ""])
 
     LOCAL_IPS[:] = _uniq_stable(LOCAL_IPS)
 
     LOCALHOST = LOCAL_IPS[0]
 
+
 def _load_ips_dumb():
     """Fallback in case of unexpected failure"""
     global LOCALHOST
-    LOCALHOST = '127.0.0.1'
-    LOCAL_IPS[:] = [LOCALHOST, '0.0.0.0', '']
+    LOCALHOST = "127.0.0.1"
+    LOCAL_IPS[:] = [LOCALHOST, "0.0.0.0", ""]
     PUBLIC_IPS[:] = []
 
+
 @_only_once
 def _load_ips(suppress_exceptions=True):
     """load the IPs that point to this machine
-
+
     This function will only ever be called once.
-
+
     It will use netifaces to do it quickly if available.
     Then it will fall back on parsing the output of
     ifconfig / ip addr / ipconfig, as appropriate.
     Finally, it will fall back on socket.gethostbyname_ex, which can be slow.
""" - + try: # first priority, use netifaces try: return _load_ips_netifaces() except ImportError: pass - + # second priority, parse subprocess output (how reliable is this?) - - if os.name == 'nt': + + if os.name == "nt": try: return _load_ips_ipconfig() except (IOError, NoIPAddresses): @@ -234,9 +248,9 @@ def _load_ips(suppress_exceptions=True): return _load_ips_ifconfig() except (IOError, OSError, NoIPAddresses): pass - + # lowest priority, use gethostbyname - + return _load_ips_gethostbyname() except Exception as e: if not suppress_exceptions: @@ -251,21 +265,25 @@ def local_ips(): """return the IP addresses that point to this machine""" return LOCAL_IPS + @_requires_ips def public_ips(): """return the IP addresses for this machine that are visible to other machines""" return PUBLIC_IPS + @_requires_ips def localhost(): """return ip for localhost (almost always 127.0.0.1)""" return LOCALHOST + @_requires_ips def is_local_ip(ip): """does `ip` point to this machine?""" return ip in LOCAL_IPS + @_requires_ips def is_public_ip(ip): """is `ip` a publicly visible address?""" diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index a2c2ad60c..99e4ffa14 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -1,39 +1,41 @@ """Base class to manage a running kernel""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - -from contextlib import contextmanager import asyncio import os import re import signal import sys -import time -import warnings -from subprocess import Popen import typing as t - +import warnings +from contextlib import contextmanager from enum import Enum +from subprocess import Popen import zmq - -from .localinterfaces import is_local_ip, local_ips -from traitlets import ( # type: ignore - Any, Float, Instance, Unicode, List, Bool, Type, DottedObjectName, - default, observe, observe_compat -) +from traitlets import Any # type: ignore +from traitlets import Bool +from traitlets import default +from traitlets import DottedObjectName +from traitlets import Float +from traitlets import Instance +from traitlets import List +from traitlets import observe +from traitlets import observe_compat +from traitlets import Type +from traitlets import Unicode from traitlets.utils.importstring import import_item # type: ignore -from jupyter_client import ( - launch_kernel, - kernelspec, - KernelClient, -) + from .connect import ConnectionFileMixin -from .managerabc import ( - KernelManagerABC -) -from .utils import run_sync, ensure_async +from .localinterfaces import is_local_ip +from .localinterfaces import local_ips +from .managerabc import KernelManagerABC +from .utils import ensure_async +from .utils import run_sync +from jupyter_client import KernelClient +from jupyter_client import kernelspec +from jupyter_client import launch_kernel + class _ShutdownStatus(Enum): """ @@ -43,6 +45,7 @@ class _ShutdownStatus(Enum): missbehavior. """ + Unset = None ShutdownRequest = "ShutdownRequest" SigtermRequest = "SigtermRequest" @@ -64,25 +67,24 @@ def __init__(self, *args, **kwargs): # The PyZMQ Context to use for communication with the kernel. 
     context: Instance = Instance(zmq.Context)
 
-    @default('context')
+    @default("context")
     def _context_default(self) -> zmq.Context:
         self._created_context = True
         return zmq.Context()
 
     # the class to create with our `client` method
-    client_class: DottedObjectName = DottedObjectName('jupyter_client.blocking.BlockingKernelClient')
-    client_factory: Type = Type(klass='jupyter_client.KernelClient')
+    client_class: DottedObjectName = DottedObjectName(
+        "jupyter_client.blocking.BlockingKernelClient"
+    )
+    client_factory: Type = Type(klass="jupyter_client.KernelClient")
 
-    @default('client_factory')
+    @default("client_factory")
     def _client_factory_default(self) -> Type:
         return import_item(self.client_class)
 
-    @observe('client_class')
-    def _client_class_changed(
-        self,
-        change: t.Dict[str, DottedObjectName]
-    ) -> None:
-        self.client_factory = import_item(str(change['new']))
+    @observe("client_class")
+    def _client_class_changed(self, change: t.Dict[str, DottedObjectName]) -> None:
+        self.client_factory = import_item(str(change["new"]))
 
     # The kernel process with which the KernelManager is communicating.
     # generally a Popen instance
@@ -90,20 +92,18 @@ def _client_class_changed(
 
     kernel_spec_manager: Instance = Instance(kernelspec.KernelSpecManager)
 
-    @default('kernel_spec_manager')
+    @default("kernel_spec_manager")
     def _kernel_spec_manager_default(self) -> kernelspec.KernelSpecManager:
         return kernelspec.KernelSpecManager(data_dir=self.data_dir)
 
-    @observe('kernel_spec_manager')
+    @observe("kernel_spec_manager")
     @observe_compat
-    def _kernel_spec_manager_changed(
-        self,
-        change: t.Dict[str, Instance]
-    ) -> None:
+    def _kernel_spec_manager_changed(self, change: t.Dict[str, Instance]) -> None:
         self._kernel_spec = None
 
     shutdown_wait_time: Float = Float(
-        5.0, config=True,
+        5.0,
+        config=True,
         help="Time to wait for a kernel to terminate before killing it, "
         "in seconds. When a shutdown request is initiated, the kernel "
         "will immediately be sent an interrupt (SIGINT), followed"
@@ -115,24 +115,23 @@ def _kernel_spec_manager_changed(
 
     kernel_name: Unicode = Unicode(kernelspec.NATIVE_KERNEL_NAME)
 
-    @observe('kernel_name')
-    def _kernel_name_changed(
-        self,
-        change: t.Dict[str, Unicode]
-    ) -> None:
+    @observe("kernel_name")
+    def _kernel_name_changed(self, change: t.Dict[str, Unicode]) -> None:
         self._kernel_spec = None
-        if change['new'] == 'python':
+        if change["new"] == "python":
             self.kernel_name = kernelspec.NATIVE_KERNEL_NAME
 
     _kernel_spec: t.Optional[kernelspec.KernelSpec] = None
 
     @property
     def kernel_spec(self) -> t.Optional[kernelspec.KernelSpec]:
-        if self._kernel_spec is None and self.kernel_name != '':
+        if self._kernel_spec is None and self.kernel_name != "":
             self._kernel_spec = self.kernel_spec_manager.get_kernel_spec(self.kernel_name)
         return self._kernel_spec
 
-    kernel_cmd: List = List(Unicode(), config=True,
+    kernel_cmd: List = List(
+        Unicode(),
+        config=True,
         help="""DEPRECATED: Use kernel_name instead.
 
         The Popen Command to launch the kernel.
@@ -143,22 +142,25 @@ def kernel_spec(self) -> t.Optional[kernelspec.KernelSpec]:
         arguments that the kernel understands. In particular, this means
         that the kernel does not receive the option --debug if it is given
         on the Jupyter command line.
-        """
+        """,
     )
 
     def _kernel_cmd_changed(self, name, old, new):
-        warnings.warn("Setting kernel_cmd is deprecated, use kernel_spec to "
-                      "start different kernels.")
+        warnings.warn(
+            "Setting kernel_cmd is deprecated, use kernel_spec to " "start different kernels."
+        )
 
-    cache_ports: Bool = Bool(help='True if the MultiKernelManager should cache ports for this KernelManager instance')
+    cache_ports: Bool = Bool(
+        help="True if the MultiKernelManager should cache ports for this KernelManager instance"
+    )
 
-    @default('cache_ports')
+    @default("cache_ports")
     def _default_cache_ports(self) -> bool:
-        return self.transport == 'tcp'
+        return self.transport == "tcp"
 
     @property
     def ipykernel(self) -> bool:
-        return self.kernel_name in {'python', 'python2', 'python3'}
+        return self.kernel_name in {"python", "python2", "python3"}
 
     # Protected traits
     _launch_args: Any = Any()
@@ -166,8 +168,8 @@ def ipykernel(self) -> bool:
 
     _restarter: Any = Any()
 
-    autorestart: Bool = Bool(True, config=True,
-        help="""Should we autorestart the kernel if it dies."""
+    autorestart: Bool = Bool(
+        True, config=True, help="""Should we autorestart the kernel if it dies."""
     )
 
     shutting_down: bool = False
@@ -176,9 +178,9 @@ def __del__(self) -> None:
         self._close_control_socket()
         self.cleanup_connection_file()
 
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # Kernel restarter
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
 
     def start_restarter(self) -> None:
         pass
@@ -186,51 +188,42 @@ def start_restarter(self) -> None:
     def stop_restarter(self) -> None:
         pass
 
-    def add_restart_callback(
-        self,
-        callback: t.Callable,
-        event: str = 'restart'
-    ) -> None:
+    def add_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
         """register a callback to be called when a kernel is restarted"""
         if self._restarter is None:
             return
         self._restarter.add_callback(callback, event)
 
-    def remove_restart_callback(
-        self,
-        callback: t.Callable,
-        event: str ='restart'
-    ) -> None:
+    def remove_restart_callback(self, callback: t.Callable, event: str = "restart") -> None:
         """unregister a callback to be called when a kernel is restarted"""
         if self._restarter is None:
             return
         self._restarter.remove_callback(callback, event)
 
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # create a Client connected to our Kernel
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
 
     def client(self, **kwargs) -> KernelClient:
         """Create a client configured to connect to our kernel"""
         kw = {}
         kw.update(self.get_connection_info(session=True))
-        kw.update(dict(
-            connection_file=self.connection_file,
-            parent=self,
-        ))
+        kw.update(
+            dict(
+                connection_file=self.connection_file,
+                parent=self,
+            )
+        )
 
         # add kwargs last, for manual overrides
         kw.update(kwargs)
         return self.client_factory(**kw)
 
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # Kernel management
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
 
-    def format_kernel_cmd(
-        self,
-        extra_arguments: t.Optional[t.List[str]] = None
-    ) -> t.List[str]:
+    def format_kernel_cmd(self, extra_arguments: t.Optional[t.List[str]] = None) -> t.List[str]:
         """replace templated args (e.g. {connection_file})"""
         extra_arguments = extra_arguments or []
         if self.kernel_cmd:
@@ -239,9 +232,11 @@ def format_kernel_cmd(
             assert self.kernel_spec is not None
             cmd = self.kernel_spec.argv + extra_arguments
 
-        if cmd and cmd[0] in {'python',
-                              'python%i' % sys.version_info[0],
-                              'python%i.%i' % sys.version_info[:2]}:
+        if cmd and cmd[0] in {
+            "python",
+            "python%i" % sys.version_info[0],
+            "python%i.%i" % sys.version_info[:2],
+        }:
             # executable is 'python' or 'python3', use sys.executable.
             # These will typically be the same,
             # but if the current process is in an env
@@ -255,27 +250,25 @@ def format_kernel_cmd(
         # is not usable by non python kernels because the path is being rerouted when
         # inside of a store app.
         # See this bug here: https://bugs.python.org/issue41196
-        ns = dict(connection_file=os.path.realpath(self.connection_file),
-                  prefix=sys.prefix,
-                 )
+        ns = dict(
+            connection_file=os.path.realpath(self.connection_file),
+            prefix=sys.prefix,
+        )
 
         if self.kernel_spec:
             ns["resource_dir"] = self.kernel_spec.resource_dir
 
         ns.update(self._launch_args)
 
-        pat = re.compile(r'\{([A-Za-z0-9_]+)\}')
+        pat = re.compile(r"\{([A-Za-z0-9_]+)\}")
+
         def from_ns(match):
             """Get the key out of ns if it's there, otherwise no change."""
             return ns.get(match.group(1), match.group())
 
         return [pat.sub(from_ns, arg) for arg in cmd]
 
-    async def _async_launch_kernel(
-        self,
-        kernel_cmd: t.List[str],
-        **kw
-    ) -> Popen:
+    async def _async_launch_kernel(self, kernel_cmd: t.List[str], **kw) -> Popen:
         """actually launch the kernel
 
         override in a subclass to launch kernel subprocesses differently
@@ -288,7 +281,7 @@ async def _async_launch_kernel(
 
     def _connect_control_socket(self) -> None:
         if self._control_socket is None:
-            self._control_socket = self._create_connected_socket('control')
+            self._control_socket = self._create_connected_socket("control")
             self._control_socket.linger = 100
 
     def _close_control_socket(self) -> None:
@@ -310,13 +303,14 @@ def pre_start_kernel(self, **kw) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]:
             and launching the kernel (e.g. Popen kwargs).
         """
         self.shutting_down = False
-        if self.transport == 'tcp' and not is_local_ip(self.ip):
-            raise RuntimeError("Can only launch a kernel on a local interface. "
-                               "This one is not: %s."
-                               "Make sure that the '*_address' attributes are "
-                               "configured properly. "
-                               "Currently valid addresses are: %s" % (self.ip, local_ips())
-                               )
+        if self.transport == "tcp" and not is_local_ip(self.ip):
+            raise RuntimeError(
+                "Can only launch a kernel on a local interface. "
+                "This one is not: %s."
+                "Make sure that the '*_address' attributes are "
+                "configured properly. "
+                "Currently valid addresses are: %s" % (self.ip, local_ips())
+            )
 
         # write connection file / get default ports
         self.write_connection_file()
@@ -324,29 +318,29 @@ def pre_start_kernel(self, **kw) -> t.Tuple[t.List[str], t.Dict[str, t.Any]]:
 
         # save kwargs for use in restart
         self._launch_args = kw.copy()
         # build the Popen cmd
-        extra_arguments = kw.pop('extra_arguments', [])
+        extra_arguments = kw.pop("extra_arguments", [])
         kernel_cmd = self.format_kernel_cmd(extra_arguments=extra_arguments)
-        env = kw.pop('env', os.environ).copy()
+        env = kw.pop("env", os.environ).copy()
         # Don't allow PYTHONEXECUTABLE to be passed to kernel process.
         # If set, it can bork all the things.
-        env.pop('PYTHONEXECUTABLE', None)
+        env.pop("PYTHONEXECUTABLE", None)
 
         if not self.kernel_cmd:
             # If kernel_cmd has been set manually, don't refer to a kernel spec.
             # Environment variables from kernel spec are added to os.environ.
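
Illustration (editor's sketch, not part of the patch): the substitution mentioned in the comment above lets a kernelspec's env values reference variables from the launching environment. Assuming `${VAR}`-style templating as in `string.Template` (the exact mechanics live in `_get_env_substitutions` below; values hypothetical):

    from string import Template

    templated_env = {"MY_PATH": "${PATH}:/opt/demo/bin"}  # e.g. from a kernel.json "env"
    current_env = {"PATH": "/usr/bin"}
    result = {k: Template(v).safe_substitute(current_env) for k, v in templated_env.items()}
    assert result == {"MY_PATH": "/usr/bin:/opt/demo/bin"}
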
             assert self.kernel_spec is not None
             env.update(self._get_env_substitutions(self.kernel_spec.env, env))
 
-        kw['env'] = env
+        kw["env"] = env
         return kernel_cmd, kw
 
     def _get_env_substitutions(
         self,
         templated_env: t.Optional[t.Dict[str, str]],
-        substitution_values: t.Dict[str, str]
+        substitution_values: t.Dict[str, str],
     ) -> t.Optional[t.Dict[str, str]]:
-        """ Walks env entries in templated_env and applies possible substitutions from current env
-            (represented by substitution_values).
-            Returns the substituted list of env entries.
+        """Walks env entries in templated_env and applies possible substitutions from current env
+        (represented by substitution_values).
+        Returns the substituted list of env entries.
         """
         substituted_env = {}
         if templated_env:
@@ -384,12 +378,8 @@ async def _async_start_kernel(self, **kw):
 
     start_kernel = run_sync(_async_start_kernel)
 
-    def request_shutdown(
-        self,
-        restart: bool = False
-    ) -> None:
-        """Send a shutdown request via control channel
-        """
+    def request_shutdown(self, restart: bool = False) -> None:
+        """Send a shutdown request via control channel"""
         content = dict(restart=restart)
         msg = self.session.msg("shutdown_request", content=content)
         # ensure control socket is connected
@@ -397,9 +387,7 @@ def request_shutdown(
         self.session.send(self._control_socket, msg)
 
     async def _async_finish_shutdown(
-        self,
-        waittime: t.Optional[float] = None,
-        pollinterval: float = 0.1
+        self, waittime: t.Optional[float] = None, pollinterval: float = 0.1
     ) -> None:
         """Wait for kernel shutdown, then kill process if it doesn't shut down.
 
@@ -435,10 +423,7 @@ async def _async_finish_shutdown(
 
     finish_shutdown = run_sync(_async_finish_shutdown)
 
-    def cleanup_resources(
-        self,
-        restart: bool = False
-    ) -> None:
+    def cleanup_resources(self, restart: bool = False) -> None:
         """Clean up resources when the kernel is shut down"""
         if not restart:
             self.cleanup_connection_file()
@@ -450,20 +435,16 @@ def cleanup_resources(
         if self._created_context and not restart:
             self.context.destroy(linger=100)
 
-    def cleanup(
-        self,
-        connection_file: bool = True
-    ) -> None:
+    def cleanup(self, connection_file: bool = True) -> None:
         """Clean up resources when the kernel is shut down"""
-        warnings.warn("Method cleanup(connection_file=True) is deprecated, use cleanup_resources(restart=False).",
-                      FutureWarning)
+        warnings.warn(
+            "Method cleanup(connection_file=True) is deprecated, use cleanup_resources"
+            "(restart=False).",
+            FutureWarning,
+        )
         self.cleanup_resources(restart=not connection_file)
 
-    async def _async_shutdown_kernel(
-        self,
-        now: bool = False,
-        restart: bool = False
-    ):
+    async def _async_shutdown_kernel(self, now: bool = False, restart: bool = False):
         """Attempts to stop the kernel process cleanly.
 
         This attempts to shut down the kernel cleanly by:
@@ -510,7 +491,9 @@ async def _async_shutdown_kernel(
         # path if cleanup() is overridden but cleanup_resources() is not.
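
Illustration (editor's sketch, not part of the patch): the back-compat check below relies on unbound-method identity to detect whether a subclass still overrides the deprecated hook. The idiom in isolation, with hypothetical classes:

    class Base:
        def cleanup(self):
            pass

    class Legacy(Base):
        def cleanup(self):  # subclass implements only the old API
            pass

    assert type(Legacy()).cleanup is not Base.cleanup  # overridden
    assert type(Base()).cleanup is Base.cleanup        # not overridden
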
         overrides_cleanup = type(self).cleanup is not KernelManager.cleanup
-        overrides_cleanup_resources = type(self).cleanup_resources is not KernelManager.cleanup_resources
+        overrides_cleanup_resources = (
+            type(self).cleanup_resources is not KernelManager.cleanup_resources
+        )
 
         if overrides_cleanup and not overrides_cleanup_resources:
             self.cleanup(connection_file=not restart)
@@ -519,12 +502,7 @@ async def _async_shutdown_kernel(
 
     shutdown_kernel = run_sync(_async_shutdown_kernel)
 
-    async def _async_restart_kernel(
-        self,
-        now: bool = False,
-        newports: bool = False,
-        **kw
-    ) -> None:
+    async def _async_restart_kernel(self, now: bool = False, newports: bool = False, **kw) -> None:
         """Restarts a kernel with the arguments that were used to launch it.
 
         Parameters
@@ -550,8 +528,7 @@ async def _async_restart_kernel(
             kernel.
         """
         if self._launch_args is None:
-            raise RuntimeError("Cannot restart the kernel. "
-                               "No previous call to 'start_kernel'.")
+            raise RuntimeError("Cannot restart the kernel. " "No previous call to 'start_kernel'.")
         else:
             # Stop currently running kernel.
             await ensure_async(self.shutdown_kernel(now=now, restart=True))
@@ -611,20 +588,21 @@ async def _async_kill_kernel(self) -> None:
             # Signal the kernel to terminate (sends SIGKILL on Unix and calls
             # TerminateProcess() on Win32).
             try:
-                if hasattr(signal, 'SIGKILL'):
+                if hasattr(signal, "SIGKILL"):
                     await self._async_signal_kernel(signal.SIGKILL)  # type: ignore
                 else:
                     self.kernel.kill()
             except OSError as e:
                 # In Windows, we will get an Access Denied error if the process
                 # has already terminated. Ignore it.
-                if sys.platform == 'win32':
+                if sys.platform == "win32":
                     if e.winerror != 5:  # type: ignore
                         raise
                 # On Unix, we may get an ESRCH error if the process has already
                 # terminated. Ignore it.
                 else:
                     from errno import ESRCH
+
                     if e.errno != ESRCH:
                         raise
 
@@ -653,14 +631,15 @@ async def _async_interrupt_kernel(self) -> None:
         if self.has_kernel:
             assert self.kernel_spec is not None
             interrupt_mode = self.kernel_spec.interrupt_mode
-            if interrupt_mode == 'signal':
-                if sys.platform == 'win32':
+            if interrupt_mode == "signal":
+                if sys.platform == "win32":
                     from .win_interrupt import send_interrupt
+
                     send_interrupt(self.kernel.win32_interrupt_event)
                 else:
                     await self._async_signal_kernel(signal.SIGINT)
 
-            elif interrupt_mode == 'message':
+            elif interrupt_mode == "message":
                 msg = self.session.msg("interrupt_request", content={})
                 self._connect_control_socket()
                 self.session.send(self._control_socket, msg)
@@ -669,10 +648,7 @@ async def _async_interrupt_kernel(self) -> None:
 
     interrupt_kernel = run_sync(_async_interrupt_kernel)
 
-    async def _async_signal_kernel(
-        self,
-        signum: int
-    ) -> None:
+    async def _async_signal_kernel(self, signum: int) -> None:
         """Sends a signal to the process group of the kernel (this
         usually includes the kernel and any subprocesses spawned by
         the kernel).
@@ -707,10 +683,7 @@ async def _async_is_alive(self) -> bool:
 
     is_alive = run_sync(_async_is_alive)
 
-    async def _async_wait(
-        self,
-        pollinterval: float = 0.1
-    ) -> None:
+    async def _async_wait(self, pollinterval: float = 0.1) -> None:
         # Use busy loop at 100ms intervals, polling until the process is
         # not alive. If we find the process is no longer alive, complete
         # its cleanup via the blocking wait(). Callers are responsible for
@@ -721,8 +694,10 @@ async def _async_wait(
 
 class AsyncKernelManager(KernelManager):
     # the class to create with our `client` method
-    client_class: DottedObjectName = DottedObjectName('jupyter_client.asynchronous.AsyncKernelClient')
-    client_factory: Type = Type(klass='jupyter_client.asynchronous.AsyncKernelClient')
+    client_class: DottedObjectName = DottedObjectName(
+        "jupyter_client.asynchronous.AsyncKernelClient"
+    )
+    client_factory: Type = Type(klass="jupyter_client.asynchronous.AsyncKernelClient")
 
     _launch_kernel = KernelManager._async_launch_kernel
     start_kernel = KernelManager._async_start_kernel
@@ -740,9 +715,7 @@ class AsyncKernelManager(KernelManager):
 
 def start_new_kernel(
-    startup_timeout: float =60,
-    kernel_name: str = 'python',
-    **kwargs
+    startup_timeout: float = 60, kernel_name: str = "python", **kwargs
 ) -> t.Tuple[KernelManager, KernelClient]:
     """Start a new kernel, and return its Manager and Client"""
     km = KernelManager(kernel_name=kernel_name)
@@ -760,9 +733,7 @@ def start_new_kernel(
 
 async def start_new_async_kernel(
-    startup_timeout: float = 60,
-    kernel_name: str = 'python',
-    **kwargs
+    startup_timeout: float = 60, kernel_name: str = "python", **kwargs
 ) -> t.Tuple[AsyncKernelManager, KernelClient]:
     """Start a new kernel, and return its Manager and Client"""
     km = AsyncKernelManager(kernel_name=kernel_name)
diff --git a/jupyter_client/managerabc.py b/jupyter_client/managerabc.py
index 4b3f94baf..138485e2c 100644
--- a/jupyter_client/managerabc.py
+++ b/jupyter_client/managerabc.py
@@ -1,8 +1,6 @@
 """Abstract base class for kernel managers."""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 import abc
 
 
@@ -18,9 +16,9 @@ class KernelManagerABC(object, metaclass=abc.ABCMeta):
     def kernel(self):
         pass
 
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
     # Kernel management
-    #--------------------------------------------------------------------------
+    # --------------------------------------------------------------------------
 
     @abc.abstractmethod
     def start_kernel(self, **kw):
diff --git a/jupyter_client/multikernelmanager.py b/jupyter_client/multikernelmanager.py
index 8d8cb5d9c..ffc0f174c 100644
--- a/jupyter_client/multikernelmanager.py
+++ b/jupyter_client/multikernelmanager.py
@@ -1,41 +1,39 @@
 """A kernel manager for multiple kernels"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
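
Illustration (editor's sketch, not part of the patch): for orientation before this file's diff, a minimal synchronous use of the MultiKernelManager it modifies. Assumes a python3 kernelspec is installed:

    from jupyter_client.multikernelmanager import MultiKernelManager

    mkm = MultiKernelManager()
    kid = mkm.start_kernel(kernel_name="python3")  # returns the new kernel_id
    try:
        assert kid in mkm and mkm.get_kernel(kid).is_alive()
    finally:
        mkm.shutdown_kernel(kid)
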
-
 import asyncio
 import os
-import uuid
 import socket
 import typing as t
+import uuid
 
 import zmq
-
+from traitlets import Any  # type: ignore
+from traitlets import Bool
+from traitlets import default
+from traitlets import Dict
+from traitlets import DottedObjectName
+from traitlets import Instance
+from traitlets import observe
+from traitlets import Unicode
 from traitlets.config.configurable import LoggingConfigurable  # type: ignore
 from traitlets.utils.importstring import import_item  # type: ignore
-from traitlets import (  # type: ignore
-    Any, Bool, Dict, DottedObjectName, Instance, Unicode, default, observe
-)
-from .kernelspec import NATIVE_KERNEL_NAME, KernelSpecManager
+
+from .kernelspec import KernelSpecManager
+from .kernelspec import NATIVE_KERNEL_NAME
 from .manager import KernelManager
-from .utils import run_sync, ensure_async
+from .utils import ensure_async
+from .utils import run_sync
 
 
 class DuplicateKernelError(Exception):
     pass
 
 
-def kernel_method(
-    f: t.Callable
-) -> t.Callable:
+def kernel_method(f: t.Callable) -> t.Callable:
     """decorator for proxying MKM.method(kernel_id) to individual KMs by ID"""
-    def wrapped(
-        self,
-        kernel_id: str,
-        *args,
-        **kwargs
-    ) -> t.Union[t.Callable, t.Awaitable]:
+
+    def wrapped(self, kernel_id: str, *args, **kwargs) -> t.Union[t.Callable, t.Awaitable]:
         # get the kernel
         km = self.get_kernel(kernel_id)
         method = getattr(km, f.__name__)
@@ -46,23 +44,25 @@ def wrapped(
             f(self, kernel_id, *args, **kwargs)
         # return the method result
         return r
+
     return wrapped
 
 
 class MultiKernelManager(LoggingConfigurable):
     """A class for managing multiple kernels."""
 
-    default_kernel_name = Unicode(NATIVE_KERNEL_NAME, config=True,
-        help="The name of the default kernel to start"
+    default_kernel_name = Unicode(
+        NATIVE_KERNEL_NAME, config=True, help="The name of the default kernel to start"
     )
 
     kernel_spec_manager = Instance(KernelSpecManager, allow_none=True)
 
     kernel_manager_class = DottedObjectName(
-        "jupyter_client.ioloop.IOLoopKernelManager", config=True,
+        "jupyter_client.ioloop.IOLoopKernelManager",
+        config=True,
         help="""The kernel manager class.  This is configurable to allow
         subclassing of the KernelManager for customized behavior.
- """ + """, ) def __init__(self, *args, **kwargs): @@ -71,13 +71,13 @@ def __init__(self, *args, **kwargs): # Cache all the currently used ports self.currently_used_ports = set() - @observe('kernel_manager_class') + @observe("kernel_manager_class") def _kernel_manager_class_changed(self, change): self.kernel_manager_factory = self._create_kernel_manager_factory() kernel_manager_factory = Any(help="this is kernel_manager_class after import") - @default('kernel_manager_factory') + @default("kernel_manager_factory") def _kernel_manager_factory_default(self): return self._create_kernel_manager_factory() @@ -103,13 +103,10 @@ def create_kernel_manager(*args, **kwargs) -> KernelManager: return create_kernel_manager - def _find_available_port( - self, - ip: str - ) -> int: + def _find_available_port(self, ip: str) -> int: while True: tmp_sock = socket.socket() - tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b'\0' * 8) + tmp_sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, b"\0" * 8) tmp_sock.bind((ip, 0)) port = tmp_sock.getsockname()[1] tmp_sock.close() @@ -129,7 +126,7 @@ def _find_available_port( _created_context = Bool(False) - context = Instance('zmq.Context') + context = Instance("zmq.Context") _starting_kernels = Dict() @@ -150,7 +147,7 @@ def __del__(self): else: super_del() - connection_dir = Unicode('') + connection_dir = Unicode("") _kernels = Dict() @@ -168,14 +165,12 @@ def __contains__(self, kernel_id) -> bool: return kernel_id in self._kernels def pre_start_kernel( - self, - kernel_name: t.Optional[str], - kwargs + self, kernel_name: t.Optional[str], kwargs ) -> t.Tuple[KernelManager, str, str]: # kwargs should be mutable, passing it as a dict argument. - kernel_id = kwargs.pop('kernel_id', self.new_kernel_id(**kwargs)) + kernel_id = kwargs.pop("kernel_id", self.new_kernel_id(**kwargs)) if kernel_id in self: - raise DuplicateKernelError('Kernel already exists: %s' % kernel_id) + raise DuplicateKernelError("Kernel already exists: %s" % kernel_id) if kernel_name is None: kernel_name = self.default_kernel_name @@ -184,28 +179,23 @@ def pre_start_kernel( # including things like its transport and ip. constructor_kwargs = {} if self.kernel_spec_manager: - constructor_kwargs['kernel_spec_manager'] = self.kernel_spec_manager - km = self.kernel_manager_factory(connection_file=os.path.join( - self.connection_dir, "kernel-%s.json" % kernel_id), - parent=self, log=self.log, kernel_name=kernel_name, - **constructor_kwargs + constructor_kwargs["kernel_spec_manager"] = self.kernel_spec_manager + km = self.kernel_manager_factory( + connection_file=os.path.join(self.connection_dir, "kernel-%s.json" % kernel_id), + parent=self, + log=self.log, + kernel_name=kernel_name, + **constructor_kwargs, ) return km, kernel_name, kernel_id async def _add_kernel_when_ready( - self, - kernel_id: str, - km: KernelManager, - kernel_awaitable: t.Awaitable + self, kernel_id: str, km: KernelManager, kernel_awaitable: t.Awaitable ) -> None: await kernel_awaitable self._kernels[kernel_id] = km - async def _async_start_kernel( - self, - kernel_name: t.Optional[str] = None, - **kwargs - ) -> str: + async def _async_start_kernel(self, kernel_name: t.Optional[str] = None, **kwargs) -> str: """Start a new kernel. 
The caller can pick a kernel_id by passing one in as a keyword arg, @@ -215,14 +205,13 @@ async def _async_start_kernel( """ km, kernel_name, kernel_id = self.pre_start_kernel(kernel_name, kwargs) if not isinstance(km, KernelManager): - self.log.warning("Kernel manager class ({km_class}) is not an instance of 'KernelManager'!". - format(km_class=self.kernel_manager_class.__class__)) - fut = asyncio.ensure_future( - self._add_kernel_when_ready( - kernel_id, - km, - ensure_async(km.start_kernel(**kwargs)) + self.log.warning( + "Kernel manager class ({km_class}) is not an instance of 'KernelManager'!".format( + km_class=self.kernel_manager_class.__class__ + ) ) + fut = asyncio.ensure_future( + self._add_kernel_when_ready(kernel_id, km, ensure_async(km.start_kernel(**kwargs))) ) self._starting_kernels[kernel_id] = fut await fut @@ -235,7 +224,7 @@ async def _async_shutdown_kernel( self, kernel_id: str, now: t.Optional[bool] = False, - restart: t.Optional[bool] = False + restart: t.Optional[bool] = False, ) -> None: """Shutdown a kernel by its kernel uuid. @@ -253,8 +242,11 @@ async def _async_shutdown_kernel( km = self.get_kernel(kernel_id) ports = ( - km.shell_port, km.iopub_port, km.stdin_port, - km.hb_port, km.control_port + km.shell_port, + km.iopub_port, + km.stdin_port, + km.hb_port, + km.control_port, ) await ensure_async(km.shutdown_kernel(now, restart)) @@ -267,11 +259,7 @@ async def _async_shutdown_kernel( shutdown_kernel = run_sync(_async_shutdown_kernel) @kernel_method - def request_shutdown( - self, - kernel_id: str, - restart: t.Optional[bool] = False - ) -> None: + def request_shutdown(self, kernel_id: str, restart: t.Optional[bool] = False) -> None: """Ask a kernel to shut down by its kernel uuid""" @kernel_method @@ -279,32 +267,20 @@ def finish_shutdown( self, kernel_id: str, waittime: t.Optional[float] = None, - pollinterval: t.Optional[float] = 0.1 + pollinterval: t.Optional[float] = 0.1, ) -> None: - """Wait for a kernel to finish shutting down, and kill it if it doesn't - """ + """Wait for a kernel to finish shutting down, and kill it if it doesn't""" self.log.info("Kernel shutdown: %s" % kernel_id) @kernel_method - def cleanup( - self, - kernel_id: str, - connection_file: bool = True - ) -> None: + def cleanup(self, kernel_id: str, connection_file: bool = True) -> None: """Clean up a kernel's resources""" @kernel_method - def cleanup_resources( - self, - kernel_id: str, - restart: bool = False - ) -> None: + def cleanup_resources(self, kernel_id: str, restart: bool = False) -> None: """Clean up a kernel's resources""" - def remove_kernel( - self, - kernel_id: str - ) -> KernelManager: + def remove_kernel(self, kernel_id: str) -> KernelManager: """remove a kernel from our mapping. 
Mainly so that a kernel can be removed if it is already dead, @@ -314,35 +290,24 @@ def remove_kernel( """ return self._kernels.pop(kernel_id) - async def _shutdown_starting_kernel( - self, - kid: str, - now: bool - ) -> None: + async def _shutdown_starting_kernel(self, kid: str, now: bool) -> None: if kid in self._starting_kernels: await self._starting_kernels[kid] await ensure_async(self.shutdown_kernel(kid, now=now)) - async def _async_shutdown_all( - self, - now: bool = False - ) -> None: + async def _async_shutdown_all(self, now: bool = False) -> None: """Shutdown all kernels.""" kids = self.list_kernel_ids() futs = [ensure_async(self.shutdown_kernel(kid, now=now)) for kid in kids] futs += [ - self._shutdown_starting_kernel(kid, now=now) - for kid in self._starting_kernels.keys() + self._shutdown_starting_kernel(kid, now=now) for kid in self._starting_kernels.keys() ] await asyncio.gather(*futs) shutdown_all = run_sync(_async_shutdown_all) @kernel_method - def interrupt_kernel( - self, - kernel_id: str - ) -> None: + def interrupt_kernel(self, kernel_id: str) -> None: """Interrupt (SIGINT) the kernel by its uuid. Parameters @@ -353,11 +318,7 @@ def interrupt_kernel( self.log.info("Kernel interrupted: %s" % kernel_id) @kernel_method - def signal_kernel( - self, - kernel_id: str, - signum: int - ) -> None: + def signal_kernel(self, kernel_id: str, signum: int) -> None: """Sends a signal to the kernel by its uuid. Note that since only SIGTERM is supported on Windows, this function @@ -371,11 +332,7 @@ def signal_kernel( self.log.info("Signaled Kernel %s with %s" % (kernel_id, signum)) @kernel_method - def restart_kernel( - self, - kernel_id: str, - now: bool = False - ) -> None: + def restart_kernel(self, kernel_id: str, now: bool = False) -> None: """Restart a kernel by its uuid, keeping the same ports. Parameters @@ -386,10 +343,7 @@ def restart_kernel( self.log.info("Kernel restarted: %s" % kernel_id) @kernel_method - def is_alive( - self, - kernel_id: str - ) -> bool: + def is_alive(self, kernel_id: str) -> bool: """Is the kernel alive. This calls KernelManager.is_alive() which calls Popen.poll on the @@ -401,18 +355,12 @@ def is_alive( The id of the kernel. """ - def _check_kernel_id( - self, - kernel_id: str - ) -> None: + def _check_kernel_id(self, kernel_id: str) -> None: """check that a kernel id is valid""" if kernel_id not in self: raise KeyError("Kernel with id not found: %s" % kernel_id) - def get_kernel( - self, - kernel_id: str - ) -> KernelManager: + def get_kernel(self, kernel_id: str) -> KernelManager: """Get the single KernelManager object for a kernel by its uuid. Parameters @@ -425,27 +373,18 @@ def get_kernel( @kernel_method def add_restart_callback( - self, - kernel_id: str, - callback: t.Callable, - event: str = 'restart' + self, kernel_id: str, callback: t.Callable, event: str = "restart" ) -> None: """add a callback for the KernelRestarter""" @kernel_method def remove_restart_callback( - self, - kernel_id: str, - callback: t.Callable, - event: str = 'restart' + self, kernel_id: str, callback: t.Callable, event: str = "restart" ) -> None: """remove a callback for the KernelRestarter""" @kernel_method - def get_connection_info( - self, - kernel_id: str - ) -> t.Dict[str, t.Any]: + def get_connection_info(self, kernel_id: str) -> t.Dict[str, t.Any]: """Return a dictionary of connection data for a kernel. 
Parameters @@ -463,11 +402,7 @@ def get_connection_info( """ @kernel_method - def connect_iopub( - self, - kernel_id: str, - identity: t.Optional[bytes] = None - ) -> socket.socket: + def connect_iopub(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket: """Return a zmq Socket connected to the iopub channel. Parameters @@ -483,11 +418,7 @@ def connect_iopub( """ @kernel_method - def connect_shell( - self, - kernel_id: str, - identity: t.Optional[bytes] = None - ) -> socket.socket: + def connect_shell(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket: """Return a zmq Socket connected to the shell channel. Parameters @@ -503,11 +434,7 @@ def connect_shell( """ @kernel_method - def connect_control( - self, - kernel_id: str, - identity: t.Optional[bytes] = None - ) -> socket.socket: + def connect_control(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket: """Return a zmq Socket connected to the control channel. Parameters @@ -523,11 +450,7 @@ def connect_control( """ @kernel_method - def connect_stdin( - self, - kernel_id: str, - identity: t.Optional[bytes] = None - ) -> socket.socket: + def connect_stdin(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket: """Return a zmq Socket connected to the stdin channel. Parameters @@ -543,11 +466,7 @@ def connect_stdin( """ @kernel_method - def connect_hb( - self, - kernel_id: str, - identity: t.Optional[bytes] = None - ) -> socket.socket: + def connect_hb(self, kernel_id: str, identity: t.Optional[bytes] = None) -> socket.socket: """Return a zmq Socket connected to the hb channel. Parameters @@ -575,10 +494,11 @@ def new_kernel_id(self, **kwargs) -> str: class AsyncMultiKernelManager(MultiKernelManager): kernel_manager_class = DottedObjectName( - "jupyter_client.ioloop.AsyncIOLoopKernelManager", config=True, + "jupyter_client.ioloop.AsyncIOLoopKernelManager", + config=True, help="""The kernel manager class. This is configurable to allow subclassing of the AsyncKernelManager for customized behavior. - """ + """, ) start_kernel = MultiKernelManager._async_start_kernel diff --git a/jupyter_client/restarter.py b/jupyter_client/restarter.py index 8070680eb..a87253a62 100644 --- a/jupyter_client/restarter.py +++ b/jupyter_client/restarter.py @@ -5,44 +5,49 @@ It is an incomplete base class, and must be subclassed. """ - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - -from traitlets.config.configurable import LoggingConfigurable -from traitlets import ( - Instance, Float, Dict, Bool, Integer, -) +from traitlets import Bool # type: ignore +from traitlets import Dict +from traitlets import Float +from traitlets import Instance +from traitlets import Integer +from traitlets.config.configurable import LoggingConfigurable # type: ignore class KernelRestarter(LoggingConfigurable): """Monitor and autorestart a kernel.""" - kernel_manager = Instance('jupyter_client.KernelManager') + kernel_manager = Instance("jupyter_client.KernelManager") - debug = Bool(False, config=True, + debug = Bool( + False, + config=True, help="""Whether to include every poll event in debugging output. Has to be set explicitly, because there will be *a lot* of output. 
- """ + """, ) - time_to_dead = Float(3.0, config=True, - help="""Kernel heartbeat interval in seconds.""" - ) + time_to_dead = Float(3.0, config=True, help="""Kernel heartbeat interval in seconds.""") - restart_limit = Integer(5, config=True, - help="""The number of consecutive autorestarts before the kernel is presumed dead.""" + restart_limit = Integer( + 5, + config=True, + help="""The number of consecutive autorestarts before the kernel is presumed dead.""", ) - random_ports_until_alive = Bool(True, config=True, - help="""Whether to choose new random ports when restarting before the kernel is alive.""" + random_ports_until_alive = Bool( + True, + config=True, + help="""Whether to choose new random ports when restarting before the kernel is alive.""", ) _restarting = Bool(False) _restart_count = Integer(0) _initial_startup = Bool(True) callbacks = Dict() + def _callbacks_default(self): return dict(restart=[], dead=[]) @@ -54,7 +59,7 @@ def stop(self): """Stop the kernel polling.""" raise NotImplementedError("Must be implemented in a subclass") - def add_callback(self, f, event='restart'): + def add_callback(self, f, event="restart"): """register a callback to fire on a particular event Possible values for event: @@ -65,7 +70,7 @@ def add_callback(self, f, event='restart'): """ self.callbacks[event].append(f) - def remove_callback(self, f, event='restart'): + def remove_callback(self, f, event="restart"): """unregister a callback to fire on a particular event Possible values for event: @@ -84,14 +89,19 @@ def _fire_callbacks(self, event): for callback in self.callbacks[event]: try: callback() - except Exception as e: - self.log.error("KernelRestarter: %s callback %r failed", event, callback, exc_info=True) + except Exception: + self.log.error( + "KernelRestarter: %s callback %r failed", + event, + callback, + exc_info=True, + ) def poll(self): if self.debug: - self.log.debug('Polling kernel...') + self.log.debug("Polling kernel...") if self.kernel_manager.shutting_down: - self.log.debug('Kernel shutdown in progress...') + self.log.debug("Kernel shutdown in progress...") return if not self.kernel_manager.is_alive(): if self._restarting: @@ -101,18 +111,19 @@ def poll(self): if self._restart_count >= self.restart_limit: self.log.warning("KernelRestarter: restart failed") - self._fire_callbacks('dead') + self._fire_callbacks("dead") self._restarting = False self._restart_count = 0 self.stop() else: newports = self.random_ports_until_alive and self._initial_startup - self.log.info('KernelRestarter: restarting kernel (%i/%i), %s random ports', + self.log.info( + "KernelRestarter: restarting kernel (%i/%i), %s random ports", self._restart_count, self.restart_limit, - 'new' if newports else 'keep' + "new" if newports else "keep", ) - self._fire_callbacks('restart') + self._fire_callbacks("restart") self.kernel_manager.restart_kernel(now=True, newports=newports) self._restarting = True else: diff --git a/jupyter_client/runapp.py b/jupyter_client/runapp.py index 96e3cdce6..9b2746aa2 100644 --- a/jupyter_client/runapp.py +++ b/jupyter_client/runapp.py @@ -1,23 +1,22 @@ - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
- -import logging -import signal import queue -import time +import signal import sys +import time +from jupyter_core.application import base_aliases +from jupyter_core.application import base_flags +from jupyter_core.application import JupyterApp +from traitlets import Any +from traitlets import Dict +from traitlets import Float from traitlets.config import catch_config_error -from traitlets import ( - Instance, Dict, Unicode, Bool, List, CUnicode, Any, Float -) -from jupyter_core.application import ( - JupyterApp, base_flags, base_aliases -) from . import __version__ -from .consoleapp import JupyterConsoleApp, app_aliases, app_flags +from .consoleapp import app_aliases +from .consoleapp import app_flags +from .consoleapp import JupyterConsoleApp OUTPUT_TIMEOUT = 10 @@ -40,6 +39,7 @@ frontend_aliases = set(frontend_aliases.keys()) frontend_flags = set(frontend_flags.keys()) + class RunApp(JupyterApp, JupyterConsoleApp): version = __version__ name = "jupyter run" @@ -48,14 +48,16 @@ class RunApp(JupyterApp, JupyterConsoleApp): aliases = Dict(aliases) frontend_aliases = Any(frontend_aliases) frontend_flags = Any(frontend_flags) - kernel_timeout = Float(60, config=True, + kernel_timeout = Float( + 60, + config=True, help="""Timeout for giving up on a kernel (in seconds). On first connect and restart, the console tests whether the kernel is running and responsive by sending kernel_info_requests. This sets the timeout in seconds for how long the kernel can take before being presumed dead. - """ + """, ) def parse_command_line(self, argv=None): @@ -90,8 +92,8 @@ def init_kernel_info(self): if (time.time() - tic) > timeout: raise RuntimeError("Kernel didn't respond to kernel_info_request") from e else: - if reply['parent_header'].get('msg_id') == msg_id: - self.kernel_info = reply['content'] + if reply["parent_header"].get("msg_id") == msg_id: + self.kernel_info = reply["content"] return def start(self): @@ -103,17 +105,18 @@ def start(self): with open(filename) as fp: code = fp.read() reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT) - return_code = 0 if reply['content']['status'] == 'ok' else 1 + return_code = 0 if reply["content"]["status"] == "ok" else 1 if return_code: raise Exception("jupyter-run error running '%s'" % filename) else: code = sys.stdin.read() reply = self.kernel_client.execute_interactive(code, timeout=OUTPUT_TIMEOUT) - return_code = 0 if reply['content']['status'] == 'ok' else 1 + return_code = 0 if reply["content"]["status"] == "ok" else 1 if return_code: raise Exception("jupyter-run error running 'stdin'") + main = launch_new_instance = RunApp.launch_instance -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/jupyter_client/session.py b/jupyter_client/session.py index fde87d165..8066c5ce2 100644 --- a/jupyter_client/session.py +++ b/jupyter_client/session.py @@ -8,11 +8,8 @@ Sessions. * A Message object for convenience that allows attribute-access to the msg dict. """ - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
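`RunApp.init_kernel_info` above implements a common readiness handshake: send a `kernel_info_request`, then poll the shell channel until the matching reply arrives or `kernel_timeout` elapses. Extracted as a standalone sketch (`kc` is assumed to be a started `KernelClient`, and 60 stands in for the configurable timeout):

import queue
import time

msg_id = kc.kernel_info()  # send kernel_info_request, keep the msg_id
tic = time.time()
while True:
    try:
        reply = kc.get_shell_msg(timeout=1)
    except queue.Empty:
        if time.time() - tic > 60:
            raise RuntimeError("Kernel didn't respond to kernel_info_request")
    else:
        # only accept the reply to *our* request
        if reply["parent_header"].get("msg_id") == msg_id:
            kernel_info = reply["content"]
            break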
- -from binascii import b2a_hex import hashlib import hmac import logging @@ -20,58 +17,70 @@ import pickle import pprint import random -import warnings import typing as t - +import warnings +from binascii import b2a_hex from datetime import datetime from datetime import timezone - -PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL - -# We are using compare_digest to limit the surface of timing attacks -from hmac import compare_digest - -utc = timezone.utc +from hmac import ( + compare_digest, +) # We are using compare_digest to limit the surface of timing attacks import zmq -from zmq.utils import jsonapi +from traitlets import Any # type: ignore +from traitlets import Bool +from traitlets import CBytes +from traitlets import CUnicode +from traitlets import Dict # type: ignore +from traitlets import DottedObjectName +from traitlets import Instance +from traitlets import Integer +from traitlets import observe +from traitlets import Set +from traitlets import TraitError +from traitlets import Unicode +from traitlets.config.configurable import Configurable # type: ignore +from traitlets.config.configurable import LoggingConfigurable +from traitlets.log import get_logger # type: ignore +from traitlets.utils.importstring import import_item # type: ignore from zmq.eventloop.ioloop import IOLoop from zmq.eventloop.zmqstream import ZMQStream +from zmq.utils import jsonapi - -from jupyter_client.jsonutil import extract_dates, squash_dates, date_default from jupyter_client import protocol_version from jupyter_client.adapter import adapt +from jupyter_client.jsonutil import date_default +from jupyter_client.jsonutil import extract_dates +from jupyter_client.jsonutil import squash_dates -from traitlets import ( # type: ignore - CBytes, Unicode, Bool, Any, Instance, Set, DottedObjectName, CUnicode, - Dict, Integer, TraitError, observe -) -from traitlets.log import get_logger # type: ignore -from traitlets.utils.importstring import import_item # type: ignore -from traitlets.config.configurable import Configurable, LoggingConfigurable # type: ignore -#----------------------------------------------------------------------------- +PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL + +utc = timezone.utc + +# ----------------------------------------------------------------------------- # utility functions -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- + def squash_unicode(obj): """coerce unicode back to bytestrings.""" - if isinstance(obj,dict): + if isinstance(obj, dict): for key in obj.keys(): obj[key] = squash_unicode(obj[key]) if isinstance(key, str): obj[squash_unicode(key)] = obj.pop(key) elif isinstance(obj, list): - for i,v in enumerate(obj): + for i, v in enumerate(obj): obj[i] = squash_unicode(v) elif isinstance(obj, str): - obj = obj.encode('utf8') + obj = obj.encode("utf8") return obj -#----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- # globals and defaults -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # default values for the thresholds: MAX_ITEMS = 64 @@ -80,12 +89,25 @@ def squash_unicode(obj): # ISO8601-ify datetime objects # allow unicode # disallow nan, because it's not actually valid JSON -json_packer = lambda obj: jsonapi.dumps(obj, default=date_default, - 
ensure_ascii=False, allow_nan=False, -) -json_unpacker = lambda s: jsonapi.loads(s) -pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) + +def json_packer(obj): + return jsonapi.dumps( + obj, + default=date_default, + ensure_ascii=False, + allow_nan=False, + ) + + +def json_unpacker(s): + return jsonapi.loads(s) + + +def pickle_packer(o): + return pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) + + pickle_unpacker = pickle.loads default_packer = json_packer @@ -95,9 +117,10 @@ def squash_unicode(obj): # singleton dummy tracker, which will always report as done DONE = zmq.MessageTracker() -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Mixin tools for apps that use Sessions -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- + def new_id() -> str: """Generate a new random id. @@ -110,30 +133,34 @@ def new_id() -> str: id string (16 random bytes as hex-encoded text, chunks separated by '-') """ buf = os.urandom(16) - return '-'.join(b2a_hex(x).decode('ascii') for x in ( - buf[:4], buf[4:] - )) + return "-".join(b2a_hex(x).decode("ascii") for x in (buf[:4], buf[4:])) + def new_id_bytes() -> bytes: """Return new_id as ascii bytes""" - return new_id().encode('ascii') + return new_id().encode("ascii") + session_aliases = dict( - ident = 'Session.session', - user = 'Session.username', - keyfile = 'Session.keyfile', + ident="Session.session", + user="Session.username", + keyfile="Session.keyfile", ) session_flags = { - 'secure' : ({'Session' : { 'key' : new_id_bytes(), - 'keyfile' : '' }}, + "secure": ( + {"Session": {"key": new_id_bytes(), "keyfile": ""}}, """Use HMAC digests for authentication of messages. Setting this flag will generate a new UUID to use as the HMAC key. - """), - 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }}, - """Don't authenticate messages."""), + """, + ), + "no-secure": ( + {"Session": {"key": b"", "keyfile": ""}}, + """Don't authenticate messages.""", + ), } + def default_secure(cfg) -> None: """Set the default behavior for a config environment to be secure. @@ -141,40 +168,44 @@ def default_secure(cfg) -> None: a new random UUID. """ warnings.warn("default_secure is deprecated", DeprecationWarning) - if 'Session' in cfg: - if 'key' in cfg.Session or 'keyfile' in cfg.Session: + if "Session" in cfg: + if "key" in cfg.Session or "keyfile" in cfg.Session: return # key/keyfile not specified, generate new UUID: cfg.Session.key = new_id_bytes() + def utcnow() -> datetime: """Return timezone-aware UTC timestamp""" return datetime.utcnow().replace(tzinfo=utc) -#----------------------------------------------------------------------------- + +# ----------------------------------------------------------------------------- # Classes -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- + class SessionFactory(LoggingConfigurable): """The Base class for configurables that have a Session, Context, logger, and IOLoop. 
""" - logname = Unicode('') + logname = Unicode("") - @observe('logname') + @observe("logname") def _logname_changed(self, change) -> None: - self.log = logging.getLogger(change['new']) + self.log = logging.getLogger(change["new"]) # not configurable: - context = Instance('zmq.Context') + context = Instance("zmq.Context") + def _context_default(self) -> zmq.Context: return zmq.Context() - session = Instance('jupyter_client.session.Session', - allow_none=True) + session = Instance("jupyter_client.session.Session", allow_none=True) + + loop = Instance("tornado.ioloop.IOLoop") - loop = Instance('tornado.ioloop.IOLoop') def _loop_default(self): return IOLoop.current() @@ -192,10 +223,7 @@ class Message(object): A Message can be created from a dict and a dict from a Message instance simply by calling dict(msg_obj).""" - def __init__( - self, - msg_dict: t.Dict[str, t.Any] - ) -> None: + def __init__(self, msg_dict: t.Dict[str, t.Any]) -> None: dct = self.__dict__ for k, v in dict(msg_dict).items(): if isinstance(v, dict): @@ -219,30 +247,24 @@ def __getitem__(self, k) -> t.Any: return self.__dict__[k] -def msg_header( - msg_id: str, - msg_type: str, - username: str, - session: 'Session' -) -> t.Dict[str, t.Any]: +def msg_header(msg_id: str, msg_type: str, username: str, session: "Session") -> t.Dict[str, t.Any]: """Create a new message header""" date = utcnow() version = protocol_version return locals() -def extract_header( - msg_or_header: t.Dict[str, t.Any] -) -> t.Dict[str, t.Any]: + +def extract_header(msg_or_header: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: """Given a message or header, return the header.""" if not msg_or_header: return {} try: # See if msg_or_header is the entire message. - h = msg_or_header['header'] + h = msg_or_header["header"] except KeyError: try: # See if msg_or_header is just the header - h = msg_or_header['msg_id'] + h = msg_or_header["msg_id"] except KeyError: raise else: @@ -251,6 +273,7 @@ def extract_header( h = dict(h) return h + class Session(Configurable): """Object for handling serialization and sending of messages. @@ -294,103 +317,120 @@ class Session(Configurable): debug = Bool(False, config=True, help="""Debug output in the Session""") - check_pid = Bool(True, config=True, + check_pid = Bool( + True, + config=True, help="""Whether to check PID to protect against calls after fork. This check can be disabled if fork-safety is handled elsewhere. - """) + """, + ) - packer = DottedObjectName('json',config=True, - help="""The name of the packer for serializing messages. + packer = DottedObjectName( + "json", + config=True, + help="""The name of the packer for serializing messages. Should be one of 'json', 'pickle', or an import name - for a custom callable serializer.""") + for a custom callable serializer.""", + ) - @observe('packer') + @observe("packer") def _packer_changed(self, change): - new = change['new'] - if new.lower() == 'json': + new = change["new"] + if new.lower() == "json": self.pack = json_packer self.unpack = json_unpacker self.unpacker = new - elif new.lower() == 'pickle': + elif new.lower() == "pickle": self.pack = pickle_packer self.unpack = pickle_unpacker self.unpacker = new else: self.pack = import_item(str(new)) - unpacker = DottedObjectName('json', config=True, + unpacker = DottedObjectName( + "json", + config=True, help="""The name of the unpacker for unserializing messages. 
- Only used with custom functions for `packer`.""") + Only used with custom functions for `packer`.""", + ) - @observe('unpacker') + @observe("unpacker") def _unpacker_changed(self, change): - new = change['new'] - if new.lower() == 'json': + new = change["new"] + if new.lower() == "json": self.pack = json_packer self.unpack = json_unpacker self.packer = new - elif new.lower() == 'pickle': + elif new.lower() == "pickle": self.pack = pickle_packer self.unpack = pickle_unpacker self.packer = new else: self.unpack = import_item(str(new)) - session = CUnicode('', config=True, - help="""The UUID identifying this session.""") + session = CUnicode("", config=True, help="""The UUID identifying this session.""") + def _session_default(self) -> str: u = new_id() - self.bsession = u.encode('ascii') + self.bsession = u.encode("ascii") return u - @observe('session') + @observe("session") def _session_changed(self, change): - self.bsession = self.session.encode('ascii') + self.bsession = self.session.encode("ascii") # bsession is the session as bytes - bsession = CBytes(b'') + bsession = CBytes(b"") username = Unicode( os.environ.get("USER", "username"), help="""Username for the Session. Default is your system username.""", - config=True) + config=True, + ) - metadata = Dict({}, config=True, - help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""") + metadata = Dict( + {}, + config=True, + help="Metadata dictionary, which serves as the default top-level metadata dict for each " + "message.", + ) # if 0, no adapting to do. adapt_version = Integer(0) # message signature related traits: - key = CBytes(config=True, - help="""execution key, for signing messages.""") + key = CBytes(config=True, help="""execution key, for signing messages.""") + def _key_default(self) -> bytes: return new_id_bytes() - @observe('key') + @observe("key") def _key_changed(self, change): self._new_auth() - signature_scheme = Unicode('hmac-sha256', config=True, + signature_scheme = Unicode( + "hmac-sha256", + config=True, help="""The digest scheme used to construct the message signatures. - Must have the form 'hmac-HASH'.""") + Must have the form 'hmac-HASH'.""", + ) - @observe('signature_scheme') + @observe("signature_scheme") def _signature_scheme_changed(self, change): - new = change['new'] - if not new.startswith('hmac-'): + new = change["new"] + if not new.startswith("hmac-"): raise TraitError("signature_scheme must start with 'hmac-', got %r" % new) - hash_name = new.split('-', 1)[1] + hash_name = new.split("-", 1)[1] try: self.digest_mod = getattr(hashlib, hash_name) except AttributeError as e: - raise TraitError("hashlib has no such attribute: %s" % - hash_name) from e + raise TraitError("hashlib has no such attribute: %s" % hash_name) from e self._new_auth() digest_mod = Any() + def _digest_mod_default(self) -> t.Callable: return hashlib.sha256 @@ -403,19 +443,20 @@ def _new_auth(self) -> None: self.auth = None digest_history = Set() - digest_history_size = Integer(2**16, config=True, + digest_history_size = Integer( + 2 ** 16, + config=True, help="""The maximum number of digests to remember. The digest history will be culled when it exceeds this value. 
- """ + """, ) - keyfile = Unicode('', config=True, - help="""path to file containing execution key.""") + keyfile = Unicode("", config=True, help="""path to file containing execution key.""") - @observe('keyfile') + @observe("keyfile") def _keyfile_changed(self, change): - with open(change['new'], 'rb') as f: + with open(change["new"], "rb") as f: self.key = f.read().strip() # for protecting against sends from forks @@ -423,35 +464,43 @@ def _keyfile_changed(self, change): # serialization traits: - pack = Any(default_packer) # the actual packer function + pack = Any(default_packer) # the actual packer function - @observe('pack') + @observe("pack") def _pack_changed(self, change): - new = change['new'] + new = change["new"] if not callable(new): - raise TypeError("packer must be callable, not %s"%type(new)) + raise TypeError("packer must be callable, not %s" % type(new)) - unpack = Any(default_unpacker) # the actual packer function + unpack = Any(default_unpacker) # the actual packer function - @observe('unpack') + @observe("unpack") def _unpack_changed(self, change): # unpacker is not checked - it is assumed to be - new = change['new'] + new = change["new"] if not callable(new): - raise TypeError("unpacker must be callable, not %s"%type(new)) + raise TypeError("unpacker must be callable, not %s" % type(new)) # thresholds: - copy_threshold = Integer(2**16, config=True, - help="Threshold (in bytes) beyond which a buffer should be sent without copying.") - buffer_threshold = Integer(MAX_BYTES, config=True, - help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.") - item_threshold = Integer(MAX_ITEMS, config=True, + copy_threshold = Integer( + 2 ** 16, + config=True, + help="Threshold (in bytes) beyond which a buffer should be sent without copying.", + ) + buffer_threshold = Integer( + MAX_BYTES, + config=True, + help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid " + "pickling.", + ) + item_threshold = Integer( + MAX_ITEMS, + config=True, help="""The maximum number of items for a container to be introspected for custom serialization. Containers larger than this are pickled outright. - """ + """, ) - def __init__(self, **kwargs): """create a Session object @@ -500,9 +549,11 @@ def __init__(self, **kwargs): self.pid = os.getpid() self._new_auth() if not self.key: - get_logger().warning("Message signing is disabled. This is insecure and not recommended!") + get_logger().warning( + "Message signing is disabled. This is insecure and not recommended!" + ) - def clone(self) -> 'Session': + def clone(self) -> "Session": """Create a copy of this Session Useful when connecting multiple times to a given kernel. 
@@ -521,11 +572,12 @@ def clone(self) -> 'Session': return new_session message_count = 0 + @property def msg_id(self) -> str: message_number = self.message_count self.message_count += 1 - return '{}_{}'.format(self.session, message_number) + return "{}_{}".format(self.session, message_number) def _check_packers(self) -> None: """check packers for datetime support.""" @@ -533,30 +585,30 @@ def _check_packers(self) -> None: unpack = self.unpack # check simple serialization - msg_list = dict(a=[1,'hi']) + msg_list = dict(a=[1, "hi"]) try: packed = pack(msg_list) except Exception as e: error_msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" - if self.packer == 'json': + if self.packer == "json": jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" - raise ValueError( - error_msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg) - ) from e + raise ValueError(error_msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)) from e # ensure packed message is bytes if not isinstance(packed, bytes): - raise ValueError("message packed to %r, but bytes are required"%type(packed)) + raise ValueError("message packed to %r, but bytes are required" % type(packed)) # check that unpack is pack's inverse try: unpacked = unpack(packed) assert unpacked == msg_list except Exception as e: - error_msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" - if self.packer == 'json': + error_msg = ( + "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" + ) + if self.packer == "json": jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod else: jsonmsg = "" @@ -568,16 +620,13 @@ def _check_packers(self) -> None: msg_datetime = dict(t=utcnow()) try: unpacked = unpack(pack(msg_datetime)) - if isinstance(unpacked['t'], datetime): + if isinstance(unpacked["t"], datetime): raise ValueError("Shouldn't deserialize to datetime") except Exception: self.pack = lambda o: pack(squash_dates(o)) self.unpack = lambda s: unpack(s) - def msg_header( - self, - msg_type: str - ) -> t.Dict[str, t.Any]: + def msg_header(self, msg_type: str) -> t.Dict[str, t.Any]: return msg_header(self.msg_id, msg_type, self.username, self.session) def msg( @@ -586,7 +635,7 @@ def msg( content: t.Optional[t.Dict] = None, parent: t.Optional[t.Dict[str, t.Any]] = None, header: t.Optional[t.Dict[str, t.Any]] = None, - metadata: t.Optional[t.Dict[str, t.Any]] = None + metadata: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Dict[str, t.Any]: """Return the nested message dict. @@ -596,20 +645,17 @@ def msg( """ msg = {} header = self.msg_header(msg_type) if header is None else header - msg['header'] = header - msg['msg_id'] = header['msg_id'] - msg['msg_type'] = header['msg_type'] - msg['parent_header'] = {} if parent is None else extract_header(parent) - msg['content'] = {} if content is None else content - msg['metadata'] = self.metadata.copy() + msg["header"] = header + msg["msg_id"] = header["msg_id"] + msg["msg_type"] = header["msg_type"] + msg["parent_header"] = {} if parent is None else extract_header(parent) + msg["content"] = {} if content is None else content + msg["metadata"] = self.metadata.copy() if metadata is not None: - msg['metadata'].update(metadata) + msg["metadata"].update(metadata) return msg - def sign( - self, - msg_list: t.List - ) -> bytes: + def sign(self, msg_list: t.List) -> bytes: """Sign a message with HMAC digest. If no auth, return b''. 
Parameters @@ -618,7 +664,7 @@ def sign( The [p_header,p_parent,p_content] part of the message list. """ if self.auth is None: - return b'' + return b"" h = self.auth.copy() for m in msg_list: h.update(m) @@ -627,7 +673,7 @@ def sign( def serialize( self, msg: t.Dict[str, t.Any], - ident: t.Optional[t.Union[t.List[bytes], bytes]] = None + ident: t.Optional[t.Union[t.List[bytes], bytes]] = None, ) -> t.List[bytes]: """Serialize the message components to bytes. @@ -651,7 +697,7 @@ def serialize( In this list, the ``p_*`` entities are the packed or serialized versions, so if JSON is used, these are utf8 encoded JSON strings. """ - content = msg.get('content', {}) + content = msg.get("content", {}) if content is None: content = self.none elif isinstance(content, dict): @@ -661,14 +707,15 @@ def serialize( pass elif isinstance(content, str): # should be bytes, but JSON often spits out unicode - content = content.encode('utf8') + content = content.encode("utf8") else: - raise TypeError("Content incorrect type: %s"%type(content)) + raise TypeError("Content incorrect type: %s" % type(content)) - real_message = [self.pack(msg['header']), - self.pack(msg['parent_header']), - self.pack(msg['metadata']), - content, + real_message = [ + self.pack(msg["header"]), + self.pack(msg["parent_header"]), + self.pack(msg["metadata"]), + content, ] to_send = [] @@ -697,7 +744,7 @@ def send( buffers: t.Optional[t.List[bytes]] = None, track: bool = False, header: t.Optional[t.Dict[str, t.Any]] = None, - metadata: t.Optional[t.Dict[str, t.Any]] = None + metadata: t.Optional[t.Dict[str, t.Any]] = None, ) -> t.Optional[t.Dict[str, t.Any]]: """Build and send a message via stream or socket. @@ -750,14 +797,17 @@ def send( # We got a Message or message dict, not a msg_type so don't # build a new Message. 
msg = msg_or_type - buffers = buffers or msg.get('buffers', []) + buffers = buffers or msg.get("buffers", []) else: - msg = self.msg(msg_or_type, content=content, parent=parent, - header=header, metadata=metadata) - if self.check_pid and not os.getpid() == self.pid: - get_logger().warning("WARNING: attempted to send message from fork\n%s", - msg + msg = self.msg( + msg_or_type, + content=content, + parent=parent, + header=header, + metadata=metadata, ) + if self.check_pid and not os.getpid() == self.pid: + get_logger().warning("WARNING: attempted to send message from fork\n%s", msg) return None buffers = [] if buffers is None else buffers for idx, buf in enumerate(buffers): @@ -771,7 +821,7 @@ def send( raise TypeError("Buffer objects must support the buffer protocol.") from e # memoryview.contiguous is new in 3.3, # just skip the check on Python 2 - if hasattr(view, 'contiguous') and not view.contiguous: + if hasattr(view, "contiguous") and not view.contiguous: # zmq requires memoryviews to be contiguous raise ValueError("Buffer %i (%r) is not contiguous" % (idx, buf)) @@ -779,8 +829,8 @@ def send( msg = adapt(msg, self.adapt_version) to_send = self.serialize(msg, ident) to_send.extend(buffers) - longest = max([ len(s) for s in to_send ]) - copy = (longest < self.copy_threshold) + longest = max([len(s) for s in to_send]) + copy = longest < self.copy_threshold if buffers and track and not copy: # only really track when we are doing zero-copy buffers @@ -795,7 +845,7 @@ def send( pprint.pprint(to_send) pprint.pprint(buffers) - msg['tracker'] = tracker + msg["tracker"] = tracker return msg @@ -837,9 +887,9 @@ def send_raw( def recv( self, socket: zmq.sugar.socket.Socket, - mode: int =zmq.NOBLOCK, - content: bool =True, - copy: bool = True + mode: int = zmq.NOBLOCK, + content: bool = True, + copy: bool = True, ) -> t.Tuple[t.Optional[t.List[bytes]], t.Optional[t.Dict[str, t.Any]]]: """Receive and unpack a message. @@ -875,9 +925,7 @@ def recv( raise e def feed_identities( - self, - msg_list: t.Union[t.List[bytes], t.List[zmq.Message]], - copy: bool =True + self, msg_list: t.Union[t.List[bytes], t.List[zmq.Message]], copy: bool = True ) -> t.Tuple[t.List[bytes], t.Union[t.List[bytes], t.List[zmq.Message]]]: """Split the identities from the rest of the message. @@ -904,7 +952,7 @@ def feed_identities( if copy: msg_list = t.cast(t.List[bytes], msg_list) idx = msg_list.index(DELIM) - return msg_list[:idx], msg_list[idx+1:] + return msg_list[:idx], msg_list[idx + 1 :] # noqa else: msg_list = t.cast(t.List[zmq.Message], msg_list) failed = True @@ -914,13 +962,10 @@ def feed_identities( break if failed: raise ValueError("DELIM not in msg_list") - idents, msg_list = msg_list[:idx], msg_list[idx+1:] + idents, msg_list = msg_list[:idx], msg_list[idx + 1 :] # noqa return [bytes(m.bytes) for m in idents], msg_list - def _add_digest( - self, - signature: bytes - ) -> None: + def _add_digest(self, signature: bytes) -> None: """add a digest to history to protect against replay attacks""" if self.digest_history_size == 0: # no history, never add digests @@ -947,8 +992,8 @@ def _cull_digest_history(self) -> None: def deserialize( self, msg_list: t.Union[t.List[bytes], t.List[zmq.Message]], - content: bool =True, - copy: bool =True + content: bool = True, + copy: bool = True, ) -> t.Dict[str, t.Any]: """Unserialize a msg_list to a nested message dict. 
@@ -996,23 +1041,23 @@ def deserialize( if not compare_digest(signature, check): raise ValueError("Invalid Signature: %r" % signature) if not len(msg_list) >= minlen: - raise TypeError("malformed message, must have at least %i elements"%minlen) + raise TypeError("malformed message, must have at least %i elements" % minlen) header = self.unpack(msg_list[1]) - message['header'] = extract_dates(header) - message['msg_id'] = header['msg_id'] - message['msg_type'] = header['msg_type'] - message['parent_header'] = extract_dates(self.unpack(msg_list[2])) - message['metadata'] = self.unpack(msg_list[3]) + message["header"] = extract_dates(header) + message["msg_id"] = header["msg_id"] + message["msg_type"] = header["msg_type"] + message["parent_header"] = extract_dates(self.unpack(msg_list[2])) + message["metadata"] = self.unpack(msg_list[3]) if content: - message['content'] = self.unpack(msg_list[4]) + message["content"] = self.unpack(msg_list[4]) else: - message['content'] = msg_list[4] + message["content"] = msg_list[4] buffers = [memoryview(b) for b in msg_list[5:]] if buffers and buffers[0].shape is None: # force copy to workaround pyzmq #646 msg_list = t.cast(t.List[zmq.Message], msg_list) buffers = [memoryview(bytes(b.bytes)) for b in msg_list[5:]] - message['buffers'] = buffers + message["buffers"] = buffers if self.debug: pprint.pprint(message) # adapt to the current version @@ -1029,15 +1074,15 @@ def unserialize(self, *args, **kwargs) -> t.Dict[str, t.Any]: def test_msg2obj(): am = dict(x=1) ao = Message(am) - assert ao.x == am['x'] + assert ao.x == am["x"] - am['y'] = dict(z=1) + am["y"] = dict(z=1) ao = Message(am) - assert ao.y.z == am['y']['z'] + assert ao.y.z == am["y"]["z"] - k1, k2 = 'y', 'z' + k1, k2 = "y", "z" assert ao[k1][k2] == am[k1][k2] am2 = dict(ao) - assert am['x'] == am2['x'] - assert am['y']['z'] == am2['y']['z'] + assert am["x"] == am2["x"] + assert am["y"]["z"] == am2["y"]["z"] diff --git a/jupyter_client/ssh/__init__.py b/jupyter_client/ssh/__init__.py index d7bc9d566..bc0db11d4 100644 --- a/jupyter_client/ssh/__init__.py +++ b/jupyter_client/ssh/__init__.py @@ -1 +1 @@ -from jupyter_client.ssh.tunnel import * +from jupyter_client.ssh.tunnel import * # noqa diff --git a/jupyter_client/ssh/forward.py b/jupyter_client/ssh/forward.py index 16e4cc681..618f82bf4 100644 --- a/jupyter_client/ssh/forward.py +++ b/jupyter_client/ssh/forward.py @@ -16,7 +16,6 @@ # You should have received a copy of the GNU Lesser General Public License # along with Paramiko; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA. - """ Sample script showing how to do local port forwarding over paramiko. @@ -24,38 +23,46 @@ forwarding (the openssh -L option) from a local port through a tunneled connection to a destination reachable from the SSH server machine. 
""" - import logging import select import socketserver -logger = logging.getLogger('ssh') +logger = logging.getLogger("ssh") -class ForwardServer (socketserver.ThreadingTCPServer): +class ForwardServer(socketserver.ThreadingTCPServer): daemon_threads = True allow_reuse_address = True -class Handler (socketserver.BaseRequestHandler): - +class Handler(socketserver.BaseRequestHandler): def handle(self): try: - chan = self.ssh_transport.open_channel('direct-tcpip', - (self.chain_host, self.chain_port), - self.request.getpeername()) + chan = self.ssh_transport.open_channel( + "direct-tcpip", + (self.chain_host, self.chain_port), + self.request.getpeername(), + ) except Exception as e: - logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host, - self.chain_port, - repr(e))) + logger.debug( + "Incoming request to %s:%d failed: %s" % (self.chain_host, self.chain_port, repr(e)) + ) return if chan is None: - logger.debug('Incoming request to %s:%d was rejected by the SSH server.' % - (self.chain_host, self.chain_port)) + logger.debug( + "Incoming request to %s:%d was rejected by the SSH server." + % (self.chain_host, self.chain_port) + ) return - logger.debug('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(), - chan.getpeername(), (self.chain_host, self.chain_port))) + logger.debug( + "Connected! Tunnel open %r -> %r -> %r" + % ( + self.request.getpeername(), + chan.getpeername(), + (self.chain_host, self.chain_port), + ) + ) while True: r, w, x = select.select([self.request, chan], [], []) if self.request in r: @@ -70,18 +77,19 @@ def handle(self): self.request.send(data) chan.close() self.request.close() - logger.debug('Tunnel closed ') + logger.debug("Tunnel closed ") def forward_tunnel(local_port, remote_host, remote_port, transport): # this is a little convoluted, but lets me configure things for the Handler # object. (SocketServer doesn't give Handlers any way to access the outer # server normally.) - class SubHander (Handler): + class SubHander(Handler): chain_host = remote_host chain_port = remote_port ssh_transport = transport - ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever() + + ForwardServer(("127.0.0.1", local_port), SubHander).serve_forever() -__all__ = ['forward_tunnel'] +__all__ = ["forward_tunnel"] diff --git a/jupyter_client/ssh/tunnel.py b/jupyter_client/ssh/tunnel.py index 3da2b5138..9f95da12c 100644 --- a/jupyter_client/ssh/tunnel.py +++ b/jupyter_client/ssh/tunnel.py @@ -1,12 +1,10 @@ """Basic ssh tunnel utilities, and convenience functions for tunneling zeromq connections. """ - # Copyright (C) 2010-2011 IPython Development Team # Copyright (C) 2011- PyZMQ Developers # # Redistributed from IPython under the terms of the BSD License. 
- import atexit import os import re @@ -14,18 +12,23 @@ import socket import sys import warnings -from getpass import getpass, getuser +from getpass import getpass +from getpass import getuser from multiprocessing import Process try: with warnings.catch_warnings(): - warnings.simplefilter('ignore', DeprecationWarning) + warnings.simplefilter("ignore", DeprecationWarning) import paramiko + SSHException = paramiko.ssh_exception.SSHException except ImportError: paramiko = None + class SSHException(Exception): pass + + else: from .forward import forward_tunnel @@ -43,7 +46,7 @@ def select_random_ports(n): sockets = [] for i in range(n): sock = socket.socket() - sock.bind(('', 0)) + sock.bind(("", 0)) ports.append(sock.getsockname()[1]) sockets.append(sock) for sock in sockets: @@ -51,10 +54,10 @@ def select_random_ports(n): return ports -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Check for passwordless login -#----------------------------------------------------------------------------- -_password_pat = re.compile(b(r'pass(word|phrase):'), re.IGNORECASE) +# ----------------------------------------------------------------------------- +_password_pat = re.compile(b(r"pass(word|phrase):"), re.IGNORECASE) def try_passwordless_ssh(server, keyfile, paramiko=None): @@ -65,7 +68,7 @@ def try_passwordless_ssh(server, keyfile, paramiko=None): If paramiko is None, the default for the platform is chosen. """ if paramiko is None: - paramiko = sys.platform == 'win32' + paramiko = sys.platform == "win32" if not paramiko: f = _try_passwordless_openssh else: @@ -77,22 +80,22 @@ def _try_passwordless_openssh(server, keyfile): """Try passwordless login with shell ssh command.""" if pexpect is None: raise ImportError("pexpect unavailable, use paramiko") - cmd = 'ssh -f ' + server + cmd = "ssh -f " + server if keyfile: - cmd += ' -i ' + keyfile - cmd += ' exit' + cmd += " -i " + keyfile + cmd += " exit" # pop SSH_ASKPASS from env env = os.environ.copy() - env.pop('SSH_ASKPASS', None) + env.pop("SSH_ASKPASS", None) - ssh_newkey = 'Are you sure you want to continue connecting' + ssh_newkey = "Are you sure you want to continue connecting" p = pexpect.spawn(cmd, env=env) while True: try: - i = p.expect([ssh_newkey, _password_pat], timeout=.1) + i = p.expect([ssh_newkey, _password_pat], timeout=0.1) if i == 0: - raise SSHException('The authenticity of the host can\'t be established.') + raise SSHException("The authenticity of the host can't be established.") except pexpect.TIMEOUT: continue except pexpect.EOF: @@ -105,7 +108,7 @@ def _try_passwordless_paramiko(server, keyfile): """Try passwordless login with paramiko.""" if paramiko is None: msg = "Paramiko unavailable, " - if sys.platform == 'win32': + if sys.platform == "win32": msg += "Paramiko is required for ssh tunneled connections on Windows." else: msg += "use OpenSSH." @@ -115,8 +118,7 @@ def _try_passwordless_paramiko(server, keyfile): client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy()) try: - client.connect(server, port, username=username, key_filename=keyfile, - look_for_keys=True) + client.connect(server, port, username=username, key_filename=keyfile, look_for_keys=True) except paramiko.AuthenticationException: return False else: @@ -133,7 +135,14 @@ def tunnel_connection(socket, addr, server, keyfile=None, password=None, paramik selected local port of the tunnel. 
""" - new_url, tunnel = open_tunnel(addr, server, keyfile=keyfile, password=password, paramiko=paramiko, timeout=timeout) + new_url, tunnel = open_tunnel( + addr, + server, + keyfile=keyfile, + password=password, + paramiko=paramiko, + timeout=timeout, + ) socket.connect(new_url) return tunnel @@ -151,21 +160,31 @@ def open_tunnel(addr, server, keyfile=None, password=None, paramiko=None, timeou """ lport = select_random_ports(1)[0] - transport, addr = addr.split('://') - ip, rport = addr.split(':') + transport, addr = addr.split("://") + ip, rport = addr.split(":") rport = int(rport) if paramiko is None: - paramiko = sys.platform == 'win32' + paramiko = sys.platform == "win32" if paramiko: tunnelf = paramiko_tunnel else: tunnelf = openssh_tunnel - tunnel = tunnelf(lport, rport, server, remoteip=ip, keyfile=keyfile, password=password, timeout=timeout) - return 'tcp://127.0.0.1:%i' % lport, tunnel - - -def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): + tunnel = tunnelf( + lport, + rport, + server, + remoteip=ip, + keyfile=keyfile, + password=password, + timeout=timeout, + ) + return "tcp://127.0.0.1:%i" % lport, tunnel + + +def openssh_tunnel( + lport, rport, server, remoteip="127.0.0.1", keyfile=None, password=None, timeout=60 +): """Create an ssh tunnel using command-line ssh that connects port lport on this machine to localhost:rport on server. The tunnel will automatically close when not in use, remaining open @@ -207,35 +226,46 @@ def openssh_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, pas if keyfile: ssh += "-i " + keyfile - if ':' in server: - server, port = server.split(':') + if ":" in server: + server, port = server.split(":") ssh += " -p %s" % port cmd = "%s -O check %s" % (ssh, server) (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) if not exitstatus: - pid = int(output[output.find(b"(pid=")+5:output.find(b")")]) + pid = int(output[output.find(b"(pid=") + 5 : output.find(b")")]) # noqa cmd = "%s -O forward -L 127.0.0.1:%i:%s:%i %s" % ( - ssh, lport, remoteip, rport, server) + ssh, + lport, + remoteip, + rport, + server, + ) (output, exitstatus) = pexpect.run(cmd, withexitstatus=True) if not exitstatus: atexit.register(_stop_tunnel, cmd.replace("-O forward", "-O cancel", 1)) return pid cmd = "%s -f -S none -L 127.0.0.1:%i:%s:%i %s sleep %i" % ( - ssh, lport, remoteip, rport, server, timeout) + ssh, + lport, + remoteip, + rport, + server, + timeout, + ) # pop SSH_ASKPASS from env env = os.environ.copy() - env.pop('SSH_ASKPASS', None) + env.pop("SSH_ASKPASS", None) - ssh_newkey = 'Are you sure you want to continue connecting' + ssh_newkey = "Are you sure you want to continue connecting" tunnel = pexpect.spawn(cmd, env=env) failed = False while True: try: - i = tunnel.expect([ssh_newkey, _password_pat], timeout=.1) + i = tunnel.expect([ssh_newkey, _password_pat], timeout=0.1) if i == 0: - raise SSHException('The authenticity of the host can\'t be established.') + raise SSHException("The authenticity of the host can't be established.") except pexpect.TIMEOUT: continue except pexpect.EOF as e: @@ -261,19 +291,21 @@ def _stop_tunnel(cmd): def _split_server(server): - if '@' in server: - username, server = server.split('@', 1) + if "@" in server: + username, server = server.split("@", 1) else: username = getuser() - if ':' in server: - server, port = server.split(':') + if ":" in server: + server, port = server.split(":") port = int(port) else: port = 22 return username, server, port -def 
paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, password=None, timeout=60): +def paramiko_tunnel( + lport, rport, server, remoteip="127.0.0.1", keyfile=None, password=None, timeout=60 +): """launch a tunner with paramiko in a subprocess. This should only be used when shell ssh is unavailable (e.g. Windows). @@ -320,9 +352,11 @@ def paramiko_tunnel(lport, rport, server, remoteip='127.0.0.1', keyfile=None, pa if not _try_passwordless_paramiko(server, keyfile): password = getpass("%s's password: " % (server)) - p = Process(target=_paramiko_tunnel, - args=(lport, rport, server, remoteip), - kwargs=dict(keyfile=keyfile, password=password)) + p = Process( + target=_paramiko_tunnel, + args=(lport, rport, server, remoteip), + kwargs=dict(keyfile=keyfile, password=password), + ) p.daemon = True p.start() return p @@ -338,16 +372,22 @@ def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None client.set_missing_host_key_policy(paramiko.WarningPolicy()) try: - client.connect(server, port, username=username, key_filename=keyfile, - look_for_keys=True, password=password) -# except paramiko.AuthenticationException: -# if password is None: -# password = getpass("%s@%s's password: "%(username, server)) -# client.connect(server, port, username=username, password=password) -# else: -# raise + client.connect( + server, + port, + username=username, + key_filename=keyfile, + look_for_keys=True, + password=password, + ) + # except paramiko.AuthenticationException: + # if password is None: + # password = getpass("%s@%s's password: "%(username, server)) + # client.connect(server, port, username=username, password=password) + # else: + # raise except Exception as e: - print('*** Failed to connect to %s:%d: %r' % (server, port, e)) + print("*** Failed to connect to %s:%d: %r" % (server, port, e)) sys.exit(1) # Don't let SIGINT kill the tunnel subprocess @@ -356,17 +396,23 @@ def _paramiko_tunnel(lport, rport, server, remoteip, keyfile=None, password=None try: forward_tunnel(lport, remoteip, rport, client.get_transport()) except KeyboardInterrupt: - print('SIGINT: Port forwarding stopped cleanly') + print("SIGINT: Port forwarding stopped cleanly") sys.exit(0) except Exception as e: print("Port forwarding stopped uncleanly: %s" % e) sys.exit(255) -if sys.platform == 'win32': +if sys.platform == "win32": ssh_tunnel = paramiko_tunnel else: ssh_tunnel = openssh_tunnel -__all__ = ['tunnel_connection', 'ssh_tunnel', 'openssh_tunnel', 'paramiko_tunnel', 'try_passwordless_ssh'] +__all__ = [ + "tunnel_connection", + "ssh_tunnel", + "openssh_tunnel", + "paramiko_tunnel", + "try_passwordless_ssh", +] diff --git a/jupyter_client/tests/conftest.py b/jupyter_client/tests/conftest.py index bd6e97c1e..19f481dbd 100644 --- a/jupyter_client/tests/conftest.py +++ b/jupyter_client/tests/conftest.py @@ -1,20 +1,19 @@ - import asyncio import os import sys import pytest - -if os.name == 'nt' and sys.version_info >= (3, 7): +if os.name == "nt" and sys.version_info >= (3, 7): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + @pytest.fixture def event_loop(): # Make sure we test against a selector event loop # since pyzmq doesn't like the proactor loop. 
diff --git a/jupyter_client/tests/conftest.py b/jupyter_client/tests/conftest.py
index bd6e97c1e..19f481dbd 100644
--- a/jupyter_client/tests/conftest.py
+++ b/jupyter_client/tests/conftest.py
@@ -1,20 +1,19 @@
-
 import asyncio
 import os
 import sys

 import pytest

-
-if os.name == 'nt' and sys.version_info >= (3, 7):
+if os.name == "nt" and sys.version_info >= (3, 7):
     asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

+
 @pytest.fixture
 def event_loop():
     # Make sure we test against a selector event loop
     # since pyzmq doesn't like the proactor loop.
     # This fixture is picked up by pytest-asyncio
-    if os.name == 'nt' and sys.version_info >= (3, 7):
+    if os.name == "nt" and sys.version_info >= (3, 7):
         asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
     loop = asyncio.SelectorEventLoop()
     try:
diff --git a/jupyter_client/tests/signalkernel.py b/jupyter_client/tests/signalkernel.py
index 93e290b7b..e26731ff8 100644
--- a/jupyter_client/tests/signalkernel.py
+++ b/jupyter_client/tests/signalkernel.py
@@ -1,31 +1,27 @@
 """Test kernel for signalling subprocesses"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 import os
-
-from subprocess import Popen, PIPE
-import sys
+import signal
 import time
+from subprocess import PIPE
+from subprocess import Popen

 from ipykernel.displayhook import ZMQDisplayHook
-from ipykernel.kernelbase import Kernel
 from ipykernel.kernelapp import IPKernelApp
-
+from ipykernel.kernelbase import Kernel
 from tornado.web import gen

-import signal
-

 class SignalTestKernel(Kernel):
     """Kernel for testing subprocess signaling"""
-    implementation = 'signaltest'
-    implementation_version = '0.0'
-    banner = ''
+
+    implementation = "signaltest"
+    implementation_version = "0.0"
+    banner = ""

     def __init__(self, **kwargs):
-        kwargs.pop('user_ns', None)
+        kwargs.pop("user_ns", None)
         super().__init__(**kwargs)
         self.children = []

@@ -37,33 +33,34 @@ def shutdown_request(self, stream, ident, parent):
         if os.environ.get("NO_SHUTDOWN_REPLY") != "1":
             yield gen.maybe_future(super().shutdown_request(stream, ident, parent))

-    def do_execute(self, code, silent, store_history=True, user_expressions=None,
-                   allow_stdin=False):
+    def do_execute(
+        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
+    ):
         code = code.strip()
         reply = {
-            'status': 'ok',
-            'user_expressions': {},
+            "status": "ok",
+            "user_expressions": {},
         }
-        if code == 'start':
-            child = Popen(['bash', '-i', '-c', 'sleep 30'], stderr=PIPE)
+        if code == "start":
+            child = Popen(["bash", "-i", "-c", "sleep 30"], stderr=PIPE)
             self.children.append(child)
-            reply['user_expressions']['pid'] = self.children[-1].pid
-        elif code == 'check':
-            reply['user_expressions']['poll'] = [ child.poll() for child in self.children ]
-        elif code == 'env':
-            reply['user_expressions']['env'] = os.getenv("TEST_VARS", "")
-        elif code == 'sleep':
+            reply["user_expressions"]["pid"] = self.children[-1].pid
+        elif code == "check":
+            reply["user_expressions"]["poll"] = [child.poll() for child in self.children]
+        elif code == "env":
+            reply["user_expressions"]["env"] = os.getenv("TEST_VARS", "")
+        elif code == "sleep":
             try:
                 time.sleep(10)
             except KeyboardInterrupt:
-                reply['user_expressions']['interrupted'] = True
+                reply["user_expressions"]["interrupted"] = True
             else:
-                reply['user_expressions']['interrupted'] = False
+                reply["user_expressions"]["interrupted"] = False
         else:
-            reply['status'] = 'error'
-            reply['ename'] = 'Error'
-            reply['evalue'] = code
-            reply['traceback'] = ['no such command: %s' % code]
+            reply["status"] = "error"
+            reply["ename"] = "Error"
+            reply["evalue"] = code
+            reply["traceback"] = ["no such command: %s" % code]
         return reply

@@ -75,7 +72,7 @@ def init_io(self):
         self.displayhook = ZMQDisplayHook(self.session, self.iopub_socket)


-if __name__ == '__main__':
+if __name__ == "__main__":
     # make startup artificially slow,
     # so that we exercise client logic for slow-starting kernels
     time.sleep(2)
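The `check` command above reports `Popen.poll()` for each spawned subprocess, which is how the kernel-manager tests later verify interrupts: on POSIX, `poll()` returns the negative signal number for a child killed by a signal. A standalone stdlib sketch of that contract (not code from this diff):

```python
import signal
import time
from subprocess import Popen

child = Popen(["sleep", "30"])
assert child.poll() is None            # None -> still running
child.send_signal(signal.SIGINT)       # what an interrupt delivers
time.sleep(0.5)                        # give the child a moment to exit
assert child.poll() == -signal.SIGINT  # negative signal number on POSIX
```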
diff --git a/jupyter_client/tests/test_adapter.py b/jupyter_client/tests/test_adapter.py
index dae020788..c54217f10 100644
--- a/jupyter_client/tests/test_adapter.py
+++ b/jupyter_client/tests/test_adapter.py
@@ -1,31 +1,32 @@
 """Tests for adapting Jupyter msg spec versions"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 import copy
 import json
 from unittest import TestCase

-from jupyter_client.adapter import adapt, V4toV5, V5toV4, code_to_line
+from jupyter_client.adapter import adapt
+from jupyter_client.adapter import code_to_line
+from jupyter_client.adapter import V4toV5
 from jupyter_client.session import Session


 def test_default_version():
     s = Session()
     msg = s.msg("msg_type")
-    msg['header'].pop('version')
+    msg["header"].pop("version")
     original = copy.deepcopy(msg)
     adapted = adapt(original)
-    assert adapted['header']['version'] == V4toV5.version
+    assert adapted["header"]["version"] == V4toV5.version
+

 def test_code_to_line_no_code():
     line, pos = code_to_line("", 0)
     assert line == ""
     assert pos == 0

-class AdapterTest(TestCase):

+class AdapterTest(TestCase):
     def setUp(self):
         self.session = Session()

@@ -45,190 +46,220 @@ class V4toV5TestCase(AdapterTest):
     def msg(self, msg_type, content):
         """Create a v4 msg (same as v5, minus version header)"""
         msg = self.session.msg(msg_type, content)
-        msg['header'].pop('version')
+        msg["header"].pop("version")
         return msg

     def test_same_version(self):
-        msg = self.msg("execute_result",
-            content={'status' : 'ok'}
-        )
+        msg = self.msg("execute_result", content={"status": "ok"})
         original, adapted = self.adapt(msg, self.from_version)
         self.assertEqual(original, adapted)

     def test_no_adapt(self):
-        msg = self.msg("input_reply", {'value' : 'some text'})
+        msg = self.msg("input_reply", {"value": "some text"})
         v4, v5 = self.adapt(msg)
-        self.assertEqual(v5['header']['version'], V4toV5.version)
-        v5['header'].pop('version')
+        self.assertEqual(v5["header"]["version"], V4toV5.version)
+        v5["header"].pop("version")
         self.assertEqual(v4, v5)

     def test_rename_type(self):
         for v5_type, v4_type in [
-                ('execute_result', 'pyout'),
-                ('execute_input', 'pyin'),
-                ('error', 'pyerr'),
-            ]:
-            msg = self.msg(v4_type, {'key' : 'value'})
+            ("execute_result", "pyout"),
+            ("execute_input", "pyin"),
+            ("error", "pyerr"),
+        ]:
+            msg = self.msg(v4_type, {"key": "value"})
             v4, v5 = self.adapt(msg)
-            self.assertEqual(v5['header']['version'], V4toV5.version)
-            self.assertEqual(v5['header']['msg_type'], v5_type)
-            self.assertEqual(v4['content'], v5['content'])
+            self.assertEqual(v5["header"]["version"], V4toV5.version)
+            self.assertEqual(v5["header"]["msg_type"], v5_type)
+            self.assertEqual(v4["content"], v5["content"])

     def test_execute_request(self):
-        msg = self.msg("execute_request", {
-            'code' : 'a=5',
-            'silent' : False,
-            'user_expressions' : {'a' : 'apple'},
-            'user_variables' : ['b'],
-        })
+        msg = self.msg(
+            "execute_request",
+            {
+                "code": "a=5",
+                "silent": False,
+                "user_expressions": {"a": "apple"},
+                "user_variables": ["b"],
+            },
+        )
         v4, v5 = self.adapt(msg)
-        self.assertEqual(v4['header']['msg_type'], v5['header']['msg_type'])
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v5c['user_expressions'], {'a' : 'apple', 'b': 'b'})
-        self.assertNotIn('user_variables', v5c)
-        self.assertEqual(v5c['code'], v4c['code'])
+        self.assertEqual(v4["header"]["msg_type"], v5["header"]["msg_type"])
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v5c["user_expressions"], {"a": "apple", "b": "b"})
+        self.assertNotIn("user_variables", v5c)
+        self.assertEqual(v5c["code"], v4c["code"])

     def test_execute_reply(self):
-        msg = self.msg("execute_reply", {
-            'status': 'ok',
-            'execution_count': 7,
-            'user_variables': {'a': 1},
-            'user_expressions': {'a+a': 2},
-            'payload': [{'source':'page', 'text':'blah'}]
-        })
+        msg = self.msg(
+            "execute_reply",
+            {
+                "status": "ok",
+                "execution_count": 7,
+                "user_variables": {"a": 1},
+                "user_expressions": {"a+a": 2},
+                "payload": [{"source": "page", "text": "blah"}],
+            },
+        )
         v4, v5 = self.adapt(msg)
-        v5c = v5['content']
-        self.assertNotIn('user_variables', v5c)
-        self.assertEqual(v5c['user_expressions'], {'a': 1, 'a+a': 2})
-        self.assertEqual(v5c['payload'], [{'source': 'page',
-            'data': {'text/plain': 'blah'}}
-        ])
+        v5c = v5["content"]
+        self.assertNotIn("user_variables", v5c)
+        self.assertEqual(v5c["user_expressions"], {"a": 1, "a+a": 2})
+        self.assertEqual(v5c["payload"], [{"source": "page", "data": {"text/plain": "blah"}}])

     def test_complete_request(self):
-        msg = self.msg("complete_request", {
-            'text' : 'a.is',
-            'line' : 'foo = a.is',
-            'block' : None,
-            'cursor_pos' : 10,
-        })
+        msg = self.msg(
+            "complete_request",
+            {
+                "text": "a.is",
+                "line": "foo = a.is",
+                "block": None,
+                "cursor_pos": 10,
+            },
+        )
         v4, v5 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
-        for key in ('text', 'line', 'block'):
+        v4c = v4["content"]
+        v5c = v5["content"]
+        for key in ("text", "line", "block"):
             self.assertNotIn(key, v5c)
-        self.assertEqual(v5c['cursor_pos'], v4c['cursor_pos'])
-        self.assertEqual(v5c['code'], v4c['line'])
+        self.assertEqual(v5c["cursor_pos"], v4c["cursor_pos"])
+        self.assertEqual(v5c["code"], v4c["line"])

     def test_complete_reply(self):
-        msg = self.msg("complete_reply", {
-            'matched_text' : 'a.is',
-            'matches' : ['a.isalnum',
-                         'a.isalpha',
-                         'a.isdigit',
-                         'a.islower',
-                        ],
-        })
+        msg = self.msg(
+            "complete_reply",
+            {
+                "matched_text": "a.is",
+                "matches": [
+                    "a.isalnum",
+                    "a.isalpha",
+                    "a.isdigit",
+                    "a.islower",
+                ],
+            },
+        )
         v4, v5 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
+        v4c = v4["content"]
+        v5c = v5["content"]

-        self.assertEqual(v5c['matches'], v4c['matches'])
-        self.assertEqual(v5c['metadata'], {})
-        self.assertEqual(v5c['cursor_start'], -4)
-        self.assertEqual(v5c['cursor_end'], None)
+        self.assertEqual(v5c["matches"], v4c["matches"])
+        self.assertEqual(v5c["metadata"], {})
+        self.assertEqual(v5c["cursor_start"], -4)
+        self.assertEqual(v5c["cursor_end"], None)

     def test_object_info_request(self):
-        msg = self.msg("object_info_request", {
-            'oname' : 'foo',
-            'detail_level' : 1,
-        })
+        msg = self.msg(
+            "object_info_request",
+            {
+                "oname": "foo",
+                "detail_level": 1,
+            },
+        )
         v4, v5 = self.adapt(msg)
-        self.assertEqual(v5['header']['msg_type'], 'inspect_request')
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v5c['code'], v4c['oname'])
-        self.assertEqual(v5c['cursor_pos'], len(v4c['oname']))
-        self.assertEqual(v5c['detail_level'], v4c['detail_level'])
+        self.assertEqual(v5["header"]["msg_type"], "inspect_request")
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v5c["code"], v4c["oname"])
+        self.assertEqual(v5c["cursor_pos"], len(v4c["oname"]))
+        self.assertEqual(v5c["detail_level"], v4c["detail_level"])

     def test_object_info_reply(self):
-        msg = self.msg("object_info_reply", {
-            'name' : 'foo',
-            'found' : True,
-            'status' : 'ok',
-            'definition' : 'foo(a=5)',
-            'docstring' : "the docstring",
-        })
+        msg = self.msg(
+            "object_info_reply",
+            {
+                "name": "foo",
+                "found": True,
+                "status": "ok",
+                "definition": "foo(a=5)",
+                "docstring": "the docstring",
+            },
+        )
         v4, v5 = self.adapt(msg)
-        self.assertEqual(v5['header']['msg_type'], 'inspect_reply')
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(sorted(v5c), [ 'data', 'found', 'metadata', 'status'])
-        text = v5c['data']['text/plain']
-        self.assertEqual(text, '\n'.join([v4c['definition'], v4c['docstring']]))
+        self.assertEqual(v5["header"]["msg_type"], "inspect_reply")
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(sorted(v5c), ["data", "found", "metadata", "status"])
+        text = v5c["data"]["text/plain"]
+        self.assertEqual(text, "\n".join([v4c["definition"], v4c["docstring"]]))

     def test_object_info_reply_not_found(self):
-        msg = self.msg("object_info_reply", {
-            'name' : 'foo',
-            'found' : False,
-        })
+        msg = self.msg(
+            "object_info_reply",
+            {
+                "name": "foo",
+                "found": False,
+            },
+        )
         v4, v5 = self.adapt(msg)
-        self.assertEqual(v5['header']['msg_type'], 'inspect_reply')
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v5c, {
-            'status': 'ok',
-            'found': False,
-            'data': {},
-            'metadata': {},
-        })
+        self.assertEqual(v5["header"]["msg_type"], "inspect_reply")
+        v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(
+            v5c,
+            {
+                "status": "ok",
+                "found": False,
+                "data": {},
+                "metadata": {},
+            },
+        )

     def test_kernel_info_reply(self):
-        msg = self.msg("kernel_info_reply", {
-            'language': 'python',
-            'language_version': [2,8,0],
-            'ipython_version': [1,2,3],
-        })
+        msg = self.msg(
+            "kernel_info_reply",
+            {
+                "language": "python",
+                "language_version": [2, 8, 0],
+                "ipython_version": [1, 2, 3],
+            },
+        )
         v4, v5 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v5c, {
-            'protocol_version': '4.1',
-            'implementation': 'ipython',
-            'implementation_version': '1.2.3',
-            'language_info': {
-                'name': 'python',
-                'version': '2.8.0',
+        v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(
+            v5c,
+            {
+                "protocol_version": "4.1",
+                "implementation": "ipython",
+                "implementation_version": "1.2.3",
+                "language_info": {
+                    "name": "python",
+                    "version": "2.8.0",
+                },
+                "banner": "",
             },
-            'banner' : '',
-        })
+        )

     # iopub channel

     def test_display_data(self):
         jsondata = dict(a=5)
-        msg = self.msg("display_data", {
-            'data' : {
-                'text/plain' : 'some text',
-                'application/json' : json.dumps(jsondata)
+        msg = self.msg(
+            "display_data",
+            {
+                "data": {
+                    "text/plain": "some text",
+                    "application/json": json.dumps(jsondata),
+                },
+                "metadata": {"text/plain": {"key": "value"}},
             },
-            'metadata' : {'text/plain' : { 'key' : 'value' }},
-        })
+        )
         v4, v5 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v5c['metadata'], v4c['metadata'])
-        self.assertEqual(v5c['data']['text/plain'], v4c['data']['text/plain'])
-        self.assertEqual(v5c['data']['application/json'], jsondata)
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v5c["metadata"], v4c["metadata"])
+        self.assertEqual(v5c["data"]["text/plain"], v4c["data"]["text/plain"])
+        self.assertEqual(v5c["data"]["application/json"], jsondata)

     # stdin channel

     def test_input_request(self):
-        msg = self.msg('input_request', {'prompt': "$>"})
+        msg = self.msg("input_request", {"prompt": "$>"})
         v4, v5 = self.adapt(msg)
-        self.assertEqual(v5['content']['prompt'], v4['content']['prompt'])
-        self.assertFalse(v5['content']['password'])
+        self.assertEqual(v5["content"]["prompt"], v4["content"]["prompt"])
+        self.assertFalse(v5["content"]["password"])


 class V5toV4TestCase(AdapterTest):
@@ -239,166 +270,188 @@ def msg(self, msg_type, content):
         return self.session.msg(msg_type, content)

     def test_same_version(self):
-        msg = self.msg("execute_result",
-            content={'status' : 'ok'}
-        )
+        msg = self.msg("execute_result", content={"status": "ok"})
         original, adapted = self.adapt(msg, self.from_version)
         self.assertEqual(original, adapted)

     def test_no_adapt(self):
-        msg = self.msg("input_reply", {'value' : 'some text'})
+        msg = self.msg("input_reply", {"value": "some text"})
         v5, v4 = self.adapt(msg)
-        self.assertNotIn('version', v4['header'])
-        v5['header'].pop('version')
+        self.assertNotIn("version", v4["header"])
+        v5["header"].pop("version")
         self.assertEqual(v4, v5)

     def test_rename_type(self):
         for v5_type, v4_type in [
-                ('execute_result', 'pyout'),
-                ('execute_input', 'pyin'),
-                ('error', 'pyerr'),
-            ]:
-            msg = self.msg(v5_type, {'key' : 'value'})
+            ("execute_result", "pyout"),
+            ("execute_input", "pyin"),
+            ("error", "pyerr"),
+        ]:
+            msg = self.msg(v5_type, {"key": "value"})
             v5, v4 = self.adapt(msg)
-            self.assertEqual(v4['header']['msg_type'], v4_type)
-            assert 'version' not in v4['header']
-            self.assertEqual(v4['content'], v5['content'])
+            self.assertEqual(v4["header"]["msg_type"], v4_type)
+            assert "version" not in v4["header"]
+            self.assertEqual(v4["content"], v5["content"])

     def test_execute_request(self):
-        msg = self.msg("execute_request", {
-            'code' : 'a=5',
-            'silent' : False,
-            'user_expressions' : {'a' : 'apple'},
-        })
+        msg = self.msg(
+            "execute_request",
+            {
+                "code": "a=5",
+                "silent": False,
+                "user_expressions": {"a": "apple"},
+            },
+        )
         v5, v4 = self.adapt(msg)
-        self.assertEqual(v4['header']['msg_type'], v5['header']['msg_type'])
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v4c['user_variables'], [])
-        self.assertEqual(v5c['code'], v4c['code'])
+        self.assertEqual(v4["header"]["msg_type"], v5["header"]["msg_type"])
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v4c["user_variables"], [])
+        self.assertEqual(v5c["code"], v4c["code"])

     def test_complete_request(self):
-        msg = self.msg("complete_request", {
-            'code' : 'def foo():\n'
-                     '    a.is\n'
-                     'foo()',
-            'cursor_pos': 19,
-        })
+        msg = self.msg(
+            "complete_request",
+            {
+                "code": "def foo():\n" "    a.is\n" "foo()",
+                "cursor_pos": 19,
+            },
+        )
         v5, v4 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertNotIn('code', v4c)
-        self.assertEqual(v4c['line'], v5c['code'].splitlines(True)[1])
-        self.assertEqual(v4c['cursor_pos'], 8)
-        self.assertEqual(v4c['text'], '')
-        self.assertEqual(v4c['block'], None)
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertNotIn("code", v4c)
+        self.assertEqual(v4c["line"], v5c["code"].splitlines(True)[1])
+        self.assertEqual(v4c["cursor_pos"], 8)
+        self.assertEqual(v4c["text"], "")
+        self.assertEqual(v4c["block"], None)

     def test_complete_reply(self):
-        msg = self.msg("complete_reply", {
-            'cursor_start' : 10,
-            'cursor_end' : 14,
-            'matches' : ['a.isalnum',
-                         'a.isalpha',
-                         'a.isdigit',
-                         'a.islower',
-                        ],
-            'metadata' : {},
-        })
+        msg = self.msg(
+            "complete_reply",
+            {
+                "cursor_start": 10,
+                "cursor_end": 14,
+                "matches": [
+                    "a.isalnum",
+                    "a.isalpha",
+                    "a.isdigit",
+                    "a.islower",
+                ],
+                "metadata": {},
+            },
+        )
         v5, v4 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v4c['matched_text'], 'a.is')
-        self.assertEqual(v4c['matches'], v5c['matches'])
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v4c["matched_text"], "a.is")
+        self.assertEqual(v4c["matches"], v5c["matches"])

     def test_inspect_request(self):
-        msg = self.msg("inspect_request", {
-            'code' : 'def foo():\n'
-                     '    apple\n'
-                     'bar()',
-            'cursor_pos': 18,
-            'detail_level' : 1,
-        })
+        msg = self.msg(
+            "inspect_request",
+            {
+                "code": "def foo():\n" "    apple\n" "bar()",
+                "cursor_pos": 18,
+                "detail_level": 1,
+            },
+        )
         v5, v4 = self.adapt(msg)
-        self.assertEqual(v4['header']['msg_type'], 'object_info_request')
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v4c['oname'], 'apple')
-        self.assertEqual(v5c['detail_level'], v4c['detail_level'])
+        self.assertEqual(v4["header"]["msg_type"], "object_info_request")
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v4c["oname"], "apple")
+        self.assertEqual(v5c["detail_level"], v4c["detail_level"])

     def test_inspect_request_token(self):
-        line = 'something(range(10), kwarg=smth) ; xxx.xxx.xxx( firstarg, rand(234,23), kwarg1=2,'
-        msg = self.msg("inspect_request", {
-            'code' : line,
-            'cursor_pos': len(line)-1,
-            'detail_level' : 1,
-        })
+        line = "something(range(10), kwarg=smth) ; xxx.xxx.xxx( firstarg, rand(234,23), kwarg1=2,"
+        msg = self.msg(
+            "inspect_request",
+            {
+                "code": line,
+                "cursor_pos": len(line) - 1,
+                "detail_level": 1,
+            },
+        )
         v5, v4 = self.adapt(msg)
-        self.assertEqual(v4['header']['msg_type'], 'object_info_request')
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v4c['oname'], 'xxx.xxx.xxx')
-        self.assertEqual(v5c['detail_level'], v4c['detail_level'])
+        self.assertEqual(v4["header"]["msg_type"], "object_info_request")
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v4c["oname"], "xxx.xxx.xxx")
+        self.assertEqual(v5c["detail_level"], v4c["detail_level"])

     def test_inspect_reply(self):
-        msg = self.msg("inspect_reply", {
-            'name' : 'foo',
-            'found' : True,
-            'data' : {'text/plain' : 'some text'},
-            'metadata' : {},
-        })
+        msg = self.msg(
+            "inspect_reply",
+            {
+                "name": "foo",
+                "found": True,
+                "data": {"text/plain": "some text"},
+                "metadata": {},
+            },
+        )
         v5, v4 = self.adapt(msg)
-        self.assertEqual(v4['header']['msg_type'], 'object_info_reply')
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(sorted(v4c), ['found', 'oname'])
-        self.assertEqual(v4c['found'], False)
+        self.assertEqual(v4["header"]["msg_type"], "object_info_reply")
+        v4c = v4["content"]
+        v5["content"]
+        self.assertEqual(sorted(v4c), ["found", "oname"])
+        self.assertEqual(v4c["found"], False)

     def test_kernel_info_reply(self):
-        msg = self.msg("kernel_info_reply", {
-            'protocol_version': '5.0',
-            'implementation': 'ipython',
-            'implementation_version': '1.2.3',
-            'language_info': {
-                'name': 'python',
-                'version': '2.8.0',
-                'mimetype': 'text/x-python',
+        msg = self.msg(
+            "kernel_info_reply",
+            {
+                "protocol_version": "5.0",
+                "implementation": "ipython",
+                "implementation_version": "1.2.3",
+                "language_info": {
+                    "name": "python",
+                    "version": "2.8.0",
+                    "mimetype": "text/x-python",
+                },
+                "banner": "the banner",
             },
-            'banner' : 'the banner',
-        })
+        )
         v5, v4 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
-        info = v5c['language_info']
-        self.assertEqual(v4c, {
-            'protocol_version': [5,0],
-            'language': 'python',
-            'language_version': [2,8,0],
-            'ipython_version': [1,2,3],
-        })
+        v4c = v4["content"]
+        v5c = v5["content"]
+        v5c["language_info"]
+        self.assertEqual(
+            v4c,
+            {
+                "protocol_version": [5, 0],
+                "language": "python",
+                "language_version": [2, 8, 0],
+                "ipython_version": [1, 2, 3],
+            },
+        )

     # iopub channel

     def test_display_data(self):
         jsondata = dict(a=5)
-        msg = self.msg("display_data", {
-            'data' : {
-                'text/plain' : 'some text',
-                'application/json' : jsondata,
+        msg = self.msg(
+            "display_data",
+            {
+                "data": {
+                    "text/plain": "some text",
+                    "application/json": jsondata,
+                },
+                "metadata": {"text/plain": {"key": "value"}},
             },
-            'metadata' : {'text/plain' : { 'key' : 'value' }},
-        })
+        )
         v5, v4 = self.adapt(msg)
-        v4c = v4['content']
-        v5c = v5['content']
-        self.assertEqual(v5c['metadata'], v4c['metadata'])
-        self.assertEqual(v5c['data']['text/plain'], v4c['data']['text/plain'])
-        self.assertEqual(v4c['data']['application/json'], json.dumps(jsondata))
+        v4c = v4["content"]
+        v5c = v5["content"]
+        self.assertEqual(v5c["metadata"], v4c["metadata"])
+        self.assertEqual(v5c["data"]["text/plain"], v4c["data"]["text/plain"])
+        self.assertEqual(v4c["data"]["application/json"], json.dumps(jsondata))

     # stdin channel

     def test_input_request(self):
-        msg = self.msg('input_request', {'prompt': "$>", 'password' : True})
+        msg = self.msg("input_request", {"prompt": "$>", "password": True})
         v5, v4 = self.adapt(msg)
-        self.assertEqual(v5['content']['prompt'], v4['content']['prompt'])
-        self.assertNotIn('password', v4['content'])
+        self.assertEqual(v5["content"]["prompt"], v4["content"]["prompt"])
+        self.assertNotIn("password", v4["content"])
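A sketch of the adapter these tests drive. Per `test_default_version` and `test_rename_type` above, a message whose header carries no `version` is treated as v4 and upgraded, with legacy type names translated along the way:

```python
from jupyter_client.adapter import adapt, V4toV5
from jupyter_client.session import Session

s = Session()
msg = s.msg("pyout", content={"data": {"text/plain": "5"}})
msg["header"].pop("version")  # version-less header -> assumed v4
v5 = adapt(msg)               # upgrade to the current (v5) spec
assert v5["header"]["msg_type"] == "execute_result"
assert v5["header"]["version"] == V4toV5.version
```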
diff --git a/jupyter_client/tests/test_client.py b/jupyter_client/tests/test_client.py
index 35188236d..3c0f625c9 100644
--- a/jupyter_client/tests/test_client.py
+++ b/jupyter_client/tests/test_client.py
@@ -1,22 +1,22 @@
 """Tests for the KernelClient"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
-
 import os
-pjoin = os.path.join
 from unittest import TestCase

-from jupyter_client.kernelspec import KernelSpecManager, NoSuchKernel, NATIVE_KERNEL_NAME
+import pytest
+from IPython.utils.capture import capture_output
+
 from ..manager import start_new_kernel
 from .utils import test_env
+from jupyter_client.kernelspec import KernelSpecManager
+from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
+from jupyter_client.kernelspec import NoSuchKernel

-import pytest
+TIMEOUT = 30

-from IPython.utils.capture import capture_output
+pjoin = os.path.join

-TIMEOUT = 30

 class TestKernelClient(TestCase):
     def setUp(self):
@@ -36,52 +36,52 @@ def test_execute_interactive(self):
         with capture_output() as io:
             reply = kc.execute_interactive("print('hello')", timeout=TIMEOUT)
-        assert 'hello' in io.stdout
-        assert reply['content']['status'] == 'ok'
+        assert "hello" in io.stdout
+        assert reply["content"]["status"] == "ok"

     def _check_reply(self, reply_type, reply):
         self.assertIsInstance(reply, dict)
-        self.assertEqual(reply['header']['msg_type'], reply_type + '_reply')
-        self.assertEqual(reply['parent_header']['msg_type'], reply_type + '_request')
+        self.assertEqual(reply["header"]["msg_type"], reply_type + "_reply")
+        self.assertEqual(reply["parent_header"]["msg_type"], reply_type + "_request")

     def test_history(self):
         kc = self.kc
         msg_id = kc.history(session=0)
         self.assertIsInstance(msg_id, str)
         reply = kc.history(session=0, reply=True, timeout=TIMEOUT)
-        self._check_reply('history', reply)
+        self._check_reply("history", reply)

     def test_inspect(self):
         kc = self.kc
-        msg_id = kc.inspect('who cares')
+        msg_id = kc.inspect("who cares")
         self.assertIsInstance(msg_id, str)
-        reply = kc.inspect('code', reply=True, timeout=TIMEOUT)
-        self._check_reply('inspect', reply)
+        reply = kc.inspect("code", reply=True, timeout=TIMEOUT)
+        self._check_reply("inspect", reply)

     def test_complete(self):
         kc = self.kc
-        msg_id = kc.complete('who cares')
+        msg_id = kc.complete("who cares")
         self.assertIsInstance(msg_id, str)
-        reply = kc.complete('code', reply=True, timeout=TIMEOUT)
-        self._check_reply('complete', reply)
+        reply = kc.complete("code", reply=True, timeout=TIMEOUT)
+        self._check_reply("complete", reply)

     def test_kernel_info(self):
         kc = self.kc
         msg_id = kc.kernel_info()
         self.assertIsInstance(msg_id, str)
         reply = kc.kernel_info(reply=True, timeout=TIMEOUT)
-        self._check_reply('kernel_info', reply)
+        self._check_reply("kernel_info", reply)

     def test_comm_info(self):
         kc = self.kc
         msg_id = kc.comm_info()
         self.assertIsInstance(msg_id, str)
         reply = kc.comm_info(reply=True, timeout=TIMEOUT)
-        self._check_reply('comm_info', reply)
+        self._check_reply("comm_info", reply)

     def test_shutdown(self):
         kc = self.kc
         msg_id = kc.shutdown()
         self.assertIsInstance(msg_id, str)
         reply = kc.shutdown(reply=True, timeout=TIMEOUT)
-        self._check_reply('shutdown', reply)
+        self._check_reply("shutdown", reply)
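For reference, the request/reply shape that `_check_reply` validates, as a minimal standalone sketch using the same `start_new_kernel` helper as the test's `setUp`:

```python
from jupyter_client.manager import start_new_kernel

km, kc = start_new_kernel()
try:
    msg_id = kc.kernel_info()                       # fire-and-forget: returns a msg_id
    reply = kc.kernel_info(reply=True, timeout=30)  # blocking: returns the reply dict
    assert reply["header"]["msg_type"] == "kernel_info_reply"
    assert reply["parent_header"]["msg_type"] == "kernel_info_request"
finally:
    kc.stop_channels()
    km.shutdown_kernel()
```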
diff --git a/jupyter_client/tests/test_connect.py b/jupyter_client/tests/test_connect.py
index fe0f3535d..7ae6618c2 100644
--- a/jupyter_client/tests/test_connect.py
+++ b/jupyter_client/tests/test_connect.py
@@ -1,23 +1,17 @@
 """Tests for kernel connection utilities"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
 import json
 import os
-import re
-import stat
-import tempfile
-import shutil
+from tempfile import TemporaryDirectory

-from traitlets.config import Config
 from jupyter_core.application import JupyterApp
 from jupyter_core.paths import jupyter_runtime_dir
-from tempfile import TemporaryDirectory
-from jupyter_client import connect, KernelClient
+
+from jupyter_client import connect
+from jupyter_client import KernelClient
 from jupyter_client.consoleapp import JupyterConsoleApp
 from jupyter_client.session import Session
-from jupyter_client.connect import secure_write


 class TemporaryWorkingDirectory(TemporaryDirectory):
@@ -40,33 +34,50 @@ def __exit__(self, exc, value, tb):
         return super().__exit__(exc, value, tb)


-
 class DummyConsoleApp(JupyterApp, JupyterConsoleApp):
     def initialize(self, argv=[]):
         JupyterApp.initialize(self, argv=argv)
         self.init_connection_file()

+
 class DummyConfigurable(connect.ConnectionFileMixin):
     def initialize(self):
         pass

-sample_info = dict(ip='1.2.3.4', transport='ipc',
-        shell_port=1, hb_port=2, iopub_port=3, stdin_port=4, control_port=5,
-        key=b'abc123', signature_scheme='hmac-md5', kernel_name='python'
-    )

-sample_info_kn = dict(ip='1.2.3.4', transport='ipc',
-        shell_port=1, hb_port=2, iopub_port=3, stdin_port=4, control_port=5,
-        key=b'abc123', signature_scheme='hmac-md5', kernel_name='test'
-    )
+sample_info = dict(
+    ip="1.2.3.4",
+    transport="ipc",
+    shell_port=1,
+    hb_port=2,
+    iopub_port=3,
+    stdin_port=4,
+    control_port=5,
+    key=b"abc123",
+    signature_scheme="hmac-md5",
+    kernel_name="python",
+)
+
+sample_info_kn = dict(
+    ip="1.2.3.4",
+    transport="ipc",
+    shell_port=1,
+    hb_port=2,
+    iopub_port=3,
+    stdin_port=4,
+    control_port=5,
+    key=b"abc123",
+    signature_scheme="hmac-md5",
+    kernel_name="test",
+)


 def test_write_connection_file():
     with TemporaryDirectory() as d:
-        cf = os.path.join(d, 'kernel.json')
+        cf = os.path.join(d, "kernel.json")
         connect.write_connection_file(cf, **sample_info)
         assert os.path.exists(cf)
-        with open(cf, 'r') as f:
+        with open(cf, "r") as f:
             info = json.load(f)
     info["key"] = info["key"].encode()
     assert info == sample_info
@@ -80,13 +91,13 @@ def test_load_connection_file_session():
     session = app.session

     with TemporaryDirectory() as d:
-        cf = os.path.join(d, 'kernel.json')
+        cf = os.path.join(d, "kernel.json")
         connect.write_connection_file(cf, **sample_info)
         app.connection_file = cf
         app.load_connection_file()

-    assert session.key == sample_info['key']
-    assert session.signature_scheme == sample_info['signature_scheme']
+    assert session.key == sample_info["key"]
+    assert session.signature_scheme == sample_info["signature_scheme"]


 def test_load_connection_file_session_with_kn():
@@ -97,25 +108,25 @@ def test_load_connection_file_session_with_kn():
     session = app.session

     with TemporaryDirectory() as d:
-        cf = os.path.join(d, 'kernel.json')
+        cf = os.path.join(d, "kernel.json")
         connect.write_connection_file(cf, **sample_info_kn)
         app.connection_file = cf
         app.load_connection_file()

-    assert session.key == sample_info_kn['key']
-    assert session.signature_scheme == sample_info_kn['signature_scheme']
+    assert session.key == sample_info_kn["key"]
+    assert session.signature_scheme == sample_info_kn["signature_scheme"]


 def test_app_load_connection_file():
     """test `ipython console --existing` loads a connection file"""
     with TemporaryDirectory() as d:
-        cf = os.path.join(d, 'kernel.json')
+        cf = os.path.join(d, "kernel.json")
         connect.write_connection_file(cf, **sample_info)
         app = DummyConsoleApp(connection_file=cf)
         app.initialize(argv=[])

     for attr, expected in sample_info.items():
-        if attr in ('key', 'signature_scheme'):
+        if attr in ("key", "signature_scheme"):
             continue
         value = getattr(app, attr)
         assert value == expected, "app.%s = %s != %s" % (attr, value, expected)
@@ -124,89 +135,89 @@ def test_app_load_connection_file():
 def test_load_connection_info():
     client = KernelClient()
     info = {
-        'control_port': 53702,
-        'hb_port': 53705,
-        'iopub_port': 53703,
-        'ip': '0.0.0.0',
-        'key': 'secret',
-        'shell_port': 53700,
-        'signature_scheme': 'hmac-sha256',
-        'stdin_port': 53701,
-        'transport': 'tcp',
+        "control_port": 53702,
+        "hb_port": 53705,
+        "iopub_port": 53703,
+        "ip": "0.0.0.0",
+        "key": "secret",
+        "shell_port": 53700,
+        "signature_scheme": "hmac-sha256",
+        "stdin_port": 53701,
+        "transport": "tcp",
     }
     client.load_connection_info(info)
-    assert client.control_port == info['control_port']
-    assert client.session.key.decode('ascii') == info['key']
-    assert client.ip == info['ip']
+    assert client.control_port == info["control_port"]
+    assert client.session.key.decode("ascii") == info["key"]
+    assert client.ip == info["ip"]


 def test_find_connection_file():
    with TemporaryDirectory() as d:
-        cf = 'kernel.json'
+        cf = "kernel.json"
         app = DummyConsoleApp(runtime_dir=d, connection_file=cf)
         app.initialize()

         security_dir = app.runtime_dir
         profile_cf = os.path.join(security_dir, cf)

-        with open(profile_cf, 'w') as f:
+        with open(profile_cf, "w") as f:
             f.write("{}")

         for query in (
-            'kernel.json',
-            'kern*',
-            '*ernel*',
-            'k*',
-            ):
+            "kernel.json",
+            "kern*",
+            "*ernel*",
+            "k*",
+        ):
             assert connect.find_connection_file(query, path=security_dir) == profile_cf


 def test_find_connection_file_local():
-    with TemporaryWorkingDirectory() as d:
-        cf = 'test.json'
+    with TemporaryWorkingDirectory():
+        cf = "test.json"
         abs_cf = os.path.abspath(cf)
-        with open(cf, 'w') as f:
-            f.write('{}')
+        with open(cf, "w") as f:
+            f.write("{}")

         for query in (
-            'test.json',
-            'test',
+            "test.json",
+            "test",
             abs_cf,
-            os.path.join('.', 'test.json'),
+            os.path.join(".", "test.json"),
         ):
-            assert connect.find_connection_file(query, path=['.', jupyter_runtime_dir()]) == abs_cf
+            assert connect.find_connection_file(query, path=[".", jupyter_runtime_dir()]) == abs_cf


 def test_find_connection_file_relative():
-    with TemporaryWorkingDirectory() as d:
-        cf = 'test.json'
-        os.mkdir('subdir')
-        cf = os.path.join('subdir', 'test.json')
+    with TemporaryWorkingDirectory():
+        cf = "test.json"
+        os.mkdir("subdir")
+        cf = os.path.join("subdir", "test.json")
         abs_cf = os.path.abspath(cf)
-        with open(cf, 'w') as f:
-            f.write('{}')
+        with open(cf, "w") as f:
+            f.write("{}")

         for query in (
-            os.path.join('.', 'subdir', 'test.json'),
-            os.path.join('subdir', 'test.json'),
+            os.path.join(".", "subdir", "test.json"),
+            os.path.join("subdir", "test.json"),
             abs_cf,
         ):
-            assert connect.find_connection_file(query, path=['.', jupyter_runtime_dir()]) == abs_cf
+            assert connect.find_connection_file(query, path=[".", jupyter_runtime_dir()]) == abs_cf


 def test_find_connection_file_abspath():
-    with TemporaryDirectory() as d:
-        cf = 'absolute.json'
+    with TemporaryDirectory():
+        cf = "absolute.json"
         abs_cf = os.path.abspath(cf)
-        with open(cf, 'w') as f:
-            f.write('{}')
+        with open(cf, "w") as f:
+            f.write("{}")
         assert connect.find_connection_file(abs_cf, path=jupyter_runtime_dir()) == abs_cf
         os.remove(abs_cf)


 def test_mixin_record_random_ports():
     with TemporaryDirectory() as d:
-        dc = DummyConfigurable(data_dir=d, kernel_name='via-tcp', transport='tcp')
+        dc = DummyConfigurable(data_dir=d, kernel_name="via-tcp", transport="tcp")
         dc.write_connection_file()

         assert dc._connection_file_written
@@ -216,7 +227,7 @@ def test_mixin_record_random_ports():

 def test_mixin_cleanup_random_ports():
     with TemporaryDirectory() as d:
-        dc = DummyConfigurable(data_dir=d, kernel_name='via-tcp', transport='tcp')
+        dc = DummyConfigurable(data_dir=d, kernel_name="via-tcp", transport="tcp")
         dc.write_connection_file()
         filename = dc.connection_file
         dc.cleanup_random_ports()
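A condensed sketch of the write/find round trip these tests cover. The assumption that ports left unspecified get assigned automatically comes from the random-port tests above rather than from any single hunk, so treat it as hedged:

```python
import json
import os
from tempfile import TemporaryDirectory

from jupyter_client import connect

with TemporaryDirectory() as d:
    cf = os.path.join(d, "kernel.json")
    connect.write_connection_file(cf, ip="127.0.0.1", key=b"secret", transport="tcp")
    assert connect.find_connection_file("kernel*", path=d) == cf
    with open(cf) as f:
        print(json.load(f)["shell_port"])  # assumed auto-assigned when not passed in
```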
diff --git a/jupyter_client/tests/test_jsonutil.py b/jupyter_client/tests/test_jsonutil.py
index d7383c1f7..2f17e4fe3 100644
--- a/jupyter_client/tests/test_jsonutil.py
+++ b/jupyter_client/tests/test_jsonutil.py
@@ -1,29 +1,26 @@
 """Test suite for our JSON utilities."""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
-import json
-import pytest
 import datetime
-
+import json
 from datetime import timedelta
 from unittest import mock
-from dateutil.tz import tzlocal, tzoffset
+
+import pytest
+from dateutil.tz import tzlocal
+from dateutil.tz import tzoffset
+
 from jupyter_client import jsonutil
 from jupyter_client.session import utcnow

-
-REFERENCE_DATETIME = datetime.datetime(
-    2013, 7, 3, 16, 34, 52, 249482, tzlocal()
-)
+REFERENCE_DATETIME = datetime.datetime(2013, 7, 3, 16, 34, 52, 249482, tzlocal())


 def test_extract_date_from_naive():
     ref = REFERENCE_DATETIME
-    timestamp = '2013-07-03T16:34:52.249482'
+    timestamp = "2013-07-03T16:34:52.249482"

-    with pytest.deprecated_call(match='Interpreting naive datetime as local'):
+    with pytest.deprecated_call(match="Interpreting naive datetime as local"):
         extracted = jsonutil.extract_dates(timestamp)

     assert isinstance(extracted, datetime.datetime)
@@ -35,11 +32,11 @@ def test_extract_date_from_naive():
 def test_extract_dates():
     ref = REFERENCE_DATETIME
     timestamps = [
-        '2013-07-03T16:34:52.249482Z',
-        '2013-07-03T16:34:52.249482-0800',
-        '2013-07-03T16:34:52.249482+0800',
-        '2013-07-03T16:34:52.249482-08:00',
-        '2013-07-03T16:34:52.249482+08:00',
+        "2013-07-03T16:34:52.249482Z",
+        "2013-07-03T16:34:52.249482-0800",
+        "2013-07-03T16:34:52.249482+0800",
+        "2013-07-03T16:34:52.249482-08:00",
+        "2013-07-03T16:34:52.249482+08:00",
     ]
     extracted = jsonutil.extract_dates(timestamps)
     for dt in extracted:
@@ -54,14 +51,14 @@ def test_extract_dates():


 def test_parse_ms_precision():
-    base = '2013-07-03T16:34:52'
-    digits = '1234567890'
+    base = "2013-07-03T16:34:52"
+    digits = "1234567890"

-    parsed = jsonutil.parse_date(base+'Z')
+    parsed = jsonutil.parse_date(base + "Z")
     assert isinstance(parsed, datetime.datetime)
     for i in range(len(digits)):
-        ts = base + '.' + digits[:i]
-        parsed = jsonutil.parse_date(ts+'Z')
+        ts = base + "." + digits[:i]
+        parsed = jsonutil.parse_date(ts + "Z")
         if i >= 1 and i <= 6:
             assert isinstance(parsed, datetime.datetime)
         else:
@@ -70,16 +67,15 @@ def test_parse_ms_precision():

 def test_date_default():
     naive = datetime.datetime.now()
-    local = tzoffset('Local', -8 * 3600)
-    other = tzoffset('Other', 2 * 3600)
+    local = tzoffset("Local", -8 * 3600)
+    other = tzoffset("Other", 2 * 3600)
     data = dict(naive=naive, utc=utcnow(), withtz=naive.replace(tzinfo=other))
-    with mock.patch.object(jsonutil, 'tzlocal', lambda : local):
-        with pytest.deprecated_call(match='Please add timezone info'):
+    with mock.patch.object(jsonutil, "tzlocal", lambda: local):
+        with pytest.deprecated_call(match="Please add timezone info"):
             jsondata = json.dumps(data, default=jsonutil.date_default)
     assert "Z" in jsondata
     assert jsondata.count("Z") == 1
     extracted = jsonutil.extract_dates(json.loads(jsondata))
     for dt in extracted.values():
         assert isinstance(dt, datetime.datetime)
-        assert dt.tzinfo != None
-
+        assert dt.tzinfo is not None
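The serialize/deserialize pair under test, in one self-contained sketch (it mirrors `test_date_default`, but uses a timezone-aware stamp so no deprecation warning fires):

```python
import json

from jupyter_client import jsonutil
from jupyter_client.session import utcnow

data = {"stamp": utcnow()}                              # tz-aware datetime
wire = json.dumps(data, default=jsonutil.date_default)  # datetime -> ISO 8601 string
parsed = jsonutil.extract_dates(json.loads(wire))       # ISO 8601 string -> datetime
assert parsed["stamp"].tzinfo is not None
```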
diff --git a/jupyter_client/tests/test_kernelapp.py b/jupyter_client/tests/test_kernelapp.py
index af28f814a..702e2b375 100644
--- a/jupyter_client/tests/test_kernelapp.py
+++ b/jupyter_client/tests/test_kernelapp.py
@@ -1,38 +1,45 @@
 import os
-import sys
 import shutil
+import sys
 import time
-
-from subprocess import Popen, PIPE
+from subprocess import PIPE
+from subprocess import Popen
 from tempfile import mkdtemp

+
 def _launch(extra_env):
     env = os.environ.copy()
     env.update(extra_env)
-    return Popen([sys.executable, '-c',
-                  'from jupyter_client.kernelapp import main; main()'],
-                 env=env, stderr=PIPE)
+    return Popen(
+        [sys.executable, "-c", "from jupyter_client.kernelapp import main; main()"],
+        env=env,
+        stderr=PIPE,
+    )
+

 WAIT_TIME = 10
 POLL_FREQ = 10

+
 def test_kernelapp_lifecycle():
     # Check that 'jupyter kernel' starts and terminates OK.
     runtime_dir = mkdtemp()
     startup_dir = mkdtemp()
-    started = os.path.join(startup_dir, 'started')
+    started = os.path.join(startup_dir, "started")
     try:
-        p = _launch({'JUPYTER_RUNTIME_DIR': runtime_dir,
-                     'JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE': started,
-                    })
+        p = _launch(
+            {
+                "JUPYTER_RUNTIME_DIR": runtime_dir,
+                "JUPYTER_CLIENT_TEST_RECORD_STARTUP_PRIVATE": started,
+            }
+        )
         # Wait for start
         for _ in range(WAIT_TIME * POLL_FREQ):
             if os.path.isfile(started):
                 break
             time.sleep(1 / POLL_FREQ)
         else:
-            raise AssertionError("No started file created in {} seconds"
-                                 .format(WAIT_TIME))
+            raise AssertionError("No started file created in {} seconds".format(WAIT_TIME))

         # Connection file should be there by now
         for _ in range(WAIT_TIME * POLL_FREQ):
@@ -41,18 +48,17 @@ def test_kernelapp_lifecycle():
                 break
             time.sleep(1 / POLL_FREQ)
         else:
-            raise AssertionError("No connection file created in {} seconds"
-                                 .format(WAIT_TIME))
+            raise AssertionError("No connection file created in {} seconds".format(WAIT_TIME))
+
         assert len(files) == 1
         cf = files[0]
-        assert cf.startswith('kernel')
-        assert cf.endswith('.json')
+        assert cf.startswith("kernel")
+        assert cf.endswith(".json")

         # Send SIGTERM to shut down
         time.sleep(1)
         p.terminate()
         _, stderr = p.communicate(timeout=WAIT_TIME)
-        assert cf in stderr.decode('utf-8', 'replace')
+        assert cf in stderr.decode("utf-8", "replace")
     finally:
         shutil.rmtree(runtime_dir)
         shutil.rmtree(startup_dir)
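The lifecycle test above leans twice on Python's `for ... else`: the `else` suite runs only when the loop finishes without `break`, which makes it a natural timed-out branch. The idiom in isolation, as a generic helper that is not part of the diff:

```python
import os
import time

WAIT_TIME = 10  # seconds
POLL_FREQ = 10  # polls per second


def wait_for_file(path):
    # else: runs only if the loop was never broken, i.e. on timeout
    for _ in range(WAIT_TIME * POLL_FREQ):
        if os.path.isfile(path):
            break
        time.sleep(1 / POLL_FREQ)
    else:
        raise AssertionError("No file created in {} seconds".format(WAIT_TIME))
```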
diff --git a/jupyter_client/tests/test_kernelmanager.py b/jupyter_client/tests/test_kernelmanager.py
index 5380bdca0..11a946418 100644
--- a/jupyter_client/tests/test_kernelmanager.py
+++ b/jupyter_client/tests/test_kernelmanager.py
@@ -1,33 +1,36 @@
 """Tests for the KernelManager"""
-
 # Copyright (c) Jupyter Development Team.
 # Distributed under the terms of the Modified BSD License.
-
-
 import asyncio
+import concurrent.futures
 import json
 import os
 import signal
 import sys
 import time
-import concurrent.futures
-
-import pytest
+from subprocess import PIPE

-import nest_asyncio
-from async_generator import async_generator, yield_
-from traitlets.config.loader import Config
+import pytest
+from async_generator import async_generator
+from async_generator import yield_
 from jupyter_core import paths
-from jupyter_client import KernelManager, AsyncKernelManager
-from subprocess import PIPE
+from traitlets.config.loader import Config

-from ..manager import start_new_kernel, start_new_async_kernel
 from ..manager import _ShutdownStatus
-from .utils import test_env, SyncKMSubclass, AsyncKMSubclass, AsyncKernelManagerWithCleanup
+from ..manager import start_new_async_kernel
+from ..manager import start_new_kernel
+from .utils import AsyncKernelManagerWithCleanup
+from .utils import AsyncKMSubclass
+from .utils import SyncKMSubclass
+from .utils import test_env
+from jupyter_client import AsyncKernelManager
+from jupyter_client import KernelManager

 pjoin = os.path.join
 TIMEOUT = 30

+
 @pytest.fixture(autouse=True)
 def env():
     env_patch = test_env()
@@ -36,9 +39,9 @@ def env():
     env_patch.stop()


-@pytest.fixture(params=['tcp', 'ipc'])
+@pytest.fixture(params=["tcp", "ipc"])
 def transport(request):
-    if sys.platform == 'win32' and request.param == 'ipc':  #
+    if sys.platform == "win32" and request.param == "ipc":  #
         pytest.skip("Transport 'ipc' not supported on Windows.")
     return request.param

@@ -47,8 +50,8 @@ def transport(request):
 def config(transport):
     c = Config()
     c.KernelManager.transport = transport
-    if transport == 'ipc':
-        c.KernelManager.ip = 'test'
+    if transport == "ipc":
+        c.KernelManager.ip = "test"
     return c

@@ -92,7 +95,7 @@ def install_kernel_dont_terminate():

 @pytest.fixture
 def start_kernel():
-    km, kc = start_new_kernel(kernel_name='signaltest')
+    km, kc = start_new_kernel(kernel_name="signaltest")
     yield km, kc
     kc.stop_channels()
     km.shutdown_kernel()
@@ -104,6 +107,7 @@ def km(config):
     km = KernelManager(config=config)
     return km

+
 @pytest.fixture
 def km_subclass(config):
     km = SyncKMSubclass(config=config)
@@ -113,6 +117,7 @@ def km_subclass(config):
 @pytest.fixture
 def zmq_context():
     import zmq
+
     ctx = zmq.Context()
     yield ctx
     ctx.term()
@@ -131,9 +136,10 @@ def async_km_subclass(config):


 @pytest.fixture
-@async_generator  # This is only necessary while Python 3.5 is support afterwhich both it and yield_() can be removed
+@async_generator  # This is only necessary while Python 3.5 is support afterwhich both it and
+# yield_() can be removed
 async def start_async_kernel():
-    km, kc = await start_new_async_kernel(kernel_name='signaltest')
+    km, kc = await start_new_async_kernel(kernel_name="signaltest")
     await yield_((km, kc))
     kc.stop_channels()
     await km.shutdown_kernel()
@@ -158,9 +164,7 @@ class TestKernelManagerShutDownGracefully:
         ],
     )

-    @pytest.mark.skipif(
-        sys.platform == "win32", reason="Windows doesn't support signals"
-    )
+    @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
     @pytest.mark.parametrize(*parameters)
     def test_signal_kernel_subprocesses(self, name, install, expected):
         install()
@@ -174,9 +178,7 @@ def test_signal_kernel_subprocesses(self, name, install, expected):
         assert km._shutdown_status == expected

     @pytest.mark.asyncio
-    @pytest.mark.skipif(
-        sys.platform == "win32", reason="Windows doesn't support signals"
-    )
+    @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
     @pytest.mark.parametrize(*parameters)
     async def test_async_signal_kernel_subprocesses(self, name, install, expected):
         install()
@@ -191,7 +193,6 @@ async def test_async_signal_kernel_subprocesses(self, name, install, expected):

 class TestKernelManager:
-
     def test_lifecycle(self, km):
         km.start_kernel(stdout=PIPE, stderr=PIPE)
         assert km.is_alive()
@@ -205,14 +206,22 @@ def test_lifecycle(self, km):
     def test_get_connect_info(self, km):
         cinfo = km.get_connection_info()
         keys = sorted(cinfo.keys())
-        expected = sorted([
-            'ip', 'transport',
-            'hb_port', 'shell_port', 'stdin_port', 'iopub_port', 'control_port',
-            'key', 'signature_scheme',
-        ])
+        expected = sorted(
+            [
+                "ip",
+                "transport",
+                "hb_port",
+                "shell_port",
+                "stdin_port",
+                "iopub_port",
+                "control_port",
+                "key",
+                "signature_scheme",
+            ]
+        )
         assert keys == expected

-    @pytest.mark.skipif(sys.platform == 'win32', reason="Windows doesn't support signals")
+    @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
     def test_signal_kernel_subprocesses(self, install_kernel, start_kernel):
         km, kc = start_kernel

@@ -221,36 +230,36 @@ def execute(cmd):
             request_id = kc.execute(cmd)
             while True:
                 reply = kc.get_shell_msg(TIMEOUT)
-                if reply['parent_header']['msg_id'] == request_id:
+                if reply["parent_header"]["msg_id"] == request_id:
                     break
-            content = reply['content']
-            assert content['status'] == 'ok'
+            content = reply["content"]
+            assert content["status"] == "ok"
             return content

         N = 5
         for i in range(N):
             execute("start")
         time.sleep(1)  # make sure subprocs stay up
-        reply = execute('check')
-        assert reply['user_expressions']['poll'] == [None] * N
-
+        reply = execute("check")
+        assert reply["user_expressions"]["poll"] == [None] * N
+
         # start a job on the kernel to be interrupted
-        kc.execute('sleep')
+        kc.execute("sleep")
         time.sleep(1)  # ensure sleep message has been handled before we interrupt
         km.interrupt_kernel()
         reply = kc.get_shell_msg(TIMEOUT)
-        content = reply['content']
-        assert content['status'] == 'ok'
-        assert content['user_expressions']['interrupted']
+        content = reply["content"]
+        assert content["status"] == "ok"
+        assert content["user_expressions"]["interrupted"]
         # wait up to 10s for subprocesses to handle signal
         for i in range(100):
-            reply = execute('check')
-            if reply['user_expressions']['poll'] != [-signal.SIGINT] * N:
+            reply = execute("check")
+            if reply["user_expressions"]["poll"] != [-signal.SIGINT] * N:
                 time.sleep(0.1)
             else:
                 break
         # verify that subprocesses were interrupted
-        assert reply['user_expressions']['poll'] == [-signal.SIGINT] * N
+        assert reply["user_expressions"]["poll"] == [-signal.SIGINT] * N

     def test_start_new_kernel(self, install_kernel, start_kernel):
         km, kc = start_kernel
@@ -263,15 +272,15 @@ def execute(cmd):
             request_id = kc.execute(cmd)
             while True:
                 reply = kc.get_shell_msg(TIMEOUT)
-                if reply['parent_header']['msg_id'] == request_id:
+                if reply["parent_header"]["msg_id"] == request_id:
                     break
-            content = reply['content']
-            assert content['status'] == 'ok'
+            content = reply["content"]
+            assert content["status"] == "ok"
             return content

-        reply = execute('env')
+        reply = execute("env")
         assert reply is not None
-        assert reply['user_expressions']['env'] == 'test_var_1:test_var_2'
+        assert reply["user_expressions"]["env"] == "test_var_1:test_var_2"

     def test_templated_kspec_env(self, install_kernel, start_kernel):
         km, kc = start_kernel
@@ -298,37 +307,37 @@ def test_no_cleanup_shared_context(self, zmq_context):

     def test_subclass_callables(self, km_subclass):
         km_subclass.reset_counts()
         km_subclass.start_kernel(stdout=PIPE, stderr=PIPE)
-        assert km_subclass.call_count('start_kernel') == 1
-        assert km_subclass.call_count('_launch_kernel') == 1
+        assert km_subclass.call_count("start_kernel") == 1
+        assert km_subclass.call_count("_launch_kernel") == 1
         is_alive = km_subclass.is_alive()
         assert is_alive

         km_subclass.reset_counts()
         km_subclass.restart_kernel(now=True)
-        assert km_subclass.call_count('restart_kernel') == 1
-        assert km_subclass.call_count('shutdown_kernel') == 1
-        assert km_subclass.call_count('interrupt_kernel') == 1
-        assert km_subclass.call_count('_kill_kernel') == 1
-        assert km_subclass.call_count('cleanup_resources') == 1
-        assert km_subclass.call_count('start_kernel') == 1
-        assert km_subclass.call_count('_launch_kernel') == 1
+        assert km_subclass.call_count("restart_kernel") == 1
+        assert km_subclass.call_count("shutdown_kernel") == 1
+        assert km_subclass.call_count("interrupt_kernel") == 1
+        assert km_subclass.call_count("_kill_kernel") == 1
+        assert km_subclass.call_count("cleanup_resources") == 1
+        assert km_subclass.call_count("start_kernel") == 1
+        assert km_subclass.call_count("_launch_kernel") == 1
         is_alive = km_subclass.is_alive()
         assert is_alive

         km_subclass.reset_counts()
         km_subclass.interrupt_kernel()
-        assert km_subclass.call_count('interrupt_kernel') == 1
+        assert km_subclass.call_count("interrupt_kernel") == 1

         assert isinstance(km_subclass, KernelManager)

         km_subclass.reset_counts()
         km_subclass.shutdown_kernel(now=False)
-        assert km_subclass.call_count('shutdown_kernel') == 1
-        assert km_subclass.call_count('interrupt_kernel') == 1
-        assert km_subclass.call_count('request_shutdown') == 1
-        assert km_subclass.call_count('finish_shutdown') == 1
-        assert km_subclass.call_count('cleanup_resources') == 1
+        assert km_subclass.call_count("shutdown_kernel") == 1
+        assert km_subclass.call_count("interrupt_kernel") == 1
+        assert km_subclass.call_count("request_shutdown") == 1
+        assert km_subclass.call_count("finish_shutdown") == 1
+        assert km_subclass.call_count("cleanup_resources") == 1

         is_alive = km_subclass.is_alive()
         assert is_alive is False
@@ -336,7 +345,6 @@ def test_subclass_callables(self, km_subclass):

 class TestParallel:
-
     @pytest.mark.timeout(TIMEOUT)
     def test_start_sequence_kernels(self, config, install_kernel):
         """Ensure that a sequence of kernel startups doesn't break anything."""
@@ -346,7 +354,7 @@ def test_start_sequence_kernels(self, config, install_kernel):

     @pytest.mark.timeout(TIMEOUT)
     def test_start_parallel_thread_kernels(self, config, install_kernel):
-        if config.KernelManager.transport == 'ipc':  # FIXME
+        if config.KernelManager.transport == "ipc":  # FIXME
             pytest.skip("IPC transport is currently not working for this test!")
         self._run_signaltest_lifecycle(config)

@@ -357,9 +365,12 @@ def test_start_parallel_thread_kernels(self, config, install_kernel):
             future2.result()

     @pytest.mark.timeout(TIMEOUT)
-    @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error')
+    @pytest.mark.skipif(
+        (sys.platform == "darwin") and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)),
+        reason='"Bad file descriptor" error',
+    )
     def test_start_parallel_process_kernels(self, config, install_kernel):
-        if config.KernelManager.transport == 'ipc':  # FIXME
+        if config.KernelManager.transport == "ipc":  # FIXME
             pytest.skip("IPC transport is currently not working for this test!")
         self._run_signaltest_lifecycle(config)
         with concurrent.futures.ThreadPoolExecutor(max_workers=1) as thread_executor:
@@ -370,9 +381,12 @@ def test_start_parallel_process_kernels(self, config, install_kernel):
             future1.result()

     @pytest.mark.timeout(TIMEOUT)
-    @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error')
+    @pytest.mark.skipif(
+        (sys.platform == "darwin") and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)),
+        reason='"Bad file descriptor" error',
+    )
     def test_start_sequence_process_kernels(self, config, install_kernel):
-        if config.KernelManager.transport == 'ipc':  # FIXME
+        if config.KernelManager.transport == "ipc":  # FIXME
             pytest.skip("IPC transport is currently not working for this test!")
         self._run_signaltest_lifecycle(config)
         with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool_executor:
@@ -393,27 +407,27 @@ def _prepare_kernel(self, km, startup_timeout=TIMEOUT, **kwargs):
         return kc

     def _run_signaltest_lifecycle(self, config=None):
-        km = KernelManager(config=config, kernel_name='signaltest')
+        km = KernelManager(config=config, kernel_name="signaltest")
         kc = self._prepare_kernel(km, stdout=PIPE, stderr=PIPE)

         def execute(cmd):
             request_id = kc.execute(cmd)
             while True:
                 reply = kc.get_shell_msg(TIMEOUT)
-                if reply['parent_header']['msg_id'] == request_id:
+                if reply["parent_header"]["msg_id"] == request_id:
                     break
-            content = reply['content']
-            assert content['status'] == 'ok'
+            content = reply["content"]
+            assert content["status"] == "ok"
             return content

         execute("start")
         assert km.is_alive()
-        execute('check')
+        execute("check")
         assert km.is_alive()

         km.restart_kernel(now=True)
         assert km.is_alive()
-        execute('check')
+        execute("check")

         km.shutdown_kernel()
         assert km.context.closed
@@ -421,7 +435,6 @@ def execute(cmd):

 @pytest.mark.asyncio
 class TestAsyncKernelManager:
-
     async def test_lifecycle(self, async_km):
         await async_km.start_kernel(stdout=PIPE, stderr=PIPE)
         is_alive = await async_km.is_alive()
@@ -439,11 +452,19 @@ async def test_lifecycle(self, async_km):
     async def test_get_connect_info(self, async_km):
         cinfo = async_km.get_connection_info()
         keys = sorted(cinfo.keys())
-        expected = sorted([
-            'ip', 'transport',
-            'hb_port', 'shell_port', 'stdin_port', 'iopub_port', 'control_port',
-            'key', 'signature_scheme',
-        ])
+        expected = sorted(
+            [
+                "ip",
+                "transport",
+                "hb_port",
+                "shell_port",
+                "stdin_port",
+                "iopub_port",
+                "control_port",
+                "key",
+                "signature_scheme",
+            ]
+        )
         assert keys == expected

     async def test_subclass_deprecations(self, async_km):
@@ -464,7 +485,7 @@ async def test_subclass_deprecations(self, async_km):
         assert hasattr(async_km, "which_cleanup") is False

     @pytest.mark.timeout(10)
-    @pytest.mark.skipif(sys.platform == 'win32', reason="Windows doesn't support signals")
+    @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals")
     async def test_signal_kernel_subprocesses(self, install_kernel, start_async_kernel):
         km, kc = start_async_kernel

@@ -473,11 +494,12 @@ async def execute(cmd):
             request_id = kc.execute(cmd)
             while True:
                 reply = await kc.get_shell_msg(TIMEOUT)
-                if reply['parent_header']['msg_id'] == request_id:
+                if reply["parent_header"]["msg_id"] == request_id:
                     break
-            content = reply['content']
-            assert content['status'] == 'ok'
+            content = reply["content"]
+            assert content["status"] == "ok"
             return content
+
         # Ensure that shutdown_kernel and stop_channels are called at the end of the test.
         # Note: we cannot use addCleanup() for these since it doesn't prpperly handle
         # coroutines - which km.shutdown_kernel now is.
@@ -485,29 +507,29 @@ async def execute(cmd):
         for i in range(N):
             await execute("start")
         await asyncio.sleep(1)  # make sure subprocs stay up
-        reply = await execute('check')
-        assert reply['user_expressions']['poll'] == [None] * N
+        reply = await execute("check")
+        assert reply["user_expressions"]["poll"] == [None] * N

         # start a job on the kernel to be interrupted
-        request_id = kc.execute('sleep')
+        request_id = kc.execute("sleep")
         await asyncio.sleep(1)  # ensure sleep message has been handled before we interrupt
         await km.interrupt_kernel()
         while True:
             reply = await kc.get_shell_msg(TIMEOUT)
-            if reply['parent_header']['msg_id'] == request_id:
+            if reply["parent_header"]["msg_id"] == request_id:
                 break
-        content = reply['content']
-        assert content['status'] == 'ok'
-        assert content['user_expressions']['interrupted'] is True
+        content = reply["content"]
+        assert content["status"] == "ok"
+        assert content["user_expressions"]["interrupted"] is True
         # wait up to 5s for subprocesses to handle signal
         for i in range(50):
-            reply = await execute('check')
-            if reply['user_expressions']['poll'] != [-signal.SIGINT] * N:
+            reply = await execute("check")
+            if reply["user_expressions"]["poll"] != [-signal.SIGINT] * N:
                 await asyncio.sleep(0.1)
             else:
                 break
         # verify that subprocesses were interrupted
-        assert reply['user_expressions']['poll'] == [-signal.SIGINT] * N
+        assert reply["user_expressions"]["poll"] == [-signal.SIGINT] * N

     @pytest.mark.timeout(10)
     async def test_start_new_async_kernel(self, install_kernel, start_async_kernel):
@@ -520,37 +542,37 @@ async def test_start_new_async_kernel(self, install_kernel, start_async_kernel):
     async def test_subclass_callables(self, async_km_subclass):
         async_km_subclass.reset_counts()
         await async_km_subclass.start_kernel(stdout=PIPE, stderr=PIPE)
-        assert async_km_subclass.call_count('start_kernel') == 1
-        assert async_km_subclass.call_count('_launch_kernel') == 1
+        assert async_km_subclass.call_count("start_kernel") == 1
+        assert async_km_subclass.call_count("_launch_kernel") == 1
         is_alive = await async_km_subclass.is_alive()
         assert is_alive

         async_km_subclass.reset_counts()
         await async_km_subclass.restart_kernel(now=True)
-        assert async_km_subclass.call_count('restart_kernel') == 1
-        assert async_km_subclass.call_count('shutdown_kernel') == 1
-        assert async_km_subclass.call_count('interrupt_kernel') == 1
-        assert async_km_subclass.call_count('_kill_kernel') == 1
-        assert async_km_subclass.call_count('cleanup_resources') == 1
-        assert async_km_subclass.call_count('start_kernel') == 1
-        assert async_km_subclass.call_count('_launch_kernel') == 1
+        assert async_km_subclass.call_count("restart_kernel") == 1
+        assert async_km_subclass.call_count("shutdown_kernel") == 1
+        assert async_km_subclass.call_count("interrupt_kernel") == 1
+        assert async_km_subclass.call_count("_kill_kernel") == 1
+        assert async_km_subclass.call_count("cleanup_resources") == 1
+        assert async_km_subclass.call_count("start_kernel") == 1
+        assert async_km_subclass.call_count("_launch_kernel") == 1
         is_alive = await async_km_subclass.is_alive()
         assert is_alive

         async_km_subclass.reset_counts()
         await async_km_subclass.interrupt_kernel()
-        assert async_km_subclass.call_count('interrupt_kernel') == 1
+        assert async_km_subclass.call_count("interrupt_kernel") == 1

         assert isinstance(async_km_subclass, AsyncKernelManager)

         async_km_subclass.reset_counts()
         await async_km_subclass.shutdown_kernel(now=False)
-        assert async_km_subclass.call_count('shutdown_kernel') == 1
-        assert async_km_subclass.call_count('interrupt_kernel') == 1
-        assert async_km_subclass.call_count('request_shutdown') == 1
-        assert async_km_subclass.call_count('finish_shutdown') == 1
-        assert async_km_subclass.call_count('cleanup_resources') == 1
+        assert async_km_subclass.call_count("shutdown_kernel") == 1
+        assert async_km_subclass.call_count("interrupt_kernel") == 1
+        assert async_km_subclass.call_count("request_shutdown") == 1
+        assert async_km_subclass.call_count("finish_shutdown") == 1
+        assert async_km_subclass.call_count("cleanup_resources") == 1

         is_alive = await async_km_subclass.is_alive()
         assert is_alive is False
sample_kernel_json['display_name']) + self.assertEqual(ks.argv, sample_kernel_json["argv"]) + self.assertEqual(ks.display_name, sample_kernel_json["display_name"]) self.assertEqual(ks.env, {}) self.assertEqual(ks.metadata, {}) def test_find_all_specs(self): kernels = self.ksm.get_all_specs() - self.assertEqual(kernels['sample']['resource_dir'], self.sample_kernel_dir) - self.assertIsNotNone(kernels['sample']['spec']) + self.assertEqual(kernels["sample"]["resource_dir"], self.sample_kernel_dir) + self.assertIsNotNone(kernels["sample"]["spec"]) def test_kernel_spec_priority(self): td = TemporaryDirectory() @@ -76,21 +78,17 @@ def test_kernel_spec_priority(self): sample_kernel = self._install_sample_kernel(td.name) self.ksm.kernel_dirs.append(td.name) kernels = self.ksm.find_kernel_specs() - self.assertEqual(kernels['sample'], self.sample_kernel_dir) + self.assertEqual(kernels["sample"], self.sample_kernel_dir) self.ksm.kernel_dirs.insert(0, td.name) kernels = self.ksm.find_kernel_specs() - self.assertEqual(kernels['sample'], sample_kernel) + self.assertEqual(kernels["sample"], sample_kernel) def test_install_kernel_spec(self): - self.ksm.install_kernel_spec(self.installable_kernel, - kernel_name='tstinstalled', - user=True) - self.assertIn('tstinstalled', self.ksm.find_kernel_specs()) + self.ksm.install_kernel_spec(self.installable_kernel, kernel_name="tstinstalled", user=True) + self.assertIn("tstinstalled", self.ksm.find_kernel_specs()) # install again works - self.ksm.install_kernel_spec(self.installable_kernel, - kernel_name='tstinstalled', - user=True) + self.ksm.install_kernel_spec(self.installable_kernel, kernel_name="tstinstalled", user=True) def test_install_kernel_spec_prefix(self): td = TemporaryDirectory() @@ -98,66 +96,75 @@ def test_install_kernel_spec_prefix(self): capture = StringIO() handler = StreamHandler(capture) self.ksm.log.addHandler(handler) - self.ksm.install_kernel_spec(self.installable_kernel, - kernel_name='tstinstalled', - prefix=td.name) + self.ksm.install_kernel_spec( + self.installable_kernel, kernel_name="tstinstalled", prefix=td.name + ) captured = capture.getvalue() self.ksm.log.removeHandler(handler) self.assertIn("may not be found", captured) - self.assertNotIn('tstinstalled', self.ksm.find_kernel_specs()) + self.assertNotIn("tstinstalled", self.ksm.find_kernel_specs()) # add prefix to path, so we find the spec - self.ksm.kernel_dirs.append(pjoin(td.name, 'share', 'jupyter', 'kernels')) - self.assertIn('tstinstalled', self.ksm.find_kernel_specs()) + self.ksm.kernel_dirs.append(pjoin(td.name, "share", "jupyter", "kernels")) + self.assertIn("tstinstalled", self.ksm.find_kernel_specs()) # Run it again, no warning this time because we've added it to the path capture = StringIO() handler = StreamHandler(capture) self.ksm.log.addHandler(handler) - self.ksm.install_kernel_spec(self.installable_kernel, - kernel_name='tstinstalled', - prefix=td.name) + self.ksm.install_kernel_spec( + self.installable_kernel, kernel_name="tstinstalled", prefix=td.name + ) captured = capture.getvalue() self.ksm.log.removeHandler(handler) self.assertNotIn("may not be found", captured) @pytest.mark.skipif( - not (os.name != 'nt' and not os.access('/usr/local/share', os.W_OK)), - reason="needs Unix system without root privileges") + not (os.name != "nt" and not os.access("/usr/local/share", os.W_OK)), + reason="needs Unix system without root privileges", + ) def test_cant_install_kernel_spec(self): with self.assertRaises(OSError): - self.ksm.install_kernel_spec(self.installable_kernel, 
- kernel_name='tstinstalled', - user=False) + self.ksm.install_kernel_spec( + self.installable_kernel, kernel_name="tstinstalled", user=False + ) def test_remove_kernel_spec(self): - path = self.ksm.remove_kernel_spec('sample') + path = self.ksm.remove_kernel_spec("sample") self.assertEqual(path, self.sample_kernel_dir) def test_remove_kernel_spec_app(self): p = Popen( - [sys.executable, '-m', 'jupyter_client.kernelspecapp', 'remove', 'sample', '-f'], - stdout=PIPE, stderr=STDOUT, + [ + sys.executable, + "-m", + "jupyter_client.kernelspecapp", + "remove", + "sample", + "-f", + ], + stdout=PIPE, + stderr=STDOUT, env=os.environ, ) out, _ = p.communicate() - self.assertEqual(p.returncode, 0, out.decode('utf8', 'replace')) + self.assertEqual(p.returncode, 0, out.decode("utf8", "replace")) def test_validate_kernel_name(self): for good in [ - 'julia-0.4', - 'ipython', - 'R', - 'python_3', - 'Haskell-1-2-3', + "julia-0.4", + "ipython", + "R", + "python_3", + "Haskell-1-2-3", ]: assert kernelspec._is_valid_kernel_name(good) for bad in [ - 'has space', - 'ünicode', - '%percent', - 'question?', + "has space", + "ünicode", + "%percent", + "question?", ]: assert not kernelspec._is_valid_kernel_name(bad) @@ -171,7 +178,7 @@ def test_subclass(self): class MyKSM(kernelspec.KernelSpecManager): def get_kernel_spec(self, name): spec = copy.copy(native_kernel) - if name == 'fake': + if name == "fake": spec.name = name spec.resource_dir = resource_dir elif name == native_name: @@ -182,7 +189,7 @@ def get_kernel_spec(self, name): def find_kernel_specs(self): return { - 'fake': resource_dir, + "fake": resource_dir, native_name: native_kernel.resource_dir, } @@ -190,4 +197,4 @@ def find_kernel_specs(self): # find_kernel_specs and get_kernel_spec are defined myksm = MyKSM() specs = myksm.get_all_specs() - assert sorted(specs) == ['fake', native_name] + assert sorted(specs) == ["fake", native_name] diff --git a/jupyter_client/tests/test_localinterfaces.py b/jupyter_client/tests/test_localinterfaces.py index ed4f68e01..7b6c4a9ad 100644 --- a/jupyter_client/tests/test_localinterfaces.py +++ b/jupyter_client/tests/test_localinterfaces.py @@ -1,15 +1,15 @@ -#----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- # Copyright (c) The Jupyter Development Team # # Distributed under the terms of the BSD License. The full license is in # the file COPYING, distributed as part of this software. -#----------------------------------------------------------------------------- - +# ----------------------------------------------------------------------------- from .. import localinterfaces + def test_load_ips(): # Override the machinery that skips it if it was called before localinterfaces._load_ips.called = False # Just check this doesn't error - localinterfaces._load_ips(suppress_exceptions=False) \ No newline at end of file + localinterfaces._load_ips(suppress_exceptions=False) diff --git a/jupyter_client/tests/test_manager.py b/jupyter_client/tests/test_manager.py index 8a59ab532..4f62ca17b 100644 --- a/jupyter_client/tests/test_manager.py +++ b/jupyter_client/tests/test_manager.py @@ -1,38 +1,34 @@ """Tests for KernelManager""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
+import os +import tempfile +from unittest import mock from jupyter_client.kernelspec import KernelSpec -from unittest import mock from jupyter_client.manager import KernelManager -import os -import tempfile def test_connection_file_real_path(): """ Verify realpath is used when formatting connection file """ - with mock.patch('os.path.realpath') as patched_realpath: - patched_realpath.return_value = 'foobar' - km = KernelManager(connection_file=os.path.join( - tempfile.gettempdir(), "kernel-test.json"), - kernel_name='test_kernel') + with mock.patch("os.path.realpath") as patched_realpath: + patched_realpath.return_value = "foobar" + km = KernelManager( + connection_file=os.path.join(tempfile.gettempdir(), "kernel-test.json"), + kernel_name="test_kernel", + ) # KernelSpec and launch args have to be mocked as we don't have an actual kernel on disk - km._kernel_spec = KernelSpec(resource_dir='test', + km._kernel_spec = KernelSpec( + resource_dir="test", **{ - "argv": [ - "python.exe", - "-m", - "test_kernel", - "-f", - "{connection_file}" - ], + "argv": ["python.exe", "-m", "test_kernel", "-f", "{connection_file}"], "env": {}, "display_name": "test_kernel", "language": "python", - "metadata": {} - }) + "metadata": {}, + }, + ) km._launch_args = {} cmds = km.format_kernel_cmd() - assert cmds[4] == 'foobar' + assert cmds[4] == "foobar" diff --git a/jupyter_client/tests/test_multikernelmanager.py b/jupyter_client/tests/test_multikernelmanager.py index fba0eff04..7f45ff11f 100644 --- a/jupyter_client/tests/test_multikernelmanager.py +++ b/jupyter_client/tests/test_multikernelmanager.py @@ -1,19 +1,26 @@ """Tests for the notebook kernel and session manager.""" - import asyncio import concurrent.futures -import uuid import sys - -import pytest +import uuid from subprocess import PIPE from unittest import TestCase -from tornado.testing import AsyncTestCase, gen_test + +import pytest +from tornado.testing import AsyncTestCase +from tornado.testing import gen_test from traitlets.config.loader import Config -from jupyter_client import KernelManager, AsyncKernelManager -from jupyter_client.multikernelmanager import MultiKernelManager, AsyncMultiKernelManager -from .utils import skip_win32, SyncMKMSubclass, AsyncMKMSubclass, SyncKMSubclass, AsyncKMSubclass + from ..localinterfaces import localhost +from .utils import AsyncKMSubclass +from .utils import AsyncMKMSubclass +from .utils import skip_win32 +from .utils import SyncKMSubclass +from .utils import SyncMKMSubclass +from jupyter_client import AsyncKernelManager +from jupyter_client import KernelManager +from jupyter_client.multikernelmanager import AsyncMultiKernelManager +from jupyter_client.multikernelmanager import MultiKernelManager TIMEOUT = 30 @@ -37,8 +44,8 @@ def _get_tcp_km_sub(): @staticmethod def _get_ipc_km(): c = Config() - c.KernelManager.transport = 'ipc' - c.KernelManager.ip = 'test' + c.KernelManager.transport = "ipc" + c.KernelManager.ip = "test" km = MultiKernelManager(config=c) return km @@ -53,7 +60,7 @@ def _run_lifecycle(km, test_kid=None): assert km.is_alive(kid) assert kid in km assert kid in km.list_kernel_ids() - assert len(km) == 1, f'{len(km)} != {1}' + assert len(km) == 1, f"{len(km)} != {1}" km.restart_kernel(kid, now=True) assert km.is_alive(kid) assert kid in km.list_kernel_ids() @@ -61,22 +68,22 @@ def _run_lifecycle(km, test_kid=None): k = km.get_kernel(kid) assert isinstance(k, KernelManager) km.shutdown_kernel(kid, now=True) - assert kid not in km, f'{kid} not in {km}' + assert kid not in km, f"{kid} not in {km}" 
def _run_cinfo(self, km, transport, ip): kid = km.start_kernel(stdout=PIPE, stderr=PIPE) - k = km.get_kernel(kid) + km.get_kernel(kid) cinfo = km.get_connection_info(kid) - self.assertEqual(transport, cinfo['transport']) - self.assertEqual(ip, cinfo['ip']) - self.assertTrue('stdin_port' in cinfo) - self.assertTrue('iopub_port' in cinfo) + self.assertEqual(transport, cinfo["transport"]) + self.assertEqual(ip, cinfo["ip"]) + self.assertTrue("stdin_port" in cinfo) + self.assertTrue("iopub_port" in cinfo) stream = km.connect_iopub(kid) stream.close() - self.assertTrue('shell_port' in cinfo) + self.assertTrue("shell_port" in cinfo) stream = km.connect_shell(kid) stream.close() - self.assertTrue('hb_port' in cinfo) + self.assertTrue("hb_port" in cinfo) stream = km.connect_hb(kid) stream.close() km.shutdown_kernel(kid, now=True) @@ -102,7 +109,7 @@ def test_shutdown_all(self): def test_tcp_cinfo(self): km = self._get_tcp_km() - self._run_cinfo(km, 'tcp', localhost()) + self._run_cinfo(km, "tcp", localhost()) @skip_win32 def test_ipc_lifecycle(self): @@ -112,7 +119,7 @@ def test_ipc_lifecycle(self): @skip_win32 def test_ipc_cinfo(self): km = self._get_ipc_km() - self._run_cinfo(km, 'ipc', 'test') + self._run_cinfo(km, "ipc", "test") def test_start_sequence_tcp_kernels(self): """Ensure that a sequence of kernel startups doesn't break anything.""" @@ -141,7 +148,10 @@ def test_start_parallel_thread_kernels(self): future1.result() future2.result() - @pytest.mark.skipif((sys.platform == 'darwin') and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), reason='"Bad file descriptor" error') + @pytest.mark.skipif( + (sys.platform == "darwin") and (sys.version_info >= (3, 6)) and (sys.version_info < (3, 8)), + reason='"Bad file descriptor" error', + ) def test_start_parallel_process_kernels(self): self.test_tcp_lifecycle() @@ -158,28 +168,28 @@ def test_subclass_callables(self): km.reset_counts() kid = km.start_kernel(stdout=PIPE, stderr=PIPE) - assert km.call_count('start_kernel') == 1 + assert km.call_count("start_kernel") == 1 assert isinstance(km.get_kernel(kid), SyncKMSubclass) - assert km.get_kernel(kid).call_count('start_kernel') == 1 - assert km.get_kernel(kid).call_count('_launch_kernel') == 1 + assert km.get_kernel(kid).call_count("start_kernel") == 1 + assert km.get_kernel(kid).call_count("_launch_kernel") == 1 assert km.is_alive(kid) assert kid in km assert kid in km.list_kernel_ids() - assert len(km) == 1, f'{len(km)} != {1}' + assert len(km) == 1, f"{len(km)} != {1}" km.get_kernel(kid).reset_counts() km.reset_counts() km.restart_kernel(kid, now=True) - assert km.call_count('restart_kernel') == 1 - assert km.call_count('get_kernel') == 1 - assert km.get_kernel(kid).call_count('restart_kernel') == 1 - assert km.get_kernel(kid).call_count('shutdown_kernel') == 1 - assert km.get_kernel(kid).call_count('interrupt_kernel') == 1 - assert km.get_kernel(kid).call_count('_kill_kernel') == 1 - assert km.get_kernel(kid).call_count('cleanup_resources') == 1 - assert km.get_kernel(kid).call_count('start_kernel') == 1 - assert km.get_kernel(kid).call_count('_launch_kernel') == 1 + assert km.call_count("restart_kernel") == 1 + assert km.call_count("get_kernel") == 1 + assert km.get_kernel(kid).call_count("restart_kernel") == 1 + assert km.get_kernel(kid).call_count("shutdown_kernel") == 1 + assert km.get_kernel(kid).call_count("interrupt_kernel") == 1 + assert km.get_kernel(kid).call_count("_kill_kernel") == 1 + assert km.get_kernel(kid).call_count("cleanup_resources") == 1 + assert 
km.get_kernel(kid).call_count("start_kernel") == 1 + assert km.get_kernel(kid).call_count("_launch_kernel") == 1 assert km.is_alive(kid) assert kid in km.list_kernel_ids() @@ -187,26 +197,26 @@ def test_subclass_callables(self): km.get_kernel(kid).reset_counts() km.reset_counts() km.interrupt_kernel(kid) - assert km.call_count('interrupt_kernel') == 1 - assert km.call_count('get_kernel') == 1 - assert km.get_kernel(kid).call_count('interrupt_kernel') == 1 + assert km.call_count("interrupt_kernel") == 1 + assert km.call_count("get_kernel") == 1 + assert km.get_kernel(kid).call_count("interrupt_kernel") == 1 km.get_kernel(kid).reset_counts() km.reset_counts() k = km.get_kernel(kid) assert isinstance(k, SyncKMSubclass) - assert km.call_count('get_kernel') == 1 + assert km.call_count("get_kernel") == 1 km.get_kernel(kid).reset_counts() km.reset_counts() km.shutdown_all(now=True) - assert km.call_count('shutdown_kernel') == 1 - assert km.call_count('remove_kernel') == 1 - assert km.call_count('request_shutdown') == 0 - assert km.call_count('finish_shutdown') == 0 - assert km.call_count('cleanup_resources') == 0 + assert km.call_count("shutdown_kernel") == 1 + assert km.call_count("remove_kernel") == 1 + assert km.call_count("request_shutdown") == 0 + assert km.call_count("finish_shutdown") == 0 + assert km.call_count("cleanup_resources") == 0 - assert kid not in km, f'{kid} not in {km}' + assert kid not in km, f"{kid} not in {km}" class TestAsyncKernelManager(AsyncTestCase): @@ -228,8 +238,8 @@ def _get_tcp_km_sub(): @staticmethod def _get_ipc_km(): c = Config() - c.KernelManager.transport = 'ipc' - c.KernelManager.ip = 'test' + c.KernelManager.transport = "ipc" + c.KernelManager.ip = "test" km = AsyncMultiKernelManager(config=c) return km @@ -244,7 +254,7 @@ async def _run_lifecycle(km, test_kid=None): assert await km.is_alive(kid) assert kid in km assert kid in km.list_kernel_ids() - assert len(km) == 1, f'{len(km)} != {1}' + assert len(km) == 1, f"{len(km)} != {1}" await km.restart_kernel(kid, now=True) assert await km.is_alive(kid) assert kid in km.list_kernel_ids() @@ -252,22 +262,22 @@ async def _run_lifecycle(km, test_kid=None): k = km.get_kernel(kid) assert isinstance(k, AsyncKernelManager) await km.shutdown_kernel(kid, now=True) - assert kid not in km, f'{kid} not in {km}' + assert kid not in km, f"{kid} not in {km}" async def _run_cinfo(self, km, transport, ip): kid = await km.start_kernel(stdout=PIPE, stderr=PIPE) - k = km.get_kernel(kid) + km.get_kernel(kid) cinfo = km.get_connection_info(kid) - self.assertEqual(transport, cinfo['transport']) - self.assertEqual(ip, cinfo['ip']) - self.assertTrue('stdin_port' in cinfo) - self.assertTrue('iopub_port' in cinfo) + self.assertEqual(transport, cinfo["transport"]) + self.assertEqual(ip, cinfo["ip"]) + self.assertTrue("stdin_port" in cinfo) + self.assertTrue("iopub_port" in cinfo) stream = km.connect_iopub(kid) stream.close() - self.assertTrue('shell_port' in cinfo) + self.assertTrue("shell_port" in cinfo) stream = km.connect_shell(kid) stream.close() - self.assertTrue('hb_port' in cinfo) + self.assertTrue("hb_port" in cinfo) stream = km.connect_hb(kid) stream.close() await km.shutdown_kernel(kid, now=True) @@ -327,7 +337,7 @@ async def test_shutdown_all_while_starting(self): @gen_test async def test_tcp_cinfo(self): km = self._get_tcp_km() - await self._run_cinfo(km, 'tcp', localhost()) + await self._run_cinfo(km, "tcp", localhost()) @skip_win32 @gen_test @@ -339,7 +349,7 @@ async def test_ipc_lifecycle(self): @gen_test async def 
test_ipc_cinfo(self): km = self._get_ipc_km() - await self._run_cinfo(km, 'ipc', 'test') + await self._run_cinfo(km, "ipc", "test") @gen_test async def test_start_sequence_tcp_kernels(self): @@ -407,28 +417,28 @@ async def test_subclass_callables(self): mkm.reset_counts() kid = await mkm.start_kernel(stdout=PIPE, stderr=PIPE) - assert mkm.call_count('start_kernel') == 1 + assert mkm.call_count("start_kernel") == 1 assert isinstance(mkm.get_kernel(kid), AsyncKMSubclass) - assert mkm.get_kernel(kid).call_count('start_kernel') == 1 - assert mkm.get_kernel(kid).call_count('_launch_kernel') == 1 + assert mkm.get_kernel(kid).call_count("start_kernel") == 1 + assert mkm.get_kernel(kid).call_count("_launch_kernel") == 1 assert await mkm.is_alive(kid) assert kid in mkm assert kid in mkm.list_kernel_ids() - assert len(mkm) == 1, f'{len(mkm)} != {1}' + assert len(mkm) == 1, f"{len(mkm)} != {1}" mkm.get_kernel(kid).reset_counts() mkm.reset_counts() await mkm.restart_kernel(kid, now=True) - assert mkm.call_count('restart_kernel') == 1 - assert mkm.call_count('get_kernel') == 1 - assert mkm.get_kernel(kid).call_count('restart_kernel') == 1 - assert mkm.get_kernel(kid).call_count('shutdown_kernel') == 1 - assert mkm.get_kernel(kid).call_count('interrupt_kernel') == 1 - assert mkm.get_kernel(kid).call_count('_kill_kernel') == 1 - assert mkm.get_kernel(kid).call_count('cleanup_resources') == 1 - assert mkm.get_kernel(kid).call_count('start_kernel') == 1 - assert mkm.get_kernel(kid).call_count('_launch_kernel') == 1 + assert mkm.call_count("restart_kernel") == 1 + assert mkm.call_count("get_kernel") == 1 + assert mkm.get_kernel(kid).call_count("restart_kernel") == 1 + assert mkm.get_kernel(kid).call_count("shutdown_kernel") == 1 + assert mkm.get_kernel(kid).call_count("interrupt_kernel") == 1 + assert mkm.get_kernel(kid).call_count("_kill_kernel") == 1 + assert mkm.get_kernel(kid).call_count("cleanup_resources") == 1 + assert mkm.get_kernel(kid).call_count("start_kernel") == 1 + assert mkm.get_kernel(kid).call_count("_launch_kernel") == 1 assert await mkm.is_alive(kid) assert kid in mkm.list_kernel_ids() @@ -436,23 +446,23 @@ async def test_subclass_callables(self): mkm.get_kernel(kid).reset_counts() mkm.reset_counts() await mkm.interrupt_kernel(kid) - assert mkm.call_count('interrupt_kernel') == 1 - assert mkm.call_count('get_kernel') == 1 - assert mkm.get_kernel(kid).call_count('interrupt_kernel') == 1 + assert mkm.call_count("interrupt_kernel") == 1 + assert mkm.call_count("get_kernel") == 1 + assert mkm.get_kernel(kid).call_count("interrupt_kernel") == 1 mkm.get_kernel(kid).reset_counts() mkm.reset_counts() k = mkm.get_kernel(kid) assert isinstance(k, AsyncKMSubclass) - assert mkm.call_count('get_kernel') == 1 + assert mkm.call_count("get_kernel") == 1 mkm.get_kernel(kid).reset_counts() mkm.reset_counts() await mkm.shutdown_all(now=True) - assert mkm.call_count('shutdown_kernel') == 1 - assert mkm.call_count('remove_kernel') == 1 - assert mkm.call_count('request_shutdown') == 0 - assert mkm.call_count('finish_shutdown') == 0 - assert mkm.call_count('cleanup_resources') == 0 + assert mkm.call_count("shutdown_kernel") == 1 + assert mkm.call_count("remove_kernel") == 1 + assert mkm.call_count("request_shutdown") == 0 + assert mkm.call_count("finish_shutdown") == 0 + assert mkm.call_count("cleanup_resources") == 0 - assert kid not in mkm, f'{kid} not in {mkm}' + assert kid not in mkm, f"{kid} not in {mkm}" diff --git a/jupyter_client/tests/test_public_api.py b/jupyter_client/tests/test_public_api.py index 
5ebf2f3d3..6b3fdc671 100644 --- a/jupyter_client/tests/test_public_api.py +++ b/jupyter_client/tests/test_public_api.py @@ -1,11 +1,10 @@ """Test the jupyter_client public API """ - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - -from jupyter_client import launcher, connect import jupyter_client +from jupyter_client import connect +from jupyter_client import launcher def test_kms(): @@ -13,15 +12,18 @@ def test_kms(): KM = base + "KernelManager" assert KM in dir(jupyter_client) + def test_kcs(): for base in ("", "Blocking", "Async"): KM = base + "KernelClient" assert KM in dir(jupyter_client) + def test_launcher(): for name in launcher.__all__: assert name in dir(jupyter_client) + def test_connect(): for name in connect.__all__: assert name in dir(jupyter_client) diff --git a/jupyter_client/tests/test_session.py b/jupyter_client/tests/test_session.py index 45be9a96a..82ad9e666 100644 --- a/jupyter_client/tests/test_session.py +++ b/jupyter_client/tests/test_session.py @@ -1,8 +1,6 @@ """test building messages with Session""" - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - import hmac import os import uuid @@ -10,23 +8,23 @@ from unittest import mock import pytest - import zmq - -from zmq.tests import BaseZMQTestCase from zmq.eventloop.zmqstream import ZMQStream +from zmq.tests import BaseZMQTestCase -from jupyter_client import session as ss from jupyter_client import jsonutil +from jupyter_client import session as ss + def _bad_packer(obj): raise TypeError("I don't work") + def _bad_unpacker(bytes): raise TypeError("I don't work either") -class SessionTestCase(BaseZMQTestCase): +class SessionTestCase(BaseZMQTestCase): def setUp(self): BaseZMQTestCase.setUp(self) self.session = ss.Session() @@ -35,42 +33,41 @@ def setUp(self): @pytest.fixture def no_copy_threshold(): """Disable zero-copy optimizations in pyzmq >= 17""" - with mock.patch.object(zmq, 'COPY_THRESHOLD', 1, create=True): + with mock.patch.object(zmq, "COPY_THRESHOLD", 1, create=True): yield -@pytest.mark.usefixtures('no_copy_threshold') +@pytest.mark.usefixtures("no_copy_threshold") class TestSession(SessionTestCase): - def test_msg(self): """message format""" - msg = self.session.msg('execute') - thekeys = set('header parent_header metadata content msg_type msg_id'.split()) + msg = self.session.msg("execute") + thekeys = set("header parent_header metadata content msg_type msg_id".split()) s = set(msg.keys()) self.assertEqual(s, thekeys) - self.assertTrue(isinstance(msg['content'],dict)) - self.assertTrue(isinstance(msg['metadata'],dict)) - self.assertTrue(isinstance(msg['header'],dict)) - self.assertTrue(isinstance(msg['parent_header'],dict)) - self.assertTrue(isinstance(msg['msg_id'], str)) - self.assertTrue(isinstance(msg['msg_type'], str)) - self.assertEqual(msg['header']['msg_type'], 'execute') - self.assertEqual(msg['msg_type'], 'execute') + self.assertTrue(isinstance(msg["content"], dict)) + self.assertTrue(isinstance(msg["metadata"], dict)) + self.assertTrue(isinstance(msg["header"], dict)) + self.assertTrue(isinstance(msg["parent_header"], dict)) + self.assertTrue(isinstance(msg["msg_id"], str)) + self.assertTrue(isinstance(msg["msg_type"], str)) + self.assertEqual(msg["header"]["msg_type"], "execute") + self.assertEqual(msg["msg_type"], "execute") def test_serialize(self): - msg = self.session.msg('execute', content=dict(a=10, b=1.1)) - msg_list = self.session.serialize(msg, ident=b'foo') + msg = 
self.session.msg("execute", content=dict(a=10, b=1.1)) + msg_list = self.session.serialize(msg, ident=b"foo") ident, msg_list = self.session.feed_identities(msg_list) new_msg = self.session.deserialize(msg_list) - self.assertEqual(ident[0], b'foo') - self.assertEqual(new_msg['msg_id'],msg['msg_id']) - self.assertEqual(new_msg['msg_type'],msg['msg_type']) - self.assertEqual(new_msg['header'],msg['header']) - self.assertEqual(new_msg['content'],msg['content']) - self.assertEqual(new_msg['parent_header'],msg['parent_header']) - self.assertEqual(new_msg['metadata'],msg['metadata']) + self.assertEqual(ident[0], b"foo") + self.assertEqual(new_msg["msg_id"], msg["msg_id"]) + self.assertEqual(new_msg["msg_type"], msg["msg_type"]) + self.assertEqual(new_msg["header"], msg["header"]) + self.assertEqual(new_msg["content"], msg["content"]) + self.assertEqual(new_msg["parent_header"], msg["parent_header"]) + self.assertEqual(new_msg["metadata"], msg["metadata"]) # ensure floats don't come out as Decimal: - self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b'])) + self.assertEqual(type(new_msg["content"]["b"]), type(new_msg["content"]["b"])) def test_default_secure(self): self.assertIsInstance(self.session.key, bytes) @@ -83,60 +80,68 @@ def test_send(self): A.bind("inproc://test") B.connect("inproc://test") - msg = self.session.msg('execute', content=dict(a=10)) - self.session.send(A, msg, ident=b'foo', buffers=[b'bar']) + msg = self.session.msg("execute", content=dict(a=10)) + self.session.send(A, msg, ident=b"foo", buffers=[b"bar"]) ident, msg_list = self.session.feed_identities(B.recv_multipart()) new_msg = self.session.deserialize(msg_list) - self.assertEqual(ident[0], b'foo') - self.assertEqual(new_msg['msg_id'],msg['msg_id']) - self.assertEqual(new_msg['msg_type'],msg['msg_type']) - self.assertEqual(new_msg['header'],msg['header']) - self.assertEqual(new_msg['content'],msg['content']) - self.assertEqual(new_msg['parent_header'],msg['parent_header']) - self.assertEqual(new_msg['metadata'],msg['metadata']) - self.assertEqual(new_msg['buffers'],[b'bar']) - - content = msg['content'] - header = msg['header'] - header['msg_id'] = self.session.msg_id - parent = msg['parent_header'] - metadata = msg['metadata'] - msg_type = header['msg_type'] - self.session.send(A, None, content=content, parent=parent, - header=header, metadata=metadata, ident=b'foo', buffers=[b'bar']) + self.assertEqual(ident[0], b"foo") + self.assertEqual(new_msg["msg_id"], msg["msg_id"]) + self.assertEqual(new_msg["msg_type"], msg["msg_type"]) + self.assertEqual(new_msg["header"], msg["header"]) + self.assertEqual(new_msg["content"], msg["content"]) + self.assertEqual(new_msg["parent_header"], msg["parent_header"]) + self.assertEqual(new_msg["metadata"], msg["metadata"]) + self.assertEqual(new_msg["buffers"], [b"bar"]) + + content = msg["content"] + header = msg["header"] + header["msg_id"] = self.session.msg_id + parent = msg["parent_header"] + metadata = msg["metadata"] + header["msg_type"] + self.session.send( + A, + None, + content=content, + parent=parent, + header=header, + metadata=metadata, + ident=b"foo", + buffers=[b"bar"], + ) ident, msg_list = self.session.feed_identities(B.recv_multipart()) new_msg = self.session.deserialize(msg_list) - self.assertEqual(ident[0], b'foo') - self.assertEqual(new_msg['msg_id'],header['msg_id']) - self.assertEqual(new_msg['msg_type'],msg['msg_type']) - self.assertEqual(new_msg['header'],msg['header']) - self.assertEqual(new_msg['content'],msg['content']) - 
self.assertEqual(new_msg['metadata'],msg['metadata']) - self.assertEqual(new_msg['parent_header'],msg['parent_header']) - self.assertEqual(new_msg['buffers'],[b'bar']) - - header['msg_id'] = self.session.msg_id - - self.session.send(A, msg, ident=b'foo', buffers=[b'bar']) + self.assertEqual(ident[0], b"foo") + self.assertEqual(new_msg["msg_id"], header["msg_id"]) + self.assertEqual(new_msg["msg_type"], msg["msg_type"]) + self.assertEqual(new_msg["header"], msg["header"]) + self.assertEqual(new_msg["content"], msg["content"]) + self.assertEqual(new_msg["metadata"], msg["metadata"]) + self.assertEqual(new_msg["parent_header"], msg["parent_header"]) + self.assertEqual(new_msg["buffers"], [b"bar"]) + + header["msg_id"] = self.session.msg_id + + self.session.send(A, msg, ident=b"foo", buffers=[b"bar"]) ident, new_msg = self.session.recv(B) - self.assertEqual(ident[0], b'foo') - self.assertEqual(new_msg['msg_id'],header['msg_id']) - self.assertEqual(new_msg['msg_type'],msg['msg_type']) - self.assertEqual(new_msg['header'],msg['header']) - self.assertEqual(new_msg['content'],msg['content']) - self.assertEqual(new_msg['metadata'],msg['metadata']) - self.assertEqual(new_msg['parent_header'],msg['parent_header']) - self.assertEqual(new_msg['buffers'],[b'bar']) + self.assertEqual(ident[0], b"foo") + self.assertEqual(new_msg["msg_id"], header["msg_id"]) + self.assertEqual(new_msg["msg_type"], msg["msg_type"]) + self.assertEqual(new_msg["header"], msg["header"]) + self.assertEqual(new_msg["content"], msg["content"]) + self.assertEqual(new_msg["metadata"], msg["metadata"]) + self.assertEqual(new_msg["parent_header"], msg["parent_header"]) + self.assertEqual(new_msg["buffers"], [b"bar"]) # buffers must support the buffer protocol with self.assertRaises(TypeError): - self.session.send(A, msg, ident=b'foo', buffers=[1]) + self.session.send(A, msg, ident=b"foo", buffers=[1]) # buffers must be contiguous buf = memoryview(os.urandom(16)) with self.assertRaises(ValueError): - self.session.send(A, msg, ident=b'foo', buffers=[buf[::2]]) + self.session.send(A, msg, ident=b"foo", buffers=[buf[::2]]) A.close() B.close() @@ -147,73 +152,70 @@ def test_args(self): s = self.session self.assertTrue(s.pack is ss.default_packer) self.assertTrue(s.unpack is ss.default_unpacker) - self.assertEqual(s.username, os.environ.get('USER', 'username')) + self.assertEqual(s.username, os.environ.get("USER", "username")) s = ss.Session() - self.assertEqual(s.username, os.environ.get('USER', 'username')) + self.assertEqual(s.username, os.environ.get("USER", "username")) - self.assertRaises(TypeError, ss.Session, pack='hi') - self.assertRaises(TypeError, ss.Session, unpack='hi') + self.assertRaises(TypeError, ss.Session, pack="hi") + self.assertRaises(TypeError, ss.Session, unpack="hi") u = str(uuid.uuid4()) - s = ss.Session(username='carrot', session=u) + s = ss.Session(username="carrot", session=u) self.assertEqual(s.session, u) - self.assertEqual(s.username, 'carrot') + self.assertEqual(s.username, "carrot") def test_tracking(self): """test tracking messages""" - a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) + a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR) s = self.session s.copy_threshold = 1 - stream = ZMQStream(a) - msg = s.send(a, 'hello', track=False) - self.assertTrue(msg['tracker'] is ss.DONE) - msg = s.send(a, 'hello', track=True) - self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker)) - M = zmq.Message(b'hi there', track=True) - msg = s.send(a, 'hello', buffers=[M], track=True) - t = msg['tracker'] + 
ZMQStream(a) + msg = s.send(a, "hello", track=False) + self.assertTrue(msg["tracker"] is ss.DONE) + msg = s.send(a, "hello", track=True) + self.assertTrue(isinstance(msg["tracker"], zmq.MessageTracker)) + M = zmq.Message(b"hi there", track=True) + msg = s.send(a, "hello", buffers=[M], track=True) + t = msg["tracker"] self.assertTrue(isinstance(t, zmq.MessageTracker)) - self.assertRaises(zmq.NotDone, t.wait, .1) + self.assertRaises(zmq.NotDone, t.wait, 0.1) del M - t.wait(1) # this will raise - + t.wait(1) # this will raise def test_unique_msg_ids(self): """test that messages receive unique ids""" ids = set() - for i in range(2**12): - h = self.session.msg_header('test') - msg_id = h['msg_id'] + for i in range(2 ** 12): + h = self.session.msg_header("test") + msg_id = h["msg_id"] self.assertTrue(msg_id not in ids) ids.add(msg_id) def test_feed_identities(self): """scrub the front for zmq IDENTITIES""" - theids = "engine client other".split() - content = dict(code='whoda',stuff=object()) - themsg = self.session.msg('execute',content=content) - pmsg = theids + content = dict(code="whoda", stuff=object()) + self.session.msg("execute", content=content) def test_session_id(self): session = ss.Session() # get bs before us bs = session.bsession us = session.session - self.assertEqual(us.encode('ascii'), bs) + self.assertEqual(us.encode("ascii"), bs) session = ss.Session() # get us before bs us = session.session bs = session.bsession - self.assertEqual(us.encode('ascii'), bs) + self.assertEqual(us.encode("ascii"), bs) # change propagates: - session.session = 'something else' + session.session = "something else" bs = session.bsession us = session.session - self.assertEqual(us.encode('ascii'), bs) - session = ss.Session(session='stuff') + self.assertEqual(us.encode("ascii"), bs) + session = ss.Session(session="stuff") # get us before bs - self.assertEqual(session.bsession, session.session.encode('ascii')) - self.assertEqual(b'stuff', session.bsession) + self.assertEqual(session.bsession, session.session.encode("ascii")) + self.assertEqual(b"stuff", session.bsession) def test_zero_digest_history(self): session = ss.Session(digest_history_size=0) @@ -236,7 +238,7 @@ def test_cull_digest_history(self): def test_bad_pack(self): try: - session = ss.Session(pack=_bad_packer) + ss.Session(pack=_bad_packer) except ValueError as e: self.assertIn("could not serialize", str(e)) self.assertIn("don't work", str(e)) @@ -245,7 +247,7 @@ def test_bad_pack(self): def test_bad_unpack(self): try: - session = ss.Session(unpack=_bad_unpacker) + ss.Session(unpack=_bad_unpacker) except ValueError as e: self.assertIn("could not handle output", str(e)) self.assertIn("don't work either", str(e)) @@ -254,7 +256,7 @@ def test_bad_unpack(self): def test_bad_packer(self): try: - session = ss.Session(packer=__name__ + '._bad_packer') + ss.Session(packer=__name__ + "._bad_packer") except ValueError as e: self.assertIn("could not serialize", str(e)) self.assertIn("don't work", str(e)) @@ -263,7 +265,7 @@ def test_bad_packer(self): def test_bad_unpacker(self): try: - session = ss.Session(unpacker=__name__ + '._bad_unpacker') + ss.Session(unpacker=__name__ + "._bad_unpacker") except ValueError as e: self.assertIn("could not handle output", str(e)) self.assertIn("don't work either", str(e)) @@ -272,35 +274,35 @@ def test_bad_unpacker(self): def test_bad_roundtrip(self): with self.assertRaises(ValueError): - session = ss.Session(unpack=lambda b: 5) + ss.Session(unpack=lambda b: 5) def _datetime_test(self, session): content = 
dict(t=ss.utcnow()) metadata = dict(t=ss.utcnow()) - p = session.msg('msg') - msg = session.msg('msg', content=content, metadata=metadata, parent=p['header']) + p = session.msg("msg") + msg = session.msg("msg", content=content, metadata=metadata, parent=p["header"]) smsg = session.serialize(msg) msg2 = session.deserialize(session.feed_identities(smsg)[1]) - assert isinstance(msg2['header']['date'], datetime) - self.assertEqual(msg['header'], msg2['header']) - self.assertEqual(msg['parent_header'], msg2['parent_header']) - self.assertEqual(msg['parent_header'], msg2['parent_header']) - assert isinstance(msg['content']['t'], datetime) - assert isinstance(msg['metadata']['t'], datetime) - assert isinstance(msg2['content']['t'], str) - assert isinstance(msg2['metadata']['t'], str) - self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content'])) - self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content'])) + assert isinstance(msg2["header"]["date"], datetime) + self.assertEqual(msg["header"], msg2["header"]) + self.assertEqual(msg["parent_header"], msg2["parent_header"]) + self.assertEqual(msg["parent_header"], msg2["parent_header"]) + assert isinstance(msg["content"]["t"], datetime) + assert isinstance(msg["metadata"]["t"], datetime) + assert isinstance(msg2["content"]["t"], str) + assert isinstance(msg2["metadata"]["t"], str) + self.assertEqual(msg["content"], jsonutil.extract_dates(msg2["content"])) + self.assertEqual(msg["content"], jsonutil.extract_dates(msg2["content"])) def test_datetimes(self): self._datetime_test(self.session) def test_datetimes_pickle(self): - session = ss.Session(packer='pickle') + session = ss.Session(packer="pickle") self._datetime_test(session) def test_datetimes_msgpack(self): - msgpack = pytest.importorskip('msgpack') + msgpack = pytest.importorskip("msgpack") session = ss.Session( pack=msgpack.packb, @@ -315,19 +317,21 @@ def test_send_raw(self): A.bind("inproc://test") B.connect("inproc://test") - msg = self.session.msg('execute', content=dict(a=10)) - msg_list = [self.session.pack(msg[part]) for part in - ['header', 'parent_header', 'metadata', 'content']] - self.session.send_raw(A, msg_list, ident=b'foo') + msg = self.session.msg("execute", content=dict(a=10)) + msg_list = [ + self.session.pack(msg[part]) + for part in ["header", "parent_header", "metadata", "content"] + ] + self.session.send_raw(A, msg_list, ident=b"foo") ident, new_msg_list = self.session.feed_identities(B.recv_multipart()) new_msg = self.session.deserialize(new_msg_list) - self.assertEqual(ident[0], b'foo') - self.assertEqual(new_msg['msg_type'],msg['msg_type']) - self.assertEqual(new_msg['header'],msg['header']) - self.assertEqual(new_msg['parent_header'],msg['parent_header']) - self.assertEqual(new_msg['content'],msg['content']) - self.assertEqual(new_msg['metadata'],msg['metadata']) + self.assertEqual(ident[0], b"foo") + self.assertEqual(new_msg["msg_type"], msg["msg_type"]) + self.assertEqual(new_msg["header"], msg["header"]) + self.assertEqual(new_msg["parent_header"], msg["parent_header"]) + self.assertEqual(new_msg["content"], msg["content"]) + self.assertEqual(new_msg["metadata"], msg["metadata"]) A.close() B.close() @@ -335,12 +339,12 @@ def test_send_raw(self): def test_clone(self): s = self.session - s._add_digest('initial') + s._add_digest("initial") s2 = s.clone() assert s2.session == s.session assert s2.digest_history == s.digest_history assert s2.digest_history is not s.digest_history - digest = 'abcdef' + digest = "abcdef" s._add_digest(digest) 
assert digest in s.digest_history assert digest not in s2.digest_history diff --git a/jupyter_client/tests/test_ssh.py b/jupyter_client/tests/test_ssh.py index e1673f9f4..90f96c443 100644 --- a/jupyter_client/tests/test_ssh.py +++ b/jupyter_client/tests/test_ssh.py @@ -1,5 +1,6 @@ from jupyter_client.ssh.tunnel import select_random_ports + def test_random_ports(): for i in range(4096): ports = select_random_ports(10) diff --git a/jupyter_client/tests/utils.py b/jupyter_client/tests/utils.py index 2778f927d..9165330ca 100644 --- a/jupyter_client/tests/utils.py +++ b/jupyter_client/tests/utils.py @@ -2,65 +2,74 @@ """ import os -pjoin = os.path.join import sys -from unittest.mock import patch from tempfile import TemporaryDirectory from typing import Dict +from unittest.mock import patch import pytest -from jupyter_client import AsyncKernelManager, KernelManager, AsyncMultiKernelManager, MultiKernelManager +from jupyter_client import AsyncKernelManager +from jupyter_client import AsyncMultiKernelManager +from jupyter_client import KernelManager +from jupyter_client import MultiKernelManager + +pjoin = os.path.join -skip_win32 = pytest.mark.skipif(sys.platform.startswith('win'), reason="Windows") +skip_win32 = pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows") class test_env(object): """Set Jupyter path variables to a temporary directory - + Useful as a context manager or with explicit start/stop """ + def start(self): self.test_dir = td = TemporaryDirectory() - self.env_patch = patch.dict(os.environ, { - 'JUPYTER_CONFIG_DIR': pjoin(td.name, 'jupyter'), - 'JUPYTER_DATA_DIR': pjoin(td.name, 'jupyter_data'), - 'JUPYTER_RUNTIME_DIR': pjoin(td.name, 'jupyter_runtime'), - 'IPYTHONDIR': pjoin(td.name, 'ipython'), - 'TEST_VARS': 'test_var_1', - }) + self.env_patch = patch.dict( + os.environ, + { + "JUPYTER_CONFIG_DIR": pjoin(td.name, "jupyter"), + "JUPYTER_DATA_DIR": pjoin(td.name, "jupyter_data"), + "JUPYTER_RUNTIME_DIR": pjoin(td.name, "jupyter_runtime"), + "IPYTHONDIR": pjoin(td.name, "ipython"), + "TEST_VARS": "test_var_1", + }, + ) self.env_patch.start() - + def stop(self): self.env_patch.stop() self.test_dir.cleanup() - + def __enter__(self): self.start() return self.test_dir.name - + def __exit__(self, *exc_info): self.stop() -def execute(code='', kc=None, **kwargs): +def execute(code="", kc=None, **kwargs): """wrapper for doing common steps for validating an execution request""" from .test_message_spec import validate_message + if kc is None: - kc = KC + kc = KC # noqa msg_id = kc.execute(code=code, **kwargs) - reply = kc.get_shell_msg(timeout=TIMEOUT) - validate_message(reply, 'execute_reply', msg_id) - busy = kc.get_iopub_msg(timeout=TIMEOUT) - validate_message(busy, 'status', msg_id) - assert busy['content']['execution_state'] == 'busy' + reply = kc.get_shell_msg(timeout=TIMEOUT) # noqa + validate_message(reply, "execute_reply", msg_id) + busy = kc.get_iopub_msg(timeout=TIMEOUT) # noqa + validate_message(busy, "status", msg_id) + assert busy["content"]["execution_state"] == "busy" - if not kwargs.get('silent'): - execute_input = kc.get_iopub_msg(timeout=TIMEOUT) - validate_message(execute_input, 'execute_input', msg_id) - assert execute_input['content']['code'] == code + if not kwargs.get("silent"): + execute_input = kc.get_iopub_msg(timeout=TIMEOUT) # noqa + validate_message(execute_input, "execute_input", msg_id) + assert execute_input["content"]["code"] == code - return msg_id, reply['content'] + return msg_id, reply["content"] class RecordCallMixin: @@ -95,11 +104,11 
@@ def wrapped(self, *args, **kwargs): # call anything defined in the actual class method f(self, *args, **kwargs) return r + return wrapped class KMSubclass(RecordCallMixin): - @subclass_recorder def start_kernel(self, **kw): """ Record call and defer to superclass """ @@ -144,39 +153,40 @@ class SyncKMSubclass(KMSubclass, KernelManager): class AsyncKMSubclass(KMSubclass, AsyncKernelManager): """Used to test subclass hierarchies to ensure methods are called when expected. - This class is also used to test deprecation "routes" that are determined by superclass' - detection of methods. + This class is also used to test deprecation "routes" that are determined by superclass' + detection of methods. - This class represents a current subclass that overrides "interesting" methods of AsyncKernelManager. + This class represents a current subclass that overrides "interesting" methods of + AsyncKernelManager. """ + _superclass = AsyncKernelManager which_cleanup = "" # cleanup deprecation testing @subclass_recorder def cleanup(self, connection_file=True): - self.which_cleanup = 'cleanup' + self.which_cleanup = "cleanup" @subclass_recorder def cleanup_resources(self, restart=False): - self.which_cleanup = 'cleanup_resources' + self.which_cleanup = "cleanup_resources" class AsyncKernelManagerWithCleanup(AsyncKernelManager): """Used to test deprecation "routes" that are determined by superclass' detection of methods. - This class represents the older subclass that overrides cleanup(). We should find that - cleanup() is called on these instances via TestAsyncKernelManagerWithCleanup. + This class represents the older subclass that overrides cleanup(). We should find that + cleanup() is called on these instances via TestAsyncKernelManagerWithCleanup. """ def cleanup(self, connection_file=True): super().cleanup(connection_file=connection_file) - self.which_cleanup = 'cleanup' + self.which_cleanup = "cleanup" class MKMSubclass(RecordCallMixin): - def _kernel_manager_class_default(self): - return 'jupyter_client.tests.utils.SyncKMSubclass' + return "jupyter_client.tests.utils.SyncKMSubclass" @subclass_recorder def get_kernel(self, kernel_id): @@ -224,7 +234,7 @@ class SyncMKMSubclass(MKMSubclass, MultiKernelManager): _superclass = MultiKernelManager def _kernel_manager_class_default(self): - return 'jupyter_client.tests.utils.SyncKMSubclass' + return "jupyter_client.tests.utils.SyncKMSubclass" class AsyncMKMSubclass(MKMSubclass, AsyncMultiKernelManager): @@ -232,4 +242,4 @@ class AsyncMKMSubclass(MKMSubclass, AsyncMultiKernelManager): _superclass = AsyncMultiKernelManager def _kernel_manager_class_default(self): - return 'jupyter_client.tests.utils.AsyncKMSubclass' + return "jupyter_client.tests.utils.AsyncKMSubclass" diff --git a/jupyter_client/threaded.py b/jupyter_client/threaded.py index 46a166f57..25dfc6184 100644 --- a/jupyter_client/threaded.py +++ b/jupyter_client/threaded.py @@ -1,23 +1,30 @@ -""" Defines a KernelClient that provides thread-safe sockets with async callbacks on message replies. +""" Defines a KernelClient that provides thread-safe sockets with async callbacks on message +replies. 
""" import atexit import errno import sys -from threading import Thread, Event import time +from threading import Event +from threading import Thread -# import ZMQError in top-level namespace, to avoid ugly attribute-error messages -# during garbage collection of threads at exit: +from traitlets import Instance +from traitlets import Type from zmq import ZMQError -from zmq.eventloop import ioloop, zmqstream +from zmq.eventloop import ioloop +from zmq.eventloop import zmqstream -# Local imports -from traitlets import Type, Instance -from jupyter_client.channels import HBChannel from jupyter_client import KernelClient +from jupyter_client.channels import HBChannel + +# Local imports +# import ZMQError in top-level namespace, to avoid ugly attribute-error messages +# during garbage collection of threads at exit + class ThreadedZMQSocketChannel(object): """A ZMQ socket invoking a callback in the ioloop""" + session = None socket = None ioloop = None @@ -52,6 +59,7 @@ def setup_stream(): evt.wait() _is_alive = False + def is_alive(self): return self._is_alive @@ -79,8 +87,10 @@ def send(self, msg): This is threadsafe, as it uses IOLoop.add_callback to give the loop's thread control of the action. """ + def thread_send(): self.session.send(self.stream, msg) + self.ioloop.add_callback(thread_send) def _handle_recv(self, msg): @@ -88,7 +98,7 @@ def _handle_recv(self, msg): Unpacks message, and calls handlers with it. """ - ident,smsg = self.session.feed_identities(msg) + ident, smsg = self.session.feed_identities(msg) msg = self.session.deserialize(smsg) # let client inspect messages if self._inspect: @@ -111,7 +121,6 @@ def process_events(self): """ pass - def flush(self, timeout=1.0): """Immediately processes all pending messages on this channel. @@ -145,8 +154,8 @@ def _flush(self): class IOLoopThread(Thread): - """Run a pyzmq ioloop in a thread to send and receive messages - """ + """Run a pyzmq ioloop in a thread to send and receive messages""" + _exiting = False ioloop = None @@ -174,10 +183,11 @@ def start(self): def run(self): """Run my loop, ignoring EINTR events in the poller""" - if 'asyncio' in sys.modules: + if "asyncio" in sys.modules: # tornado may be using asyncio, # ensure an eventloop exists for this thread import asyncio + asyncio.set_event_loop(asyncio.new_event_loop()) self.ioloop = ioloop.IOLoop() # signal that self.ioloop is defined @@ -220,8 +230,7 @@ def close(self): class ThreadedKernelClient(KernelClient): - """ A KernelClient that provides thread-safe sockets with async callbacks on message replies. - """ + """A KernelClient that provides thread-safe sockets with async callbacks on message replies.""" @property def ioloop(self): @@ -239,9 +248,8 @@ def start_channels(self, shell=True, iopub=True, stdin=True, hb=True, control=Tr super().start_channels(shell, iopub, stdin, hb, control) def _check_kernel_info_reply(self, msg): - """This is run in the ioloop thread when the kernel info reply is received - """ - if msg['msg_type'] == 'kernel_info_reply': + """This is run in the ioloop thread when the kernel info reply is received""" + if msg["msg_type"] == "kernel_info_reply": self._handle_kernel_info_reply(msg) self.shell_channel._inspect = None diff --git a/jupyter_client/utils.py b/jupyter_client/utils.py index 942932b12..7a3876146 100644 --- a/jupyter_client/utils.py +++ b/jupyter_client/utils.py @@ -3,15 +3,14 @@ - provides utility wrapeprs to run asynchronous functions in a blocking environment. 
- vendor functions from ipython_genutils that should be retired at some point. """ - -import os -import sys import asyncio import inspect -import nest_asyncio +import os +import sys +import nest_asyncio -if os.name == 'nt' and sys.version_info >= (3, 7): +if os.name == "nt" and sys.version_info >= (3, 7): asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) @@ -24,6 +23,7 @@ def wrapped(*args, **kwargs): asyncio.set_event_loop(loop) nest_asyncio.apply(loop) return loop.run_until_complete(coro(*args, **kwargs)) + wrapped.__doc__ = coro.__doc__ return wrapped diff --git a/jupyter_client/win_interrupt.py b/jupyter_client/win_interrupt.py index fd21e8ef2..05d9be941 100644 --- a/jupyter_client/win_interrupt.py +++ b/jupyter_client/win_interrupt.py @@ -3,9 +3,9 @@ The child needs to explicitly listen for this - see ipykernel.parentpoller.ParentPollerWindows for a Python implementation. """ - import ctypes + def create_interrupt_event(): """Create an interrupt event handle. @@ -18,9 +18,12 @@ def create_interrupt_event(): # handle by new processes. # FIXME: We can clean up this mess by requiring pywin32 for IPython. class SECURITY_ATTRIBUTES(ctypes.Structure): - _fields_ = [ ("nLength", ctypes.c_int), - ("lpSecurityDescriptor", ctypes.c_void_p), - ("bInheritHandle", ctypes.c_int) ] + _fields_ = [ + ("nLength", ctypes.c_int), + ("lpSecurityDescriptor", ctypes.c_void_p), + ("bInheritHandle", ctypes.c_int), + ] + sa = SECURITY_ATTRIBUTES() sa_p = ctypes.pointer(sa) sa.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) @@ -28,12 +31,10 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): sa.bInheritHandle = 1 return ctypes.windll.kernel32.CreateEventA( - sa_p, # lpEventAttributes - False, # bManualReset - False, # bInitialState - '') # lpName + sa_p, False, False, "" # lpEventAttributes # bManualReset # bInitialState + ) # lpName + def send_interrupt(interrupt_handle): - """ Sends an interrupt event using the specified handle. - """ + """Sends an interrupt event using the specified handle.""" ctypes.windll.kernel32.SetEvent(interrupt_handle) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..662812744 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[tool.black] +line-length = 100 +skip-string-normalization = true +target_version = [ + "py36", + "py37", + "py38", +] diff --git a/setup.cfg b/setup.cfg index dfc8cf3b0..69386707b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,6 @@ +[flake8] +max-line-length = 100 + [bdist_wheel] universal=0 diff --git a/setup.py b/setup.py index 8100b3f09..47280225d 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ #!/usr/bin/env python - # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. - import os import sys + from setuptools import setup +from setuptools.command.bdist_egg import bdist_egg # the name of the project name = 'jupyter_client' @@ -17,13 +17,12 @@ packages = [] for d, _, _ in os.walk(pjoin(here, name)): if os.path.exists(pjoin(d, '__init__.py')): - packages.append(d[len(here)+1:].replace(os.path.sep, '.')) + packages.append(d[len(here) + 1 :].replace(os.path.sep, '.')) # noqa version_ns = {} with open(pjoin(here, name, '_version.py')) as f: exec(f.read(), {}, version_ns) -from setuptools.command.bdist_egg import bdist_egg class bdist_egg_disabled(bdist_egg): """Disabled version of bdist_egg @@ -31,29 +30,30 @@ class bdist_egg_disabled(bdist_egg): Prevents setup.py install from performing setuptools' default easy_install, which it should never ever do. 
""" + def run(self): sys.exit("Aborting implicit building of eggs. Use `pip install .` to install from source.") setup_args = dict( - name = name, - version = version_ns['__version__'], - packages = packages, - description = 'Jupyter protocol implementation and client libraries', + name=name, + version=version_ns['__version__'], + packages=packages, + description='Jupyter protocol implementation and client libraries', long_description=open('README.md').read(), long_description_content_type='text/markdown', - author = 'Jupyter Development Team', - author_email = 'jupyter@googlegroups.com', - url = 'https://jupyter.org', - license = 'BSD', - platforms = "Linux, Mac OS X, Windows", - keywords = ['Interactive', 'Interpreter', 'Shell', 'Web'], - project_urls = { + author='Jupyter Development Team', + author_email='jupyter@googlegroups.com', + url='https://jupyter.org', + license='BSD', + platforms="Linux, Mac OS X, Windows", + keywords=['Interactive', 'Interpreter', 'Shell', 'Web'], + project_urls={ 'Documentation': 'https://jupyter-client.readthedocs.io', 'Source': 'https://github.com/jupyter/jupyter_client/', 'Tracker': 'https://github.com/jupyter/jupyter_client/issues', }, - classifiers = [ + classifiers=[ 'Framework :: Jupyter', 'Intended Audience :: Developers', 'Intended Audience :: Education', @@ -68,7 +68,7 @@ def run(self): 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', ], - install_requires = [ + install_requires=[ 'traitlets', 'jupyter_core>=4.6.0', 'pyzmq>=13', @@ -76,8 +76,8 @@ def run(self): 'tornado>=4.1', 'nest-asyncio>=1.5', ], - python_requires = '>=3.5', - extras_require = { + python_requires='>=3.5', + extras_require={ 'test': [ 'async_generator', 'ipykernel', @@ -88,13 +88,14 @@ def run(self): 'pytest-timeout', 'pytest', 'mypy', + 'pre-commit', ], 'doc': open('docs/requirements.txt').read().splitlines(), }, - cmdclass = { + cmdclass={ 'bdist_egg': bdist_egg if 'bdist_egg' in sys.argv else bdist_egg_disabled, }, - entry_points = { + entry_points={ 'console_scripts': [ 'jupyter-kernelspec = jupyter_client.kernelspecapp:KernelSpecApp.launch_instance', 'jupyter-run = jupyter_client.runapp:RunApp.launch_instance',