Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Koukyosyumei committed Aug 18, 2023
1 parent 2f70bee commit f7842b5
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 20 deletions.
3 changes: 1 addition & 2 deletions src/aijack/collaborative/tree/a.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "secureboost/secureboost.h"
#include "xgboost/xgboost.h"

int main() { // XGBoostClassifier clf(2);
}
int main() { XGBoostClassifier clf(2); }
Binary file modified src/aijack/collaborative/tree/a.out
Binary file not shown.
2 changes: 1 addition & 1 deletion src/aijack/collaborative/tree/core/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ template <typename PartyName> struct TreeModelBase {
* @param parties The vector of parties.
* @param y The vector of ground-truth vectors
*/
virtual void fit(vector<float> &y) = 0;
virtual void fit(vector<PartyName> &parties, vector<float> &y) = 0;

/**
* @brief Function to return the predicted scores of the given data.
Expand Down
2 changes: 1 addition & 1 deletion src/aijack/collaborative/tree/secureboost/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ struct SecureBoostNode : Node<SecureBoostParty> {
int num_classes;

// SecureBoostNode() {}
SecureBoostNode(vector<SecureBoostParty> &parties_, vector<float> &y_,
SecureBoostNode(vector<SecureBoostParty> parties_, vector<float> &y_,
int num_classes_,
vector<vector<PaillierCipherText>> &gradient_,
vector<vector<PaillierCipherText>> &hessian_,
Expand Down
11 changes: 3 additions & 8 deletions src/aijack/collaborative/tree/secureboost/secureboost.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,14 @@ struct SecureBoostBase : TreeModelBase<SecureBoostParty> {
vector<SecureBoostTree> estimators;
vector<float> logging_loss;

vector<SecureBoostParty> &parties;

SecureBoostBase(vector<SecureBoostParty> &parties_, int num_classes_,
float subsample_cols_ = 0.8,
SecureBoostBase(int num_classes_, float subsample_cols_ = 0.8,
float min_child_weight_ = -1 *
numeric_limits<float>::infinity(),
int depth_ = 5, int min_leaf_ = 5, float learning_rate_ = 0.4,
int boosting_rounds_ = 5, float lam_ = 1.5, float gamma_ = 1,
float eps_ = 0.1, int active_party_id_ = -1,
int completelly_secure_round_ = 0, float init_value_ = 1.0,
int n_job_ = 1, bool save_loss_ = true)
: parties(parties_) {
int n_job_ = 1, bool save_loss_ = true) {
num_classes = num_classes_;
subsample_cols = subsample_cols_;
min_child_weight = min_child_weight_;
Expand Down Expand Up @@ -81,9 +77,8 @@ struct SecureBoostBase : TreeModelBase<SecureBoostParty> {
}

vector<SecureBoostTree> get_estimators() { return estimators; }
vector<SecureBoostParty> get_parties() { return parties; }

void fit(vector<float> &y) {
void fit(vector<SecureBoostParty> &parties, vector<float> &y) {
try {
if ((active_party_id < 0) || (active_party_id > parties.size())) {
throw invalid_argument("invalid active_party_id");
Expand Down
2 changes: 1 addition & 1 deletion src/aijack/collaborative/tree/xgboost/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct XGBoostNode : public Node<XGBoostParty> {
vector<float> entire_class_cnt;

// XGBoostNode() {}
XGBoostNode(vector<XGBoostParty> &parties_, vector<float> &y_,
XGBoostNode(vector<XGBoostParty> parties_, vector<float> &y_,
int num_classes_, vector<vector<float>> &gradient_,
vector<vector<float>> &hessian_, vector<int> &idxs_,
float min_child_weight_, float lam_, float gamma_, float eps_,
Expand Down
10 changes: 3 additions & 7 deletions src/aijack/collaborative/tree/xgboost/xgboost.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,19 @@ struct XGBoostBase : TreeModelBase<XGBoostParty> {

float upsilon_Y;

vector<XGBoostParty> &parties;
LossFunc *lossfunc_obj;

vector<vector<float>> init_pred;
vector<XGBoostTree> estimators;
vector<float> logging_loss;

XGBoostBase(vector<XGBoostParty> &parties_, int num_classes_,
float subsample_cols_ = 0.8,
XGBoostBase(int num_classes_, float subsample_cols_ = 0.8,
float min_child_weight_ = -1 * numeric_limits<float>::infinity(),
int depth_ = 5, int min_leaf_ = 5, float learning_rate_ = 0.4,
int boosting_rounds_ = 5, float lam_ = 1.5, float gamma_ = 1,
float eps_ = 0.1, int active_party_id_ = -1,
int completelly_secure_round_ = 0, float init_value_ = 1.0,
int n_job_ = 1, bool save_loss_ = true)
: parties(parties_) {
int n_job_ = 1, bool save_loss_ = true) {
num_classes = num_classes_;
subsample_cols = subsample_cols_;
min_child_weight = min_child_weight_;
Expand Down Expand Up @@ -80,9 +77,8 @@ struct XGBoostBase : TreeModelBase<XGBoostParty> {
}

vector<XGBoostTree> get_estimators() { return estimators; }
vector<XGBoostParty> get_parties() { return parties; }

void fit(vector<float> &y) {
void fit(vector<XGBoostParty> &parties, vector<float> &y) {
int row_count = y.size();

vector<vector<float>> base_pred;
Expand Down

0 comments on commit f7842b5

Please sign in to comment.