diff --git a/Documentation/CHANGELOG.md b/Documentation/CHANGELOG.md index 54029c2d1..30210ce88 100644 --- a/Documentation/CHANGELOG.md +++ b/Documentation/CHANGELOG.md @@ -15,7 +15,8 @@ Release Date: TBD ### Major Changes * Updates the DCEGM tools to address the flaws identified in [issue #1062](https://github.com/econ-ark/HARK/issues/1062). PR: [1100](https://github.com/econ-ark/HARK/pull/1100). -, +* Updates `IndexDstn`, introducing the option to use an existing RNG instead of creating a new one, and creating and storing all the conditional distributions at initialization. [1104](https://github.com/econ-ark/HARK/pull/1104) + ### Minor Changes ### 0.12.0 diff --git a/HARK/distribution.py b/HARK/distribution.py index b9f7b1512..0615171da 100644 --- a/HARK/distribution.py +++ b/HARK/distribution.py @@ -62,26 +62,49 @@ class (such as Bernoulli, LogNormal, etc.) with information conditional = None engine = None - def __init__(self, engine, conditional, seed=0): - # Set up the RNG - super().__init__(seed) + def __init__(self, engine, conditional, RNG = None, seed=0): + + if RNG is None: + # Set up the RNG + super().__init__(seed) + else: + # If an RNG is received, use it in whatever state it is in. + self.RNG = RNG + # The seed will still be set, even if it is not used for the RNG, + # for whenever self.reset() is called. + # Note that self.reset() will stop using the RNG that was passed + # and create a new one. + self.seed = seed self.conditional = conditional self.engine = engine - - def __getitem__(self, y): - # test one item to determine case handling + + + self.dstns = [] + + # Test one item to determine case handling item0 = list(self.conditional.values())[0] - + if type(item0) is list: - cond = {key: val[y] for (key, val) in self.conditional.items()} - return self.engine(seed=self.RNG.randint(0, 2 ** 31 - 1), **cond) + # Create and store all the conditional distributions + for y in range(len(item0)): + cond = {key: val[y] for (key, val) in self.conditional.items()} + self.dstns.append(self.engine(seed=self.RNG.randint(0, 2 ** 31 - 1), **cond)) + + elif type(item0) is float: + + self.dstns = [self.engine(seed=self.RNG.randint(0, 2 ** 31 - 1), **conditional)] + else: raise ( Exception( f"IndexDistribution: Unhandled case for __getitem__ access. y: {y}; conditional: {self.conditional}" ) ) + + def __getitem__(self, y): + + return self.dstns[y] def approx(self, N, **kwds): """ @@ -114,9 +137,7 @@ def approx(self, N, **kwds): if type(item0) is float: # degenerate case. Treat the parameterization as constant. - return self.engine( - seed=self.RNG.randint(0, 2 ** 31 - 1), **self.conditional - ).approx(N, **kwds) + return self.dstns[0].approx(N, **kwds) if type(item0) is list: return TimeVaryingDiscreteDistribution(