Skip to content

Commit

Permalink
fix type annotations in pgm (#219)
Browse files Browse the repository at this point in the history
* fix type annotations in pgm

* fix return type of x_step
  • Loading branch information
tbalke authored Feb 10, 2022
1 parent 116df6c commit 912b460
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions docs/source/style.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ We follow the `Google string conventions <https://google.github.io/styleguide/py
.. code:: Python
state = "active"
print("The state is %s") # Not preferred
print(f"The state is {state}") # Preferred
print("The state is %s" % state) # Not preferred
print(f"The state is {state}") # Preferred
Imports
Expand Down
2 changes: 1 addition & 1 deletion scico/optimize/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class SubproblemSolver:
solver is attached.
"""

def internal_init(self, admm: "ADMM"):
def internal_init(self, admm: ADMM):
"""Second stage initializer to be called by :meth:`.ADMM.__init__`.
Args:
Expand Down
8 changes: 5 additions & 3 deletions scico/optimize/pgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def __init__(
self.timer: Timer = Timer()
self.fixed_point_residual = snp.inf

def x_step(v, L):
def x_step(v: Union[JaxArray, BlockArray], L: float) -> Union[JaxArray, BlockArray]:
return self.g.prox(v - 1.0 / L * self.f.grad(v), 1.0 / L)

self.x_step = jax.jit(x_step)
Expand Down Expand Up @@ -473,13 +473,15 @@ def x_step(v, L):

self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution

def objective(self, x=None) -> float:
def objective(self, x: Optional[Union[JaxArray, BlockArray]] = None) -> float:
r"""Evaluate the objective function :math:`f(\mb{x}) + g(\mb{x})`."""
if x is None:
x = self.x
return self.f(x) + self.g(x)

def f_quad_approx(self, x, y, L) -> float:
def f_quad_approx(
self, x: Union[JaxArray, BlockArray], y: Union[JaxArray, BlockArray], L: float
) -> float:
r"""Evaluate the quadratic approximation to function :math:`f`.
Evaluate the quadratic approximation to function :math:`f`,
Expand Down

0 comments on commit 912b460

Please sign in to comment.