Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
johnsaveus committed Dec 5, 2024
1 parent 5f3ee34 commit d85af6a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/handlers/predict_torch_geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import io
import numpy as np
from src.helpers.torch_utils import to_numpy, check_model_task
from src.helpers.torch_utils import to_numpy, generate_prediction_response
from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse
from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer
from jaqpotpy.api.openapi.models.model_task import ModelTask
Expand All @@ -23,15 +23,15 @@ def torch_geometric_post_handler(request: PredictionRequest) -> PredictionRespon
raw_model, featurizer.featurize(inp["SMILES"])
)
predictions.append(
check_model_task(model_task, target_name, model_output, inp)
generate_prediction_response(model_task, target_name, model_output, inp)
)
elif request.model.type == ModelType.TORCHSCRIPT:
for inp in user_input:
model_output = torchscript_post_handler(
raw_model, featurizer.featurize(inp["SMILES"])
)
predictions.append(
check_model_task(model_task, target_name, model_output, inp)
generate_prediction_response(model_task, target_name, model_output, inp)
)
return PredictionResponse(predictions=predictions)

Expand Down
6 changes: 4 additions & 2 deletions src/handlers/predict_torch_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import io
import numpy as np
import torch.nn.functional as f
from src.helpers.torch_utils import to_numpy, check_model_task
from src.helpers.torch_utils import to_numpy, generate_prediction_response
from jaqpotpy.descriptors.tokenizer import SmilesVectorizer
from jaqpotpy.api.openapi import ModelType, PredictionRequest, PredictionResponse
from jaqpotpy.descriptors.graph.graph_featurizer import SmilesGraphFeaturizer
Expand All @@ -22,7 +22,9 @@ def torch_sequence_post_handler(request: PredictionRequest) -> PredictionRespons
model_output = onnx_post_handler(
raw_model, featurizer.transform(featurizer.transform([inp["SMILES"]]))
)
predictions.append(check_model_task(model_task, target_name, model_output, inp))
predictions.append(
generate_prediction_response(model_task, target_name, model_output, inp)
)
return PredictionResponse(predictions=predictions)


Expand Down
2 changes: 1 addition & 1 deletion src/helpers/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def torch_regression(target_name, output, inp):
return results


def check_model_task(model_task, target_name, out, row_id):
def generate_prediction_response(model_task, target_name, out, row_id):
if model_task == ModelTask.BINARY_CLASSIFICATION:
return torch_binary_classification(target_name, out, row_id)
elif model_task == ModelTask.REGRESSION:
Expand Down

0 comments on commit d85af6a

Please sign in to comment.