Skip to content

Commit

Permalink
fix: bug in recreating featurizers (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
vassilismin authored Oct 8, 2024
1 parent 367cd6c commit fa455e8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
4 changes: 0 additions & 4 deletions src/helpers/predict_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ def predict_onnx(model, dataset: JaqpotpyDataset, request):
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(
independent_feature.type.tensor_type.elem_type
)
# if np_dtype == "float64":
# np_dtype = "float32"
# elif np_dtype == ["int64", "uint64"]:
# np_dtype = "int32"

if len(model.graph.input) == 1:
input_feed[independent_feature.name] = dataset.X.values.astype(np_dtype)
Expand Down
18 changes: 10 additions & 8 deletions src/helpers/recreate_featurizer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import numpy as np

def recreate_featurizer(featurizer_name, featurizer_config):
featurizer_class = getattr(__import__(
'jaqpotpy.descriptors.molecular',
fromlist=[featurizer_name]), featurizer_name)

def recreate_featurizer(featurizer_name, featurizer_config):
featurizer_class = getattr(
__import__("jaqpotpy.descriptors.molecular", fromlist=[featurizer_name]),
featurizer_name,
)
featurizer = featurizer_class()
for attr, value in featurizer_config.items():
if attr !='class': # skip the class attribute
if attr != "class": # skip the class attribute
if isinstance(value, list):
value = np.array(value['value'])
setattr(featurizer, attr, value['value'])
return featurizer
value = np.array(value["value"])
setattr(featurizer, attr, value)
return featurizer

0 comments on commit fa455e8

Please sign in to comment.