Skip to content

Commit

Permalink
Correctly implement the decision_function_shape parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
breyerml committed Nov 7, 2023
1 parent f7f56bf commit ac1ebee
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions bindings/Python/sklearn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct svc {

std::optional<real_type> epsilon{};
std::optional<long long> max_iter{};
plssvm::classification_type classification{ plssvm::classification_type::oao };

std::unique_ptr<plssvm::csvm> svm_{ plssvm::make_csvm() };
std::unique_ptr<data_set_type> data_{};
Expand All @@ -51,7 +52,7 @@ struct svc {

void parse_provided_params(svc &self, const py::kwargs &args) {
// check keyword arguments
check_kwargs_for_correctness(args, { "C", "kernel", "degree", "gamma", "coef0", "shrinking", "probability", "tol", "cache_size", "class_weight", "verbose", "max_iter", "decision_function_shape", "break_ties", "random_state" });
check_kwargs_for_correctness(args, { "C", "kernel", "degree", "gamma", "coef0", "shrinking", "probability", "tol", "cache_size", "class_weight", "verbose", "max_iter", "decision_function_shape", "break_ties", "random_state", "classification" });

if (args.contains("C")) {
self.svm_->set_params(plssvm::cost = args["C"].cast<typename svc::real_type>());
Expand Down Expand Up @@ -101,7 +102,14 @@ void parse_provided_params(svc &self, const py::kwargs &args) {
self.max_iter = args["max_iter"].cast<long long>();
}
if (args.contains("decision_function_shape")) {
throw py::attribute_error{ "The 'decision_function_shape' parameter for a call to the 'SVC' constructor is not implemented yet!" };
const std::string &dfs = args["decision_function_shape"].cast<std::string>();
if (dfs == "ovo") {
self.classification = plssvm::classification_type::oao;
} else if (dfs == "ovr") {
self.classification = plssvm::classification_type::oaa;
} else {
throw py::value_error{ fmt::format("decision_function_shape must be either 'ovr' or 'ovo', got {}.", dfs) };
}
}
if (args.contains("break_ties")) {
throw py::attribute_error{ "The 'break_ties' parameter for a call to the 'SVC' constructor is not implemented yet!" };
Expand All @@ -115,20 +123,20 @@ void fit(svc &self) {
// fit the model using potentially provided keyword arguments
if (self.epsilon.has_value() && self.max_iter.has_value()) {
self.model_ = std::make_unique<typename svc::model_type>(self.svm_->fit(*self.data_,
plssvm::classification = plssvm::classification_type::oao,
plssvm::classification = self.classification,
plssvm::epsilon = self.epsilon.value(),
plssvm::max_iter = self.max_iter.value()));
} else if (self.epsilon.has_value()) {
self.model_ = std::make_unique<typename svc::model_type>(self.svm_->fit(*self.data_,
plssvm::classification = plssvm::classification_type::oao,
plssvm::classification = self.classification,
plssvm::epsilon = self.epsilon.value()));
} else if (self.max_iter.has_value()) {
self.model_ = std::make_unique<typename svc::model_type>(self.svm_->fit(*self.data_,
plssvm::classification = plssvm::classification_type::oao,
plssvm::classification = self.classification,
plssvm::max_iter = self.max_iter.value()));
} else {
self.model_ = std::make_unique<typename svc::model_type>(self.svm_->fit(*self.data_,
plssvm::classification = plssvm::classification_type::oao));
plssvm::classification = self.classification));
}
}

Expand Down

0 comments on commit ac1ebee

Please sign in to comment.