diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 6cd0d50b4da7..39b5c23d4d1c 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -523,35 +523,32 @@ std::string Tree::NumericalDecisionIfElse(int node) const { str_buf << std::setprecision(std::numeric_limits::digits10 + 2); uint8_t missing_type = GetMissingType(decision_type_[node]); bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask); - if (missing_type == MissingType::None - || (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) { - str_buf << "if (fval <= " << threshold_[node] << ") {"; - } else if (missing_type == MissingType::Zero) { + if (missing_type != MissingType::NaN) { + str_buf << "if (std::isnan(fval)) fval = 0.0;"; + } + if (missing_type == MissingType::Zero) { if (default_left) { - str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {"; + str_buf << "if (Tree::IsZero(fval)) {"; } else { - str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {"; + str_buf << "if (!Tree::IsZero(fval)) {"; } - } else { + } else if (missing_type == MissingType::NaN) { if (default_left) { - str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {"; + str_buf << "if (std::isnan(fval)) {"; } else { - str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {"; + str_buf << "if (!std::isnan(fval)) {"; } + } else { + str_buf << "if (fval <= " << threshold_[node] << ") {"; } return str_buf.str(); } std::string Tree::CategoricalDecisionIfElse(int node) const { - uint8_t missing_type = GetMissingType(decision_type_[node]); std::stringstream str_buf; Common::C_stringstream(str_buf); - if (missing_type == MissingType::NaN) { - str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast(fval); }"; - } else { - str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast(fval); }"; - } int cat_idx = static_cast(threshold_[node]); + str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast(fval); }"; str_buf << "if (int_fval >= 0 && int_fval < 32 * ("; str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx]; str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];