From da72a2a9cb7cb12e342f10344c8db0dee74b34c8 Mon Sep 17 00:00:00 2001 From: Matteo Interlandi Date: Wed, 19 Aug 2020 11:23:33 -0700 Subject: [PATCH] use long() for casting --- .../onnx/onnx_operator.html | 36 ++++--------------- .../operator_converters/sklearn/skl_sv.html | 12 ++----- .../_tree_implementations.py | 3 +- 3 files changed, 9 insertions(+), 42 deletions(-) diff --git a/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html b/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html index a4d297863..5a387b069 100644 --- a/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html +++ b/doc/html/hummingbird/ml/operator_converters/onnx/onnx_operator.html @@ -327,18 +327,10 @@

Ancestors

Methods

-def forward(self, x) +def forward(self, x) -> Callable[..., Any]
-

Defines the computation performed at every call.

-

Should be overridden by all subclasses.

-
-

Note

-

Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

-
+
Expand source code @@ -377,18 +369,10 @@

Ancestors

Methods

-def forward(self, *x) +def forward(self, *x) -> Callable[..., Any]
-

Defines the computation performed at every call.

-

Should be overridden by all subclasses.

-
-

Note

-

Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

-
+
Expand source code @@ -429,18 +413,10 @@

Ancestors

Methods

-def forward(self, x) +def forward(self, x) -> Callable[..., Any]
-

Defines the computation performed at every call.

-

Should be overridden by all subclasses.

-
-

Note

-

Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

-
+
Expand source code diff --git a/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html b/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html index 647c79f7d..51e5ea058 100644 --- a/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html +++ b/doc/html/hummingbird/ml/operator_converters/sklearn/skl_sv.html @@ -304,18 +304,10 @@

Ancestors

Methods

-def forward(self, x) +def forward(self, x) -> Callable[..., Any]
-

Defines the computation performed at every call.

-

Should be overridden by all subclasses.

-
-

Note

-

Although the recipe for forward pass needs to be defined within -this function, one should call the :class:Module instance afterwards -instead of this since the former takes care of running the -registered hooks while the latter silently ignores them.

-
+
Expand source code diff --git a/hummingbird/ml/operator_converters/_tree_implementations.py b/hummingbird/ml/operator_converters/_tree_implementations.py index 8821c7667..3a42d5fb6 100644 --- a/hummingbird/ml/operator_converters/_tree_implementations.py +++ b/hummingbird/ml/operator_converters/_tree_implementations.py @@ -233,8 +233,7 @@ def forward(self, x): lefts = torch.index_select(self.lefts, 0, indexes).view(-1, self.num_trees) rights = torch.index_select(self.rights, 0, indexes).view(-1, self.num_trees) - indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts) - indexes = indexes.type(torch.LongTensor) + indexes = torch.where(torch.ge(feature_values, thresholds), rights, lefts).long() indexes = indexes + self.nodes_offset indexes = indexes.view(-1)