From 0c43e2e039f98d192a926fe042aa7b7e0e260973 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Tue, 13 Aug 2024 15:54:29 +0100 Subject: [PATCH] Move to pyright and fix type errors --- pyproject.toml | 38 +++++++++++++++++++++----------------- src/scanspec/cli.py | 2 +- src/scanspec/core.py | 12 ++++++------ src/scanspec/plot.py | 18 ++++++++++++------ src/scanspec/sphinxext.py | 3 ++- tests/test_specs.py | 2 +- 6 files changed, 43 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e8e6bf18..41ef42cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,11 +11,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] description = "Specify step and flyscan paths in a serializable, efficient and Pythonic way" -dependencies = [ - "numpy>=2", - "click>=8.1", - "pydantic>=2.0", -] +dependencies = ["numpy>=2", "click>=8.1", "pydantic>=2.0"] dynamic = ["version"] license.file = "LICENSE" readme = "README.md" @@ -33,11 +29,11 @@ dev = [ "scanspec[plotting]", "scanspec[service]", "copier", - "mypy", "myst-parser", "pipdeptree", "pre-commit", "pydata-sphinx-theme>=0.12", + "pyright", "pytest", "pytest-cov", "ruff", @@ -65,8 +61,9 @@ name = "Tom Cobb" [tool.setuptools_scm] write_to = "src/scanspec/_version.py" -[tool.mypy] -ignore_missing_imports = true # Ignore missing stubs in imported modules +[tool.pyright] +# strict = ["src", "tests"] +reportMissingImports = false # Ignore missing stubs in imported modules [tool.pytest.ini_options] # Run pytest with all our checkers, and don't spam us with massive tracebacks on error @@ -99,12 +96,12 @@ passenv = * allowlist_externals = pytest pre-commit - mypy + pyright sphinx-build sphinx-autobuild commands = pre-commit: pre-commit run --all-files {posargs} - type-checking: mypy src tests {posargs} + type-checking: pyright src tests {posargs} tests: pytest --cov=scanspec --cov-report term --cov-report xml:cov.xml {posargs} docs: sphinx-{posargs:build -E --keep-going} -T docs build/html """ @@ -115,14 +112,21 @@ line-length = 88 [tool.ruff.lint] extend-select = [ - "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b - "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 - "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e - "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f - "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w - "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i - "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b + "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 + "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e + "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f + "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w + "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i + "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "SLF", # self - https://docs.astral.sh/ruff/settings/#lintflake8-self ] ignore = [ "B008", # We use function calls in service arguments ] + +[tool.ruff.lint.per-file-ignores] +# By default, private member access is allowed in tests +# See https://github.com/DiamondLightSource/python-copier-template/issues/154 +# Remove this line to forbid private member access in tests +"tests/**/*" = ["SLF001"] diff --git a/src/scanspec/cli.py b/src/scanspec/cli.py index 7c0e4de3..bcffd072 100644 --- a/src/scanspec/cli.py +++ b/src/scanspec/cli.py @@ -25,7 +25,7 @@ def cli(ctx, log_level: str): # if no command is supplied, print the help message if ctx.invoked_subcommand is None: - click.echo(cli.get_help(ctx)) + click.echo(cli.get_help(ctx)) # type: ignore @cli.command() diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 74e0dffd..ac2f2034 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -35,11 +35,14 @@ StrictConfig: ConfigDict = {"extra": "forbid"} +C = TypeVar("C") +T = TypeVar("T", type, Callable) + def discriminated_union_of_subclasses( - super_cls: type, + super_cls: type[C], discriminator: str = "type", -) -> type: +) -> type[C]: """Add all subclasses of super_cls to a discriminated union. For all subclasses of super_cls, add a discriminator field to identify @@ -137,9 +140,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler): return super_cls -T = TypeVar("T", type, Callable) - - def uses_tagged_union(cls_or_func: T) -> T: """ T = TypeVar("T", type, Callable) @@ -616,7 +616,7 @@ def consume(self, num: int | None = None) -> Frames[Axis]: def __len__(self) -> int: """Number of frames left in a scan, reduces when `consume` is called.""" - return self.end_index - self.index + return int(self.end_index - self.index) class Midpoints(Generic[Axis]): diff --git a/src/scanspec/plot.py b/src/scanspec/plot.py index 43311663..e4fc1e8c 100644 --- a/src/scanspec/plot.py +++ b/src/scanspec/plot.py @@ -33,7 +33,7 @@ def __init__(self, xs, ys, zs, *args, **kwargs): # Added here because of https://github.com/matplotlib/matplotlib/issues/21688 def do_3d_projection(self, renderer=None): xs3d, ys3d, zs3d = self._verts3d - xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) + xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) # type: ignore self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) return np.min(zs) @@ -109,11 +109,17 @@ def plot_spec(spec: Spec[Any], title: str | None = None): # Setup axes if ndims > 2: plt.figure(figsize=(6, 6)) - plt_axes: Axes3D = plt.axes(projection="3d") + plt_axes = plt.axes(projection="3d") plt_axes.grid(False) - plt_axes.set_zlabel(axes[-3]) - plt_axes.set_ylabel(axes[-2]) - plt_axes.view_init(elev=15) + if isinstance(plt_axes, Axes3D): + plt_axes.set_zlabel(axes[-3]) + plt_axes.set_ylabel(axes[-2]) + plt_axes.view_init(elev=15) + else: + raise TypeError( + "Expected matplotlib to create an Axes3D object, " + f"instead got: {plt_axes}" + ) elif ndims == 2: plt.figure(figsize=(6, 6)) plt_axes = plt.axes() @@ -208,7 +214,7 @@ def plot_spec(spec: Spec[Any], title: str | None = None): _plot_arrow(plt_axes, arrow_arr) elif splines: # Plot the starting arrow in the direction of the first point - arrow_arr = [(2 * a[0] - a[1], a[0]) for a in splines[0]] + arrow_arr = [np.array([2 * a[0] - a[1], a[0]]) for a in splines[0]] _plot_arrow(plt_axes, arrow_arr) else: # First point isn't moving, put a right caret marker diff --git a/src/scanspec/sphinxext.py b/src/scanspec/sphinxext.py index ecde40d9..6a1e2630 100644 --- a/src/scanspec/sphinxext.py +++ b/src/scanspec/sphinxext.py @@ -1,5 +1,6 @@ from contextlib import contextmanager +from docutils.statemachine import StringList from matplotlib.sphinxext import plot_directive from . import __version__ @@ -25,7 +26,7 @@ class ExampleSpecDirective(plot_directive.PlotDirective): """Runs `plot_spec` on the ``spec`` definied in the content.""" def run(self): - self.content = ( + self.content = StringList( ["# Example Spec", "", "from scanspec.plot import plot_spec"] + [str(x) for x in self.content] + ["plot_spec(spec)"] diff --git a/tests/test_specs.py b/tests/test_specs.py index dad0ffb6..70e01fcc 100644 --- a/tests/test_specs.py +++ b/tests/test_specs.py @@ -492,7 +492,7 @@ def test_gap_repeat() -> None: def test_gap_repeat_non_snake() -> None: # Check that no gap doesn't propogate to dim.gap for non-snaked axis - spec: Spec[Any] = Repeat(3, gap=False) * Line.bounded(x, 11, 19, 1) + spec: Spec[str] = Repeat(3, gap=False) * Line.bounded(x, 11, 19, 1) dim = spec.frames() assert len(dim) == 3 assert dim.lower == {x: pytest.approx([11, 11, 11])}