
feat(JAQPOT-386): onnx_sequence #52

Merged: 6 commits into main, Dec 5, 2024
Conversation

johnsaveus (Contributor):

No description provided.

@johnsaveus johnsaveus changed the title onnx_sequence feat: onnx_sequence Dec 2, 2024
@alarv (Member) left a comment:

add the jira ticket in the title pls

model_output = onnx_post_handler(
    raw_model, featurizer.transform([inp["SMILES"]])
)
predictions.append(check_model_task(model_task, target_name, model_output, inp))
Member commented:

The name check_model_task implies a boolean function that returns true or false. Here the method actually returns the predictions based on the task, so maybe rename it to generate_prediction_response so it's clearer to the reader what the method does?

def torch_sequence_post_handler(request: PredictionRequest) -> PredictionResponse:
    feat_config = request.model.torch_config
    featurizer = _load_featurizer(feat_config)
    target_name = request.model.dependent_features[0].name
Member commented:

Is it hard to allow multiple y target names, instead of always hardcoding a single target_name by taking dependent_features[0] here?
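A hedged sketch of the multi-target idea the reviewer suggests, using hypothetical stand-in types (`Feature`, `Model`, and `target_names` are illustrative only; the real generated request classes may differ):

```python
from dataclasses import dataclass
from typing import List

# Hypothetical stand-ins for the generated request types.
@dataclass
class Feature:
    name: str

@dataclass
class Model:
    dependent_features: List[Feature]

def target_names(model: Model) -> List[str]:
    # Collect every dependent feature name rather than hardcoding index 0.
    return [feature.name for feature in model.dependent_features]

model = Model(dependent_features=[Feature("toxicity"), Feature("solubility")])
print(target_names(model))  # ['toxicity', 'solubility']
```

Downstream handlers would then loop over the returned list instead of assuming a single target.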



def check_model_task(model_task, target_name, out, row_id):
    if model_task == "BINARY_CLASSIFICATION":
Member commented:

model_task has concrete types, so model_task == ModelTask.BINARY_CLASSIFICATION is better here, because you get tooling to help you if this changes in the future: the build will fail.

def check_model_task(model_task, target_name, out, row_id):
    if model_task == "BINARY_CLASSIFICATION":
        return sequence_classification(target_name, out, row_id)
    elif model_task == "REGRESSION":
Member commented:
same here model_task == ModelTask.REGRESSION

        return sequence_regression(target_name, out, row_id)
    else:
        raise ValueError(
            "Only BINARY_CLASSIFICATION and REGRESSION tasks are supported"
        )
Member commented:
This text may become obsolete soon; if we introduce a third task type, someone has to update it. Better to rename it to: Unsupported model task
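Putting the review suggestions together (enum comparisons instead of string literals, plus a task-agnostic error message), a minimal sketch; ModelTask here is a hypothetical stand-in for the project's generated enum, and the sequence_* handlers are stubbed:

```python
from enum import Enum

# Hypothetical stand-in for the project's generated ModelTask enum.
class ModelTask(str, Enum):
    BINARY_CLASSIFICATION = "BINARY_CLASSIFICATION"
    REGRESSION = "REGRESSION"

# Stubs standing in for the real sequence_* handlers.
def sequence_classification(target_name, out, row_id):
    return {"task": "classification", "target": target_name}

def sequence_regression(target_name, out, row_id):
    return {"task": "regression", "target": target_name}

def check_model_task(model_task, target_name, out, row_id):
    # Comparing against enum members (not raw strings) lets tooling catch
    # renames or removed members at build time.
    if model_task == ModelTask.BINARY_CLASSIFICATION:
        return sequence_classification(target_name, out, row_id)
    elif model_task == ModelTask.REGRESSION:
        return sequence_regression(target_name, out, row_id)
    # Generic message stays accurate even if new task types are added later.
    raise ValueError(f"Unsupported model task: {model_task}")
```

Because ModelTask subclasses str here, existing string-based call sites would still compare equal during a migration.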

@alarv alarv changed the title feat: onnx_sequence feat(JAQPOT-386): onnx_sequence Dec 5, 2024
@johnsaveus johnsaveus merged commit b49b4c9 into main Dec 5, 2024
5 checks passed
@johnsaveus johnsaveus deleted the onnx_sequence branch December 5, 2024 12:12