Commit 5d63d86

Fixed failing meta tests after update to container and gradient functions (#6037)
vedpatwardhan authored Oct 22, 2022
1 parent fc8da59 commit 5d63d86
Showing 4 changed files with 33 additions and 40 deletions.
ivy/container/base.py (4 changes: 2 additions & 2 deletions)
@@ -697,7 +697,7 @@ def multi_map(
                     map_nests,
                     assert_identical,
                 )
-                if ret:
+                if ret is not None:
                     return_dict[key] = ret
             elif any(isinstance(x, (list, tuple)) for x in values) and map_nests:
                 ret = ivy.nested_multi_map(
@@ -1749,7 +1749,7 @@ def unstack_conts(self, axis, keepdims=False, dim_size=None):
         """
         if dim_size is None:
-            dim_size = self.shape[axis]
+            dim_size = self.shared_shape[axis]
         if keepdims:
             # noinspection PyTypeChecker
             return [
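The `multi_map` change is the classic truthiness-versus-identity fix: a falsy-but-valid return value (for example, an empty container) would previously be dropped from `return_dict`, while `is not None` only discards a genuine "no result". A minimal plain-Python sketch of the pitfall, with a dict standing in for the nested container (names illustrative, not the ivy implementation):

def keep_if_truthy(results):
    # Pre-fix behaviour: falsy-but-valid values are silently discarded.
    return {k: v for k, v in results.items() if v}

def keep_if_not_none(results):
    # Post-fix behaviour: only a genuine "no result" (None) is discarded.
    return {k: v for k, v in results.items() if v is not None}

results = {"a": 0.0, "b": {}, "c": None}
print(keep_if_truthy(results))    # {} -> drops the valid 0.0 and {} entries
print(keep_if_not_none(results))  # {'a': 0.0, 'b': {}} -> keeps them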
ivy/functional/ivy/gradients.py (2 changes: 1 addition & 1 deletion)
@@ -73,7 +73,7 @@ def _remove_zeros_and_nones(grads, x, idx=[]):
             idx.pop()

     keys = [k for k in x]
-    if len(keys) == 0:
+    if len(keys) == 0 and len(idx) and _check_if_empty(idx):
         ivy.prune_nest_at_index(grads, idx)
     return grads

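The extra guard matters because `idx` is empty at the top-level call: pruning with an empty index path would target the root of the gradient nest rather than an empty sub-nest, and `_check_if_empty` (an internal helper) presumably confirms the entry at `idx` really is empty before pruning. A standalone sketch of the empty-path failure mode, using plain nested dicts instead of the ivy helpers:

def prune_at_index(nest, idx):
    # Delete the entry reached by following idx through nested dicts.
    *path, last = idx  # raises ValueError when idx is empty
    node = nest
    for key in path:
        node = node[key]
    del node[last]

grads = {"w": {"a": 1.0, "b": {}}}
prune_at_index(grads, ["w", "b"])  # prunes the empty sub-nest
print(grads)                       # {'w': {'a': 1.0}}
# prune_at_index(grads, [])        # ValueError: the root itself cannot be
#                                  # pruned, hence the added len(idx) check.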
ivy/functional/ivy/meta.py (6 changes: 3 additions & 3 deletions)
@@ -483,7 +483,7 @@ def fomaml_step(
     """
     if num_tasks is None:
-        num_tasks = batch.shape[0]
+        num_tasks = batch.shared_shape[0]
     rets = _train_tasks(
         batch,
         inner_batch_fn,
@@ -570,7 +570,7 @@ def reptile_step(
     """
     if num_tasks is None:
-        num_tasks = batch.shape[0]
+        num_tasks = batch.shared_shape[0]
     # noinspection PyTypeChecker
     rets = _train_tasks(
         batch,
@@ -694,7 +694,7 @@ def maml_step(
     """
     if num_tasks is None:
-        num_tasks = batch.shape[0]
+        num_tasks = batch.shared_shape[0]
     unique_outer = outer_v is not None
     func_ret, grads = ivy.execute_with_gradients(
         lambda v: _train_tasks(
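All three meta-step functions now default `num_tasks` to `batch.shared_shape[0]`, tracking the container update the commit title refers to. A hedged sketch of what a shared shape is assumed to mean here: the leading dimensions common to every array leaf of a container, so the task count can be read off dimension 0 (plain Python, illustrative only):

def shared_shape(leaf_shapes):
    # Longest common leading prefix across the per-leaf shapes.
    common = []
    for dims in zip(*leaf_shapes):
        if len(set(dims)) != 1:
            break
        common.append(dims[0])
    return common

# Two leaves shaped (4, 1, 3) and (4, 1, 5): the task dimension 4 is
# common to all leaves, so num_tasks would be 4.
print(shared_shape([(4, 1, 3), (4, 1, 5)]))  # [4, 1]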
ivy_tests/test_ivy/test_functional/test_core/test_meta.py (61 changes: 27 additions & 34 deletions)
@@ -73,7 +73,7 @@ def test_fomaml_step_unique_vars(
     # inner cost function
     def inner_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -83,7 +83,7 @@ def inner_cost_fn(batch_in, v):
     # outer cost function
     def outer_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -154,15 +154,14 @@ def outer_cost_fn(batch_in, v):
     assert ivy.equal(ivy.is_variable(calc_cost, exclusive=True), False)
     assert np.allclose(ivy.to_scalar(calc_cost), true_cost)
     outer_grads = rets[1]
-    assert ivy.equal(ivy.is_variable(outer_grads), False)
     assert np.allclose(ivy.to_numpy(outer_grads.weight[0]), np.array(true_weight_grad))
     if return_inner_v:
         inner_v_rets = rets[2]
         assert isinstance(inner_v_rets, ivy.Container)
         if return_inner_v == "all":
-            assert list(inner_v_rets.shape) == [num_tasks, 1]
+            assert list(inner_v_rets.shared_shape) == [num_tasks, 1]
         elif return_inner_v == "first":
-            assert list(inner_v_rets.shape) == [1, 1]
+            assert list(inner_v_rets.shared_shape) == [1, 1]


 # fomaml step shared vars
@@ -214,7 +213,7 @@ def test_fomaml_step_shared_vars(
     # inner cost function
     def inner_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -224,7 +223,7 @@ def inner_cost_fn(batch_in, v):
     # outer cost function
     def outer_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -320,15 +319,14 @@ def loss_grad_fn(sub_batch_in, w_in, outer=False):
     assert ivy.equal(ivy.is_variable(calc_cost, exclusive=True), False)
     assert np.allclose(ivy.to_scalar(calc_cost), true_cost)
     outer_grads = rets[1]
-    assert ivy.equal(ivy.is_variable(outer_grads), False)
     assert np.allclose(ivy.to_numpy(outer_grads.latent[0]), np.array(true_outer_grad))
     if return_inner_v:
         inner_v_rets = rets[2]
         assert isinstance(inner_v_rets, ivy.Container)
         if return_inner_v == "all":
-            assert list(inner_v_rets.shape) == [num_tasks, 1]
+            assert list(inner_v_rets.shared_shape) == [num_tasks, 1]
         elif return_inner_v == "first":
-            assert list(inner_v_rets.shape) == [1, 1]
+            assert list(inner_v_rets.shared_shape) == [1, 1]


 # fomaml step overlapping vars
@@ -386,7 +384,7 @@ def test_fomaml_step_overlapping_vars(
     # inner cost function
     def inner_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -396,7 +394,7 @@ def inner_cost_fn(batch_in, v):
     # outer cost function
     def outer_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -471,16 +469,15 @@ def outer_cost_fn(batch_in, v):
     assert ivy.equal(ivy.is_variable(calc_cost, exclusive=True), False)
     assert np.allclose(ivy.to_scalar(calc_cost), true_cost)
     outer_grads = rets[1]
-    assert ivy.equal(ivy.is_variable(outer_grads), False)
     assert np.allclose(ivy.to_numpy(outer_grads.weight[0]), np.array(true_weight_grad))
     assert np.allclose(ivy.to_numpy(outer_grads.latent[0]), np.array(true_latent_grad))
     if return_inner_v:
         inner_v_rets = rets[2]
         assert isinstance(inner_v_rets, ivy.Container)
         if return_inner_v == "all":
-            assert list(inner_v_rets.shape) == [num_tasks, 1]
+            assert list(inner_v_rets.shared_shape) == [num_tasks, 1]
         elif return_inner_v == "first":
-            assert list(inner_v_rets.shape) == [1, 1]
+            assert list(inner_v_rets.shared_shape) == [1, 1]


 # reptile step
@@ -520,7 +517,7 @@ def test_reptile_step(
     # inner cost function
     def inner_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -579,15 +576,14 @@ def loss_grad_fn(sub_batch_in, w_in):
     assert ivy.equal(ivy.is_variable(calc_cost, exclusive=True), False)
     assert np.allclose(ivy.to_scalar(calc_cost), true_cost)
     outer_grads = rets[1]
-    assert ivy.equal(ivy.is_variable(outer_grads), False)
     assert np.allclose(ivy.to_numpy(outer_grads.latent[0]), np.array(true_outer_grad))
     if return_inner_v:
         inner_v_rets = rets[2]
         assert isinstance(inner_v_rets, ivy.Container)
         if return_inner_v == "all":
-            assert list(inner_v_rets.shape) == [num_tasks, 1]
+            assert list(inner_v_rets.shared_shape) == [num_tasks, 1]
         elif return_inner_v == "first":
-            assert list(inner_v_rets.shape) == [1, 1]
+            assert list(inner_v_rets.shared_shape) == [1, 1]


 # Second Order #
@@ -649,7 +645,7 @@ def test_maml_step_unique_vars(
     # inner cost function
     def inner_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -659,7 +655,7 @@ def inner_cost_fn(batch_in, v):
     # outer cost function
     def outer_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -728,15 +724,14 @@ def outer_cost_fn(batch_in, v):
     assert ivy.equal(ivy.is_variable(calc_cost, exclusive=True), False)
     assert np.allclose(ivy.to_scalar(calc_cost), true_cost)
     outer_grads = rets[1]
-    assert ivy.equal(ivy.is_variable(outer_grads), False)
     assert np.allclose(ivy.to_numpy(outer_grads.weight), np.array(true_outer_grad))
     if return_inner_v:
         inner_v_rets = rets[2]
         assert isinstance(inner_v_rets, ivy.Container)
         if return_inner_v == "all":
-            assert list(inner_v_rets.shape) == [num_tasks, 1]
+            assert list(inner_v_rets.shared_shape) == [num_tasks, 1]
         elif return_inner_v == "first":
-            assert list(inner_v_rets.shape) == [1, 1]
+            assert list(inner_v_rets.shared_shape) == [1, 1]


 # maml step shared vars
@@ -789,7 +784,7 @@ def test_maml_step_shared_vars(
     # inner cost function
     def inner_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -799,7 +794,7 @@ def inner_cost_fn(batch_in, v):
     # outer cost function
     def outer_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -939,17 +934,16 @@ def update_grad_fn(w_init, sub_batch_in, num_steps, average=False):
     assert ivy.equal(ivy.is_variable(calc_cost, exclusive=True), False)
     assert np.allclose(ivy.to_scalar(calc_cost), true_cost)
     outer_grads = rets[1]
-    assert ivy.equal(ivy.is_variable(outer_grads), False)
     assert np.allclose(
         ivy.to_numpy(outer_grads.latent), ivy.to_numpy(true_outer_grad[0])
     )
     if return_inner_v:
         inner_v_rets = rets[2]
         assert isinstance(inner_v_rets, ivy.Container)
         if return_inner_v == "all":
-            assert list(inner_v_rets.shape) == [num_tasks, 1]
+            assert list(inner_v_rets.shared_shape) == [num_tasks, 1]
         elif return_inner_v == "first":
-            assert list(inner_v_rets.shape) == [1, 1]
+            assert list(inner_v_rets.shared_shape) == [1, 1]


 # maml step overlapping vars
@@ -1007,7 +1001,7 @@ def test_maml_step_overlapping_vars(
     # inner cost function
     def inner_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -1017,7 +1011,7 @@ def inner_cost_fn(batch_in, v):
     # outer cost function
     def outer_cost_fn(batch_in, v):
         cost = 0
-        batch_size = batch_in.shape[0]
+        batch_size = batch_in.shared_shape[0]
         for sub_batch_in, sub_v in zip(
             batch_in.unstack_conts(0, keepdims=True), v.unstack_conts(0, keepdims=True)
         ):
@@ -1092,16 +1086,15 @@ def outer_cost_fn(batch_in, v):
     assert ivy.equal(ivy.is_variable(calc_cost, exclusive=True), False)
     assert np.allclose(ivy.to_scalar(calc_cost), true_cost)
     outer_grads = rets[1]
-    assert ivy.equal(ivy.is_variable(outer_grads), False)
     assert np.allclose(ivy.to_numpy(outer_grads.weight), np.array(true_weight_grad))
     assert np.allclose(ivy.to_numpy(outer_grads.latent), np.array(true_latent_grad))
     if return_inner_v:
         inner_v_rets = rets[2]
         assert isinstance(inner_v_rets, ivy.Container)
         if return_inner_v == "all":
-            assert list(inner_v_rets.shape) == [num_tasks, 1]
+            assert list(inner_v_rets.shared_shape) == [num_tasks, 1]
         elif return_inner_v == "first":
-            assert list(inner_v_rets.shape) == [1, 1]
+            assert list(inner_v_rets.shared_shape) == [1, 1]


 # Still to Add #
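Across all six meta tests the pattern is the same: batch sizes and returned-variable shapes are now read via `shared_shape`, and the standalone `assert ivy.equal(ivy.is_variable(outer_grads), False)` checks are dropped, presumably because the updated gradient functions changed what they return there. A hedged usage sketch of the container pattern these tests exercise, using only the API names that appear in this diff, with illustrative values:

import ivy

# A batch of 2 tasks with 1 sample each (values illustrative).
batch = ivy.Container(x=ivy.array([[1.0], [2.0]]))
num_tasks = batch.shared_shape[0]  # -> 2, read via shared_shape per this commit
for sub_batch in batch.unstack_conts(0, keepdims=True):
    print(sub_batch.x.shape)  # (1, 1) for each per-task sub-container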
