Skip to content

Commit

Permalink
feat: Update jaqpotpy-inference
Browse files Browse the repository at this point in the history
With this commit we can take predictions from models of SklearnModel class.
  • Loading branch information
vassilismin committed Jul 16, 2024
1 parent 5a1c03d commit c7e4c12
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 31 deletions.
31 changes: 9 additions & 22 deletions src/handlers/predict.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,19 @@
from ..entities.prediction_request import PredictionRequestPydantic
from ..helpers import model_decoder, json_to_predreq


def model_post_handler(request: PredictionRequestPydantic):
model = model_decoder.decode(request.model['rawModel'])
data_entry_all = json_to_predreq.decode(request)
_ = model(data_entry_all)

if isinstance(model.prediction[0], list):
results = {model.Y[i]: [item[i] for item in model.prediction] for i in range(len(model.prediction[0]))}
elif isinstance(model.prediction, list):
if isinstance(model.Y, list):
results = {model.Y[0]: [item for item in model.prediction]}
else:
results = {model.Y: [item for item in model.prediction]}
data_entry_all = json_to_predreq.decode(request, model)
prediction = model.predict_onnx(data_entry_all)
if model.task == 'classification':
probabilities = model.predict_proba_onnx(data_entry_all)
else:
results = {model.Y: [item for item in model.prediction]}
probabilities = [None for _ in range(len(prediction))]

if model.doa:
results['AD'] = model.doa.IN
else:
results['AD'] = [None for _ in range(len(model.prediction))]

if model.probability:
results['Probabilities'] = [list(prob) for prob in model.probability]
else:
results['Probabilities'] = [[] for _ in range(len(model.prediction))]
results = {model.dependentFeatures[0]['name']: [str(item) for item in prediction]}
results['Probabilities'] = [str(prob) for prob in probabilities]
results['AD'] = [None for _ in range(len(prediction))]

final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}

return final_all
return final_all
17 changes: 8 additions & 9 deletions src/helpers/json_to_predreq.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
def decode(request):
dataset = request.dataset
model = request.model

keys = [feature['key'] for feature in model['independentFeatures']]
transformed_values = [[data[key] for key in keys] for data in dataset['input']]

# TODO fix to support multiple rows
return transformed_values[0]
def decode(request, model):
df = pd.DataFrame(request.dataset['input'])
independentFeatures = request.model['independentFeatures']
smiles_cols = [feature['key'] for feature in independentFeatures if feature['featureType'] == 'SMILES']
x_cols = [feature['key'] for feature in independentFeatures if feature['featureType'] != 'SMILES']
dataset = JaqpotpyDataset(df=df, smiles_cols=smiles_cols, x_cols=x_cols,
task=model.task, featurizer=model.featurizer)
return dataset

0 comments on commit c7e4c12

Please sign in to comment.