Skip to content

Commit

Permalink
add rand_hyper and rand_multi_hyper methods
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Aug 14, 2024
1 parent 7d25779 commit 3c997bb
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 0 deletions.
4 changes: 4 additions & 0 deletions hail/python/hail/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@
rand_gamma,
rand_int32,
rand_int64,
rand_hyper,
rand_multi_hyper,
rand_norm,
rand_norm2d,
rand_pois,
Expand Down Expand Up @@ -395,6 +397,8 @@
'rand_gamma',
'rand_cat',
'rand_dirichlet',
'rand_hyper',
'rand_multi_hyper',
'sqrt',
'corr',
'str',
Expand Down
10 changes: 10 additions & 0 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3300,6 +3300,16 @@ def rand_dirichlet(a, seed=None) -> ArrayExpression:
return hl.bind(lambda x: x / hl.sum(x), a.map(lambda p: hl.if_else(p == 0.0, 0.0, hl.rand_gamma(p, 1, seed=seed))))


@typecheck(ngood=expr_int32, nbad=expr_int32, nsample=expr_int32, seed=nullable(int))
def rand_hyper(ngood, nbad, nsample, seed=None) -> Int32Expression:
return _seeded_func("rand_hyper", tint32, seed, ngood, nbad, nsample)


@typecheck(colors=expr_array(expr_int32), nsample=expr_int32, seed=nullable(int))
def rand_multi_hyper(colors, nsample, seed=None) -> Int32Expression:
return _seeded_func("rand_multi_hyper", tarray(tint32), seed, colors, nsample)


@typecheck(x=oneof(expr_float64, expr_ndarray(expr_float64)))
@ndarray_broadcasting
def sqrt(x) -> Float64Expression:
Expand Down
2 changes: 2 additions & 0 deletions hail/python/test/hail/expr/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def test_random_function(rand_f):
test_random_function(lambda: hl.rand_gamma(1, 1))
test_random_function(lambda: hl.rand_cat(hl.array([1, 1, 1, 1])))
test_random_function(lambda: hl.rand_dirichlet(hl.array([1, 1, 1, 1])))
test_random_function(lambda: hl.rand_hyper(5, 10, 4))
test_random_function(lambda: hl.rand_multi_hyper([5, 2, 8], 4))

def test_range(self):
def same_as_python(*args):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,74 @@ object RandomSeededFunctions extends RegistryFunctions {
primitive(cb.memoize(rng.invoke[Double, Double, Double]("rgamma", a.value, scale.value)))
}

registerSCode4(
"rand_hyper",
TRNGState,
TInt32,
TInt32,
TInt32,
TInt32,
{
case (_: Type, _: SType, _: SType, _: SType, _: SType) => SInt32
},
) {
case (
_,
cb,
_,
rngState: SRNGStateValue,
nGood: SInt32Value,
nBad: SInt32Value,
nSample: SInt32Value,
_,
) =>
val rng = cb.emb.getThreefryRNG()
rngState.copyIntoEngine(cb, rng)
primitive(cb.memoize(rng.invoke[Double, Double, Double, Double](
"rhyper",
nGood.value.toD,
nBad.value.toD,
nSample.value.toD,
).toI))
}

registerSCode3(
"rand_multi_hyper",
TRNGState,
TArray(TInt32),
TInt32,
TArray(TInt32),
{
case (_: Type, _: SType, _: SType, _: SType) =>
SIndexablePointer(PCanonicalArray(PInt32(required = true)))
},
) {
case (r, cb, _, rngState: SRNGStateValue, colors: SIndexableValue, nSample: SInt32Value, _) =>
val rng = cb.emb.getThreefryRNG()
rngState.copyIntoEngine(cb, rng)
val (push, finish) = PCanonicalArray(PInt32(required = true))
.constructFromFunctions(cb, r.region, colors.loadLength(), deepCopy = false)
cb.if_(
colors.hasMissingValues(cb),
cb._fatal("rand_multi_hyper: colors may not contain missing values"),
)
val remaining = cb.newLocal[Int]("rand_multi_hyper_N", 0)
val toSample = cb.newLocal[Int]("rand_multi_hyper_toSample", nSample.value)
colors.forEachDefined(cb)((cb, _, n) => cb.assign(remaining, remaining + n.asInt.value))
colors.forEachDefined(cb) { (cb, _, n) =>
cb.assign(remaining, remaining - n.asInt.value)
val drawn = cb.memoize(rng.invoke[Double, Double, Double, Double](
"rhyper",
n.asInt.value.toD,
remaining.toD,
toSample.toD,
).toI)
cb.assign(toSample, toSample - drawn)
push(cb, IEmitCode.present(cb, primitive(drawn)))
}
finish(cb)
}

registerSCode2(
"rand_cat",
TRNGState,
Expand Down

0 comments on commit 3c997bb

Please sign in to comment.