Skip to content

Commit

Permalink
test tree_trav
Browse files Browse the repository at this point in the history
  • Loading branch information
interesaaat committed Aug 18, 2020
1 parent c5aa8ae commit ade75de
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 283 deletions.
8 changes: 4 additions & 4 deletions hummingbird/ml/_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def convert(topology, backend, device, extra_config={}):
# For the moment only tree_trav is enabled for pytorch <= 1.6.0
# if vers < allowed_min:
extra_config[constants.TREE_IMPLEMENTATION] = "tree_trav"
elif backend == tvm_backend:
# The TVM frontend for PyTorch currently don't support index_select
# https://github.com/apache/incubator-tvm/issues/6282
extra_config[constants.TREE_IMPLEMENTATION] = "gemm"
# elif backend == tvm_backend:
# The TVM frontend for PyTorch currently don't support index_select
# https://github.com/apache/incubator-tvm/issues/6282
# extra_config[constants.TREE_IMPLEMENTATION] = "gemm"

operator_map[operator.full_name] = converter(operator, device, extra_config)
except ValueError:
Expand Down
3 changes: 2 additions & 1 deletion hummingbird/ml/operator_converters/_tree_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def forward(self, x):
lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees)
rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees)

indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long()
indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts)
indexes = indexes.type(torch.LongTensor)
indexes = indexes + self.nodes_offset
indexes = indexes.view(-1)

Expand Down
Loading

0 comments on commit ade75de

Please sign in to comment.