diff --git a/Java/benchmark/pom.xml b/Java/benchmark/pom.xml
index 2760ae02..a5fdc4ca 100644
--- a/Java/benchmark/pom.xml
+++ b/Java/benchmark/pom.xml
@@ -6,7 +6,7 @@
software.amazon.randomcutforest
randomcutforest-parent
- 3.5.1
+ 3.6.0
randomcutforest-benchmark
diff --git a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java
index ca959781..4ee272a0 100644
--- a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java
+++ b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestBenchmark.java
@@ -18,7 +18,6 @@
import java.util.List;
import java.util.Random;
-import org.github.jamm.MemoryMeter;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
@@ -131,10 +130,6 @@ public RandomCutForest scoreAndUpdate(BenchmarkState state, Blackhole blackhole)
}
blackhole.consume(score);
- if (!forest.parallelExecutionEnabled) {
- MemoryMeter meter = new MemoryMeter();
- System.out.println(" forest size " + meter.measureDeep(forest));
- }
return forest;
}
diff --git a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java
index 76faf4b0..cafec955 100644
--- a/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java
+++ b/Java/benchmark/src/main/java/com/amazon/randomcutforest/RandomCutForestShingledBenchmark.java
@@ -18,7 +18,6 @@
import java.util.List;
import java.util.Random;
-import org.github.jamm.MemoryMeter;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
@@ -130,10 +129,6 @@ public RandomCutForest scoreAndUpdate(BenchmarkState state, Blackhole blackhole)
}
blackhole.consume(score);
- if (!forest.parallelExecutionEnabled) {
- MemoryMeter meter = new MemoryMeter();
- System.out.println(" forest size " + meter.measureDeep(forest));
- }
return forest;
}
diff --git a/Java/core/pom.xml b/Java/core/pom.xml
index d801923d..39e321a7 100644
--- a/Java/core/pom.xml
+++ b/Java/core/pom.xml
@@ -6,7 +6,7 @@
software.amazon.randomcutforest
randomcutforest-parent
- 3.5.1
+ 3.6.0
randomcutforest-core
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java b/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java
index f475d181..6cd6f188 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/CommonUtils.java
@@ -16,6 +16,7 @@
package com.amazon.randomcutforest;
import java.util.Objects;
+import java.util.function.Supplier;
import com.amazon.randomcutforest.tree.IBoundingBoxView;
@@ -38,11 +39,19 @@ private CommonUtils() {
* @throws IllegalArgumentException if {@code condition} is false.
*/
public static void checkArgument(boolean condition, String message) {
+
if (!condition) {
throw new IllegalArgumentException(message);
}
}
+ // a lazy equivalent of the above, which avoids parameter evaluation
+ public static void checkArgument(boolean condition, Supplier messageSupplier) {
+ if (!condition) {
+ throw new IllegalArgumentException(messageSupplier.get());
+ }
+ }
+
/**
* Throws an {@link IllegalStateException} with the specified message if the
* specified input is false.
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java
index a28e4a0b..9b36e65a 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/RandomCutForestMapper.java
@@ -336,8 +336,9 @@ public RandomCutForest singlePrecisionForest(RandomCutForest.Builder> builder,
tree = extTrees.get(i);
} else if (treeStates != null) {
tree = treeMapper.toModel(treeStates.get(i), context, random.nextLong());
- sampler.getSample().forEach(s -> tree.addPoint(s.getValue(), s.getSequenceIndex()));
+ sampler.getSample().forEach(s -> tree.addPointToPartialTree(s.getValue(), s.getSequenceIndex()));
tree.setConfig(Config.BOUNDING_BOX_CACHE_FRACTION, treeStates.get(i).getBoundingBoxCacheFraction());
+ tree.validateAndReconstruct();
} else {
// using boundingBoxCahce for the new tree
tree = new RandomCutTree.Builder().capacity(state.getSampleSize()).randomSeed(random.nextLong())
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java
index e2b4caf6..7d77d26b 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/AbstractNodeStoreMapper.java
@@ -59,9 +59,8 @@ public AbstractNodeStore toModel(NodeStoreState state, CompactRandomCutTreeConte
}
// note boundingBoxCache is not set deliberately
return AbstractNodeStore.builder().capacity(capacity).useRoot(root).leftIndex(leftIndex).rightIndex(rightIndex)
- .cutDimension(cutDimension).cutValues(cutValue)
- .dimensions(compactRandomCutTreeContext.getPointStore().getDimensions())
- .pointStoreView(compactRandomCutTreeContext.getPointStore()).build();
+ .cutDimension(cutDimension).cutValues(cutValue).dimension(compactRandomCutTreeContext.getDimension())
+ .build();
}
@Override
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java
index 1306475b..a17ff9aa 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/CompactRandomCutTreeContext.java
@@ -23,6 +23,7 @@
@Data
public class CompactRandomCutTreeContext {
private int maxSize;
+ private int dimension;
private IPointStore> pointStore;
private Precision precision;
}
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java
index d79ad247..49caf14b 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/state/tree/RandomCutTreeMapper.java
@@ -33,11 +33,12 @@ public class RandomCutTreeMapper
@Override
public RandomCutTree toModel(CompactRandomCutTreeState state, CompactRandomCutTreeContext context, long seed) {
+ int dimension = (state.getDimensions() != 0) ? state.getDimensions() : context.getPointStore().getDimensions();
+ context.setDimension(dimension);
AbstractNodeStoreMapper nodeStoreMapper = new AbstractNodeStoreMapper();
nodeStoreMapper.setRoot(state.getRoot());
AbstractNodeStore nodeStore = nodeStoreMapper.toModel(state.getNodeStoreState(), context);
- int dimension = (state.getDimensions() != 0) ? state.getDimensions() : context.getPointStore().getDimensions();
// boundingBoxcache is not set deliberately;
// it should be set after the partial tree is complete
// likewise all the leaves, including the root, should be set to
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java
index 04ddffb1..c87217ba 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/AbstractNodeStore.java
@@ -17,14 +17,8 @@
import static com.amazon.randomcutforest.CommonUtils.checkArgument;
-import java.util.Arrays;
-import java.util.HashMap;
import java.util.Stack;
-import java.util.function.Function;
-import com.amazon.randomcutforest.MultiVisitor;
-import com.amazon.randomcutforest.Visitor;
-import com.amazon.randomcutforest.store.IPointStoreView;
import com.amazon.randomcutforest.store.IndexIntervalManager;
/**
@@ -44,8 +38,6 @@
*/
public abstract class AbstractNodeStore {
- public static double SWITCH_FRACTION = 0.499;
-
public static int Null = -1;
public static boolean DEFAULT_STORE_PARENT = false;
@@ -57,74 +49,19 @@ public abstract class AbstractNodeStore {
* number_of_leaves + X
*/
protected final int capacity;
- protected final int dimensions;
protected final float[] cutValue;
- protected double boundingboxCacheFraction;
protected IndexIntervalManager freeNodeManager;
- protected double[] rangeSumData;
- protected float[] boundingBoxData;
- protected final IPointStoreView pointStoreView;
- protected final HashMap leafMass;
- protected boolean centerOfMassEnabled;
- protected boolean storeSequenceIndexesEnabled;
- protected float[] pointSum;
- protected HashMap> sequenceMap;
public AbstractNodeStore(AbstractNodeStore.Builder> builder) {
this.capacity = builder.capacity;
- this.dimensions = builder.dimensions;
if ((builder.leftIndex == null)) {
freeNodeManager = new IndexIntervalManager(capacity);
}
- this.boundingboxCacheFraction = builder.boundingBoxCacheFraction;
cutValue = (builder.cutValues != null) ? builder.cutValues : new float[capacity];
- leafMass = new HashMap<>();
- int cache_limit = (int) Math.floor(boundingboxCacheFraction * capacity);
- rangeSumData = new double[cache_limit];
- boundingBoxData = new float[2 * dimensions * cache_limit];
- this.pointStoreView = builder.pointStoreView;
- this.centerOfMassEnabled = builder.centerOfMassEnabled;
- this.storeSequenceIndexesEnabled = builder.storeSequencesEnabled;
- if (this.centerOfMassEnabled) {
- pointSum = new float[(capacity) * dimensions];
- }
- if (this.storeSequenceIndexesEnabled) {
- sequenceMap = new HashMap<>();
- }
}
protected abstract int addNode(Stack pathToRoot, float[] point, long sendex, int pointIndex, int childIndex,
- int cutDimension, float cutValue, BoundingBox box);
-
- protected int addLeaf(int pointIndex, long sequenceIndex) {
- if (storeSequenceIndexesEnabled) {
- HashMap leafMap = sequenceMap.remove(pointIndex);
- if (leafMap == null) {
- leafMap = new HashMap<>();
- }
- Integer count = leafMap.remove(sequenceIndex);
- if (count != null) {
- leafMap.put(sequenceIndex, count + 1);
- } else {
- leafMap.put(sequenceIndex, 1);
- }
- sequenceMap.put(pointIndex, leafMap);
- }
- return pointIndex + capacity + 1;
- }
-
- public void removeLeaf(int leafPointIndex, long sequenceIndex) {
- HashMap leafMap = sequenceMap.remove(leafPointIndex);
- checkArgument(leafMap != null, " leaf index not found in tree");
- Integer count = leafMap.remove(sequenceIndex);
- checkArgument(count != null, " sequence index not found in leaf");
- if (count > 1) {
- leafMap.put(sequenceIndex, count - 1);
- sequenceMap.put(leafPointIndex, leafMap);
- } else if (leafMap.size() > 0) {
- sequenceMap.put(leafPointIndex, leafMap);
- }
- }
+ int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box);
public boolean isLeaf(int index) {
return index > capacity;
@@ -134,7 +71,7 @@ public boolean isInternal(int index) {
return index < capacity && index >= 0;
}
- public abstract void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex);
+ public abstract void assignInPartialTree(int savedParent, float[] point, int childReference);
public abstract int getLeftIndex(int index);
@@ -142,251 +79,14 @@ public boolean isInternal(int index) {
public abstract void setRoot(int index);
- public float[] getPointSum(int index) {
- checkArgument(centerOfMassEnabled, " enable center of mass");
- return (isLeaf(index)) ? pointStoreView.getScaledPoint(getPointIndex(index), getMass(index))
- : Arrays.copyOfRange(pointSum, index * dimensions, (index + 1) * dimensions);
- }
-
- public void invalidatePointSum(int index) {
- for (int i = 0; i < dimensions; i++) {
- pointSum[index * dimensions + i] = 0;
- }
- }
-
- public void recomputePointSum(int index) {
- float[] left = getPointSum(getLeftIndex(index));
- float[] right = getPointSum(getRightIndex(index));
- for (int i = 0; i < dimensions; i++) {
- pointSum[index * dimensions + i] = left[i] + right[i];
- }
- }
-
- public void increaseLeafMass(int index) {
- int y = (index - capacity - 1);
- leafMass.merge(y, 1, Integer::sum);
- }
-
- public int decreaseLeafMass(int index) {
- int y = (index - capacity - 1);
- Integer value = leafMass.remove(y);
- if (value != null) {
- if (value > 1) {
- leafMass.put(y, (value - 1));
- return value;
- } else {
- return 1;
- }
- } else {
- return 0;
- }
- }
-
- public void resizeCache(double fraction) {
- if (fraction == 0) {
- rangeSumData = null;
- boundingBoxData = null;
- } else {
- int limit = (int) Math.floor(fraction * capacity);
- rangeSumData = (rangeSumData == null) ? new double[limit] : Arrays.copyOf(rangeSumData, limit);
- boundingBoxData = (boundingBoxData == null) ? new float[limit * 2 * dimensions]
- : Arrays.copyOf(boundingBoxData, limit * 2 * dimensions);
- }
- boundingboxCacheFraction = fraction;
- }
-
- public int translate(int index) {
- if (rangeSumData.length <= index) {
- return Integer.MAX_VALUE;
- } else {
- return index;
- }
- }
-
- void copyBoxToData(int idx, BoundingBox box) {
- int base = 2 * idx * dimensions;
- int mid = base + dimensions;
- System.arraycopy(box.getMinValues(), 0, boundingBoxData, base, dimensions);
- System.arraycopy(box.getMaxValues(), 0, boundingBoxData, mid, dimensions);
- rangeSumData[idx] = box.getRangeSum();
- }
-
- public boolean checkContainsAndAddPoint(int index, float[] point) {
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) {
- int base = 2 * idx * dimensions;
- int mid = base + dimensions;
- double rangeSum = 0;
- for (int i = 0; i < dimensions; i++) {
- boundingBoxData[base + i] = Math.min(boundingBoxData[base + i], point[i]);
- }
- for (int i = 0; i < dimensions; i++) {
- boundingBoxData[mid + i] = Math.max(boundingBoxData[mid + i], point[i]);
- }
- for (int i = 0; i < dimensions; i++) {
- rangeSum += boundingBoxData[mid + i] - boundingBoxData[base + i];
- }
- boolean answer = (rangeSumData[idx] == rangeSum);
- rangeSumData[idx] = rangeSum;
- return answer;
- }
- return false;
- }
-
- public BoundingBox getBox(int index) {
- if (isLeaf(index)) {
- float[] point = pointStoreView.get(getPointIndex(index));
- return new BoundingBox(point, point);
- } else {
- checkArgument(isInternal(index), " incomplete state");
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE) {
- if (rangeSumData[idx] != 0) {
- // return non-trivial boxes
- return getBoxFromData(idx);
- } else {
- BoundingBox box = reconstructBox(index, pointStoreView);
- copyBoxToData(idx, box);
- return box;
- }
- }
- return reconstructBox(index, pointStoreView);
- }
- }
-
- public BoundingBox reconstructBox(int index, IPointStoreView pointStoreView) {
- BoundingBox mutatedBoundingBox = getBox(getLeftIndex(index));
- growNodeBox(mutatedBoundingBox, pointStoreView, index, getRightIndex(index));
- return mutatedBoundingBox;
- }
-
- boolean checkStrictlyContains(int index, float[] point) {
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE) {
- int base = 2 * idx * dimensions;
- int mid = base + dimensions;
- boolean isInside = true;
- for (int i = 0; i < dimensions && isInside; i++) {
- if (point[i] >= boundingBoxData[mid + i] || boundingBoxData[base + i] >= point[i]) {
- isInside = false;
- }
- }
- return isInside;
- }
- return false;
- }
-
- public boolean checkContainsAndRebuildBox(int index, float[] point, IPointStoreView pointStoreView) {
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) {
- if (!checkStrictlyContains(index, point)) {
- BoundingBox mutatedBoundingBox = reconstructBox(index, pointStoreView);
- copyBoxToData(idx, mutatedBoundingBox);
- return false;
- }
- return true;
- }
- return false;
- }
-
- public BoundingBox getBoxFromData(int idx) {
- int base = 2 * idx * dimensions;
- int mid = base + dimensions;
-
- return new BoundingBox(Arrays.copyOfRange(boundingBoxData, base, base + dimensions),
- Arrays.copyOfRange(boundingBoxData, mid, mid + dimensions));
- }
-
- protected void addBox(int index, float[] point, BoundingBox box) {
- if (isInternal(index)) {
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE) { // always add irrespective of rangesum
- copyBoxToData(idx, box);
- checkContainsAndAddPoint(index, point);
- }
- }
- }
-
- public void growNodeBox(BoundingBox box, IPointStoreView pointStoreView, int node, int sibling) {
- if (isLeaf(sibling)) {
- float[] point = pointStoreView.get(getPointIndex(sibling));
- box.addPoint(point);
- } else {
- checkArgument(isInternal(sibling), " incomplete state " + sibling);
- int siblingIdx = translate(sibling);
- if (siblingIdx != Integer.MAX_VALUE) {
- if (rangeSumData[siblingIdx] != 0) {
- box.addBox(getBoxFromData(siblingIdx));
- } else {
- BoundingBox newBox = getBox(siblingIdx);
- copyBoxToData(siblingIdx, newBox);
- box.addBox(newBox);
- }
- return;
- }
- growNodeBox(box, pointStoreView, sibling, getLeftIndex(sibling));
- growNodeBox(box, pointStoreView, sibling, getRightIndex(sibling));
- return;
- }
- }
-
- public double probabilityOfCut(int node, float[] point, IPointStoreView pointStoreView,
- BoundingBox otherBox) {
- int nodeIdx = translate(node);
- if (nodeIdx != Integer.MAX_VALUE && rangeSumData[nodeIdx] != 0) {
- int base = 2 * nodeIdx * dimensions;
- int mid = base + dimensions;
- double minsum = 0;
- double maxsum = 0;
- for (int i = 0; i < dimensions; i++) {
- minsum += Math.max(boundingBoxData[base + i] - point[i], 0);
- }
- for (int i = 0; i < dimensions; i++) {
- maxsum += Math.max(point[i] - boundingBoxData[mid + i], 0);
- }
- double sum = maxsum + minsum;
-
- if (sum == 0.0) {
- return 0.0;
- }
- return sum / (rangeSumData[nodeIdx] + sum);
- } else if (otherBox != null) {
- return otherBox.probabilityOfCut(point);
- } else {
- BoundingBox box = getBox(node);
- return box.probabilityOfCut(point);
- }
- }
-
protected abstract void decreaseMassOfInternalNode(int node);
protected abstract void increaseMassOfInternalNode(int node);
- protected void manageAncestorsAdd(Stack path, float[] point, IPointStoreView pointStoreview) {
+ protected void manageInternalNodesPartial(Stack path) {
while (!path.isEmpty()) {
int index = path.pop()[0];
increaseMassOfInternalNode(index);
- if (pointSum != null) {
- recomputePointSum(index);
- }
- if (boundingboxCacheFraction > 0.0) {
- checkContainsAndRebuildBox(index, point, pointStoreview);
- checkContainsAndAddPoint(index, point);
- }
- }
- }
-
- protected void manageAncestorsDelete(Stack path, float[] point, IPointStoreView pointStoreview) {
- boolean resolved = false;
- while (!path.isEmpty()) {
- int index = path.pop()[0];
- decreaseMassOfInternalNode(index);
- if (pointSum != null) {
- recomputePointSum(index);
- }
- if (boundingboxCacheFraction > 0.0 && !resolved) {
- resolved = checkContainsAndRebuildBox(index, point, pointStoreview);
- }
}
}
@@ -410,22 +110,8 @@ public Stack getPath(int root, float[] point, boolean verbose) {
public abstract void deleteInternalNode(int index);
- public int getLeafMass(int index) {
- int y = (index - capacity - 1);
- Integer value = leafMass.get(y);
- if (value != null) {
- return value + 1;
- } else {
- return 1;
- }
- }
-
public abstract int getMass(int index);
- public int getPointIndex(int index) {
- return index - capacity - 1;
- }
-
protected boolean leftOf(float cutValue, int cutDimension, float[] point) {
return point[cutDimension] <= cutValue;
}
@@ -447,88 +133,16 @@ public int getSibling(int node, int parent) {
public abstract void replaceParentBySibling(int grandParent, int parent, int node);
- public HashMap> getSequenceMap() {
- return sequenceMap;
- }
-
public abstract int getCutDimension(int index);
public double getCutValue(int index) {
return cutValue[index];
}
- public double getBoundingboxCacheFraction() {
- return boundingboxCacheFraction;
- }
-
- protected void traversePathToLeafAndVisitNodes(float[] point, Visitor visitor, int root,
- IPointStoreView pointStoreView, Function projectToTree) {
- NodeView currentNodeView = new NodeView(this, pointStoreView, root);
- traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, root, 0);
- }
-
protected boolean toLeft(float[] point, int currentNodeOffset) {
return point[getCutDimension(currentNodeOffset)] <= cutValue[currentNodeOffset];
}
- BoundingBox getLeftBox(int index) {
- return getBox(getLeftIndex(index));
- }
-
- BoundingBox getRightBox(int index) {
- return getBox(getRightIndex(index));
- }
-
- protected void traversePathToLeafAndVisitNodes(float[] point, Visitor visitor, NodeView currentNodeView,
- int node, int depthOfNode) {
- if (isLeaf(node)) {
- currentNodeView.setCurrentNode(node, getPointIndex(node), true);
- visitor.acceptLeaf(currentNodeView, depthOfNode);
- } else {
- checkArgument(isInternal(node), " incomplete state " + node + " " + depthOfNode);
- if (toLeft(point, node)) {
- traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, getLeftIndex(node), depthOfNode + 1);
- currentNodeView.updateToParent(node, getRightIndex(node), !visitor.isConverged());
- } else {
- traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, getRightIndex(node), depthOfNode + 1);
- currentNodeView.updateToParent(node, getLeftIndex(node), !visitor.isConverged());
- }
- visitor.accept(currentNodeView, depthOfNode);
- }
- }
-
- protected void traverseTreeMulti(float[] point, MultiVisitor visitor, int root,
- IPointStoreView pointStoreView, Function liftToTree) {
- NodeView currentNodeView = new NodeView(this, pointStoreView, root);
- traverseTreeMulti(point, visitor, currentNodeView, root, 0);
- }
-
- protected void traverseTreeMulti(float[] point, MultiVisitor visitor, NodeView currentNodeView, int node,
- int depthOfNode) {
- if (isLeaf(node)) {
- currentNodeView.setCurrentNode(node, getPointIndex(node), false);
- visitor.acceptLeaf(currentNodeView, depthOfNode);
- } else {
- checkArgument(isInternal(node), " incomplete state");
- currentNodeView.setCurrentNodeOnly(node);
- if (visitor.trigger(currentNodeView)) {
- traverseTreeMulti(point, visitor, currentNodeView, getLeftIndex(node), depthOfNode + 1);
- MultiVisitor newVisitor = visitor.newCopy();
- currentNodeView.setCurrentNodeOnly(getRightIndex(node));
- traverseTreeMulti(point, newVisitor, currentNodeView, getRightIndex(node), depthOfNode + 1);
- currentNodeView.updateToParent(node, getLeftIndex(node), false);
- visitor.combine(newVisitor);
- } else if (toLeft(point, node)) {
- traverseTreeMulti(point, visitor, currentNodeView, getLeftIndex(node), depthOfNode + 1);
- currentNodeView.updateToParent(node, getRightIndex(node), false);
- } else {
- traverseTreeMulti(point, visitor, currentNodeView, getRightIndex(node), depthOfNode + 1);
- currentNodeView.updateToParent(node, getLeftIndex(node), false);
- }
- visitor.accept(currentNodeView, depthOfNode);
- }
- }
-
public abstract int[] getCutDimension();
public abstract int[] getRightIndex();
@@ -552,24 +166,14 @@ public int size() {
*/
public static class Builder> {
- protected int dimensions;
protected int capacity;
protected int[] leftIndex;
protected int[] rightIndex;
protected int[] cutDimension;
protected float[] cutValues;
- protected int root;
- protected double boundingBoxCacheFraction;
- protected boolean centerOfMassEnabled;
- protected boolean storeSequencesEnabled;
protected boolean storeParent = DEFAULT_STORE_PARENT;
- protected IPointStoreView pointStoreView;
-
- // dimension of the points being stored
- public T dimensions(int dimensions) {
- this.dimensions = dimensions;
- return (T) this;
- }
+ protected int dimension;
+ protected int root;
// maximum number of points in the store
public T capacity(int capacity) {
@@ -577,6 +181,11 @@ public T capacity(int capacity) {
return (T) this;
}
+ public T dimension(int dimension) {
+ this.dimension = dimension;
+ return (T) this;
+ }
+
public T useRoot(int root) {
this.root = root;
return (T) this;
@@ -602,33 +211,12 @@ public T cutValues(float[] cutValues) {
return (T) this;
}
- public T pointStoreView(IPointStoreView pointStoreView) {
- this.pointStoreView = pointStoreView;
- return (T) this;
- }
-
- public T boundingBoxCacheFraction(double boundingBoxCacheFraction) {
- this.boundingBoxCacheFraction = boundingBoxCacheFraction;
- return (T) this;
- }
-
- public T centerOfMassEnabled(boolean centerOfMassEnabled) {
- this.centerOfMassEnabled = centerOfMassEnabled;
- return (T) this;
- }
-
public T storeParent(boolean storeParent) {
this.storeParent = storeParent;
return (T) this;
}
- public T storeSequencesEnabled(boolean storeSequencesEnabled) {
- this.storeSequencesEnabled = storeSequencesEnabled;
- return (T) this;
- }
-
public AbstractNodeStore build() {
- checkArgument(pointStoreView != null, " a point store view is required ");
if (leftIndex == null) {
checkArgument(rightIndex == null, " incorrect option of right indices");
checkArgument(cutValues == null, "incorrect option of cut values");
@@ -640,9 +228,9 @@ public AbstractNodeStore build() {
}
// capacity is numbner of internal nodes
- if (capacity < 256 && pointStoreView.getDimensions() <= 256) {
+ if (capacity < 256 && dimension <= 256) {
return new NodeStoreSmall(this);
- } else if (capacity < Character.MAX_VALUE && pointStoreView.getDimensions() <= Character.MAX_VALUE) {
+ } else if (capacity < Character.MAX_VALUE && dimension <= Character.MAX_VALUE) {
return new NodeStoreMedium(this);
} else {
return new NodeStoreLarge(this);
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java
index 14b7e277..97266385 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/HyperTree.java
@@ -51,12 +51,11 @@ public void makeTree(List list, int seed) {
int[] cutDimension = new int[numberOfLeaves - 1];
float[] cutValue = new float[numberOfLeaves - 1];
root = makeTreeInt(list, seed, 0, this.gVecBuild, leftIndex, rightIndex, cutDimension, cutValue);
- nodeStore = AbstractNodeStore.builder().storeSequencesEnabled(false).pointStoreView(pointStoreView)
- .dimensions(dimension).capacity(numberOfLeaves - 1).leftIndex(leftIndex).rightIndex(rightIndex)
- .cutDimension(cutDimension).cutValues(cutValue).build();
+ nodeStore = AbstractNodeStore.builder().dimension(dimension).capacity(numberOfLeaves - 1)
+ .leftIndex(leftIndex).rightIndex(rightIndex).cutDimension(cutDimension).cutValues(cutValue).build();
// the cuts are specififed; now build tree
for (int i = 0; i < list.size(); i++) {
- addPoint(list.get(i), 0L);
+ addPointToPartialTree(list.get(i), 0L);
}
} else {
root = Null;
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java
index 097786e6..2b6d9c0d 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/ITree.java
@@ -34,10 +34,14 @@ public interface ITree extends ITraversable, IDynamicConf
double[] liftFromTree(double[] result);
- public int[] projectMissingIndices(int[] list);
+ int[] projectMissingIndices(int[] list);
PointReference addPoint(PointReference point, long sequenceIndex);
+ void addPointToPartialTree(PointReference point, long sequenceIndex);
+
+ void validateAndReconstruct();
+
PointReference deletePoint(PointReference point, long sequenceIndex);
default long getRandomSeed() {
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java
index d6a40a03..2e3c386b 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreLarge.java
@@ -90,7 +90,7 @@ public NodeStoreLarge(AbstractNodeStore.Builder builder) {
@Override
public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, int pointIndex, int childIndex,
- int cutDimension, float cutValue, BoundingBox box) {
+ int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box) {
int index = freeNodeManager.takeIndex();
this.cutValue[index] = cutValue;
this.cutDimension[index] = (byte) cutDimension;
@@ -101,8 +101,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i
this.rightIndex[index] = (pointIndex + capacity + 1);
this.leftIndex[index] = childIndex;
}
- this.mass[index] = (getMass(childIndex) + 1) % (capacity + 1);
- addLeaf(pointIndex, sequenceIndex);
+ this.mass[index] = (((childMassIfLeaf > 0) ? childMassIfLeaf : getMass(childIndex)) + 1) % (capacity + 1);
+
int parentIndex = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0];
if (this.parentIndex != null) {
this.parentIndex[index] = parentIndex;
@@ -110,13 +110,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i
this.parentIndex[childIndex] = (index);
}
}
- addBox(index, point, box);
if (parentIndex != Null) {
spliceEdge(parentIndex, childIndex, index);
- manageAncestorsAdd(pathToRoot, point, pointStoreView);
- }
- if (pointSum != null) {
- recomputePointSum(index);
}
return index;
}
@@ -151,18 +146,20 @@ public void deleteInternalNode(int index) {
if (parentIndex != null) {
parentIndex[index] = capacity;
}
- if (pointSum != null) {
- invalidatePointSum(index);
- }
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE) {
- rangeSumData[idx] = 0.0;
- }
freeNodeManager.releaseIndex(index);
}
public int getMass(int index) {
- return (isLeaf(index)) ? getLeafMass(index) : mass[index] != 0 ? mass[index] : (capacity + 1);
+ return mass[index] != 0 ? mass[index] : (capacity + 1);
+ }
+
+ @Override
+ public void assignInPartialTree(int node, float[] point, int childReference) {
+ if (leftOf(node, point)) {
+ leftIndex[node] = childReference;
+ } else {
+ rightIndex[node] = childReference;
+ }
}
public void spliceEdge(int parent, int node, int newNode) {
@@ -205,15 +202,4 @@ public int[] getRightIndex() {
return Arrays.copyOf(rightIndex, rightIndex.length);
}
- @Override
- public void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex) {
- int node = pathToRoot.lastElement()[0];
- if (leftOf(node, point)) {
- leftIndex[node] = (pointIndex + capacity + 1);
- } else {
- rightIndex[node] = (pointIndex + capacity + 1);
- }
- manageAncestorsAdd(pathToRoot, point, pointStoreView);
- }
-
}
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java
index b2a4ae40..b2556416 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreMedium.java
@@ -93,7 +93,7 @@ public NodeStoreMedium(AbstractNodeStore.Builder builder) {
@Override
public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, int pointIndex, int childIndex,
- int cutDimension, float cutValue, BoundingBox box) {
+ int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box) {
int index = freeNodeManager.takeIndex();
this.cutValue[index] = cutValue;
this.cutDimension[index] = (char) cutDimension;
@@ -104,8 +104,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i
this.rightIndex[index] = (pointIndex + capacity + 1);
this.leftIndex[index] = childIndex;
}
- this.mass[index] = (char) ((getMass(childIndex) + 1) % (capacity + 1));
- addLeaf(pointIndex, sequenceIndex);
+ this.mass[index] = (char) ((((childMassIfLeaf > 0) ? childMassIfLeaf : getMass(childIndex)) + 1)
+ % (capacity + 1));
int parentIndex = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0];
if (this.parentIndex != null) {
this.parentIndex[index] = (char) parentIndex;
@@ -113,17 +113,21 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i
this.parentIndex[childIndex] = (char) (index);
}
}
- addBox(index, point, box);
if (parentIndex != Null) {
spliceEdge(parentIndex, childIndex, index);
- manageAncestorsAdd(pathToRoot, point, pointStoreView);
- }
- if (pointSum != null) {
- recomputePointSum(index);
}
return index;
}
+ @Override
+ public void assignInPartialTree(int node, float[] point, int childReference) {
+ if (leftOf(node, point)) {
+ leftIndex[node] = childReference;
+ } else {
+ rightIndex[node] = childReference;
+ }
+ }
+
public int getLeftIndex(int index) {
return leftIndex[index];
}
@@ -155,18 +159,11 @@ public void deleteInternalNode(int index) {
if (parentIndex != null) {
parentIndex[index] = (char) capacity;
}
- if (pointSum != null) {
- invalidatePointSum(index);
- }
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE) {
- rangeSumData[idx] = 0.0;
- }
freeNodeManager.releaseIndex(index);
}
public int getMass(int index) {
- return (isLeaf(index)) ? getLeafMass(index) : mass[index] != 0 ? mass[index] : (capacity + 1);
+ return mass[index] != 0 ? mass[index] : (capacity + 1);
}
public void spliceEdge(int parent, int node, int newNode) {
@@ -209,14 +206,4 @@ public int[] getRightIndex() {
return Arrays.copyOf(rightIndex, rightIndex.length);
}
- @Override
- public void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex) {
- int node = pathToRoot.lastElement()[0];
- if (leftOf(node, point)) {
- leftIndex[node] = (pointIndex + capacity + 1);
- } else {
- rightIndex[node] = (pointIndex + capacity + 1);
- }
- manageAncestorsAdd(pathToRoot, point, pointStoreView);
- }
}
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java
index 39f7ce07..c0f8d5c0 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeStoreSmall.java
@@ -98,7 +98,7 @@ public NodeStoreSmall(AbstractNodeStore.Builder builder) {
@Override
public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, int pointIndex, int childIndex,
- int cutDimension, float cutValue, BoundingBox box) {
+ int childMassIfLeaf, int cutDimension, float cutValue, BoundingBox box) {
int index = freeNodeManager.takeIndex();
this.cutValue[index] = cutValue;
this.cutDimension[index] = (byte) cutDimension;
@@ -109,8 +109,8 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i
this.rightIndex[index] = (char) (pointIndex + capacity + 1);
this.leftIndex[index] = (char) childIndex;
}
- this.mass[index] = (byte) (((byte) getMass(childIndex) + 1) % (capacity + 1));
- addLeaf(pointIndex, sequenceIndex);
+ this.mass[index] = (byte) ((((childMassIfLeaf > 0) ? childMassIfLeaf : getMass(childIndex)) + 1)
+ % (capacity + 1));
int parentIndex = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0];
if (this.parentIndex != null) {
this.parentIndex[index] = (byte) parentIndex;
@@ -118,26 +118,19 @@ public int addNode(Stack pathToRoot, float[] point, long sequenceIndex, i
this.parentIndex[childIndex] = (byte) (index);
}
}
- addBox(index, point, box);
if (parentIndex != Null) {
spliceEdge(parentIndex, childIndex, index);
- manageAncestorsAdd(pathToRoot, point, pointStoreView);
- }
- if (pointSum != null) {
- recomputePointSum(index);
}
return index;
}
@Override
- public void addToPartialTree(Stack pathToRoot, float[] point, int pointIndex) {
- int node = pathToRoot.lastElement()[0];
+ public void assignInPartialTree(int node, float[] point, int childReference) {
if (leftOf(node, point)) {
- leftIndex[node] = (char) (pointIndex + capacity + 1);
+ leftIndex[node] = (char) childReference;
} else {
- rightIndex[node] = (char) (pointIndex + capacity + 1);
+ rightIndex[node] = (char) childReference;
}
- manageAncestorsAdd(pathToRoot, point, pointStoreView);
}
public int getLeftIndex(int index) {
@@ -171,18 +164,11 @@ public void deleteInternalNode(int index) {
if (parentIndex != null) {
parentIndex[index] = (byte) capacity;
}
- if (pointSum != null) {
- invalidatePointSum(index);
- }
- int idx = translate(index);
- if (idx != Integer.MAX_VALUE) {
- rangeSumData[idx] = 0.0;
- }
freeNodeManager.releaseIndex(index);
}
public int getMass(int index) {
- return (isLeaf(index)) ? getLeafMass(index) : mass[index] != 0 ? (mass[index] & 0xff) : (capacity + 1);
+ return mass[index] != 0 ? (mass[index] & 0xff) : (capacity + 1);
}
public void spliceEdge(int parent, int node, int newNode) {
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java
index 7481c7a9..6f9ca01a 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/NodeView.java
@@ -22,40 +22,42 @@
import com.amazon.randomcutforest.store.IPointStoreView;
public class NodeView implements INodeView {
- AbstractNodeStore nodeStore;
+
+ public static double SWITCH_FRACTION = 0.499;
+
+ RandomCutTree tree;
int currentNodeOffset;
float[] leafPoint;
- IPointStoreView pointStoreView;
BoundingBox currentBox;
- public NodeView(AbstractNodeStore nodeStore, IPointStoreView pointStoreView, int root) {
+ public NodeView(RandomCutTree tree, IPointStoreView pointStoreView, int root) {
this.currentNodeOffset = root;
- this.pointStoreView = pointStoreView;
- this.nodeStore = nodeStore;
+ this.tree = tree;
}
public int getMass() {
- return nodeStore.getMass(currentNodeOffset);
+ return tree.getMass(currentNodeOffset);
}
public IBoundingBoxView getBoundingBox() {
if (currentBox == null) {
- return nodeStore.getBox(currentNodeOffset);
+ return tree.getBox(currentNodeOffset);
}
return currentBox;
}
public IBoundingBoxView getSiblingBoundingBox(float[] point) {
- return (toLeft(point)) ? nodeStore.getRightBox(currentNodeOffset) : nodeStore.getLeftBox(currentNodeOffset);
+ return (toLeft(point)) ? tree.getBox(tree.nodeStore.getRightIndex(currentNodeOffset))
+ : tree.getBox(tree.nodeStore.getLeftIndex(currentNodeOffset));
}
public int getCutDimension() {
- return nodeStore.getCutDimension(currentNodeOffset);
+ return tree.nodeStore.getCutDimension(currentNodeOffset);
}
@Override
public double getCutValue() {
- return nodeStore.getCutValue(currentNodeOffset);
+ return tree.nodeStore.getCutValue(currentNodeOffset);
}
public float[] getLeafPoint() {
@@ -64,8 +66,8 @@ public float[] getLeafPoint() {
public HashMap getSequenceIndexes() {
checkState(isLeaf(), "can only be invoked for a leaf");
- if (nodeStore.storeSequenceIndexesEnabled) {
- return nodeStore.sequenceMap.get(nodeStore.getPointIndex(currentNodeOffset));
+ if (tree.storeSequenceIndexesEnabled) {
+ return tree.getSequenceMap(tree.getPointIndex(currentNodeOffset));
} else {
return new HashMap<>();
}
@@ -73,23 +75,22 @@ public HashMap getSequenceIndexes() {
@Override
public double probailityOfSeparation(float[] point) {
- return nodeStore.probabilityOfCut(currentNodeOffset, point, pointStoreView, currentBox);
+ return tree.probabilityOfCut(currentNodeOffset, point, currentBox);
}
@Override
public int getLeafPointIndex() {
- checkState(isLeaf(), "cannot invoke 'getLeafPointIndex' from a non-leaf node");
- return nodeStore.getPointIndex(currentNodeOffset);
+ return tree.getPointIndex(currentNodeOffset);
}
public boolean isLeaf() {
- return nodeStore.isLeaf(currentNodeOffset);
+ return tree.nodeStore.isLeaf(currentNodeOffset);
}
protected void setCurrentNode(int newNode, int index, boolean setBox) {
currentNodeOffset = newNode;
- leafPoint = pointStoreView.get(index);
- if (setBox && nodeStore.boundingboxCacheFraction < AbstractNodeStore.SWITCH_FRACTION) {
+ leafPoint = tree.pointStoreView.get(index);
+ if (setBox && tree.boundingBoxCacheFraction < SWITCH_FRACTION) {
currentBox = new BoundingBox(leafPoint, leafPoint);
}
}
@@ -100,14 +101,15 @@ protected void setCurrentNodeOnly(int newNode) {
public void updateToParent(int parent, int currentSibling, boolean updateBox) {
currentNodeOffset = parent;
- if (updateBox && nodeStore.boundingboxCacheFraction < AbstractNodeStore.SWITCH_FRACTION) {
- nodeStore.growNodeBox(currentBox, pointStoreView, parent, currentSibling);
+ if (updateBox && tree.boundingBoxCacheFraction < SWITCH_FRACTION) {
+ tree.growNodeBox(currentBox, tree.pointStoreView, parent, currentSibling);
}
}
// this function exists for matching the behavior of RCF2.0 and will be replaced
// this function explicitly uses the encoding of the new nodestore
protected boolean toLeft(float[] point) {
- return point[nodeStore.getCutDimension(currentNodeOffset)] <= nodeStore.getCutValue(currentNodeOffset);
+ return point[tree.nodeStore.getCutDimension(currentNodeOffset)] <= tree.nodeStore
+ .getCutValue(currentNodeOffset);
}
}
diff --git a/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java b/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java
index dd9789be..540db178 100644
--- a/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java
+++ b/Java/core/src/main/java/com/amazon/randomcutforest/tree/RandomCutTree.java
@@ -20,7 +20,10 @@
import static com.amazon.randomcutforest.CommonUtils.checkState;
import static com.amazon.randomcutforest.tree.AbstractNodeStore.Null;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.Stack;
@@ -67,6 +70,11 @@ public class RandomCutTree implements ITree {
protected double boundingBoxCacheFraction;
protected int outputAfter;
protected int dimension;
+ protected final HashMap leafMass;
+ protected double[] rangeSumData;
+ protected float[] boundingBoxData;
+ protected float[] pointSum;
+ protected HashMap> sequenceMap;
protected RandomCutTree(Builder> builder) {
pointStoreView = builder.pointStoreView;
@@ -76,23 +84,28 @@ protected RandomCutTree(Builder> builder) {
outputAfter = builder.outputAfter.orElse(numberOfLeaves / 4);
dimension = (builder.dimension != 0) ? builder.dimension : pointStoreView.getDimensions();
nodeStore = (builder.nodeStore != null) ? builder.nodeStore
- : AbstractNodeStore.builder().capacity(numberOfLeaves - 1).dimensions(dimension)
- .boundingBoxCacheFraction(builder.boundingBoxCacheFraction).pointStoreView(pointStoreView)
- .centerOfMassEnabled(builder.centerOfMassEnabled)
- .storeSequencesEnabled(builder.storeSequenceIndexesEnabled).build();
- // note the number of internal nodes is one less than sampleSize
- // the RCF V2_0 states used this notion
+ : AbstractNodeStore.builder().capacity(numberOfLeaves - 1).dimension(dimension).build();
this.boundingBoxCacheFraction = builder.boundingBoxCacheFraction;
this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled;
this.centerOfMassEnabled = builder.centerOfMassEnabled;
this.root = builder.root;
+ leafMass = new HashMap<>();
+ int cache_limit = (int) Math.floor(boundingBoxCacheFraction * (numberOfLeaves - 1));
+ rangeSumData = new double[cache_limit];
+ boundingBoxData = new float[2 * dimension * cache_limit];
+ if (this.centerOfMassEnabled) {
+ pointSum = new float[(numberOfLeaves - 1) * dimension];
+ }
+ if (this.storeSequenceIndexesEnabled) {
+ sequenceMap = new HashMap<>();
+ }
}
@Override
public void setConfig(String name, T value, Class clazz) {
if (Config.BOUNDING_BOX_CACHE_FRACTION.equals(name)) {
checkArgument(Double.class.isAssignableFrom(clazz),
- String.format("Setting '%s' must be a double value", name));
+ () -> String.format("Setting '%s' must be a double value", name));
setBoundingBoxCacheFraction((Double) value);
} else {
throw new IllegalArgumentException("Unsupported configuration setting: " + name);
@@ -104,7 +117,7 @@ public T getConfig(String name, Class clazz) {
checkNotNull(clazz, "clazz must not be null");
if (Config.BOUNDING_BOX_CACHE_FRACTION.equals(name)) {
checkArgument(clazz.isAssignableFrom(Double.class),
- String.format("Setting '%s' must be a double value", name));
+ () -> String.format("Setting '%s' must be a double value", name));
return clazz.cast(boundingBoxCacheFraction);
} else {
throw new IllegalArgumentException("Unsupported configuration setting: " + name);
@@ -118,7 +131,7 @@ public T getConfig(String name, Class clazz) {
public void setBoundingBoxCacheFraction(double fraction) {
checkArgument(0 <= fraction && fraction <= 1, "incorrect parameter");
boundingBoxCacheFraction = fraction;
- nodeStore.resizeCache(fraction);
+ resizeCache(fraction);
}
/**
@@ -134,7 +147,7 @@ public void setBoundingBoxCacheFraction(double fraction) {
* @param box A bounding box that we want to find a random cut for.
* @return A new Cut corresponding to a random cut in the bounding box.
*/
- protected static Cut randomCut(double factor, float[] point, BoundingBox box) {
+ protected Cut randomCut(double factor, float[] point, BoundingBox box) {
double range = 0.0;
for (int i = 0; i < point.length; i++) {
@@ -148,7 +161,7 @@ protected static Cut randomCut(double factor, float[] point, BoundingBox box) {
range += maxValue - minValue;
}
- checkArgument(range > 0, " the union is a single point " + Arrays.toString(point)
+ checkArgument(range > 0, () -> " the union is a single point " + Arrays.toString(point)
+ "or the box is inappropriate, box" + box.toString() + "factor =" + factor);
double breakPoint = factor * range;
@@ -223,40 +236,42 @@ protected static Cut randomCut(double factor, float[] point, BoundingBox box) {
}
+ /**
+ * the following function adds a point to the tree
+ *
+ * @param pointIndex the number corresponding to the point
+ * @param sequenceIndex sequence index of the point
+ * @return the value of the point index where the point was added; this is
+ * pointIndex if there are no duplicates; otherwise it is the value of
+ * the point being duplicated.
+ */
public Integer addPoint(Integer pointIndex, long sequenceIndex) {
if (root == Null) {
- root = nodeStore.addLeaf(pointIndex, sequenceIndex);
+ root = convertToLeaf(pointIndex);
+ addLeaf(pointIndex, sequenceIndex);
return pointIndex;
} else {
- float[] point = pointStoreView.get(pointIndex);
+ float[] point = projectToTree(pointStoreView.get(pointIndex));
+ checkArgument(point.length == dimension, () -> " mismatch in dimensions for " + pointIndex);
Stack pathToRoot = nodeStore.getPath(root, point, false);
int[] first = pathToRoot.pop();
int leafNode = first[0];
int savedParent = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0];
- if (!nodeStore.isLeaf(leafNode)) {
- // this corresponds to rebuilding a partial tree
- if (savedParent == Null) {
- root = pointIndex + numberOfLeaves; // note this capacity is nodestore.capacity + 1
- } else {
- nodeStore.addToPartialTree(pathToRoot, point, pointIndex);
- nodeStore.manageAncestorsAdd(pathToRoot, point, pointStoreView);
- nodeStore.addLeaf(pointIndex, sequenceIndex);
- }
- return pointIndex;
- }
int leafSavedSibling = first[1];
int sibling = leafSavedSibling;
- int leafPointIndex = nodeStore.getPointIndex(leafNode);
- float[] oldPoint = pointStoreView.get(leafPointIndex);
+ int leafPointIndex = getPointIndex(leafNode);
+ float[] oldPoint = projectToTree(pointStoreView.get(leafPointIndex));
+ checkArgument(oldPoint.length == dimension, () -> " mismatch in dimensions for " + pointIndex);
+
Stack parentPath = new Stack<>();
if (Arrays.equals(point, oldPoint)) {
- nodeStore.increaseLeafMass(leafNode);
+ increaseLeafMass(leafNode);
checkArgument(!nodeStore.freeNodeManager.isEmpty(), "incorrect/impossible state");
- nodeStore.manageAncestorsAdd(pathToRoot, point, pointStoreView);
- nodeStore.addLeaf(leafPointIndex, sequenceIndex);
+ manageAncestorsAdd(pathToRoot, point);
+ addLeaf(leafPointIndex, sequenceIndex);
return leafPointIndex;
} else {
int node = leafNode;
@@ -293,14 +308,11 @@ public Integer addPoint(Integer pointIndex, long sequenceIndex) {
parentPath.push(new int[] { node, sibling });
}
- if (savedDim == Integer.MAX_VALUE) {
- randomCut(factor, point, currentBox);
- throw new IllegalStateException(" cut failed ");
- }
+ checkArgument(savedDim != Integer.MAX_VALUE, () -> " cut failed at index " + pointIndex);
if (currentBox.contains(point) || parent == Null) {
break;
} else {
- nodeStore.growNodeBox(currentBox, pointStoreView, parent, sibling);
+ growNodeBox(currentBox, pointStoreView, parent, sibling);
int[] next = pathToRoot.pop();
node = next[0];
sibling = next[1];
@@ -318,8 +330,15 @@ public Integer addPoint(Integer pointIndex, long sequenceIndex) {
assert (pathToRoot.lastElement()[0] == savedParent);
}
- int mergedNode = nodeStore.addNode(pathToRoot, point, sequenceIndex, pointIndex, savedNode, savedDim,
- savedCutValue, savedBox);
+ int childMassIfLeaf = isLeaf(savedNode) ? getLeafMass(savedNode) : 0;
+ int mergedNode = nodeStore.addNode(pathToRoot, point, sequenceIndex, pointIndex, savedNode,
+ childMassIfLeaf, savedDim, savedCutValue, savedBox);
+ addLeaf(pointIndex, sequenceIndex);
+ addBox(mergedNode, point, savedBox);
+ manageAncestorsAdd(pathToRoot, point);
+ if (pointSum != null) {
+ recomputePointSum(mergedNode);
+ }
if (savedParent == Null) {
root = mergedNode;
}
@@ -328,25 +347,77 @@ public Integer addPoint(Integer pointIndex, long sequenceIndex) {
}
}
- public Integer deletePoint(Integer pointIndex, long sequenceIndex) {
+ protected void manageAncestorsAdd(Stack path, float[] point) {
+ while (!path.isEmpty()) {
+ int index = path.pop()[0];
+ nodeStore.increaseMassOfInternalNode(index);
+ if (pointSum != null) {
+ recomputePointSum(index);
+ }
+ if (boundingBoxCacheFraction > 0.0) {
+ checkContainsAndRebuildBox(index, point, pointStoreView);
+ checkContainsAndAddPoint(index, point);
+ }
+ }
+ }
- if (root == Null) {
- throw new IllegalStateException(" deleting from an empty tree");
+ /**
+ * the following is the same as in addPoint() except this function is used to
+ * rebuild the tree structure. This function does not create auxiliary arrays,
+ * which should be performed using validateAndReconstruct()
+ *
+ * @param pointIndex index of point (in point store)
+ * @param sequenceIndex sequence index (stored in sampler)
+ */
+ public void addPointToPartialTree(Integer pointIndex, long sequenceIndex) {
+
+ checkArgument(root != Null, " a null root is not a partial tree");
+ float[] point = projectToTree(pointStoreView.get(pointIndex));
+ checkArgument(point.length == dimension, () -> " incorrect projection at index " + pointIndex);
+
+ Stack pathToRoot = nodeStore.getPath(root, point, false);
+ int[] first = pathToRoot.pop();
+ int leafNode = first[0];
+ int savedParent = (pathToRoot.size() == 0) ? Null : pathToRoot.lastElement()[0];
+ if (!nodeStore.isLeaf(leafNode)) {
+ if (savedParent == Null) {
+ root = convertToLeaf(pointIndex);
+ } else {
+ nodeStore.assignInPartialTree(savedParent, point, convertToLeaf(pointIndex));
+ nodeStore.manageInternalNodesPartial(pathToRoot);
+ addLeaf(pointIndex, sequenceIndex);
+ }
+ return;
}
- float[] point = pointStoreView.get(pointIndex);
+ int leafPointIndex = getPointIndex(leafNode);
+ float[] oldPoint = projectToTree(pointStoreView.get(leafPointIndex));
+
+ checkArgument(oldPoint.length == dimension && Arrays.equals(point, oldPoint),
+ () -> "incorrect state on adding " + pointIndex);
+ increaseLeafMass(leafNode);
+ checkArgument(!nodeStore.freeNodeManager.isEmpty(), "incorrect/impossible state");
+ nodeStore.manageInternalNodesPartial(pathToRoot);
+ addLeaf(leafPointIndex, sequenceIndex);
+ return;
+ }
+
+ public Integer deletePoint(Integer pointIndex, long sequenceIndex) {
+
+ checkArgument(root != Null, " deleting from an empty tree");
+ float[] point = projectToTree(pointStoreView.get(pointIndex));
+ checkArgument(point.length == dimension, () -> " incorrect projection at index " + pointIndex);
Stack pathToRoot = nodeStore.getPath(root, point, false);
int[] first = pathToRoot.pop();
int leafSavedSibling = first[1];
int leafNode = first[0];
- int leafPointIndex = nodeStore.getPointIndex(leafNode);
+ int leafPointIndex = getPointIndex(leafNode);
- if (leafPointIndex != pointIndex && !pointStoreView.pointEquals(leafPointIndex, point)) {
- throw new IllegalStateException(" deleting wrong node " + leafPointIndex + " instead of " + pointIndex);
- } else if (storeSequenceIndexesEnabled) {
- nodeStore.removeLeaf(leafPointIndex, sequenceIndex);
- }
+ checkArgument(leafPointIndex == pointIndex,
+ () -> " deleting wrong node " + leafPointIndex + " instead of " + pointIndex);
+
+ removeLeaf(leafPointIndex, sequenceIndex);
- if (nodeStore.decreaseLeafMass(leafNode) == 0) {
+ if (decreaseLeafMass(leafNode) == 0) {
if (pathToRoot.size() == 0) {
root = Null;
} else {
@@ -357,16 +428,406 @@ public Integer deletePoint(Integer pointIndex, long sequenceIndex) {
} else {
int grandParent = pathToRoot.lastElement()[0];
nodeStore.replaceParentBySibling(grandParent, parent, leafNode);
- nodeStore.manageAncestorsDelete(pathToRoot, point, pointStoreView);
+ manageAncestorsDelete(pathToRoot, point);
}
nodeStore.deleteInternalNode(parent);
+ if (pointSum != null) {
+ invalidatePointSum(parent);
+ }
+ int idx = translate(parent);
+ if (idx != Integer.MAX_VALUE) {
+ rangeSumData[idx] = 0.0;
+ }
}
} else {
- nodeStore.manageAncestorsDelete(pathToRoot, point, pointStoreView);
+ manageAncestorsDelete(pathToRoot, point);
}
return leafPointIndex;
}
+ protected void manageAncestorsDelete(Stack path, float[] point) {
+ boolean resolved = false;
+ while (!path.isEmpty()) {
+ int index = path.pop()[0];
+ nodeStore.decreaseMassOfInternalNode(index);
+ if (pointSum != null) {
+ recomputePointSum(index);
+ }
+ if (boundingBoxCacheFraction > 0.0 && !resolved) {
+ resolved = checkContainsAndRebuildBox(index, point, pointStoreView);
+ }
+ }
+ }
+
+ //// leaf, nonleaf representations
+
+ public boolean isLeaf(int index) {
+ // note that numberOfLeaves - 1 corresponds to an unspefied leaf in partial tree
+ // 0 .. numberOfLeaves - 2 corresponds to internal nodes
+ return index >= numberOfLeaves;
+ }
+
+ public boolean isInternal(int index) {
+ // note that numberOfLeaves - 1 corresponds to an unspefied leaf in partial tree
+ // 0 .. numberOfLeaves - 2 corresponds to internal nodes
+ return index < numberOfLeaves - 1;
+ }
+
+ public int convertToLeaf(int pointIndex) {
+ return pointIndex + numberOfLeaves;
+ }
+
+ public int getPointIndex(int index) {
+ checkArgument(index >= numberOfLeaves, () -> " does not have a point associated " + index);
+ return index - numberOfLeaves;
+ }
+
+ public int getLeftChild(int index) {
+ checkArgument(isInternal(index), () -> "incorrect call to get left Index " + index);
+ return nodeStore.getLeftIndex(index);
+ }
+
+ public int getRightChild(int index) {
+ checkArgument(isInternal(index), () -> "incorrect call to get right child " + index);
+ return nodeStore.getRightIndex(index);
+ }
+
+ public int getCutDimension(int index) {
+ checkArgument(isInternal(index), () -> "incorrect call to get cut dimension " + index);
+ return nodeStore.getCutDimension(index);
+ }
+
+ public double getCutValue(int index) {
+ checkArgument(isInternal(index), () -> "incorrect call to get cut value " + index);
+ return nodeStore.getCutValue(index);
+ }
+
+ ///// mass assignments; separating leafs and internal nodes
+
+ protected int getMass(int index) {
+ return (isLeaf(index)) ? getLeafMass(index) : nodeStore.getMass(index);
+ }
+
+ protected int getLeafMass(int index) {
+ int y = (index - numberOfLeaves);
+ Integer value = leafMass.get(y);
+ return (value != null) ? value + 1 : 1;
+ }
+
+ protected void increaseLeafMass(int index) {
+ int y = (index - numberOfLeaves);
+ leafMass.merge(y, 1, Integer::sum);
+ }
+
+ protected int decreaseLeafMass(int index) {
+ int y = (index - numberOfLeaves);
+ Integer value = leafMass.remove(y);
+ if (value != null) {
+ if (value > 1) {
+ leafMass.put(y, (value - 1));
+ return value;
+ } else {
+ return 1;
+ }
+ } else {
+ return 0;
+ }
+ }
+
+ @Override
+ public int getMass() {
+ return root == Null ? 0 : isLeaf(root) ? getLeafMass(root) : nodeStore.getMass(root);
+ }
+
+ /////// Bounding box
+
+ public void resizeCache(double fraction) {
+ if (fraction == 0) {
+ rangeSumData = null;
+ boundingBoxData = null;
+ } else {
+ int limit = (int) Math.floor(fraction * (numberOfLeaves - 1));
+ rangeSumData = (rangeSumData == null) ? new double[limit] : Arrays.copyOf(rangeSumData, limit);
+ boundingBoxData = (boundingBoxData == null) ? new float[limit * 2 * dimension]
+ : Arrays.copyOf(boundingBoxData, limit * 2 * dimension);
+ }
+ boundingBoxCacheFraction = fraction;
+ }
+
+ protected int translate(int index) {
+ if (rangeSumData == null || rangeSumData.length <= index) {
+ return Integer.MAX_VALUE;
+ } else {
+ return index;
+ }
+ }
+
+ void copyBoxToData(int idx, BoundingBox box) {
+ int base = 2 * idx * dimension;
+ int mid = base + dimension;
+ System.arraycopy(box.getMinValues(), 0, boundingBoxData, base, dimension);
+ System.arraycopy(box.getMaxValues(), 0, boundingBoxData, mid, dimension);
+ rangeSumData[idx] = box.getRangeSum();
+ }
+
+ boolean checkContainsAndAddPoint(int index, float[] point) {
+ int idx = translate(index);
+ if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) {
+ int base = 2 * idx * dimension;
+ int mid = base + dimension;
+ double rangeSum = 0;
+ for (int i = 0; i < dimension; i++) {
+ boundingBoxData[base + i] = Math.min(boundingBoxData[base + i], point[i]);
+ }
+ for (int i = 0; i < dimension; i++) {
+ boundingBoxData[mid + i] = Math.max(boundingBoxData[mid + i], point[i]);
+ }
+ for (int i = 0; i < dimension; i++) {
+ rangeSum += boundingBoxData[mid + i] - boundingBoxData[base + i];
+ }
+ boolean answer = (rangeSumData[idx] == rangeSum);
+ rangeSumData[idx] = rangeSum;
+ return answer;
+ }
+ return false;
+ }
+
+ public BoundingBox getBox(int index) {
+ if (isLeaf(index)) {
+ float[] point = projectToTree(pointStoreView.get(getPointIndex(index)));
+ checkArgument(point.length == dimension, () -> "failure in projection at index " + index);
+ return new BoundingBox(point, point);
+ } else {
+ checkArgument(isInternal(index), " incomplete state");
+ int idx = translate(index);
+ if (idx != Integer.MAX_VALUE) {
+ if (rangeSumData[idx] != 0) {
+ // return non-trivial boxes
+ return getBoxFromData(idx);
+ } else {
+ BoundingBox box = reconstructBox(index, pointStoreView);
+ copyBoxToData(idx, box);
+ return box;
+ }
+ }
+ return reconstructBox(index, pointStoreView);
+ }
+ }
+
+ BoundingBox reconstructBox(int index, IPointStoreView pointStoreView) {
+ BoundingBox mutatedBoundingBox = getBox(nodeStore.getLeftIndex(index));
+ growNodeBox(mutatedBoundingBox, pointStoreView, index, nodeStore.getRightIndex(index));
+ return mutatedBoundingBox;
+ }
+
+ boolean checkStrictlyContains(int index, float[] point) {
+ int idx = translate(index);
+ if (idx != Integer.MAX_VALUE) {
+ int base = 2 * idx * dimension;
+ int mid = base + dimension;
+ boolean isInside = true;
+ for (int i = 0; i < dimension && isInside; i++) {
+ if (point[i] >= boundingBoxData[mid + i] || boundingBoxData[base + i] >= point[i]) {
+ isInside = false;
+ }
+ }
+ return isInside;
+ }
+ return false;
+ }
+
+ boolean checkContainsAndRebuildBox(int index, float[] point, IPointStoreView pointStoreView) {
+ int idx = translate(index);
+ if (idx != Integer.MAX_VALUE && rangeSumData[idx] != 0) {
+ if (!checkStrictlyContains(index, point)) {
+ BoundingBox mutatedBoundingBox = reconstructBox(index, pointStoreView);
+ copyBoxToData(idx, mutatedBoundingBox);
+ return false;
+ }
+ return true;
+ }
+ return false;
+ }
+
+ BoundingBox getBoxFromData(int idx) {
+ int base = 2 * idx * dimension;
+ int mid = base + dimension;
+
+ return new BoundingBox(Arrays.copyOfRange(boundingBoxData, base, base + dimension),
+ Arrays.copyOfRange(boundingBoxData, mid, mid + dimension));
+ }
+
+ void addBox(int index, float[] point, BoundingBox box) {
+ if (isInternal(index)) {
+ int idx = translate(index);
+ if (idx != Integer.MAX_VALUE) { // always add irrespective of rangesum
+ copyBoxToData(idx, box);
+ checkContainsAndAddPoint(index, point);
+ }
+ }
+ }
+
+ void growNodeBox(BoundingBox box, IPointStoreView pointStoreView, int node, int sibling) {
+ if (isLeaf(sibling)) {
+ float[] point = projectToTree(pointStoreView.get(getPointIndex(sibling)));
+ checkArgument(point.length == dimension, () -> " incorrect projection at index " + sibling);
+ box.addPoint(point);
+ } else {
+ if (!isInternal(sibling)) {
+ throw new IllegalStateException(" incomplete state " + sibling);
+ }
+ int siblingIdx = translate(sibling);
+ if (siblingIdx != Integer.MAX_VALUE) {
+ if (rangeSumData[siblingIdx] != 0) {
+ box.addBox(getBoxFromData(siblingIdx));
+ } else {
+ BoundingBox newBox = getBox(siblingIdx);
+ copyBoxToData(siblingIdx, newBox);
+ box.addBox(newBox);
+ }
+ return;
+ }
+ growNodeBox(box, pointStoreView, sibling, nodeStore.getLeftIndex(sibling));
+ growNodeBox(box, pointStoreView, sibling, nodeStore.getRightIndex(sibling));
+ return;
+ }
+ }
+
+ public double probabilityOfCut(int node, float[] point, BoundingBox otherBox) {
+ int nodeIdx = translate(node);
+ if (nodeIdx != Integer.MAX_VALUE && rangeSumData[nodeIdx] != 0) {
+ int base = 2 * nodeIdx * dimension;
+ int mid = base + dimension;
+ double minsum = 0;
+ double maxsum = 0;
+ for (int i = 0; i < dimension; i++) {
+ minsum += Math.max(boundingBoxData[base + i] - point[i], 0);
+ }
+ for (int i = 0; i < dimension; i++) {
+ maxsum += Math.max(point[i] - boundingBoxData[mid + i], 0);
+ }
+ double sum = maxsum + minsum;
+
+ if (sum == 0.0) {
+ return 0.0;
+ }
+ return sum / (rangeSumData[nodeIdx] + sum);
+ } else if (otherBox != null) {
+ return otherBox.probabilityOfCut(point);
+ } else {
+ BoundingBox box = getBox(node);
+ return box.probabilityOfCut(point);
+ }
+ }
+
+ /// additional information at nodes
+
+ public float[] getPointSum(int index) {
+ checkArgument(centerOfMassEnabled, " enable center of mass");
+ if (isLeaf(index)) {
+ float[] point = projectToTree(pointStoreView.get(getPointIndex(index)));
+ checkArgument(point.length == dimension, () -> " incorrect projection");
+ int mass = getMass(index);
+ for (int i = 0; i < point.length; i++) {
+ point[i] *= mass;
+ }
+ return point;
+ } else {
+ return Arrays.copyOfRange(pointSum, index * dimension, (index + 1) * dimension);
+ }
+ }
+
+ public void invalidatePointSum(int index) {
+ for (int i = 0; i < dimension; i++) {
+ pointSum[index * dimension + i] = 0;
+ }
+ }
+
+ public void recomputePointSum(int index) {
+ float[] left = getPointSum(nodeStore.getLeftIndex(index));
+ float[] right = getPointSum(nodeStore.getRightIndex(index));
+ for (int i = 0; i < dimension; i++) {
+ pointSum[index * dimension + i] = left[i] + right[i];
+ }
+ }
+
+ public HashMap getSequenceMap(int index) {
+ HashMap hashMap = new HashMap<>();
+ List list = getSequenceList(index);
+ for (Long e : list) {
+ hashMap.merge(e, 1, Integer::sum);
+ }
+ return hashMap;
+ }
+
+ public List getSequenceList(int index) {
+ return sequenceMap.get(index);
+ }
+
+ protected void addLeaf(int pointIndex, long sequenceIndex) {
+ if (storeSequenceIndexesEnabled) {
+ List leafList = sequenceMap.remove(pointIndex);
+ if (leafList == null) {
+ leafList = new ArrayList<>(1);
+ }
+ leafList.add(sequenceIndex);
+ sequenceMap.put(pointIndex, leafList);
+ }
+ }
+
+ public void removeLeaf(int leafPointIndex, long sequenceIndex) {
+ if (storeSequenceIndexesEnabled) {
+ List leafList = sequenceMap.remove(leafPointIndex);
+ checkArgument(leafList != null, " leaf index not found in tree");
+ checkArgument(leafList.remove(sequenceIndex), " sequence index not found in leaf");
+ if (!leafList.isEmpty()) {
+ sequenceMap.put(leafPointIndex, leafList);
+ }
+ }
+ }
+
+ //// validations
+
+ public void validateAndReconstruct() {
+ if (root != Null) {
+ validateAndReconstruct(root);
+ }
+ }
+
+ /**
+ * This function is supposed to validate the integrity of the tree and rebuild
+ * internal data structures. At this moment the only internal structure is the
+ * pointsum.
+ *
+ * @param index the node of a tree
+ * @return a bounding box of the points
+ */
+ public BoundingBox validateAndReconstruct(int index) {
+ if (isLeaf(index)) {
+ return getBox(index);
+ } else {
+ BoundingBox leftBox = validateAndReconstruct(getLeftChild(index));
+ BoundingBox rightBox = validateAndReconstruct(getRightChild(index));
+ if (leftBox.maxValues[getCutDimension(index)] > getCutValue(index)
+ || rightBox.minValues[getCutDimension(index)] <= getCutValue(index)) {
+ throw new IllegalStateException(" incorrect bounding state at index " + index + " cut value "
+ + getCutValue(index) + "cut dimension " + getCutDimension(index) + " left Box "
+ + leftBox.toString() + " right box " + rightBox.toString());
+ }
+ if (centerOfMassEnabled) {
+ recomputePointSum(index);
+ }
+ rightBox.addBox(leftBox);
+ int idx = translate(index);
+ if (idx != Integer.MAX_VALUE) { // always add irrespective of rangesum
+ copyBoxToData(idx, rightBox);
+ }
+ return rightBox;
+ }
+ }
+
+ //// traversals
+
/**
* Starting from the root, traverse the canonical path to a leaf node and visit
* the nodes along the path. The canonical path is determined by the input
@@ -390,11 +851,31 @@ public Integer deletePoint(Integer pointIndex, long sequenceIndex) {
public R traverse(float[] point, IVisitorFactory visitorFactory) {
checkState(root != Null, "this tree doesn't contain any nodes");
Visitor visitor = visitorFactory.newVisitor(this, point);
- nodeStore.traversePathToLeafAndVisitNodes(projectToTree(point), visitor, root, pointStoreView,
- this::liftFromTree);
+ NodeView currentNodeView = new NodeView(this, pointStoreView, root);
+ traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, root, 0);
return visitorFactory.liftResult(this, visitor.getResult());
}
+ protected void traversePathToLeafAndVisitNodes(float[] point, Visitor visitor, NodeView currentNodeView,
+ int node, int depthOfNode) {
+ if (isLeaf(node)) {
+ currentNodeView.setCurrentNode(node, getPointIndex(node), true);
+ visitor.acceptLeaf(currentNodeView, depthOfNode);
+ } else {
+ checkArgument(isInternal(node), () -> " incomplete state " + node + " " + depthOfNode);
+ if (nodeStore.toLeft(point, node)) {
+ traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, nodeStore.getLeftIndex(node),
+ depthOfNode + 1);
+ currentNodeView.updateToParent(node, nodeStore.getRightIndex(node), !visitor.isConverged());
+ } else {
+ traversePathToLeafAndVisitNodes(point, visitor, currentNodeView, nodeStore.getRightIndex(node),
+ depthOfNode + 1);
+ currentNodeView.updateToParent(node, nodeStore.getLeftIndex(node), !visitor.isConverged());
+ }
+ visitor.accept(currentNodeView, depthOfNode);
+ }
+ }
+
/**
* This is a traversal method which follows the standard traversal path (defined
* in {@link #traverse(float[], IVisitorFactory)}) but at Node in checks to see
@@ -416,17 +897,35 @@ public R traverseMulti(float[] point, IMultiVisitorFactory visitorFactory
checkNotNull(visitorFactory, "visitor must not be null");
checkState(root != Null, "this tree doesn't contain any nodes");
MultiVisitor visitor = visitorFactory.newVisitor(this, point);
- nodeStore.traverseTreeMulti(projectToTree(point), visitor, root, pointStoreView, this::liftFromTree);
+ NodeView currentNodeView = new NodeView(this, pointStoreView, root);
+ traverseTreeMulti(point, visitor, currentNodeView, root, 0);
return visitorFactory.liftResult(this, visitor.getResult());
}
- /**
- *
- * @return the mass of the tree
- */
- @Override
- public int getMass() {
- return root == Null ? 0 : nodeStore.getMass(root);
+ protected void traverseTreeMulti(float[] point, MultiVisitor visitor, NodeView currentNodeView, int node,
+ int depthOfNode) {
+ if (nodeStore.isLeaf(node)) {
+ currentNodeView.setCurrentNode(node, getPointIndex(node), false);
+ visitor.acceptLeaf(currentNodeView, depthOfNode);
+ } else {
+ checkArgument(nodeStore.isInternal(node), " incomplete state");
+ currentNodeView.setCurrentNodeOnly(node);
+ if (visitor.trigger(currentNodeView)) {
+ traverseTreeMulti(point, visitor, currentNodeView, nodeStore.getLeftIndex(node), depthOfNode + 1);
+ MultiVisitor newVisitor = visitor.newCopy();
+ currentNodeView.setCurrentNodeOnly(nodeStore.getRightIndex(node));
+ traverseTreeMulti(point, newVisitor, currentNodeView, nodeStore.getRightIndex(node), depthOfNode + 1);
+ currentNodeView.updateToParent(node, nodeStore.getLeftIndex(node), false);
+ visitor.combine(newVisitor);
+ } else if (nodeStore.toLeft(point, node)) {
+ traverseTreeMulti(point, visitor, currentNodeView, nodeStore.getLeftIndex(node), depthOfNode + 1);
+ currentNodeView.updateToParent(node, nodeStore.getRightIndex(node), false);
+ } else {
+ traverseTreeMulti(point, visitor, currentNodeView, nodeStore.getRightIndex(node), depthOfNode + 1);
+ currentNodeView.updateToParent(node, nodeStore.getLeftIndex(node), false);
+ }
+ visitor.accept(currentNodeView, depthOfNode);
+ }
}
public int getNumberOfLeaves() {
diff --git a/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java b/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java
index dd814892..6a1b7915 100644
--- a/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java
+++ b/Java/core/src/test/java/com/amazon/randomcutforest/state/RandomCutForestMapperTest.java
@@ -152,4 +152,30 @@ public void testRoundTripForSingleNodeForest() {
}
}
+ private static float[] generate(int input) {
+ return new float[] { (float) (20 * Math.sin(input / 10.0)), (float) (20 * Math.cos(input / 10.0)) };
+ }
+
+ @Test
+ void benchmarkMappers() {
+ long seed = new Random().nextLong();
+ System.out.println(" Seed " + seed);
+ Random random = new Random(seed);
+
+ RandomCutForest rcf = RandomCutForest.builder().dimensions(2 * 10).shingleSize(10).sampleSize(628)
+ .internalShinglingEnabled(true).randomSeed(random.nextLong()).build();
+ for (int i = 0; i < 10000; i++) {
+ rcf.update(generate(i));
+ }
+ RandomCutForestMapper mapper = new RandomCutForestMapper();
+ mapper.setSaveExecutorContextEnabled(true);
+ mapper.setSaveTreeStateEnabled(true);
+ for (int j = 0; j < 1000; j++) {
+ RandomCutForest newRCF = mapper.toModel(mapper.toState(rcf));
+ float[] test = generate(10000 + j);
+ assertEquals(newRCF.getAnomalyScore(test), rcf.getAnomalyScore(test), 1e-6);
+ rcf.update(test);
+ }
+ }
+
}
diff --git a/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java b/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java
index 90bb2564..505da148 100644
--- a/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java
+++ b/Java/core/src/test/java/com/amazon/randomcutforest/tree/RandomCutTreeTest.java
@@ -94,7 +94,7 @@ public void setUp() {
assertEquals(pointStoreFloat.add(new float[] { 0, 1 }, 5), 4);
assertEquals(pointStoreFloat.add(new float[] { 0, 0 }, 6), 5);
- assertThrows(IllegalStateException.class, () -> tree.deletePoint(0, 1));
+ assertThrows(IllegalArgumentException.class, () -> tree.deletePoint(0, 1));
tree.addPoint(0, 1);
when(rng.nextDouble()).thenReturn(0.625);
@@ -118,57 +118,48 @@ public void testInitialTreeState() {
int node = tree.getRoot();
// the second double[] is intentional
IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 1, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(1));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON));
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(1));
+ assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON));
assertThat(tree.getMass(), is(5));
- assertArrayEquals(new double[] { -1, 2 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, -1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1);
-
- node = tree.nodeStore.getRightIndex(node);
+ assertArrayEquals(new double[] { -1, 2 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1);
+
+ node = tree.getRightChild(node);
expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new BoundingBox(new float[] { 1, 1 }));
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(0));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(0.5, EPSILON));
- assertThat(tree.nodeStore.getMass(node), is(4));
- assertArrayEquals(new double[] { 0.0, 3.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))),
- is(new float[] { 1, 1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(2L), 1);
-
- node = tree.nodeStore.getLeftIndex(node);
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(0));
+ assertThat(tree.getCutValue(node), closeTo(0.5, EPSILON));
+ assertThat(tree.getMass(node), is(4));
+ assertArrayEquals(new double[] { 0.0, 3.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
+
+ assertThat(tree.isLeaf(tree.getRightChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 1, 1 }));
+ assertThat(tree.getMass(tree.getRightChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(2L), 1);
+
+ node = tree.getLeftChild(node);
expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 0, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
-
- assertThat(tree.nodeStore.getCutDimension(node), is(0));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON));
- assertThat(tree.nodeStore.getMass(node), is(3));
- assertArrayEquals(new double[] { -1.0, 2.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, 0 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(3L), 1);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))),
- is(new float[] { 0, 1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(2));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(4L), 1);
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(5L), 1);
- assertThrows(IllegalStateException.class, () -> tree.deletePoint(5, 6));
+ assertThat(tree.getBox(node), is(expectedBox));
+
+ assertThat(tree.getCutDimension(node), is(0));
+ assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON));
+ assertThat(tree.getMass(node), is(3));
+ assertArrayEquals(new double[] { -1.0, 2.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, 0 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(3L), 1);
+
+ assertThat(tree.isLeaf(tree.getRightChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 0, 1 }));
+ assertThat(tree.getMass(tree.getRightChild(node)), is(2));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(4L), 1);
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(5L), 1);
+ assertThrows(IllegalArgumentException.class, () -> tree.deletePoint(5, 6));
}
@Test
@@ -180,44 +171,37 @@ public void testDeletePointWithLeafSibling() {
int node = tree.getRoot();
IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 1, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(1));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON));
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(1));
+ assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON));
assertThat(tree.getMass(), is(4));
- assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
+ assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, -1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1);
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1);
// sibling node moves up and bounding box recomputed
- node = tree.nodeStore.getRightIndex(node);
+ node = tree.getRightChild(node);
expectedBox = new BoundingBox(new float[] { 0, 1 }).getMergedBox(new float[] { 1, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(0));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(0.5, EPSILON));
- assertThat(tree.nodeStore.getMass(node), is(3));
- assertArrayEquals(new double[] { 1.0, 3.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { 0, 1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(2));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(4L), 1);
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(5L), 1);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))),
- is(new float[] { 1, 1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(2L), 1);
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(0));
+ assertThat(tree.getCutValue(node), closeTo(0.5, EPSILON));
+ assertThat(tree.getMass(node), is(3));
+ assertArrayEquals(new double[] { 1.0, 3.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
+
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { 0, 1 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(2));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(4L), 1);
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(5L), 1);
+
+ assertThat(tree.isLeaf(tree.getRightChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 1, 1 }));
+ assertThat(tree.getMass(tree.getRightChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(2L), 1);
}
@Test
@@ -228,41 +212,34 @@ public void testDeletePointWithNonLeafSibling() {
int node = tree.getRoot();
IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 0, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(1));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON));
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(1));
+ assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON));
assertThat(tree.getMass(), is(4));
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, -1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1);
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1);
// sibling node moves up and bounding box stays the same
- node = tree.nodeStore.getRightIndex(node);
+ node = tree.getRightChild(node);
expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 0, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(0));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON));
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, 0 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(3L), 1);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))),
- is(new float[] { 0, 1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(2));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(4L), 1);
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(5L), 1);
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(0));
+ assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON));
+
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, 0 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(3L), 1);
+
+ assertThat(tree.isLeaf(tree.getRightChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 0, 1 }));
+ assertThat(tree.getMass(tree.getRightChild(node)), is(2));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(4L), 1);
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(5L), 1);
}
@Test
@@ -273,65 +250,55 @@ public void testDeletePointWithMassGreaterThan1() {
int node = tree.getRoot();
IBoundingBoxView expectedBox = new BoundingBox(new float[] { -1, -1 }).getMergedBox(new float[] { 1, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(1));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON));
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(1));
+ assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON));
assertThat(tree.getMass(), is(4));
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, -1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1);
- assertArrayEquals(new double[] { -1.0, 1.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, -1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(1L), 1);
-
- node = tree.nodeStore.getRightIndex(node);
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1);
+ assertArrayEquals(new double[] { -1.0, 1.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
+
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, -1 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(1L), 1);
+
+ node = tree.getRightChild(node);
expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 1, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertThat(tree.nodeStore.getCutDimension(node), is(0));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(0.5, EPSILON));
- assertThat(tree.nodeStore.getMass(node), is(3));
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertThat(tree.getCutDimension(node), is(0));
+ assertThat(tree.getCutValue(node), closeTo(0.5, EPSILON));
+ assertThat(tree.getMass(node), is(3));
- assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
+ assertArrayEquals(new double[] { 0.0, 2.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))),
- is(new float[] { 1, 1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(2L), 1);
+ assertThat(tree.isLeaf(tree.getRightChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 1, 1 }));
+ assertThat(tree.getMass(tree.getRightChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(2L), 1);
- node = tree.nodeStore.getLeftIndex(node);
+ node = tree.getLeftChild(node);
expectedBox = new BoundingBox(new float[] { -1, 0 }).getMergedBox(new float[] { 0, 1 });
- assertThat(tree.nodeStore.getBox(node), is(expectedBox));
- assertEquals(expectedBox.toString(), tree.nodeStore.getBox(node).toString());
- assertThat(tree.nodeStore.getCutDimension(node), is(0));
- assertThat(tree.nodeStore.getCutValue(node), closeTo(-0.5, EPSILON));
+ assertThat(tree.getBox(node), is(expectedBox));
+ assertEquals(expectedBox.toString(), tree.getBox(node).toString());
+ assertThat(tree.getCutDimension(node), is(0));
+ assertThat(tree.getCutValue(node), closeTo(-0.5, EPSILON));
assertThat(tree.getMass(), is(4));
- assertArrayEquals(new double[] { -1.0, 1.0 }, toDoubleArray(tree.nodeStore.getPointSum(node)), EPSILON);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getLeftIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))),
- is(new float[] { -1, 0 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getLeftIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getLeftIndex(node))).get(3L), 1);
-
- assertThat(tree.nodeStore.isLeaf(tree.nodeStore.getRightIndex(node)), is(true));
- assertThat(tree.pointStoreView.get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))),
- is(new float[] { 0, 1 }));
- assertThat(tree.nodeStore.getMass(tree.nodeStore.getRightIndex(node)), is(1));
- assertEquals(tree.nodeStore.getSequenceMap()
- .get(tree.nodeStore.getPointIndex(tree.nodeStore.getRightIndex(node))).get(5L), 1);
+ assertArrayEquals(new double[] { -1.0, 1.0 }, toDoubleArray(tree.getPointSum(node)), EPSILON);
+
+ assertThat(tree.isLeaf(tree.getLeftChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getLeftChild(node))), is(new float[] { -1, 0 }));
+ assertThat(tree.getMass(tree.getLeftChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getLeftChild(node))).get(3L), 1);
+
+ assertThat(tree.isLeaf(tree.getRightChild(node)), is(true));
+ assertThat(tree.pointStoreView.get(tree.getPointIndex(tree.getRightChild(node))), is(new float[] { 0, 1 }));
+ assertThat(tree.getMass(tree.getRightChild(node)), is(1));
+ assertEquals(tree.getSequenceMap(tree.getPointIndex(tree.getRightChild(node))).get(5L), 1);
}
@Test
@@ -392,6 +359,7 @@ public void testfloat() {
System.out.println("rangesum " + box.getRangeSum());
double factor = 1.0 - 1e-16;
System.out.println(factor);
- Cut cut = RandomCutTree.randomCut(factor, possible, box);
+ RandomCutTree tree = RandomCutTree.builder().dimension(trials).build();
+ Cut cut = tree.randomCut(factor, possible, box);
}
}
diff --git a/Java/examples/pom.xml b/Java/examples/pom.xml
index 94058604..96766042 100644
--- a/Java/examples/pom.xml
+++ b/Java/examples/pom.xml
@@ -7,7 +7,7 @@
software.amazon.randomcutforest
randomcutforest-parent
- 3.5.1
+ 3.6.0
randomcutforest-examples
diff --git a/Java/parkservices/pom.xml b/Java/parkservices/pom.xml
index 85f590b7..6977c6ef 100644
--- a/Java/parkservices/pom.xml
+++ b/Java/parkservices/pom.xml
@@ -6,7 +6,7 @@
software.amazon.randomcutforest
randomcutforest-parent
- 3.5.1
+ 3.6.0
randomcutforest-parkservices
diff --git a/Java/pom.xml b/Java/pom.xml
index bad56e22..5540e139 100644
--- a/Java/pom.xml
+++ b/Java/pom.xml
@@ -4,7 +4,7 @@
software.amazon.randomcutforest
randomcutforest-parent
- 3.5.1
+ 3.6.0
pom
software.amazon.randomcutforest:randomcutforest
diff --git a/Java/serialization/pom.xml b/Java/serialization/pom.xml
index 20c37c99..1f2a33db 100644
--- a/Java/serialization/pom.xml
+++ b/Java/serialization/pom.xml
@@ -7,7 +7,7 @@
software.amazon.randomcutforest
randomcutforest-parent
- 3.5.1
+ 3.6.0
randomcutforest-serialization
diff --git a/Java/testutils/pom.xml b/Java/testutils/pom.xml
index 467f925d..9d464baf 100644
--- a/Java/testutils/pom.xml
+++ b/Java/testutils/pom.xml
@@ -4,7 +4,7 @@
randomcutforest-parent
software.amazon.randomcutforest
- 3.5.1
+ 3.6.0
randomcutforest-testutils