Skip to content

Commit

Permalink
Fix typing errors and set up mypy workflow action (#176)
Browse files Browse the repository at this point in the history
* Address mypy errors

* Change action name

* Add mypy github action

* Fix typing errors

* Correct job name

* Fix typing error

* Fix merge error

* Address test failure

* Fix typing errors

* Fix typing errors and clean up itstat default mechanism

* Fix a bug and some typing errors

* Exclude modules with dynamically generated functions

* Make docstring phrasing imperative

* Suppress typing errors

* Fix typing error and clean up

* Fix typing errors

* Fix or suppress typing errors

* Typo fix

* Revert erroneous attempt to resolve typing error

* Typing annotation fix and suppress some spurious typing errors

* Address typing error and rephrase error messages

* Fix some typing errors

* Supress some typing errors

* Address typing error

* Fix typing errors and docstring style issues

* Address test failure

* Suppress/address some typing errors

* Fix merge error

* Fix merge error

* Improve code style

* Resolve mypy errors

* Resolve mypy errors

* Resolve mypy errors

* Modify mypy configuration in workflow

* Trivial edit

* Bug fix

* Consistency improvement

* Address CodeFactor complex function

* Fix type error

* Switch back to ignore, can't solve this problem without code bloat

* Fixed extra 's' in `LinearOperator`s string.

* Trivial edit

* Use type guard rather than type ignore

* Fix EllipsisType import

* Minor edits for docs style

Co-authored-by: Michael-T-McCann <[email protected]>
Co-authored-by: Fernando Davis <[email protected]>
  • Loading branch information
3 people authored May 1, 2022
1 parent 4381be5 commit 157f2ba
Show file tree
Hide file tree
Showing 29 changed files with 272 additions and 196 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Install and run mypy

name: mypy

# Controls when the workflow will run
on:
# Triggers the workflow on push or pull request events but only for the main branch
push:
branches: [ main ]
pull_request:
branches: [ main ]

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

jobs:
mypy:
# The type of runner that the job will run on
runs-on: ubuntu-latest

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
# Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
- uses: actions/checkout@v2
with:
submodules: recursive
- name: Install Python 3
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Install dependencies
run: |
pip install mypy
- name: Run mypy
run: |
mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/
2 changes: 1 addition & 1 deletion scico/_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,5 @@ def __call__(self, x: JaxArray) -> JaxArray:
x = x.reshape((1,) + x.shape + (1,))
elif x.ndim == 3:
x = x.reshape((1,) + x.shape)
y = self.model.apply(self.variables, x, train=False, mutable=False)
y = self.model.apply(self.variables, x, train=False, mutable=False) # type: ignore
return y.reshape(x_shape)
18 changes: 12 additions & 6 deletions scico/_generic_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __init__(
#: Dtype of input
self.input_dtype: DType

#: Dtype of operator
self.dtype: DType

if isinstance(input_shape, int):
self.input_shape = (input_shape,)
else:
Expand All @@ -140,7 +143,7 @@ def __init__(
if output_shape is None or output_dtype is None:
tmp = self(snp.zeros(self.input_shape, dtype=input_dtype))
if output_shape is None:
self.output_shape = tmp.shape
self.output_shape = tmp.shape # type: ignore
else:
self.output_shape = (output_shape,) if isinstance(output_shape, int) else output_shape

Expand Down Expand Up @@ -312,10 +315,11 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator:
f"{self.input_shape[argnum]}, got {val.shape}"
)

input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum)
input_shape: Union[Shape, BlockShape]
input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum) # type: ignore

if len(input_shape) == 1:
input_shape = input_shape[0]
input_shape = input_shape[0] # type: ignore

def concat_args(args):
# Creates a blockarray with args and the frozen value in the correct place
Expand Down Expand Up @@ -456,9 +460,9 @@ def __init__(
)

if not hasattr(self, "_adj"):
self._adj = None
self._adj: Optional[Callable] = None
if not hasattr(self, "_gram"):
self._gram = None
self._gram: Optional[Callable] = None
if callable(adj_fn):
self._adj = adj_fn
self._gram = lambda x: self.adj(self(x))
Expand Down Expand Up @@ -584,7 +588,7 @@ def adj(
input `y`.
Args:
y: Point at which to compute adjoint. If `y` is
y: Point at which to compute adjoint. If `y` is
:class:`DeviceArray` or :class:`.BlockArray`, must have
`shape == self.output_shape`. If `y` is a
:class:`.LinearOperator`, must have
Expand All @@ -605,6 +609,7 @@ def adj(
f"""Shapes do not conform: input array with shape {y.shape} does not match
LinearOperator output_shape {self.output_shape}"""
)
assert self._adj is not None
return self._adj(y)

@property
Expand Down Expand Up @@ -715,6 +720,7 @@ def gram(
"""
if self._gram is None:
self._set_adjoint()
assert self._gram is not None
return self._gram(x)


Expand Down
4 changes: 2 additions & 2 deletions scico/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def variable_assign_value(path: str, var: str) -> Any:
with open(path) as f:
try:
# See http://stackoverflow.com/questions/2058802
value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s
value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s # type: ignore
except StopIteration:
raise RuntimeError(f"Could not find initialization of variable {var}")
return value
Expand Down Expand Up @@ -70,7 +70,7 @@ def current_git_hash() -> Optional[str]: # nosec pragma: no cover
Short git hash of current commit, or ``None`` if no git repo found.
"""
process = Popen(["git", "rev-parse", "--short", "HEAD"], shell=False, stdout=PIPE, stderr=PIPE)
git_hash = process.communicate()[0].strip().decode("utf-8")
git_hash: Optional[str] = process.communicate()[0].strip().decode("utf-8")
if git_hash == "":
git_hash = None
return git_hash
Expand Down
20 changes: 10 additions & 10 deletions scico/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,23 +68,23 @@ def __init__(
if not isinstance(fields, dict):
raise TypeError("Parameter fields must be an instance of dict")
# Subsampling rate of results that are to be displayed
self.period = period
self.period: int = period
# Flag indicating whether to display and overwrite, or not display at all
self.overwrite = overwrite
self.overwrite: bool = overwrite
# Number of spaces seperating fields in displayed tables
self.colsep = colsep
self.colsep: int = colsep
# Main list of inserted values
self.iterations = []
self.iterations: List = []
# Total length of header string in displayed tables
self.headlength = 0
self.headlength: int = 0
# List of field names
self.fieldname = []
self.fieldname: List[str] = []
# List of field format strings
self.fieldformat = []
self.fieldformat: List[str] = []
# List of lengths of each field in displayed tables
self.fieldlength = []
self.fieldlength: List[int] = []
# Names of fields in namedtuple used to record iteration values
self.tuplefields = []
self.tuplefields: List[str] = []
# Compile regex for decomposing format strings
fmre = re.compile(r"%(\+?-?)((?:\d+)?)(\.?)((?:\d+)?)([a-z])")
# Iterate over field names
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
self.headlength -= colsep

# Construct namedtuple used to record values
self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields)
self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields) # type: ignore

# Set up table header string display if requested
self.display = display
Expand Down
26 changes: 13 additions & 13 deletions scico/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import os
import tempfile
import zipfile
from typing import List, Optional
from typing import List, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -201,7 +201,7 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array:
"""

if x.ndim == 3:
fshape = (x.shape[0], sep_width)
fshape: Tuple[int, ...] = (x.shape[0], sep_width)
else:
fshape = (x.shape[0], sep_width, 3)
out = snp.concatenate(
Expand All @@ -214,9 +214,9 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array:
)

if x.ndim == 3:
fshape0 = (sep_width, out.shape[1])
fshape1 = (x.shape[2], x.shape[2] + sep_width)
trans = (1, 0)
fshape0: Tuple[int, ...] = (sep_width, out.shape[1])
fshape1: Tuple[int, ...] = (x.shape[2], x.shape[2] + sep_width)
trans: Tuple[int, ...] = (1, 0)

else:
fshape0 = (sep_width, out.shape[1], 3)
Expand Down Expand Up @@ -300,7 +300,7 @@ def create_3D_foam_phantom(
r_std: float = 0.001,
pad: float = 0.01,
is_random: bool = False,
):
) -> JaxArray:
"""Construct a 3D phantom with random radii and centers.
Args:
Expand All @@ -316,7 +316,7 @@ def create_3D_foam_phantom(
process deterministic. Default ``False``.
Returns:
3D phantom of shape im_shape
3D phantom of shape `im_shape`.
"""
c_lo = 0.0
c_hi = 1.0
Expand All @@ -331,10 +331,10 @@ def create_3D_foam_phantom(
radii = r_std * np.random.randn(N_sphere) + r_mean

im = snp.zeros(im_shape) + c_lo
for c, r in zip(centers, radii):
for c, r in zip(centers, radii): # type: ignore
dist = snp.sum((x - c) ** 2, axis=-1)
if snp.mean(im[dist < r**2] - c_lo) < 0.01 * c_hi:
# In numpy: im[dist < r**2] = c_hi
# equivalent to im[dist < r**2] = c_hi in numpy
im = im.at[dist < r**2].set(c_hi)

return im
Expand All @@ -354,13 +354,13 @@ def spnoise(img: Array, nfrac: float, nmin: float = 0.0, nmax: float = 1.0) -> A
"""

if isinstance(img, np.ndarray):
spm = np.random.uniform(-1.0, 1.0, img.shape)
spm = np.random.uniform(-1.0, 1.0, img.shape) # type: ignore
imgn = img.copy()
imgn[spm < nfrac - 1.0] = nmin
imgn[spm > 1.0 - nfrac] = nmax
else:
spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0)
spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) # type: ignore
imgn = img
imgn = imgn.at[spm < nfrac - 1.0].set(nmin)
imgn = imgn.at[spm > 1.0 - nfrac].set(nmax)
imgn = imgn.at[spm < nfrac - 1.0].set(nmin) # type: ignore
imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) # type: ignore
return imgn
6 changes: 3 additions & 3 deletions scico/functional/_denoiser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, is_rgb: bool = False):
self.is_rgb = is_rgb
super().__init__()

def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore
r"""Apply BM3D denoiser.
Args:
Expand Down Expand Up @@ -65,7 +65,7 @@ def __init__(self):
r"""Initialize a :class:`BM4D` object."""
super().__init__()

def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore
r"""Apply BM4D denoiser.
Args:
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(self, variant: str = "6M"):
"""
self.dncnn = denoiser.DnCNN(variant)

def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray:
def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore
r"""Apply DnCNN denoiser.
*Warning*: The `lam` parameter is ignored, and has no effect on
Expand Down
12 changes: 4 additions & 8 deletions scico/functional/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,8 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float:
x: Point at which to evaluate this functional.
"""
if not self.has_eval:
raise NotImplementedError(
f"Functional {type(self)} cannot be evaluated; has_eval={self.has_eval}"
)
# Functionals that can be evaluated should override this method.
raise NotImplementedError(f"Functional {type(self)} cannot be evaluated.")

def prox(
self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
Expand All @@ -91,10 +89,8 @@ def prox(
classes. These include `x0`, an initial guess for the
minimizer in the definition of :math:`\mathrm{prox}`.
"""
if not self.has_prox:
raise NotImplementedError(
f"Functional {type(self)} does not have a prox; has_prox={self.has_prox}"
)
# Functionals that have a prox should override this method.
raise NotImplementedError(f"Functional {type(self)} does not have a prox.")

def conj_prox(
self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs
Expand Down
14 changes: 7 additions & 7 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import math
import operator
from functools import partial
from typing import Optional
from typing import Optional, Tuple

import numpy as np

Expand All @@ -19,7 +19,7 @@
import scico.numpy as snp
from scico._generic_operators import Operator
from scico.numpy.util import is_nested
from scico.typing import DType, JaxArray, Shape
from scico.typing import Array, DType, JaxArray, Shape

from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar

Expand Down Expand Up @@ -123,9 +123,9 @@ def __init__(
self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes)
output_dtype = result_type(h.dtype, input_dtype)

if h_center is not None:
if self.h_center is not None:
offset = -self.h_center
shifts = np.ix_(
shifts: Tuple[Array, ...] = np.ix_(
*tuple(
np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s))
for k, s in zip(offset, input_shape[-self.ndims :])
Expand Down Expand Up @@ -173,7 +173,7 @@ def _eval(self, x: JaxArray) -> JaxArray:
hx = hx.real
return hx

def _adj(self, x: JaxArray) -> JaxArray:
def _adj(self, x: JaxArray) -> JaxArray: # type: ignore
x_dft = snp.fft.fftn(x, axes=self.ifft_axes)
H_adj_x = snp.fft.ifftn(
snp.conj(self.h_dft) * x_dft,
Expand Down Expand Up @@ -266,7 +266,7 @@ def from_operator(
ndims = ndims

if center is None:
center = tuple(d // 2 for d in H.input_shape[-ndims:])
center = tuple(d // 2 for d in H.input_shape[-ndims:]) # type: ignore

# compute impulse response
d = snp.zeros(H.input_shape, H.input_dtype)
Expand All @@ -276,7 +276,7 @@ def from_operator(
# build CircularConvolve
return CircularConvolve(
Hd,
H.input_shape,
H.input_shape, # type: ignore
ndims=ndims,
input_dtype=H.input_dtype,
h_center=snp.array(center),
Expand Down
9 changes: 4 additions & 5 deletions scico/linop/_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,18 @@ def __init__(
functions of the LinearOperator.
"""

self.axes = parse_axes(axes, input_shape)

if axes is None:
axes_list = range(len(input_shape))
axes_list = tuple(range(len(input_shape)))
elif isinstance(axes, (list, tuple)):
axes_list = axes
else:
axes_list = (axes,)
single_kwargs = dict(append=append, circular=circular, jit=False, input_dtype=input_dtype)
self.axes = parse_axes(axes_list, input_shape)
single_kwargs = dict(input_dtype=input_dtype, append=append, circular=circular, jit=False)
ops = [FiniteDifferenceSingleAxis(axis, input_shape, **single_kwargs) for axis in axes_list]

super().__init__(
ops,
ops, # type: ignore
jit=jit,
**kwargs,
)
Expand Down
Loading

0 comments on commit 157f2ba

Please sign in to comment.