Skip to content

Commit

Permalink
Use numpy and pmf function to speed up gen
Browse files Browse the repository at this point in the history
Numpy has a built-in way to sum probability mass functions (pmf).
This shaves of 60% of the generation time :D
  • Loading branch information
spinerak committed Jun 13, 2024
1 parent 1789824 commit 9290191
Showing 1 changed file with 56 additions and 33 deletions.
89 changes: 56 additions & 33 deletions worlds/yachtdice/Rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from worlds.generic.Rules import set_rule

from .YachtWeights import yacht_weights
import numpy as np


# This module adds logic to the apworld.
Expand Down Expand Up @@ -69,8 +70,6 @@ def mean_score(self, num_dice, num_rolls):
return mean_score * self.quantity




class ListState:
def __init__(self, state: List[str]):
self.state = state
Expand Down Expand Up @@ -99,19 +98,25 @@ def extract_progression(state, player, options):
)
number_of_fixed_mults = state.count("Fixed Score Multiplier", player)
number_of_step_mults = state.count("Step Score Multiplier", player)

categories = [
Category(category_value, state.count(category_name, player))
for category_name, category_value in category_mappings.items()
if state.count(category_name, player) # want all categories that have count >= 1
]
]

extra_points_in_logic = state.count("1 Point", player)
extra_points_in_logic += state.count("10 Points", player) * 10
extra_points_in_logic += state.count("100 Points", player) * 100

return categories, number_of_dice, number_of_rerolls, number_of_fixed_mults * 0.1, number_of_step_mults * 0.01, extra_points_in_logic,

return (
categories,
number_of_dice,
number_of_rerolls,
number_of_fixed_mults * 0.1,
number_of_step_mults * 0.01,
extra_points_in_logic,
)


# We will store the results of this function as it is called often for the same parameters.
Expand Down Expand Up @@ -140,18 +145,44 @@ def dice_simulation_strings(categories, num_dice, num_rolls, fixed_mult, step_mu
# sort categories because for the step multiplier, you will want low-scoring categories first
categories.sort(key=lambda category: category.mean_score(num_dice, num_rolls))

# function to add two discrete distribution.
# defaultdict is a dict where you don't need to check if an id is present, you can just use += (lot faster)
def add_distributions(dist1, dist2):
combined_dist = defaultdict(float)
for val1, prob1 in dist1.items():
for val2, prob2 in dist2.items():
combined_dist[val1 + val2] += prob1 * prob2
return dict(combined_dist)

# function to take the maximum of "times" i.i.d. dist1.
# (I have tried using defaultdict here too but this made it slower.)
# we have two ways to store a distribution (example, 0 with probability 0.4, 10 with probability 0.6):
# dict: {0: 0.4, 10: 0.6}
# pmf (probability mass function): [0.4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.6] (numpy array)
# adding two distributions works fast with pmf's and numpy's convolve method.
# maximizing two distributions (with multipliers) seems to work fastest with dictionaries.

def dict_to_pmf(dist):
"""
Convert dict-distribution to pmf-distribution
"""
max_value = max(dist) + 1
return np.array([dist.get(i, 0) for i in range(max_value)])

def pmf_to_dict(pmf):
"""
Convert pmf-distribution to dict-distribution
"""
sum_values = np.arange(0, len(pmf) + 1)
sum_dist = {v: p for v, p in zip(sum_values, pmf)}
return sum_dist

def add_distributions(pmf1, dic2):
"""
function to add two discrete distributions. The first in pmf form, the second in dict form, returns pmf.
"""
pmf2 = dict_to_pmf(dic2)

# Sum the two distributions using convolution
sum_pmf = np.convolve(pmf1, pmf2)

return sum_pmf

def max_dist(dist1, mults):
"""
function to take the maximum of "times" i.i.d. dist1.
dist1 is a dict-distribution
(I have tried using defaultdict here too but this made it slower.)
"""
new_dist = {0: 1}
for mult in mults:
c = new_dist.copy()
Expand All @@ -171,25 +202,19 @@ def max_dist(dist1, mults):

# Returns percentile value of a distribution.
def percentile_distribution(dist, percentile):
sorted_values = sorted(dist.keys())
cumulative_prob = 0

for val in sorted_values:
cumulative_prob += dist[val]
if cumulative_prob >= percentile:
return val
cumdist = np.cumsum(dist)

# Return the last value if percentile is higher than all probabilities
return sorted_values[-1]
return np.argmax(cumdist > percentile)

# parameters for logic.
# perc_return is, per difficulty, the percentages of total score it returns (it averages out the values)
# diff_divide determines how many shots the logic gets per category. Lower = more shots.
perc_return = [[0], [0.1, 0.5], [0.3, 0.7], [0.55, 0.85], [0.85, 0.95]][diff]
diff_divide = [0, 9, 7, 3, 2][diff]

# calculate total distribution
total_dist = {0: 1}
# calculate total distribution, start in pmf-form
total_dist = [1]
for j, category in enumerate(categories):
if num_dice == 0 or num_rolls == 0:
dist = {0: 100000}
Expand All @@ -208,12 +233,12 @@ def percentile_distribution(dist, percentile):

total_dist = add_distributions(total_dist, dist)

# save result into the cache, then return it
# note, total_dist is in pmf-form
outcome = sum([percentile_distribution(total_dist, perc) for perc in perc_return]) / len(perc_return)
# save result into the cache, then return it
yachtdice_cache[tup] = max(5, math.floor(outcome)) # at least 5.

return yachtdice_cache[tup]

return yachtdice_cache[tup]


def dice_simulation(state, player, options):
Expand Down Expand Up @@ -243,7 +268,6 @@ def dice_simulation(state, player, options):
return state.prog_items[player]["maximum_achievable_score"]



def set_yacht_rules(world: MultiWorld, player: int, options):
"""
Sets rules on entrances and advancements that are always applied
Expand All @@ -262,4 +286,3 @@ def set_yacht_completion_rules(world: MultiWorld, player: int):
Sets rules on completion condition
"""
world.completion_condition[player] = lambda state: state.has("Victory", player)

0 comments on commit 9290191

Please sign in to comment.