Skip to content

Commit

Permalink
feat: dataframe,probabilities,jaqpot_row_label (#36)
Browse files Browse the repository at this point in the history
* dataframe_and_probabilities

* added_row_label

* small_bug

* remove_duplicate

---------

Co-authored-by: Alex Arvanitidis <[email protected]>
  • Loading branch information
johnsaveus and alarv authored Oct 23, 2024
1 parent 650bca8 commit e47076c
Showing 1 changed file with 34 additions and 24 deletions.
58 changes: 34 additions & 24 deletions src/handlers/predict_pyg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,26 @@


def graph_post_handler(request: PredictionRequestPydantic):

feat_config = request.extraConfig["torchConfig"]["featurizerConfig"]
featurizer = _load_featurizer(feat_config)
target_name = request.model["dependentFeatures"][0]["name"]
model_task = request.model["task"]
smiles = request.dataset["input"][0]["SMILES"]
data = featurizer.featurize(smiles)
user_input = request.dataset["input"]
raw_model = request.model["rawModel"]
preds = []
if request.model["type"] == "TORCH_ONNX":
model_output = onnx_post_handler(raw_model, data)
return check_model_task(model_task, target_name, model_output)
for inp in user_input:
model_output = onnx_post_handler(
raw_model, featurizer.featurize(inp["SMILES"])
)
preds.append(check_model_task(model_task, target_name, model_output, inp))
elif request.model["type"] == "TORCHSCRIPT":
model_output = torchscript_post_handler(raw_model, data)
return check_model_task(model_task, target_name, model_output)
for inp in user_input:
model_output = torchscript_post_handler(
raw_model, featurizer.featurize(inp["SMILES"])
)
preds.append(check_model_task(model_task, target_name, model_output, inp))
return {"predictions": preds}


def onnx_post_handler(raw_model, data):
Expand Down Expand Up @@ -60,38 +66,42 @@ def _to_numpy(tensor):


def _load_featurizer(config):

featurizer = SmilesGraphFeaturizer()
featurizer.load_dict(config)
featurizer.sort_allowable_sets()
return featurizer


def graph_regression(target_name, output):
preds = [output.squeeze().tolist()]
def graph_regression(target_name, output, inp):
pred = [output.squeeze().tolist()]
results = {}
results[target_name] = [str(pred) for pred in preds]
final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}
return final_all
results["jaqpotMetadata"] = {"jaqpotRowId": inp["jaqpotRowId"]}
if "jaqpotRowLabel" in inp:
results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
results[target_name] = pred
return results


def graph_binary_classification(target_name, output):
probs = [F.sigmoid(output).squeeze().tolist()]
preds = [int(prob > 0.5) for prob in probs]
def graph_binary_classification(target_name, output, inp):
proba = F.sigmoid(output).squeeze().tolist()
pred = int(proba > 0.5)
# UI Results
results = {}
results["Probabilities"] = [str(prob) for prob in probs]
results[target_name] = [str(pred) for pred in preds]
final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]}
return final_all

results["jaqpotMetadata"] = {
"probabilities": [round((1 - proba), 3), round(proba, 3)],
"jaqpotRowId": inp["jaqpotRowId"],
}
if "jaqpotRowLabel" in inp:
results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"]
results[target_name] = pred
return results

def check_model_task(model_task, target_name, out):

def check_model_task(model_task, target_name, out, row_id):
if model_task == "BINARY_CLASSIFICATION":
return graph_binary_classification(target_name, out)
return graph_binary_classification(target_name, out, row_id)
elif model_task == "REGRESSION":
return graph_regression(target_name, out)
return graph_regression(target_name, out, row_id)
else:
raise ValueError(
"Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"
Expand Down

0 comments on commit e47076c

Please sign in to comment.