Skip to content

Commit

Permalink
Merge branch 'main' into selected_features_bug
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsaveus authored Dec 4, 2024
2 parents 12f57fd + 2105594 commit 428acf7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ fastapi==0.111.0
pydantic==2.7.1
uvicorn==0.29.0
starlette~=0.37.2
jaqpotpy==6.17.0
jaqpotpy==6.18.2
pre-commit==4.0.1
ruff==0.6.3
27 changes: 25 additions & 2 deletions src/helpers/predict_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from onnxruntime import InferenceSession
from jaqpotpy.datasets import JaqpotpyDataset
from src.helpers.recreate_preprocessor import recreate_preprocessor
from jaqpotpy.doa import Leverage, BoundingBox, MeanVar
from jaqpotpy.doa import (
Leverage,
BoundingBox,
MeanVar,
Mahalanobis,
KernelBased,
CityBlock,
)


def calculate_doas(input_feed, request):
Expand All @@ -18,7 +25,6 @@ def calculate_doas(input_feed, request):
data instance. The keys in the dictionary are the names of the DoA methods used, and
the values are the corresponding DoA predictions.
"""

doas_results = []
input_df = pd.DataFrame(input_feed["input"])
for _, data_instance in input_df.iterrows():
Expand All @@ -34,6 +40,23 @@ def calculate_doas(input_feed, request):
elif doa_data.method == "MEAN_VAR":
doa_method = MeanVar()
doa_method.bounds = doa_data.data["bounds"]
elif doa_data.method == "MAHALANOBIS":
doa_method = Mahalanobis()
doa_method._mean_vector = doa_data.data["meanVector"]
doa_method._inv_cov_matrix = doa_data.data["invCovMatrix"]
doa_method._threshold = doa_data.data["threshold"]
elif doa_data.method == "KERNEL_BASED":
doa_method = KernelBased()
doa_method._sigma = doa_data.data["sigma"]
doa_method._gamma = doa_data.data.get("gamma", None)
doa_method._threshold = doa_data.data["threshold"]
doa_method._kernel_type = doa_data.data["kernelType"]
doa_method._data = doa_data.data["dataPoints"]
elif doa_data.method == "CITY_BLOCK":
doa_method = CityBlock()
doa_method._mean_vector = doa_data.data["meanVector"]
doa_method._threshold = doa_data.data["threshold"]

doa_instance_prediction[doa_method.__name__] = doa_method.predict(
pd.DataFrame(data_instance.values.reshape(1, -1))
)[0]
Expand Down

0 comments on commit 428acf7

Please sign in to comment.