Skip to content

Commit

Permalink
Fix einsum testcase for duplicate modes
Browse files Browse the repository at this point in the history
  • Loading branch information
manopapad authored and ipdemes committed Jan 13, 2022
1 parent 78a2e06 commit 52d95dc
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tests/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import List, Optional, Set, Tuple

import numpy as np
from test_tools.generators import broadcasts_to, mk_0to1_array, permutes_to
from test_tools.generators import mk_0to1_array, permutes_to

import cunumeric as cn

Expand Down Expand Up @@ -161,8 +161,22 @@ def mk_inputs_that_permute_to(lib, shape):


@lru_cache(maxsize=None)
def mk_inputs_that_broadcast_to(lib, shape):
return [x for x in broadcasts_to(lib, shape)]
def mk_inputs_that_broadcast_to(lib, tgt_shape):
# If an operand contains the same mode multiple times, then we can't set
# just one of them to 1. Consider the operation 'aab->ab': (10,10,11),
# (10,10,1), (1,1,11), (1,1,1) are all acceptable input shapes, but
# (1,10,11) is not.
tgt_sizes = list(sorted(set(tgt_shape)))
res = []
for mask in product([True, False], repeat=len(tgt_sizes)):
if all(mask):
continue
tgt2src_size = {
d: (d if keep else 1) for (keep, d) in zip(mask, tgt_sizes)
}
src_shape = tuple(tgt2src_size[d] for d in tgt_shape)
res.append(mk_0to1_array(lib, src_shape))
return res


@lru_cache(maxsize=None)
Expand Down

0 comments on commit 52d95dc

Please sign in to comment.