diff --git a/src/helpers/predict_methods.py b/src/helpers/predict_methods.py index 13ea311..beaba52 100644 --- a/src/helpers/predict_methods.py +++ b/src/helpers/predict_methods.py @@ -12,10 +12,6 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request): np_dtype = onnx.helper.tensor_dtype_to_np_dtype( independent_feature.type.tensor_type.elem_type ) - # if np_dtype == "float64": - # np_dtype = "float32" - # elif np_dtype == ["int64", "uint64"]: - # np_dtype = "int32" if len(model.graph.input) == 1: input_feed[independent_feature.name] = dataset.X.values.astype(np_dtype) diff --git a/src/helpers/recreate_featurizer.py b/src/helpers/recreate_featurizer.py index d73194a..bf4f587 100644 --- a/src/helpers/recreate_featurizer.py +++ b/src/helpers/recreate_featurizer.py @@ -1,13 +1,15 @@ import numpy as np -def recreate_featurizer(featurizer_name, featurizer_config): - featurizer_class = getattr(__import__( - 'jaqpotpy.descriptors.molecular', - fromlist=[featurizer_name]), featurizer_name) + +def recreate_featurizer(featurizer_name, featurizer_config): + featurizer_class = getattr( + __import__("jaqpotpy.descriptors.molecular", fromlist=[featurizer_name]), + featurizer_name, + ) featurizer = featurizer_class() for attr, value in featurizer_config.items(): - if attr !='class': # skip the class attribute + if attr != "class": # skip the class attribute if isinstance(value, list): - value = np.array(value['value']) - setattr(featurizer, attr, value['value']) - return featurizer \ No newline at end of file + value = np.array(value["value"]) + setattr(featurizer, attr, value) + return featurizer