diff --git a/python/tvm/topi/random/kernel.py b/python/tvm/topi/random/kernel.py index 576fd9254a79..b21db3778744 100644 --- a/python/tvm/topi/random/kernel.py +++ b/python/tvm/topi/random/kernel.py @@ -17,6 +17,7 @@ """Pseudorandom number kernels.""" import tvm import tvm.topi +import numpy as np from ... import tir from ...tir import ir_builder @@ -135,7 +136,7 @@ def _threefry( assert key_buf.dtype == counter_buf.dtype, "threefry key and counter must be the same dtype" def mix(a, b, rotation): - x = a + b # TODO should be wrapping + x = a + b # wrapping y = x ^ ((b << rotation) | (b >> (iwidth - rotation))) return [x, y] @@ -167,7 +168,7 @@ def key_schedule(s, i): with irb.for_range(0, out_shape, name="l") as l: # pylint: disable=invalid-name for i in range(nrounds // 4): for j in range(nwords): - out_buf[out_offset + l * nwords + j] += key_schedule(i, j) # TODO wrapping + out_buf[out_offset + l * nwords + j] += key_schedule(i, j) # wrapping for k in range(4): for j in range(nwords // 2): ( @@ -201,6 +202,13 @@ def threefry_generate(gen, out_shape): then a new generator is created by applying Threefry to the current key, path, and counter. This new generator will have a reset counter. + Warning + ------- + Threefry requires that unsigned integer arithmetic wraps on overflow. Currently TVM has no + guarantee of this, so threefry contains an internal assert to check wrapping behavior. This + assert may or may not run depending on your platform, so it is recommended you run + :py:func:`threefry_test_wrapping` to verify wrapping behavior. + Parameters ---------- gen : Tensor[10, uint64] @@ -234,6 +242,18 @@ def gen_ir(gen_ptr, out_gen_ptr, out_array_ptr): out_gen = irb.buffer_ptr(out_gen_ptr) out_array = irb.buffer_ptr(out_array_ptr) + # Check that unsigned arithmetic wraps, as it is required to implement threefry correctly. 
+ irb.emit( + tvm.tir.AssertStmt( + tvm.tir.const(0xFFFFFFFFFFFFFFFF, "uint64") + tvm.tir.const(1, "uint64") + == tvm.tir.const(0, "uint64"), + tvm.tir.StringImm( + "Unsigned integer arithmetic is not wrapping, but threefry requires wrapping." + ), + tvm.tir.Evaluate(0), + ) + ) + # Create a temporary array to hold the generator state we will use to create the random # numbers. We cannot use gen because we may need to update the key + path if there is not # enough room in the counter. @@ -408,3 +428,41 @@ def gen_ir(gen_ptr, out_left_ptr, out_right_ptr): name="threefry_split", tag="threefry_split", ) + + +def threefry_test_wrapping(target, ctx): + """Test that unsigned arithmetic wraps on overflow. + + Parameters + ---------- + target : tvm.target.Target + Target to run against + ctx : tvm.runtime.TVMContext + Context to run the test on + + Returns + ------- + is_wrapping : bool + Whether or not unsigned integer arithmetic is wrapping for this target, context pair. True + indicates that threefry will work on this platform. 
+ """ + if isinstance(target, str): + target = tvm.target.Target(target) + + def gen_ir(out_ptr): + irb = ir_builder.create() + out = irb.buffer_ptr(out_ptr) + if "gpu" in target.keys: + thread_x = tvm.te.thread_axis("threadIdx.x") + irb.scope_attr(thread_x, "thread_extent", 1) + out[0] = tvm.tir.const(0xFFFFFFFFFFFFFFFF, "uint64") + tvm.tir.const(1, "uint64") + return irb.get() + + out = tvm.tir.decl_buffer((1,), dtype="uint64") + f = tvm.te.extern( + [out.shape], [], lambda ins, outs: gen_ir(outs[0]), dtype="uint64", out_buffers=[out] + ) + s = tvm.te.create_schedule([f.op]) + out_ary = tvm.nd.array(np.ones((1,), "uint64"), ctx) + tvm.build(s, [f], target=target)(out_ary) + return out_ary.asnumpy()[0] == 0 diff --git a/tests/python/topi/python/test_topi_prng.py b/tests/python/topi/python/test_topi_prng.py index 43b0494ee6f5..649e5410c147 100644 --- a/tests/python/topi/python/test_topi_prng.py +++ b/tests/python/topi/python/test_topi_prng.py @@ -111,6 +111,14 @@ def test_threefry_generate(target, ctx): ).any(), "Overflowing counter with no space left in path should change state" +@tvm.testing.parametrize_targets +def test_threefry_wrapping(target, ctx): + assert tvm.topi.random.threefry_test_wrapping( + target, ctx + ), f"{target} does not support wrapping unsigned integer arithmetic" + + if __name__ == "__main__": test_threefry_split(tvm.target.Target("llvm"), tvm.context("cpu")) test_threefry_generate(tvm.target.Target("llvm"), tvm.context("cpu")) + test_threefry_wrapping(tvm.target.Target("llvm"), tvm.context("cpu"))