Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing errors and set up mypy workflow action #176

Merged
merged 72 commits into from
May 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
0101d98
Address mypy errors
bwohlberg Jan 13, 2022
1f3af74
Change action name
bwohlberg Jan 13, 2022
296c1f6
Add mypy github action
bwohlberg Jan 13, 2022
68f3406
Fix typing errors
bwohlberg Jan 14, 2022
7c67fec
Fix typing errors
bwohlberg Jan 14, 2022
837214c
Fix typing errors
bwohlberg Jan 14, 2022
9de6723
Correct job name
bwohlberg Jan 14, 2022
ecef089
Merge branch 'main' into brendt/typing
bwohlberg Jan 15, 2022
7cbea48
Fix typing errors
bwohlberg Jan 15, 2022
9f8667d
Fix typing errors
bwohlberg Jan 19, 2022
efcd355
Merge remote-tracking branch 'origin/main' into brendt/typing
Michael-T-McCann Jan 19, 2022
ea327ec
Merge branch 'brendt/typing' of github.com:lanl/scico into brendt/typing
bwohlberg Jan 19, 2022
f81b40b
Fix typing error
bwohlberg Jan 19, 2022
5ce49eb
Merge branch 'main' into brendt/typing
bwohlberg Jan 27, 2022
0e47246
Merge branch 'main' into brendt/typing
bwohlberg Jan 27, 2022
de30803
Merge branch 'main' into brendt/typing
bwohlberg Jan 27, 2022
55ce354
Fix merge error
bwohlberg Jan 27, 2022
126a422
Merge branch 'main' into brendt/typing
bwohlberg Jan 28, 2022
3681a34
Merge branch 'main' into brendt/typing
bwohlberg Feb 1, 2022
ec3fd93
Address test failure
bwohlberg Feb 1, 2022
7c6c0a9
Merge branch 'main' into brendt/typing
bwohlberg Feb 4, 2022
1ff915b
Fix typing errors
bwohlberg Feb 5, 2022
0ca1c88
Fix typing errors and clean up itstat default mechanism
bwohlberg Feb 5, 2022
06b6613
Fix a bug and some typing errors
bwohlberg Feb 5, 2022
b537927
Exclude modules with dynamically generated functions
bwohlberg Feb 5, 2022
36cfab2
Make docstring phrasing imperative
bwohlberg Feb 5, 2022
acda6d8
Suppress typing errors
bwohlberg Feb 5, 2022
6910348
Fix typing error and clean up
bwohlberg Feb 5, 2022
665fdc5
Fix typing errors
bwohlberg Feb 5, 2022
cf43c14
Fix or suppress typing errors
bwohlberg Feb 5, 2022
8d2b0cc
Typo fix
bwohlberg Feb 5, 2022
99cedad
Merge branch 'main' into brendt/typing
bwohlberg Feb 8, 2022
67e6b24
Revert erroneous attempt to resolve typing error
bwohlberg Feb 8, 2022
58ad1dd
Typing annotation fix and suppress some spurious typing errors
bwohlberg Feb 8, 2022
3457d75
Merge branch 'main' into brendt/typing
bwohlberg Feb 8, 2022
08f078a
Address typing error and rephrase error messages
bwohlberg Feb 9, 2022
5d91e85
Merge branch 'main' into brendt/typing
bwohlberg Feb 10, 2022
abb30e6
Merge branch 'main' into brendt/typing
bwohlberg Feb 10, 2022
4abb593
Merge branch 'main' into brendt/typing
bwohlberg Feb 11, 2022
9bfeb3c
Fix some typing errors
bwohlberg Feb 12, 2022
ebc4392
Supress some typing errors
bwohlberg Feb 12, 2022
f9b6cda
Address typing error
bwohlberg Feb 12, 2022
3ef307b
Fix typing errors and docstring style issues
bwohlberg Feb 12, 2022
88100e5
Address test failure
bwohlberg Feb 12, 2022
a08ddf8
Merge branch 'main' into brendt/typing
bwohlberg Feb 13, 2022
0c58b90
Merge branch 'main' into brendt/typing
bwohlberg Feb 15, 2022
60cc81a
Suppress/address some typing errors
bwohlberg Feb 18, 2022
04bbe11
Merge branch 'main' into brendt/typing
bwohlberg Feb 23, 2022
ca7b609
Merge branch 'main' into brendt/typing
bwohlberg Mar 3, 2022
9205279
Fix merge error
bwohlberg Mar 3, 2022
e5c7f2c
Merge branch 'main' into brendt/typing
bwohlberg Mar 21, 2022
a87af57
Merge branch 'main' into brendt/typing
bwohlberg Apr 8, 2022
6e62734
Merge branch 'main' into brendt/typing
bwohlberg Apr 12, 2022
22c8882
Merge branch 'main' into brendt/typing
bwohlberg Apr 22, 2022
f1aec60
Fix merge error
bwohlberg Apr 22, 2022
6f02f7b
Improve code style
bwohlberg Apr 23, 2022
6ab90d0
Resolve mypy errors
bwohlberg Apr 23, 2022
feb1830
Resolve mypy errors
bwohlberg Apr 23, 2022
07537c5
Resolve mypy errors
bwohlberg Apr 23, 2022
d0edf91
Modify mypy configuration in workflow
bwohlberg Apr 23, 2022
8e48de4
Trivial edit
bwohlberg Apr 24, 2022
72d2cd4
Bug fix
bwohlberg Apr 24, 2022
9fbf68d
Consistency improvement
bwohlberg Apr 24, 2022
68d2850
Address CodeFactor complex function
Michael-T-McCann Apr 25, 2022
7c61e25
Fix type error
Michael-T-McCann Apr 26, 2022
2f5b14c
Switch back to ignore, can't solve this problem without code bloat
Michael-T-McCann Apr 26, 2022
1fa46bb
Fixed extra 's' in `LinearOperator`s string.
FernandoDavis Apr 28, 2022
e3d8974
Trivial edit
bwohlberg Apr 28, 2022
52eba3b
Use type guard rather than type ignore
bwohlberg Apr 28, 2022
14f1450
Merge branch 'brendt/typing' of github.com:lanl/scico into brendt/typing
bwohlberg Apr 28, 2022
9bd0214
Fix EllipsisType import
bwohlberg Apr 29, 2022
61779ff
Minor edits for docs style
bwohlberg Apr 30, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Install and run mypy

name: mypy

# Controls when the workflow will run
on:
# Triggers the workflow on push or pull request events but only for the main branch
push:
branches: [ main ]
pull_request:
branches: [ main ]

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

jobs:
mypy:
# The type of runner that the job will run on
runs-on: ubuntu-latest

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
with:
submodules: recursive
- name: Install Python 3
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Install dependencies
run: |
pip install mypy
- name: Run mypy
run: |
mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/
2 changes: 1 addition & 1 deletion scico/_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,5 @@ def __call__(self, x: JaxArray) -> JaxArray:
x = x.reshape((1,) + x.shape + (1,))
elif x.ndim == 3:
x = x.reshape((1,) + x.shape)
y = self.model.apply(self.variables, x, train=False, mutable=False)
y = self.model.apply(self.variables, x, train=False, mutable=False) # type: ignore
return y.reshape(x_shape)
18 changes: 12 additions & 6 deletions scico/_generic_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __init__(
#: Dtype of input
self.input_dtype: DType

#: Dtype of operator
self.dtype: DType

if isinstance(input_shape, int):
self.input_shape = (input_shape,)
else:
Expand All @@ -140,7 +143,7 @@ def __init__(
if output_shape is None or output_dtype is None:
tmp = self(snp.zeros(self.input_shape, dtype=input_dtype))
if output_shape is None:
self.output_shape = tmp.shape
self.output_shape = tmp.shape # type: ignore
else:
self.output_shape = (output_shape,) if isinstance(output_shape, int) else output_shape

Expand Down Expand Up @@ -312,10 +315,11 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator:
f"{self.input_shape[argnum]}, got {val.shape}"
)

input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum)
input_shape: Union[Shape, BlockShape]
input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum) # type: ignore

if len(input_shape) == 1:
input_shape = input_shape[0]
input_shape = input_shape[0] # type: ignore

def concat_args(args):
# Creates a blockarray with args and the frozen value in the correct place
Expand Down Expand Up @@ -456,9 +460,9 @@ def __init__(
)

if not hasattr(self, "_adj"):
self._adj = None
self._adj: Optional[Callable] = None
if not hasattr(self, "_gram"):
self._gram = None
self._gram: Optional[Callable] = None
if callable(adj_fn):
self._adj = adj_fn
self._gram = lambda x: self.adj(self(x))
Expand Down Expand Up @@ -584,7 +588,7 @@ def adj(
input `y`.

Args:
y: Point at which to compute adjoint. If `y` is
y: Point at which to compute adjoint. If `y` is
:class:`DeviceArray` or :class:`.BlockArray`, must have
`shape == self.output_shape`. If `y` is a
:class:`.LinearOperator`, must have
Expand All @@ -605,6 +609,7 @@ def adj(
f"""Shapes do not conform: input array with shape {y.shape} does not match
LinearOperator output_shape {self.output_shape}"""
)
assert self._adj is not None
return self._adj(y)

@property
Expand Down Expand Up @@ -715,6 +720,7 @@ def gram(
"""
if self._gram is None:
self._set_adjoint()
assert self._gram is not None
return self._gram(x)


Expand Down
4 changes: 2 additions & 2 deletions scico/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def variable_assign_value(path: str, var: str) -> Any:
with open(path) as f:
try:
# See http://stackoverflow.com/questions/2058802
value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s
value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s # type: ignore
except StopIteration:
raise RuntimeError(f"Could not find initialization of variable {var}")
return value
Expand Down Expand Up @@ -70,7 +70,7 @@ def current_git_hash() -> Optional[str]: # nosec pragma: no cover
Short git hash of current commit, or ``None`` if no git repo found.
"""
process = Popen(["git", "rev-parse", "--short", "HEAD"], shell=False, stdout=PIPE, stderr=PIPE)
git_hash = process.communicate()[0].strip().decode("utf-8")
git_hash: Optional[str] = process.communicate()[0].strip().decode("utf-8")
if git_hash == "":
git_hash = None
return git_hash
Expand Down
20 changes: 10 additions & 10 deletions scico/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,23 @@ def __init__(
if not isinstance(fields, dict):
raise TypeError("Parameter fields must be an instance of dict")
# Subsampling rate of results that are to be displayed
self.period = period
self.period: int = period
# Flag indicating whether to display and overwrite, or not display at all
self.overwrite = overwrite
self.overwrite: bool = overwrite
# Number of spaces seperating fields in displayed tables
self.colsep = colsep
self.colsep: int = colsep
# Main list of inserted values
self.iterations = []
self.iterations: List = []
# Total length of header string in displayed tables
self.headlength = 0
self.headlength: int = 0
# List of field names
self.fieldname = []
self.fieldname: List[str] = []
# List of field format strings
self.fieldformat = []
self.fieldformat: List[str] = []
# List of lengths of each field in displayed tables
self.fieldlength = []
self.fieldlength: List[int] = []
# Names of fields in namedtuple used to record iteration values
self.tuplefields = []
self.tuplefields: List[str] = []
# Compile regex for decomposing format strings
fmre = re.compile(r"%(\+?-?)((?:\d+)?)(\.?)((?:\d+)?)([a-z])")
# Iterate over field names
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
self.headlength -= colsep

# Construct namedtuple used to record values
self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields)
self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields) # type: ignore

# Set up table header string display if requested
self.display = display
Expand Down
26 changes: 13 additions & 13 deletions scico/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import os
import tempfile
import zipfile
from typing import List, Optional
from typing import List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -201,7 +201,7 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array:
"""

if x.ndim == 3:
fshape = (x.shape[0], sep_width)
fshape: Tuple[int, ...] = (x.shape[0], sep_width)
else:
fshape = (x.shape[0], sep_width, 3)
out = snp.concatenate(
Expand All @@ -214,9 +214,9 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array:
)

if x.ndim == 3:
fshape0 = (sep_width, out.shape[1])
fshape1 = (x.shape[2], x.shape[2] + sep_width)
trans = (1, 0)
fshape0: Tuple[int, ...] = (sep_width, out.shape[1])
fshape1: Tuple[int, ...] = (x.shape[2], x.shape[2] + sep_width)
trans: Tuple[int, ...] = (1, 0)

else:
fshape0 = (sep_width, out.shape[1], 3)
Expand Down Expand Up @@ -300,7 +300,7 @@ def create_3D_foam_phantom(
r_std: float = 0.001,
pad: float = 0.01,
is_random: bool = False,
):
) -> JaxArray:
"""Construct a 3D phantom with random radii and centers.

Args:
Expand All @@ -316,7 +316,7 @@ def create_3D_foam_phantom(
process deterministic. Default ``False``.

Returns:
3D phantom of shape im_shape
3D phantom of shape `im_shape`.
"""
c_lo = 0.0
c_hi = 1.0
Expand All @@ -331,10 +331,10 @@ def create_3D_foam_phantom(
radii = r_std * np.random.randn(N_sphere) + r_mean

im = snp.zeros(im_shape) + c_lo
for c, r in zip(centers, radii):
for c, r in zip(centers, radii): # type: ignore
dist = snp.sum((x - c) ** 2, axis=-1)
if snp.mean(im[dist < r**2] - c_lo) < 0.01 * c_hi:
# In numpy: im[dist < r**2] = c_hi
# equivalent to im[dist < r**2] = c_hi in numpy
im = im.at[dist < r**2].set(c_hi)

return im
Expand All @@ -354,13 +354,13 @@ def spnoise(img: Array, nfrac: float, nmin: float = 0.0, nmax: float = 1.0) -> A
"""

if isinstance(img, np.ndarray):
spm = np.random.uniform(-1.0, 1.0, img.shape)
spm = np.random.uniform(-1.0, 1.0, img.shape) # type: ignore
imgn = img.copy()
imgn[spm < nfrac - 1.0] = nmin
imgn[spm > 1.0 - nfrac] = nmax
else:
spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0)
spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) # type: ignore
imgn = img
imgn = imgn.at[spm < nfrac - 1.0].set(nmin)
imgn = imgn.at[spm > 1.0 - nfrac].set(nmax)
imgn = imgn.at[spm < nfrac - 1.0].set(nmin) # type: ignore
imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) # type: ignore
return imgn
6 changes: 3 additions & 3 deletions scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, is_rgb: bool = False):
self.is_rgb = is_rgb
super().__init__()

def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type ignore used to suppress error Signature of "prox" incompatible with supertype "Functional". Is there any way of restructuring the code to avoid this violation of the Liskov substitution principle?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A proposal: Union[JaxArray, BlockArray] is used whenever a vector is expected. If the operator/functional can't deal with inputs of a certain shape (including not dealing with BlockArrays of any shape), that's a runtime error.

By analogy, a function may take a float and then throw a runtime error if it is negative. We don't try to change the type annotation to positive_float.

Copy link
Contributor

@Michael-T-McCann Michael-T-McCann Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, we do something at the subclass level, maybe making prox and eval abstract (or not defined at all?). That helps with Liskov substitution because a generic Functional would have no useful behavior at all.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, our code has quite a few Liskov violations, so that would be worth considering. On "A proposal", that's a reasonable analogy, but "mypy" doesn't see it that way: the type narrowing was introduced to address other typing complaints.

r"""Apply BM3D denoiser.

Args:
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(self):
r"""Initialize a :class:`BM4D` object."""
super().__init__()

def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore
r"""Apply BM4D denoiser.

Args:
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(self, variant: str = "6M"):
"""
self.dncnn = denoiser.DnCNN(variant)

def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore
r"""Apply DnCNN denoiser.

*Warning*: The `lam` parameter is ignored, and has no effect on
Expand Down
12 changes: 4 additions & 8 deletions scico/functional/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
x: Point at which to evaluate this functional.

"""
if not self.has_eval:
raise NotImplementedError(
f"Functional {type(self)} cannot be evaluated; has_eval={self.has_eval}"
)
# Functionals that can be evaluated should override this method.
raise NotImplementedError(f"Functional {type(self)} cannot be evaluated.")
bwohlberg marked this conversation as resolved.
Show resolved Hide resolved

def prox(
self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
Expand All @@ -91,10 +89,8 @@ def prox(
classes. These include `x0`, an initial guess for the
minimizer in the definition of :math:`\mathrm{prox}`.
"""
if not self.has_prox:
raise NotImplementedError(
f"Functional {type(self)} does not have a prox; has_prox={self.has_prox}"
)
# Functionals that have a prox should override this method.
raise NotImplementedError(f"Functional {type(self)} does not have a prox.")

def conj_prox(
self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
Expand Down
14 changes: 7 additions & 7 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import math
import operator
from functools import partial
from typing import Optional
from typing import Optional, Tuple

import numpy as np

Expand All @@ -19,7 +19,7 @@
import scico.numpy as snp
from scico._generic_operators import Operator
from scico.numpy.util import is_nested
from scico.typing import DType, JaxArray, Shape
from scico.typing import Array, DType, JaxArray, Shape

from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar

Expand Down Expand Up @@ -123,9 +123,9 @@ def __init__(
self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes)
output_dtype = result_type(h.dtype, input_dtype)

if h_center is not None:
if self.h_center is not None:
offset = -self.h_center
shifts = np.ix_(
shifts: Tuple[Array, ...] = np.ix_(
*tuple(
np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s))
for k, s in zip(offset, input_shape[-self.ndims :])
Expand Down Expand Up @@ -173,7 +173,7 @@ def _eval(self, x: JaxArray) -> JaxArray:
hx = hx.real
return hx

def _adj(self, x: JaxArray) -> JaxArray:
def _adj(self, x: JaxArray) -> JaxArray: # type: ignore
x_dft = snp.fft.fftn(x, axes=self.ifft_axes)
H_adj_x = snp.fft.ifftn(
snp.conj(self.h_dft) * x_dft,
Expand Down Expand Up @@ -266,7 +266,7 @@ def from_operator(
ndims = ndims

if center is None:
center = tuple(d // 2 for d in H.input_shape[-ndims:])
center = tuple(d // 2 for d in H.input_shape[-ndims:]) # type: ignore

# compute impulse response
d = snp.zeros(H.input_shape, H.input_dtype)
Expand All @@ -276,7 +276,7 @@ def from_operator(
# build CircularConvolve
return CircularConvolve(
Hd,
H.input_shape,
H.input_shape, # type: ignore
ndims=ndims,
input_dtype=H.input_dtype,
h_center=snp.array(center),
Expand Down
9 changes: 4 additions & 5 deletions scico/linop/_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,18 @@ def __init__(
functions of the LinearOperator.
"""

self.axes = parse_axes(axes, input_shape)

if axes is None:
axes_list = range(len(input_shape))
axes_list = tuple(range(len(input_shape)))
elif isinstance(axes, (list, tuple)):
axes_list = axes
else:
axes_list = (axes,)
single_kwargs = dict(append=append, circular=circular, jit=False, input_dtype=input_dtype)
self.axes = parse_axes(axes_list, input_shape)
single_kwargs = dict(input_dtype=input_dtype, append=append, circular=circular, jit=False)
ops = [FiniteDifferenceSingleAxis(axis, input_shape, **single_kwargs) for axis in axes_list]

super().__init__(
ops,
ops, # type: ignore
jit=jit,
**kwargs,
)
Expand Down
Loading