Skip to content

Commit

Permalink
fix typos, gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Feb 13, 2024
1 parent bb9b02b commit 6802239
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
17 changes: 11 additions & 6 deletions pytensor/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,17 +1815,22 @@ def random_projection(shape, dtype):
# This sum() is defined above, it's not the builtin sum.
if sum_outputs:
t_rs = [
shared(random_projection(o.shape, o.dtype), borrow=True) for o in o_fn_out
shared(
value=random_projection(o.shape, o.dtype),
borrow=True,
name=f"random_projection_{i}",
)
for i, o in enumerate(o_fn_out)
]
for i, x in enumerate(t_rs):
x.name = "ranom_projection_{i}"
cost = pytensor.tensor.sum(
[pytensor.tensor.sum(x * y) for x, y in zip(t_rs, o_output)]
)
else:
t_r = shared(random_projection(o_fn_out.shape, o_fn_out.dtype), borrow=True)
t_r.name = "random_projection"

t_r = shared(
value=random_projection(o_fn_out.shape, o_fn_out.dtype),
borrow=True,
name="random_projection",
)
cost = pytensor.tensor.sum(t_r * o_output)

if no_debug_ref:
Expand Down
6 changes: 4 additions & 2 deletions pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore

import pytensor.printing
from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
Expand Down Expand Up @@ -552,7 +553,7 @@ def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
if self.full_matrices:
self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)"
else:
self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)"
self.gufunc_signature = "(m,n)->(o,k),(k),(k,p)"
else:
self.gufunc_signature = "(m,n)->(k)"

Expand Down Expand Up @@ -653,9 +654,10 @@ def h(t):
sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
return ptm.maximum(ptm.abs(t), eps) * sign_t

numer = ptb.ones_like(A) - eye
numer = ptb.ones((k, k)) - eye
denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
E = numer / denom
E = pytensor.printing.Print("E")(E)

utgu = U.T @ dU
vtgv = VT @ dV
Expand Down

0 comments on commit 6802239

Please sign in to comment.