Skip to content

Commit

Permalink
add test for chack_params inside scan can be replaced by ninf
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov committed Nov 7, 2023
1 parent 097eae3 commit 8e5d012
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion tests/test_pytensorf.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_check_parameters_can_be_replaced_by_ninf(self):

def test_check_parameters_removed_from_scan(self):
def scan_step(x_0):
cond = pt.ge(x_0, 0)
cond = pt.ge(x_0, 1)
x = check_parameters(x_0, cond)
x_update = collect_default_updates([x])
return x, x_update
Expand Down Expand Up @@ -375,6 +375,23 @@ def scan_step(x_0):
fn = compile_pymc([], xs)
assert np.all(fn() == -np.inf)

def test_check_parameters_can_be_replaced_by_ninf_from_scan(self):
def scan_step(x_0):
cond = pt.ge(x_0, 0)
x = check_parameters(x_0, cond, can_be_replaced_by_ninf=True)
x_update = collect_default_updates([x])
return x, x_update

xs, _ = scan(
fn=scan_step,
sequences=[
pt.as_tensor_variable([-1, 0, 1]),
],
name="xs",
)
fn = compile_pymc([], xs)
np.testing.assert_array_equal(fn(), [-np.inf, 0, 1])

def test_compile_pymc_sets_rng_updates(self):
rng = pytensor.shared(np.random.default_rng(0))
x = pm.Normal.dist(rng=rng)
Expand Down

0 comments on commit 8e5d012

Please sign in to comment.