Skip to content

Commit

Permalink
feat(JAQPOT-186): Update jaqpotpy-inference (#9)
Browse files Browse the repository at this point in the history
* feat: Update jaqpotpy-inference

With this commit we can take predictions from models of SklearnModel class.

* fix: declare imports

* feat: support models with multiple output variables

* chore: simplify the code of model_post_handler
  • Loading branch information
vassilismin authored Jul 22, 2024
1 parent 5a1c03d commit 468dc67
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 29 deletions.
37 changes: 16 additions & 21 deletions src/handlers/predict.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
from ..entities.prediction_request import PredictionRequestPydantic
from ..helpers import model_decoder, json_to_predreq


def model_post_handler(request: PredictionRequestPydantic):
model = model_decoder.decode(request.model['rawModel'])
data_entry_all = json_to_predreq.decode(request)
_ = model(data_entry_all)

if isinstance(model.prediction[0], list):
results = {model.Y[i]: [item[i] for item in model.prediction] for i in range(len(model.prediction[0]))}
elif isinstance(model.prediction, list):
if isinstance(model.Y, list):
results = {model.Y[0]: [item for item in model.prediction]}
else:
results = {model.Y: [item for item in model.prediction]}
data_entry_all = json_to_predreq.decode(request, model)
prediction = model.predict_onnx(data_entry_all)
if model.task == 'classification':
probabilities = model.predict_proba_onnx(data_entry_all)
else:
results = {model.Y: [item for item in model.prediction]}
probabilities = [None for _ in range(len(prediction))]

if model.doa:
results['AD'] = model.doa.IN
else:
results['AD'] = [None for _ in range(len(model.prediction))]
results = {}
for i, feature in enumerate(model.dependentFeatures):
key = feature['key']
if len(model.dependentFeatures) == 1:
values = [str(item) for item in prediction]
else:
values = [str(item) for item in prediction[:, i]]
results[key] = values

if model.probability:
results['Probabilities'] = [list(prob) for prob in model.probability]
else:
results['Probabilities'] = [[] for _ in range(len(model.prediction))]
results['Probabilities'] = [str(prob) for prob in probabilities]
results['AD'] = [None for _ in range(len(prediction))]

final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}

return final_all
return final_all
18 changes: 10 additions & 8 deletions src/helpers/json_to_predreq.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
def decode(request):
dataset = request.dataset
model = request.model
from jaqpotpy.datasets import JaqpotpyDataset
import pandas as pd

keys = [feature['key'] for feature in model['independentFeatures']]
transformed_values = [[data[key] for key in keys] for data in dataset['input']]

# TODO fix to support multiple rows
return transformed_values[0]
def decode(request, model):
df = pd.DataFrame(request.dataset['input'])
independentFeatures = request.model['independentFeatures']
smiles_cols = [feature['key'] for feature in independentFeatures if feature['featureType'] == 'SMILES']
x_cols = [feature['key'] for feature in independentFeatures if feature['featureType'] != 'SMILES']
dataset = JaqpotpyDataset(df=df, smiles_cols=smiles_cols, x_cols=x_cols,
task=model.task, featurizer=model.featurizer)
return dataset

0 comments on commit 468dc67

Please sign in to comment.