Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stardew Valley: Improve generation performance by around 11% by moving calculating from rule evaluation to collect #4231

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
40 changes: 34 additions & 6 deletions worlds/stardew_valley/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,15 +441,43 @@ def fill_slot_data(self) -> Dict[str, Any]:

def collect(self, state: CollectionState, item: StardewItem) -> bool:
change = super().collect(state, item)
if change:
state.prog_items[self.player][Event.received_walnuts] += self.get_walnut_amount(item.name)
return change
if not change:
return False

player_state = state.prog_items[self.player]

received_progression_count = player_state[Event.received_progression_item]
received_progression_count += 1
Jouramie marked this conversation as resolved.
Show resolved Hide resolved
if self.total_progression_items:
# We can't update the percentage if we don't know the total progression items, can't divide by 0.
player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items
player_state[Event.received_progression_item] = received_progression_count

walnut_amount = self.get_walnut_amount(item.name)
if walnut_amount:
player_state[Event.received_walnuts] += walnut_amount

return True

def remove(self, state: CollectionState, item: StardewItem) -> bool:
change = super().remove(state, item)
if change:
state.prog_items[self.player][Event.received_walnuts] -= self.get_walnut_amount(item.name)
return change
if not change:
return False

player_state = state.prog_items[self.player]

received_progression_count = player_state[Event.received_progression_item]
received_progression_count -= 1
if self.total_progression_items:
# We can't update the percentage if we don't know the total progression items, can't divide by 0.
player_state[Event.received_progression_percent] = received_progression_count * 100 // self.total_progression_items
player_state[Event.received_progression_item] = received_progression_count

walnut_amount = self.get_walnut_amount(item.name)
if walnut_amount:
player_state[Event.received_walnuts] -= walnut_amount

return True

@staticmethod
def get_walnut_amount(item_name: str) -> int:
Expand Down
42 changes: 7 additions & 35 deletions worlds/stardew_valley/stardew_rule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from BaseClasses import CollectionState
from .base import BaseStardewRule, CombinableStardewRule
from .protocol import StardewRule
from ..strings.ap_names.event_names import Event


class TotalReceived(BaseStardewRule):
Expand Down Expand Up @@ -84,42 +85,13 @@ def __repr__(self):
return f"Reach {self.resolution_hint} {self.spot}"


@dataclass(frozen=True)
class HasProgressionPercent(CombinableStardewRule):
player: int
percent: int
class HasProgressionPercent(Received):
def __init__(self, player: int, percent: int):
super().__init__(Event.received_progression_percent, player, percent, event=True)

def __post_init__(self):
assert self.percent > 0, "HasProgressionPercent rule must be above 0%"
assert self.percent <= 100, "HasProgressionPercent rule can't require more than 100% of items"

@property
def combination_key(self) -> Hashable:
return HasProgressionPercent.__name__

@property
def value(self):
return self.percent

def __call__(self, state: CollectionState) -> bool:
stardew_world = state.multiworld.worlds[self.player]
total_count = stardew_world.total_progression_items
needed_count = (total_count * self.percent) // 100
player_state = state.prog_items[self.player]

if needed_count <= len(player_state):
return True

total_count = 0
for item, item_count in player_state.items():
total_count += item_count
if total_count >= needed_count:
return True

return False

def evaluate_while_simplifying(self, state: CollectionState) -> Tuple[StardewRule, bool]:
return self, self(state)
assert self.count > 0, "HasProgressionPercent rule must be above 0%"
assert self.count <= 100, "HasProgressionPercent rule can't require more than 100% of items"

def __repr__(self):
return f"Received {self.percent}% progression items"
return f"Received {self.count}% progression items"
2 changes: 2 additions & 0 deletions worlds/stardew_valley/strings/ap_names/event_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ class Event:
winter_farming = event("Winter Farming")

received_walnuts = event("Received Walnuts")
received_progression_item = event("Received Progression Item")
received_progression_percent = event("Received Progression Percent")
11 changes: 7 additions & 4 deletions worlds/stardew_valley/test/rules/TestShipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,17 @@ class TestShipsanityEverything(SVTestBase):
def test_all_shipsanity_locations_require_shipping_bin(self):
bin_name = "Shipping Bin"
self.collect_all_except(bin_name)
shipsanity_locations = [location for location in self.get_real_locations() if
LocationTags.SHIPSANITY in location_table[location.name].tags]
shipsanity_locations = [location
for location in self.get_real_locations()
if LocationTags.SHIPSANITY in location_table[location.name].tags]
bin_item = self.create_item(bin_name)

for location in shipsanity_locations:
with self.subTest(location.name):
self.remove(bin_item)
self.assertFalse(self.world.logic.region.can_reach_location(location.name)(self.multiworld.state))
self.multiworld.state.collect(bin_item, prevent_sweep=False)

self.collect(bin_item)
shipsanity_rule = self.world.logic.region.can_reach_location(location.name)
self.assert_rule_true(shipsanity_rule, self.multiworld.state)

self.remove(bin_item)
Loading