-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #404 from kaituo/log4.1
Add Python Wrapper for RCF and Fix Error Message
- Loading branch information
Showing
7 changed files
with
257 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Random Cut Forest (RCF) in Python | ||
|
||
RCF (Random Cut Forest) is implemented in Java and Rust. To use it in Python, follow these steps: | ||
|
||
## Step 1: Install JPype | ||
|
||
Install JPype to enable the interaction between Python and Java. You can find the installation instructions at [JPype Installation](https://jpype.readthedocs.io/en/latest/install.html). | ||
|
||
## Step 2: Import and Use TRCF from `python_rcf_wrapper` | ||
|
||
You need to import `TRCF` from the `python_rcf_wrapper` and call its `process` method. Below is an example Python script to demonstrate this: | ||
|
||
```python | ||
from python_rcf_wrapper.trcf_model import TRandomCutForestModel as TRCF | ||
import numpy as np | ||
|
||
# Parameters for the RCF model | ||
shingle_size = 8 | ||
dimensions = 2 | ||
num_trees = 50 | ||
output_after = 32 | ||
sample_size = 256 | ||
|
||
# Initialize the RCF model | ||
model = TRCF( | ||
rcf_dimensions=shingle_size * dimensions, | ||
shingle_size=shingle_size, | ||
num_trees=num_trees, | ||
output_after=output_after, | ||
anomaly_rate=0.001, | ||
z_factor=3, | ||
score_differencing=0.5, | ||
sample_size=sample_size | ||
) | ||
|
||
# Generate test data | ||
TEST_DATA = np.random.normal(size=(300, 2)) | ||
|
||
# Process each data point and print the RCF score and anomaly grade | ||
for point in TEST_DATA: | ||
descriptor = model.process(point) | ||
print("RCF score: {}, Anomaly grade: {}".format(descriptor.getRCFScore(), descriptor.getAnomalyGrade())) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from pathlib import Path | ||
|
||
import logging | ||
|
||
# Import JPype Module for Java imports | ||
import jpype.imports | ||
from jpype.types import * | ||
|
||
import os | ||
|
||
java_home = os.environ.get("JAVA_HOME", None) | ||
|
||
DEFAULT_JAVA_PATH = Path(__file__).parent / "lib" | ||
|
||
|
||
java_path = str(Path(os.environ.get("JAVA_LIB", DEFAULT_JAVA_PATH)) / "*") | ||
|
||
jpype.addClassPath(java_path) | ||
|
||
# Launch the JVM | ||
jpype.startJVM(convertStrings=False) | ||
|
||
logging.info("availableProcess {}".format(jpype.java.lang.Runtime.getRuntime().availableProcessors())) |
Binary file not shown.
Binary file added
BIN
+105 KB
python_rcf_wrapper/lib/randomcutforest-parkservices-4.0.0-SNAPSHOT.jar
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Java imports | ||
from typing import List, Optional, Tuple, Any | ||
|
||
import numpy as np | ||
import logging | ||
from com.amazon.randomcutforest import RandomCutForest | ||
import jpype | ||
|
||
class RandomCutForestModel: | ||
""" | ||
Random Cut Forest Python Binding around the AWS Random Cut Forest Official Java version: | ||
https://github.com/aws/random-cut-forest-by-aws | ||
""" | ||
|
||
def __init__(self, forest: RandomCutForest = None, shingle_size: int = 8, | ||
num_trees: int = 100, random_seed: int = None, | ||
sample_size: int = 256, parallel_execution_enabled: bool = True, | ||
thread_pool_size: Optional[int] = None, lam: float=0.0001, | ||
output_after: int=256): | ||
if forest is not None: | ||
self.forest = forest | ||
else: | ||
builder = RandomCutForest.builder().numberOfTrees(num_trees). \ | ||
sampleSize(sample_size). \ | ||
dimensions(shingle_size). \ | ||
storeSequenceIndexesEnabled(True). \ | ||
centerOfMassEnabled(True). \ | ||
parallelExecutionEnabled(parallel_execution_enabled). \ | ||
timeDecay(lam). \ | ||
outputAfter(output_after) | ||
if thread_pool_size is not None: | ||
builder.threadPoolSize(thread_pool_size) | ||
|
||
if random_seed is not None: | ||
builder = builder.randomSeed(random_seed) | ||
|
||
self.forest = builder.build() | ||
|
||
def score(self, point: List[float]) -> float: | ||
""" | ||
Compute an anomaly score for the given point. | ||
Parameters | ||
---------- | ||
point: List[float] | ||
A data point with shingle size | ||
Returns | ||
------- | ||
float | ||
The anomaly score for the given point | ||
""" | ||
return self.forest.getAnomalyScore(point) | ||
|
||
def update(self, point: List[float]): | ||
""" | ||
Update the model with the data point. | ||
Parameters | ||
---------- | ||
point: List[float] | ||
Point with shingle size | ||
""" | ||
self.forest.update(point) | ||
|
||
|
||
def impute(self, point: List[float]) -> List[float]: | ||
""" | ||
Given a point with missing values, return a new point with the missing values imputed. Each tree in the forest | ||
individual produces an imputed value. For 1-dimensional points, the median imputed value is returned. For | ||
points with more than 1 dimension, the imputed point with the 25th percentile anomaly score is returned. | ||
Parameters | ||
---------- | ||
point: List[float] | ||
The point with shingle size | ||
Returns | ||
------- | ||
List[float] | ||
The imputed point. | ||
""" | ||
|
||
num_missing = np.isnan(point).sum() | ||
if num_missing == 0: | ||
return point | ||
missing_index = np.argwhere(np.isnan(point)).flatten() | ||
imputed_shingle = list(self.forest.imputeMissingValues(point, num_missing, missing_index)) | ||
return imputed_shingle | ||
|
||
def forecast(self, point: List[float]) -> float: | ||
""" | ||
Given one shingled data point, return one step forecast containing the next value. | ||
Parameters | ||
---------- | ||
point: List[float] | ||
The point with shingle size | ||
Returns | ||
------- | ||
float | ||
Forecast value of next timestamp. | ||
""" | ||
val = list(self.forest.extrapolateBasic(point, 1, 1, False, 0))[0] | ||
return val | ||
|
||
@property | ||
def shingle_size(self) -> int: | ||
""" | ||
Returns | ||
------- | ||
int | ||
Shingle size of random cut trees. | ||
""" | ||
return self.forest.getDimensions() | ||
|
||
def get_attribution(self, point: List[float]) -> Tuple[List[float], List[float]]: | ||
try: | ||
attribution_di_vec: Any = self.forest.getAnomalyAttribution(point) | ||
low: List[float] = list(attribution_di_vec.low) | ||
high: List[float] = list(attribution_di_vec.high) | ||
return low, high | ||
except jpype.JException as exception: | ||
logging.info("Error when loading the model: %s", exception.message()) | ||
logging.info("Stack track: %s", exception.stacktrace()) | ||
# Throw it back | ||
raise exception | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Java imports | ||
from typing import List, Optional, Tuple, Any | ||
|
||
import numpy as np | ||
import logging | ||
from com.amazon.randomcutforest.parkservices import ThresholdedRandomCutForest | ||
from com.amazon.randomcutforest.config import Precision | ||
from com.amazon.randomcutforest.parkservices import AnomalyDescriptor | ||
from com.amazon.randomcutforest.config import TransformMethod | ||
import jpype | ||
|
||
class TRandomCutForestModel: | ||
""" | ||
Random Cut Forest Python Binding around the AWS Random Cut Forest Official Java version: | ||
https://github.com/aws/random-cut-forest-by-aws | ||
""" | ||
|
||
def __init__(self, rcf_dimensions, shingle_size, num_trees: int = 30, output_after: int=256, anomaly_rate=0.005, | ||
z_factor=2.5, score_differencing=0.5, ignore_delta_threshold=0, sample_size=256): | ||
self.forest = (ThresholdedRandomCutForest | ||
.builder() | ||
.dimensions(rcf_dimensions) | ||
.sampleSize(sample_size) | ||
.numberOfTrees(num_trees) | ||
.timeDecay(0.0001) | ||
.initialAcceptFraction(output_after*1.0/sample_size) | ||
.parallelExecutionEnabled(True) | ||
.compact(True) | ||
.precision(Precision.FLOAT_32) | ||
.boundingBoxCacheFraction(1) | ||
.shingleSize(shingle_size) | ||
.anomalyRate(anomaly_rate) | ||
.outputAfter(output_after) | ||
.internalShinglingEnabled(True) | ||
.transformMethod(TransformMethod.NORMALIZE) | ||
.alertOnce(True) | ||
.autoAdjust(True) | ||
.build()) | ||
self.forest.setZfactor(z_factor) | ||
|
||
def process(self, point: List[float]) -> AnomalyDescriptor: | ||
""" | ||
a single call that prepreprocesses data, compute score/grade and updates | ||
state. | ||
Parameters | ||
---------- | ||
point: List[float] | ||
A data point with shingle size | ||
Returns | ||
------- | ||
AnomalyDescriptor | ||
Encapsulate detailed information about anomalies detected by RCF model. This class stores various attributes | ||
related to an anomaly, such as confidence levels, attribution scores, and expected values. | ||
""" | ||
return self.forest.process(point, 0) |