diff --git a/src/handlers/predict_pyg.py b/src/handlers/predict_pyg.py index abe13c1..7683984 100644 --- a/src/handlers/predict_pyg.py +++ b/src/handlers/predict_pyg.py @@ -9,20 +9,26 @@ def graph_post_handler(request: PredictionRequestPydantic): - feat_config = request.extraConfig["torchConfig"]["featurizerConfig"] featurizer = _load_featurizer(feat_config) target_name = request.model["dependentFeatures"][0]["name"] model_task = request.model["task"] - smiles = request.dataset["input"][0]["SMILES"] - data = featurizer.featurize(smiles) + user_input = request.dataset["input"] raw_model = request.model["rawModel"] + preds = [] if request.model["type"] == "TORCH_ONNX": - model_output = onnx_post_handler(raw_model, data) - return check_model_task(model_task, target_name, model_output) + for inp in user_input: + model_output = onnx_post_handler( + raw_model, featurizer.featurize(inp["SMILES"]) + ) + preds.append(check_model_task(model_task, target_name, model_output, inp)) elif request.model["type"] == "TORCHSCRIPT": - model_output = torchscript_post_handler(raw_model, data) - return check_model_task(model_task, target_name, model_output) + for inp in user_input: + model_output = torchscript_post_handler( + raw_model, featurizer.featurize(inp["SMILES"]) + ) + preds.append(check_model_task(model_task, target_name, model_output, inp)) + return {"predictions": preds} def onnx_post_handler(raw_model, data): @@ -60,38 +66,42 @@ def _to_numpy(tensor): def _load_featurizer(config): - featurizer = SmilesGraphFeaturizer() featurizer.load_dict(config) featurizer.sort_allowable_sets() return featurizer -def graph_regression(target_name, output): - preds = [output.squeeze().tolist()] +def graph_regression(target_name, output, inp): + pred = [output.squeeze().tolist()] results = {} - results[target_name] = [str(pred) for pred in preds] - final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} - return final_all + results["jaqpotMetadata"] = {"jaqpotRowId": inp["jaqpotRowId"]} + if "jaqpotRowLabel" in inp: + results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] + results[target_name] = pred + return results -def graph_binary_classification(target_name, output): - probs = [F.sigmoid(output).squeeze().tolist()] - preds = [int(prob > 0.5) for prob in probs] +def graph_binary_classification(target_name, output, inp): + proba = F.sigmoid(output).squeeze().tolist() + pred = int(proba > 0.5) # UI Results results = {} - results["Probabilities"] = [str(prob) for prob in probs] - results[target_name] = [str(pred) for pred in preds] - final_all = {"predictions": [dict(zip(results, t)) for t in zip(*results.values())]} - return final_all - + results["jaqpotMetadata"] = { + "probabilities": [round((1 - proba), 3), round(proba, 3)], + "jaqpotRowId": inp["jaqpotRowId"], + } + if "jaqpotRowLabel" in inp: + results["jaqpotMetadata"]["jaqpotRowLabel"] = inp["jaqpotRowLabel"] + results[target_name] = pred + return results -def check_model_task(model_task, target_name, out): +def check_model_task(model_task, target_name, out, row_id): if model_task == "BINARY_CLASSIFICATION": - return graph_binary_classification(target_name, out) + return graph_binary_classification(target_name, out, row_id) elif model_task == "REGRESSION": - return graph_regression(target_name, out) + return graph_regression(target_name, out, row_id) else: raise ValueError( "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"