Skip to content

Commit

Permalink
use linear weighted selection instead of softmax for choosing fad des…
Browse files Browse the repository at this point in the history
…tinations

this includes a bit of refactoring of the various `oneOf` methods
in MasonUtils, which are now using precondition
  • Loading branch information
nicolaspayette committed Oct 3, 2019
1 parent b0021be commit 42768e9
Show file tree
Hide file tree
Showing 8 changed files with 221 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ Bag getFadsHere() {
return fadMap.fadsAt(fisher.getLocation());
}

Optional<Fad> oneOfFadsHere() { return oneOf(getFadsHere(), fisher.grabRandomizer()); }
/**
 * Asks {@code oneOf} to select an object from the bag returned by
 * {@code getFadsHere()}, using the fisher's randomizer.
 *
 * @return the selected object wrapped in an {@code Optional} when it is a
 *         {@code Fad}; an empty {@code Optional} otherwise
 */
Optional<Fad> oneOfFadsHere() {
    final Object candidate = oneOf(getFadsHere(), fisher.grabRandomizer());
    if (candidate instanceof Fad) {
        return Optional.of((Fad) candidate);
    }
    return Optional.empty();
}

/**
 * @return the current value of {@code numFadsInStock}
 */
public int getNumFadsInStock() {
    return numFadsInStock;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package uk.ac.ox.oxfish.fisher.equipment.fads;

import org.apache.commons.collections15.set.ListOrderedSet;
import uk.ac.ox.oxfish.fisher.Fisher;
import uk.ac.ox.oxfish.fisher.equipment.gear.PurseSeineGear;
import uk.ac.ox.oxfish.geography.SeaTile;
Expand All @@ -25,7 +26,10 @@ static Optional<Fad> oneOfFadsHere(Fisher fisher) {
}

static Optional<Fad> oneOfDeployedFads(Fisher fisher) {
return oneOf(getFadManager(fisher).getDeployedFads(), fisher.grabRandomizer());
final ListOrderedSet<Fad> deployedFads = getFadManager(fisher).getDeployedFads();
return deployedFads.isEmpty() ?
Optional.empty() :
Optional.of(oneOf(deployedFads, fisher.grabRandomizer()));
}

/**
 * Looks up the given fisher's FAD manager, takes the bag returned by its
 * {@code getFadsHere()} and converts it with {@code bagToStream}.
 *
 * @param fisher the fisher whose FAD manager is queried
 * @return the stream produced by {@code bagToStream} for that bag
 */
static Stream<Fad> fadsHere(Fisher fisher) {
    return bagToStream(
        getFadManager(fisher).getFadsHere()
    );
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static uk.ac.ox.oxfish.utility.bandit.SoftmaxBanditAlgorithm.drawFromSoftmax;
import static uk.ac.ox.oxfish.utility.MasonUtils.weightedOneOf;

abstract class IntermediateDestinationsStrategy {

Expand All @@ -26,6 +26,8 @@ abstract class IntermediateDestinationsStrategy {
@NotNull
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private Optional<Deque<SeaTile>> currentRoute = Optional.empty();
// TODO: this should be a parameter somewhere
private double holdFillProportionConsideredFull = 0.99;

IntermediateDestinationsStrategy(NauticalMap map) {
this.map = map;
Expand Down Expand Up @@ -53,8 +55,6 @@ private void goToPort(Fisher fisher) {
currentRoute = getRoute(fisher, fisher.getHomePort().getLocation());
}

// TODO: this should be a parameter somewhere
private double holdFillProportionConsideredFull = 0.99;
/**
 * Tells whether the fisher's hold counts as full, i.e. whether the fill
 * level reported by the hold has reached
 * {@code holdFillProportionConsideredFull}.
 *
 * @param fisher the fisher whose hold is inspected
 * @return {@code true} when the reported fill level is at or above the threshold
 */
private boolean holdFull(Fisher fisher) {
    final double filled = fisher.getHold().getPercentageFilled();
    return filled >= holdFillProportionConsideredFull;
}
Expand Down Expand Up @@ -92,14 +92,9 @@ private void chooseNewRoute(Fisher fisher, MersenneTwisterFast random) {
if (possibleRoutes.isEmpty())
currentRoute = Optional.empty();
else {
Function<Integer, Double> destinationValue = i ->
possibleRoutes.get(i)
.stream()
.mapToDouble(seaTileValues::get)
.sum();
currentRoute = Optional.of(
possibleRoutes.get(drawFromSoftmax(random, possibleRoutes.size(), destinationValue))
);
Function<Deque<SeaTile>, Double> destinationValue =
route -> route.stream().mapToDouble(seaTileValues::get).sum();
currentRoute = Optional.of(weightedOneOf(possibleRoutes, destinationValue, random));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;

import com.google.common.base.Preconditions;
import uk.ac.ox.oxfish.fisher.Fisher;
import uk.ac.ox.oxfish.fisher.actions.fads.DeployFad;
import uk.ac.ox.oxfish.fisher.actions.fads.FadAction;
Expand Down Expand Up @@ -49,11 +51,9 @@ public RandomPlanFadDestinationStrategy(NauticalMap map, int numberOfStepsToPlan

@Override
void makeNewPlan(Fisher fisher) {

Preconditions.checkState(!possibleActions.isEmpty(), "No possible action!");
actionQueue.addAll(Stream
.generate(() -> oneOf(possibleActions, fisher.grabRandomizer())
.orElseThrow(() -> new RuntimeException("No possible action!"))
)
.generate(() -> oneOf(possibleActions, fisher.grabRandomizer()))
.filter(pair -> pair.getSecond().apply(fisher))
.map(pair -> pair.getFirst().apply(fisher))
.limit(numberOfStepsToPlan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
import uk.ac.ox.oxfish.model.FishState;
import uk.ac.ox.oxfish.model.regs.Regulation;

import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Stream;

import static com.google.common.collect.Streams.stream;
import static java.util.stream.Collectors.toList;
import static uk.ac.ox.oxfish.fisher.equipment.fads.FadManagerUtils.oneOfFadsHere;
import static uk.ac.ox.oxfish.utility.MasonUtils.oneOf;

Expand Down Expand Up @@ -50,9 +52,10 @@ public boolean shouldFish(
public ActionResult act(
FishState model, Fisher fisher, Regulation regulation, double hoursLeft
) {
return oneOf(possibleActions(model, fisher), model.random)
.map(action -> new ActionResult(action, hoursLeft))
.orElse(new ActionResult(new Arriving(), 0));
final List<FadAction> possibleActions = possibleActions(model, fisher).collect(toList());
return possibleActions.isEmpty() ?
new ActionResult(new Arriving(), 0) :
new ActionResult(oneOf(possibleActions, model.random), hoursLeft);
}

@Override
Expand Down
20 changes: 14 additions & 6 deletions src/main/java/uk/ac/ox/oxfish/model/scenario/TunaScenario.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package uk.ac.ox.oxfish.model.scenario;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.vividsolutions.jts.geom.Coordinate;
import org.apache.commons.lang3.tuple.Triple;
Expand All @@ -8,7 +9,11 @@
import uk.ac.ox.oxfish.biology.initializer.BiologyInitializer;
import uk.ac.ox.oxfish.biology.initializer.MultipleIndependentSpeciesBiomassInitializer;
import uk.ac.ox.oxfish.biology.initializer.SingleSpeciesBiomassInitializer;
import uk.ac.ox.oxfish.biology.initializer.allocator.*;
import uk.ac.ox.oxfish.biology.initializer.allocator.ConstantAllocatorFactory;
import uk.ac.ox.oxfish.biology.initializer.allocator.CoordinateFileAllocatorFactory;
import uk.ac.ox.oxfish.biology.initializer.allocator.FileBiomassAllocatorFactory;
import uk.ac.ox.oxfish.biology.initializer.allocator.PolygonAllocatorFactory;
import uk.ac.ox.oxfish.biology.initializer.allocator.SmootherFileAllocatorFactory;
import uk.ac.ox.oxfish.biology.initializer.factory.SingleSpeciesBiomassNormalizedFactory;
import uk.ac.ox.oxfish.biology.weather.initializer.WeatherInitializer;
import uk.ac.ox.oxfish.biology.weather.initializer.factory.ConstantWeatherFactory;
Expand Down Expand Up @@ -56,11 +61,15 @@
import java.util.stream.Stream;

import static java.util.function.Function.identity;
import static java.util.stream.Collectors.*;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static si.uom.NonSI.KNOT;
import static si.uom.NonSI.TONNE;
import static tech.units.indriya.quantity.Quantities.getQuantity;
import static tech.units.indriya.unit.Units.*;
import static tech.units.indriya.unit.Units.CUBIC_METRE;
import static tech.units.indriya.unit.Units.KILOGRAM;
import static tech.units.indriya.unit.Units.KILOMETRE_PER_HOUR;
import static uk.ac.ox.oxfish.utility.MasonUtils.oneOf;
import static uk.ac.ox.oxfish.utility.Measures.asDouble;
import static uk.ac.ox.oxfish.utility.csv.CsvParserUtil.parseAllRecords;
Expand Down Expand Up @@ -199,6 +208,7 @@ public ScenarioPopulation populateModel(FishState model) {
model.registerStartable(fadMap);

final LinkedList<Port> ports = model.getMap().getPorts();
Preconditions.checkState(!ports.isEmpty());

FisherFactory fisherFactory = fisherDefinition.getFisherFactory(model, ports, 0);
fisherFactory.getAdditionalSetups().add(fisher ->
Expand Down Expand Up @@ -245,9 +255,7 @@ record -> {
// TODO: we don't have boat entry in the tuna model for now, but when we do, this shouldn't be entirely random
fisherFactory.setBoatSupplier(fisherDefinition.makeBoatSupplier(model.random));
fisherFactory.setHoldSupplier(fisherDefinition.makeHoldSupplier(model.random, model.getBiology()));
fisherFactory.setPortSupplier(() ->
oneOf(ports, model.random).orElseThrow(() -> new RuntimeException("No ports!"))
);
fisherFactory.setPortSupplier(() -> oneOf(ports, model.random));

final Map<String, FisherFactory> fisherFactories =
ImmutableMap.of(FishState.DEFAULT_POPULATION_NAME, fisherFactory);
Expand Down
143 changes: 143 additions & 0 deletions src/main/java/uk/ac/ox/oxfish/utility/AliasMethod.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package uk.ac.ox.oxfish.utility;

/******************************************************************************
* File: AliasMethod.java
* Author: Keith Schwarz ([email protected])
*
* An implementation of the alias method implemented using Vose's algorithm.
* The alias method allows for efficient sampling of random values from a
* discrete probability distribution (i.e. rolling a loaded die) in O(1) time
* each after O(n) preprocessing time.
*
* For a complete writeup on the alias method, including the intuition and
* important proofs, please see the article "Darts, Dice, and Coins: Sampling
* from a Discrete Distribution" at
*
* http://www.keithschwarz.com/darts-dice-coins/
/******************************************************************************
* Modified to use ec.util.MersenneTwisterFast instead of java.util.Random
* -- NP 2019-09-18.
*/

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import ec.util.MersenneTwisterFast;

/**
 * Sampler for a discrete distribution over the outcomes 0 .. n-1, built with
 * the alias method (Vose's algorithm). The constructor fills two tables of
 * length n; each call to {@link #next()} then needs one integer draw and one
 * double draw from the supplied generator.
 */
public final class AliasMethod {
    /* Generator consulted on every call to next(). */
    private final MersenneTwisterFast random;

    /* alias[i]: outcome handed back when column i loses its coin toss. */
    private final int[] alias;
    /* probability[i]: chance that column i keeps its own outcome. */
    private final double[] probability;

    /**
     * Builds the probability and alias tables for the given distribution.
     * <p>
     * The input list is only read, never modified; the values are not
     * checked for sign or for summing to one.
     *
     * @param probabilities probabilities of outcomes 0, 1, ..., n - 1
     * @param random        generator used for all subsequent sampling
     * @throws NullPointerException     if either argument is null
     * @throws IllegalArgumentException if {@code probabilities} is empty
     */
    public AliasMethod(List<Double> probabilities, MersenneTwisterFast random) {
        /* Structural checks on the inputs come first. */
        if (probabilities == null || random == null)
            throw new NullPointerException();
        if (probabilities.size() == 0)
            throw new IllegalArgumentException("Probability vector must be nonempty.");

        final int n = probabilities.size();
        this.probability = new double[n];
        this.alias = new int[n];
        this.random = random;

        /* Every column of the table carries exactly this much mass. */
        final double average = 1.0 / n;

        /* Private working copy: the pairing loop below rewrites entries. */
        final double[] remaining = new double[n];
        int position = 0;
        for (Double p : probabilities)
            remaining[position++] = p;

        /* Two stacks of outcome indices: below-average and at-or-above-average mass. */
        final Deque<Integer> small = new ArrayDeque<Integer>();
        final Deque<Integer> large = new ArrayDeque<Integer>();
        for (int i = 0; i < n; ++i)
            (remaining[i] >= average ? large : small).addLast(i);

        /* Mathematically the small stack always runs dry first; floating
         * point error can break that, so stop as soon as either one is empty.
         */
        while (!(small.isEmpty() || large.isEmpty())) {
            final int less = small.removeLast();
            final int more = large.removeLast();

            /* Rescale so that a mass of 1/n corresponds to weight 1.0. */
            probability[less] = remaining[less] * n;
            alias[less] = more;

            /* The large entry donates whatever is needed to top up the column. */
            remaining[more] = (remaining[more] + remaining[less]) - average;

            /* Re-file the donor according to what it has left. */
            (remaining[more] >= average ? large : small).addLast(more);
        }

        /* Whatever is left over should hold exactly 1/n each. Because of
         * rounding we cannot know which stack the leftovers sit in, so both
         * are swept.
         */
        for (int leftover : small)
            probability[leftover] = 1.0;
        for (int leftover : large)
            probability[leftover] = 1.0;
    }

    /**
     * Draws one outcome index from the distribution.
     *
     * @return an index between 0 (inclusive) and the number of outcomes (exclusive)
     */
    public int next() {
        /* Fair die roll: which column of the table to look at. */
        final int column = random.nextInt(probability.length);

        /* Biased coin: keep the column itself, or fall through to its alias. */
        if (random.nextDouble() < probability[column])
            return column;
        return alias[column];
    }
}
Loading

0 comments on commit 42768e9

Please sign in to comment.