Skip to content

Commit

Permalink
Add broadcasting test to coord_transform
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan committed Jun 6, 2022
1 parent d118b86 commit e26dd87
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions dali/test/python/test_operator_coord_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def test_sequences():
num_iters = 4

def points():
return np_rng.uniform(-100, 250, (num_points, 2))
return np.float32(np_rng.uniform(-100, 250, (num_points, 2)))

def rand_range(limit):
return range(rng.randint(1, limit) + 1)
Expand All @@ -209,9 +209,11 @@ def t(sample_desc):
def mt(sample_desc):
return np.append(m(sample_desc), t(sample_desc).reshape(-1, 1), axis=1)

input_seq_data = [[np.array([points() for _ in rand_range(max_num_frames)], dtype=np.float32)
for _ in rand_range(max_batch_size)]
for _ in range(num_iters)]
input_seq_data = [[
np.array([points() for _ in rand_range(max_num_frames)], dtype=np.float32)
for _ in rand_range(max_batch_size)]
for _ in range(num_iters)]

input_cases = [
(fn.coord_transform, {}, [ArgCb("M", m, True)]),
(fn.coord_transform, {}, [ArgCb("T", t, True)]),
Expand All @@ -222,3 +224,15 @@ def mt(sample_desc):
]

yield from sequence_suite_helper(rng, "F", [("F**", input_seq_data)], input_cases, num_iters)

input_mt_data = [[
np.array([mt(None) for _ in rand_range(max_num_frames)], dtype=np.float32)
for _ in rand_range(max_batch_size)]
for _ in range(num_iters)]

input_broadcast_cases = [
(fn.coord_transform, {}, [ArgCb(0, lambda _: points(), False, "cpu")], ["cpu"], "MT"),
(fn.coord_transform, {}, [ArgCb(0, lambda _: points(), False, "gpu")], ["cpu"], "MT"),
]

yield from sequence_suite_helper(rng, "F", [("F**", input_mt_data)], input_broadcast_cases, num_iters)

0 comments on commit e26dd87

Please sign in to comment.