[ONNX] Support Bernoulli op on ONNX front-end #13802
Changes from 13 commits
@@ -6707,6 +6707,117 @@ def verify_qlinearsigmoid(a_shape):
    verify_qlinearsigmoid([])


@tvm.testing.parametrize_targets("llvm")
def test_random_bernoulli(target, dev):
    """test_random_bernoulli"""

    def verify_bernoulli(
        inputs=None,
        shape=[],
        in_dtype="float32",
        out_dtype="int32",
        seed=None,
        target=target,
        dev=dev,
        use_vm=False,
        freeze_params=False,
        rtol=0.1,
        atol=0.1,
        in_out_equal=False,
    ):
        def get_bernoulli_model(shape, in_dtype="float32", out_dtype="int32", seed=None):
            onnx_itype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
            onnx_otype = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_dtype)]
            node = helper.make_node(
                "Bernoulli",
                ["input"],
                ["output"],
            )
            dtype_attr = helper.make_attribute("dtype", onnx_otype)
            node.attribute.append(dtype_attr)
            if seed is not None:
                seed_attr = helper.make_attribute("seed", float(seed))
                node.attribute.append(seed_attr)

            graph = helper.make_graph(
                [node],
                "random_bernoulli_test",
                inputs=[helper.make_tensor_value_info("input", onnx_itype, list(shape))],
                outputs=[helper.make_tensor_value_info("output", onnx_otype, list(shape))],
            )
            return helper.make_model(graph, producer_name="random_bernoulli_test")

        if inputs is None:
            assert len(shape) != 0
            inputs = np.random.uniform(size=shape).astype(in_dtype)
        else:
            shape = inputs.shape
            in_dtype = inputs.dtype
        model = get_bernoulli_model(shape, in_dtype, out_dtype, seed)

        if use_vm:
            tvm_out = get_tvm_output_with_vm(
                model,
                inputs,
                target,
                dev,
                freeze_params=freeze_params,
            )
        else:
            tvm_out = get_tvm_output(
                model,
                inputs,
                target,
                dev,
            )

        if isinstance(tvm_out, list):
            tvm_out = tvm_out[0]
        ideal_mean = np.mean(inputs)
        # check that all generated values are 0 or 1
        tvm_flat = tvm_out.flatten()
        for i in range(len(tvm_flat)):
            assert tvm_flat[i] == 0 or tvm_flat[i] == 1
        if in_out_equal:
            tvm.testing.assert_allclose(inputs, tvm_out)
        else:
            # check that the mean value is close to the theoretical one by a binomial test
            bnm_test_res = scipy.stats.binomtest(
                k=np.sum(tvm_flat, dtype="int32"), n=len(tvm_flat), p=ideal_mean
            )
            assert bnm_test_res.pvalue >= 1e-6

Review thread on the binomial test:

octoJon: Nit: strictly speaking, this test is only appropriate when the input probabilities are all identical. I think it's guaranteed to be over-conservative in cases where the input probabilities are not identical, like when you call […]. I would recommend just documenting this with a comment in the code. Something like: […]

octoJon: Oh, god. I didn't read that reference link closely enough. It contains a really blatantly misogynistic comment. Apologies to anyone I exposed to that. Please don't include that reference link in the code. (I think the math is right, though.)

vvchernov: Hello @octoJon! I've modified the test because it was already over-conservative with a p-value threshold of 1e-6. I've increased the threshold to 0.05, the more classical choice. If the test condition fails, there are two possible causes: something is wrong in the operation, or we got a "bad" output sequence from the tail of the distribution. Since the latter is rare and should be rechecked, I repeat the test (up to three times) with a new seed for the internal distribution while the input stays the same.
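To make the statistical check above concrete, here is a minimal standalone sketch of `scipy.stats.binomtest` on identically distributed Bernoulli samples (the case the reviewer notes the test is exact for). The seed, probability, and sample size are arbitrary choices for this illustration, not values from the PR:

```python
import numpy as np
from scipy.stats import binomtest

# Draw n Bernoulli(p) samples with a single, identical probability p.
rng = np.random.default_rng(0)
p, n = 0.5, 10000
samples = (rng.random(n) < p).astype("int32")

# binomtest compares k successes in n trials against the hypothesized p;
# a correct sampler yields a p-value far above a 1e-6 rejection threshold.
res = binomtest(k=int(samples.sum()), n=n, p=p)
print(res.pvalue >= 1e-6)  # prints True
```

With heterogeneous probabilities (e.g. uniform inputs), the sum of samples has smaller variance than the binomial with `p = mean(inputs)`, which is why the check becomes over-conservative rather than invalid.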
    # Test an input sequence of 0s and 1s
    inputs = np.random.randint(2, size=[10000]).astype("float32")
    verify_bernoulli(inputs, in_out_equal=True)

    # Binomial test on an input filled with 0.5
    val_num = 10000
    arr = [0.5] * val_num
    inputs = np.array(arr).astype("float32")
    verify_bernoulli(inputs)

    # Binomial test on an input filled with 0.1
    arr = [0.1] * val_num
    inputs = np.array(arr).astype("float32")
    verify_bernoulli(inputs)

    # Simple test
    verify_bernoulli(shape=[1000])

    # Floating-point output type
    verify_bernoulli(shape=[1000], out_dtype="float32")

    # Double input type
    verify_bernoulli(shape=[1000], in_dtype="float64")

    # Test N-D tensor generation
    verify_bernoulli(shape=[2, 4, 100, 100])

    # Test with seed
    verify_bernoulli(shape=[1000], seed=np.random.randint(1e6))
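For readers unfamiliar with the op under test, here is a pure-NumPy sketch of the Bernoulli semantics the cases above exercise. This is an illustrative reference model, not TVM's implementation, and `bernoulli_ref` is a hypothetical helper name:

```python
import numpy as np

def bernoulli_ref(probs, out_dtype="int32", seed=None):
    # Each input element is a probability p in [0, 1]; the matching output
    # element is 1 with probability p and 0 otherwise, cast to out_dtype.
    rng = np.random.default_rng(seed)
    return (rng.random(probs.shape) < probs).astype(out_dtype)

p = np.array([0.0, 1.0, 0.5, 0.5], dtype="float32")
out = bernoulli_ref(p, seed=42)
# p == 0.0 and p == 1.0 are deterministic regardless of the seed:
print(out[0] == 0 and out[1] == 1)  # prints True
```

This also explains the `in_out_equal=True` case above: an input of only 0s and 1s must be reproduced exactly.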
@tvm.testing.parametrize_targets("llvm")
def test_random_uniform(target, dev):
    """test_random_uniform"""
Review comment: ONNX has support for float16. However, this type is not supported here. Maybe it's worth pointing out the reason (TODO) why this data type is not currently supported at the ONNX front-end level?