diff --git a/README.md b/README.md index bd6621c7..1f0c4bc6 100644 --- a/README.md +++ b/README.md @@ -181,12 +181,12 @@ You can also find more examples in our tutorials and documentation. | Attack | Poisoning | [History Attack](https://arxiv.org/abs/2203.08669), [Label Flip](https://arxiv.org/abs/2203.08669), [MAPF](https://arxiv.org/abs/2203.08669), [SVM Poisoning](https://arxiv.org/abs/1206.6389) | | Attack | Backdoor | [DBA](https://openreview.net/forum?id=rkgyS0VFvr) | | Attack | Free-Rider | [Delta-Weight](https://arxiv.org/pdf/1911.12560.pdf) | -| Attack | Evasion | [Gradient-Descent Attack](https://arxiv.org/abs/1708.06131) | +| Attack | Evasion | [Gradient-Descent Attack](https://arxiv.org/abs/1708.06131), [FGSM](https://arxiv.org/abs/1412.6572) | | Attack | Membership Inference | [Shadow Attack](https://arxiv.org/abs/1610.05820) | | Defense | Homomorphic Encryption | [Paillier](https://link.springer.com/chapter/10.1007/3-540-48910-X_16) | | Defense | Differential Privacy | [DPSGD](https://arxiv.org/abs/1607.00133), [AdaDPS](https://arxiv.org/pdf/2202.05963.pdf) | | Defense | Anonymization | [Mondrian](https://ieeexplore.ieee.org/document/1617393) | -| Defense | Certified Robustness | [PixelDP](https://arxiv.org/abs/1802.03471v4) | +| Defense | Robust Training | [PixelDP](https://arxiv.org/abs/1802.03471v4), [Cost-Aware Robust Tree Ensemble](https://arxiv.org/abs/1912.01149) | | Defense | Debugging | [Model Assertions](https://cs.stanford.edu/~matei/papers/2019/debugml_model_assertions.pdf), [Rain](https://arxiv.org/abs/2004.05722), [Neuron Coverage](https://dl.acm.org/doi/abs/10.1145/3132747.3132785) | | Defense | Others | [Soteria](https://openaccess.thecvf.com/content/CVPR2021/papers/Sun_Soteria_Provable_Defense_Against_Privacy_Leakage_in_Federated_Learning_From_CVPR_2021_paper.pdf), [FoolsGold](https://arxiv.org/abs/1808.04866), [MID](https://arxiv.org/abs/2009.05241), [Sparse Gradient](https://aclanthology.org/D17-1045/) | diff --git a/docs/source/aijack.attack.evasion.rst b/docs/source/aijack.attack.evasion.rst index 62fc430d..5c0406a5 100644 --- a/docs/source/aijack.attack.evasion.rst +++ b/docs/source/aijack.attack.evasion.rst @@ -12,6 +12,14 @@ aijack.attack.evasion.evasion\_attack module :undoc-members: :show-inheritance: +aijack.attack.evasion.fgsm module +--------------------------------- + +.. automodule:: aijack.attack.evasion.fgsm + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/docs/source/aijack.defense.crobustness.rst b/docs/source/aijack.defense.crobustness.rst new file mode 100644 index 00000000..6f847f37 --- /dev/null +++ b/docs/source/aijack.defense.crobustness.rst @@ -0,0 +1,21 @@ +aijack.defense.crobustness package +================================== + +Submodules +---------- + +aijack.defense.crobustness.pixeldp module +----------------------------------------- + +.. automodule:: aijack.defense.crobustness.pixeldp + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: aijack.defense.crobustness + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/aijack.defense.rst b/docs/source/aijack.defense.rst index 6f48850e..b6ea6459 100644 --- a/docs/source/aijack.defense.rst +++ b/docs/source/aijack.defense.rst @@ -7,6 +7,7 @@ Subpackages ..
toctree:: :maxdepth: 4 + aijack.defense.crobustness aijack.defense.debugging aijack.defense.dp aijack.defense.foolsgold diff --git a/docs/source/notebooks/aijack_transferbility_and_robustness.ipynb b/docs/source/notebooks/aijack_transferbility_and_robustness.ipynb new file mode 100644 index 00000000..456fc9e7 --- /dev/null +++ b/docs/source/notebooks/aijack_transferbility_and_robustness.ipynb @@ -0,0 +1,611 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Exploring Adversarial Example Transferability and Robust Tree Models\n", + "\n", + "This tutorial explores the transferability of adversarial examples: inputs crafted to exploit one model's vulnerabilities can often fool other models as well. Our experiments on the MNIST dataset show that adversarial examples ([1]) designed to attack a neural network can also mislead a tree ensemble model.\n", + "\n", + "We then show that the Cost-Aware Robust Tree Ensemble ([2]) can counteract such evasion attacks by incorporating domain knowledge about feature manipulation costs during training.\n", + "\n", + "With AIJack, you can evaluate both techniques in a few lines of code.\n", + "\n", + "```\n", + "[1] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining and harnessing adversarial examples.\" arXiv preprint arXiv:1412.6572 (2014).\n", + "[2] Chen, Yizheng, et al. \"Cost-Aware Robust Tree Ensembles for Security Applications.\" 30th USENIX Security Symposium (USENIX Security 21). 
2021.\n", + "```" + ], + "metadata": { + "id": "pUKTRr8zekCE" + } + }, + { + "cell_type": "code", + "source": [ + "import cv2\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torchvision\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import TensorDataset\n", + "from matplotlib import pyplot as plt\n", + "from sklearn.metrics import accuracy_score\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "from aijack.attack.evasion import FGSMAttacker\n", + "from aijack.collaborative.tree import (\n", + " XGBoostClassifierAPI,\n", + " XGBoostClient,\n", + ")\n", + "from aijack.utils import NumpyDataset\n", + "\n", + "BASE = \"data/\"\n", + "torch.manual_seed(42)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Uu73jVt5iePg", + "outputId": "74db5d62-4d75-45eb-d4b3-815b28d79384" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "" + ] + }, + "metadata": {}, + "execution_count": 2 + } + ] + }, + { + "cell_type": "code", + "source": [ + "mnist_dataset_train = torchvision.datasets.MNIST(root=\"\", train=True, download=True)\n", + "\n", + "transform = transforms.Compose(\n", + " [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]\n", + ")\n", + "\n", + "X = mnist_dataset_train.train_data.numpy()\n", + "y = mnist_dataset_train.train_labels.numpy()\n", + "X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, test_size=0.33, random_state=42, shuffle=True\n", + ")\n", + "\n", + "X_train = X_train[:2000]\n", + "y_train = y_train[:2000]\n", + "X_test = X_test[:1000]\n", + "y_test = y_test[:1000]\n", + "\n", + "train_dataset = NumpyDataset(\n", + " X_train,\n", + " y_train,\n", + " transform=transform,\n", + ")\n", + "train_dataloader = torch.utils.data.DataLoader(\n", + " train_dataset, batch_size=16, shuffle=True, num_workers=2\n", + ")\n", + "\n", + "test_dataset = NumpyDataset(\n", + " X_test,\n", + " y_test,\n", + " transform=transform,\n", + ")\n", + "test_dataloader = torch.utils.data.DataLoader(\n", + " test_dataset, batch_size=16, shuffle=True, num_workers=2\n", + ")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d2E7oZx_KuZU", + "outputId": "8a82fbea-404e-40c0-f303-e37a26b30f60" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 9912422/9912422 [00:00<00:00, 66438329.26it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 28881/28881 [00:00<00:00, 64331223.49it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw\n", + "\n", + "Downloading 
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 1648877/1648877 [00:00<00:00, 22731323.10it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 4542/4542 [00:00<00:00, 5683331.97it/s]" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n", + "/usr/local/lib/python3.10/dist-packages/torchvision/datasets/mnist.py:75: UserWarning: train_data has been renamed data\n", + " warnings.warn(\"train_data has been renamed data\")\n", + "/usr/local/lib/python3.10/dist-packages/torchvision/datasets/mnist.py:65: UserWarning: train_labels has been renamed targets\n", + " warnings.warn(\"train_labels has been renamed targets\")\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "class Net(nn.Module):\n", + " def __init__(self):\n", + " super(Net, self).__init__()\n", + " self.fla = nn.Flatten()\n", + " self.fc1 = nn.Linear(28 * 28, 10)\n", + " # self.fc2 = nn.Linear(100, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.fla(x)\n", + " x = self.fc1(x)\n", + " # x = torch.relu(x)\n", + " # x = self.fc2(x)\n", + " # x = F.softmax(x, dim=1)\n", + " return x" + ], + "metadata": { + "id": "T8lJ0X7SjHFB" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "net = Net()\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.SGD(net.parameters(), lr=0.003, momentum=0.9)" + ], + "metadata": { + "id": "Z-MoJqPXjHqq" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "for epoch in range(30): # loop over the dataset multiple times\n", + " running_loss = 0\n", + " data_size = 0\n", + " for i, data in enumerate(train_dataloader, 0):\n", + " # get the inputs; data is a list of [inputs, labels]\n", + " inputs, labels = data\n", + "\n", + " # zero the parameter gradients\n", + " optimizer.zero_grad()\n", + "\n", + " # forward + backward + optimize\n", + " outputs = net(inputs)\n", + " loss = criterion(outputs, labels.to(torch.int64))\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " running_loss += loss.item()\n", + " data_size += inputs.shape[0]\n", + "\n", + " print(f\"epoch {epoch}: loss is {running_loss/data_size}\")\n", + "\n", + "\n", + "in_preds = []\n", + "in_label = []\n", + "with torch.no_grad():\n", + " for data in test_dataloader:\n", + " inputs, labels = data\n", + " outputs = net(inputs)\n", + " in_preds.append(outputs)\n", + " in_label.append(labels)\n", + " in_preds = torch.cat(in_preds)\n", + " in_label = torch.cat(in_label)\n", + "print(\n", + " \"\\nTest Accuracy is: \",\n", + " accuracy_score(np.array(torch.argmax(in_preds, axis=1)), np.array(in_label)),\n", + ")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "k1Qi44IdjRiU", + "outputId": 
"e2f841d0-69d3-4821-8692-6420afe1e64e" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 0: loss is 0.06138062975555658\n", + "epoch 1: loss is 0.03086455798149109\n", + "epoch 2: loss is 0.026004610773175955\n", + "epoch 3: loss is 0.024114671636372806\n", + "epoch 4: loss is 0.02078699535317719\n", + "epoch 5: loss is 0.019672507184557618\n", + "epoch 6: loss is 0.018815230432897807\n", + "epoch 7: loss is 0.01698792629688978\n", + "epoch 8: loss is 0.01645607683621347\n", + "epoch 9: loss is 0.01524037384428084\n", + "epoch 10: loss is 0.014240541556850075\n", + "epoch 11: loss is 0.013692389758303761\n", + "epoch 12: loss is 0.012920912820845842\n", + "epoch 13: loss is 0.012520179092884064\n", + "epoch 14: loss is 0.012519657954573632\n", + "epoch 15: loss is 0.011793930067680775\n", + "epoch 16: loss is 0.011448755952529609\n", + "epoch 17: loss is 0.010951191697269679\n", + "epoch 18: loss is 0.010661566779017449\n", + "epoch 19: loss is 0.010179236607626081\n", + "epoch 20: loss is 0.009936179189942777\n", + "epoch 21: loss is 0.009754820148460568\n", + "epoch 22: loss is 0.00917115265596658\n", + "epoch 23: loss is 0.009030795649625362\n", + "epoch 24: loss is 0.008823388266377151\n", + "epoch 25: loss is 0.008829819331876933\n", + "epoch 26: loss is 0.008454289820045233\n", + "epoch 27: loss is 0.008023065636865794\n", + "epoch 28: loss is 0.007618186932988465\n", + "epoch 29: loss is 0.007679891352541744\n", + "\n", + "Test Accuracy is: 0.873\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## FGSM Attack against NN" + ], + "metadata": { + "id": "eBRDCoZ-hOeW" + } + }, + { + "cell_type": "code", + "source": [ + "x_origin = inputs[[0]]\n", + "y_origin = labels[[0]]\n", + "\n", + "attacker = FGSMAttacker(\n", + " net, criterion, eps=0.3, grad_lower_bound=-0.15, grad_upper_bound=0.15\n", + ")\n", + "perturbed_x = attacker.attack((x_origin, y_origin.to(torch.int64)))" + ], + "metadata": { + "id": "e6pbzHb7jSEG" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "fig = plt.figure()\n", + "fig.add_subplot(121)\n", + "plt.imshow(x_origin[0][0].detach().numpy(), cmap=\"gray\", vmin=-1.0, vmax=1.0)\n", + "plt.title(f\"Predicted Label: {net(x_origin).argmax().item()}\")\n", + "plt.axis(\"off\")\n", + "fig.add_subplot(122)\n", + "plt.imshow(perturbed_x[0][0].detach().numpy(), cmap=\"gray\", vmin=-1.0, vmax=1.0)\n", + "plt.title(f\"Predicted Label: {net(perturbed_x).argmax().item()}\")\n", + "plt.axis(\"off\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 302 + }, + "id": "ZO3_lvc2mCPm", + "outputId": "545bc5e3-0149-4e7c-e4c5-01a5e5ea222f" + }, + "execution_count": 8, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "(-0.5, 27.5, 27.5, -0.5)" + ] + }, + "metadata": {}, + "execution_count": 8 + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAELCAYAAABEYIWnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYnElEQVR4nO3dfXBU1f3H8e+S56dKShIghUkjGIQIptChAmqgIIwRsNOhqG0dwFojBTE4bbRMpSBpVURAIoJoCzOUWpIp6IzSMqQGhHT6AEQLqZiUkoyUqRjNA2AgJjm/Pyj5EQLnbry52c1+368ZZsx+7z337O7dMx9Pcs76jDFGAACAWn0C3QEAABBYhAEAAJQjDAAAoBxhAAAA5QgDAAAoRxgAAEA5wgAAAMoRBgAAUI4wAACAcoSBbvLVr35V5s6d2/7z3r17xefzyd69ewPWpytd2ceeMHHiRLnpppu6tc1APA/AC4wbV8e40fNCIgxs2bJFfD5f+7/o6GjJyMiQhQsXykcffRTo7nXJrl27ZNmyZQHtg8/nk4ULFwa0D17617/+JbNmzZLExESJjY2VW2+9VUpLSwPdLfQwxo3uFerjxuW2bdsmPp9P4uPjA92VbhMe6A50p6eeekrS09Pl/PnzcuDAAdmwYYPs2rVLjh49KrGxsT3al9tvv12ampokMjKyS+ft2rVL1q9fH/APdqj68MMPZdy4cRIWFiY/+clPJC4uTjZv3ixTp06VP/3pT3L77bcHuovoYYwb6IqzZ89Kfn6+xMXFBbor3SqkwsCdd94pX//610VE5MEHH5R+/frJ6tWr5Y033pD77rvvquecO3fOkze1T58+Eh0d3e3twp1nnnlG6uvr5ejRozJs2DAREfnhD38oN954oyxevFgOHToU4B6ipzFuoCsKCgokISFBJk2aJK+//nqgu9NtQuLXBNfyzW9+U0RETpw4ISIic+fOlfj4eDl+/Ljk5ORIQkKCfO973xMRkba2Nlm7dq1kZmZKdHS09O/fX3Jzc6Wurq5Dm8YYKSgokEGDBklsbKxMmjRJKioqOl37Wr/7++tf/yo5OTmSmJgocXFxMmrUKHnhhRfa+7d+/XoRkQ7Tl5d0dx/deOONN+Suu+6S1NRUiYqKkiFDhsiKFSuktbX1qscfOnRIxo8fLzExMZKeni4bN27sdMyFCxfk5z//uQwdOlSioqJk8ODBkp+fLxcuXHDsz/Hjx+X48eOOx+3fv1++9rWvtQcBEZHY2FiZOXOmHD58WKqqqhzbQGhj3GDcuJaqqipZs2aNrF69WsLDQ+r/pUNrZuBKl97kfv36tT/W0tIi06ZNk1tvvVVWrVrVPg2Ym5srW7ZskXnz5smiRYvkxIkT8uKLL0p5ebmUlZVJRESEiIgsXbpUCgoKJCcnR3JycuTw4cMydepUaW5uduzPnj17ZPr06TJw4EB59NFHZcCAAfL+++/Lm2++KY8++qjk5ubKqVOnZM+ePbJ169ZO5/dEH/21ZcsWiY+Pl8cee0zi4+Pl7bfflqVLl0pjY6M899xzHY6tq6uTnJwcmT17ttx3331SVFQk8+fPl8jISHnggQdE5OKANXPmTDlw4IA89NBDMnz4cDly5IisWbNGKisrHRP45MmTRUSkurraetyFCxckMTGx0+OX7oNDhw7JDTfc4OergFDEuMG4cS15eXkyadIkycnJkaKioi4//6BmQsDmzZuNiJiSkhLz8ccfmw8//ND87ne/M/369TMxMTHm5MmTxhhj5syZY0TEPPHEEx3O379/vxERs23btg6P//GPf+zw+OnTp01kZKS56667TFtbW/txS5YsMSJi5syZ0/5YaWmpERFTWlpqjDGmpaXFpKenm7S0NFNXV9fhOpe3tWDBAnO1t8WLPl6LiJgFCxZYj/nss886PZabm2tiY2PN+fPn2x/Lzs42ImKef/759scuXLhgsrKyTEpKimlubjbGGLN161bTp08fs3///g5tbty40YiIKSsra38sLS2t0/NIS0szaWlpjs9txowZpm/fvqaxsbHD4+PGjTMiYlatWuXYBkID4wbjhr/jhjHGvPnmmyY8PNxUVFQYYy7eF3FxcX6d2xuE1K8JpkyZIsnJyTJ48GC59957JT4+Xnbu3Clf+cpXOhw3f/78Dj8XFxfLddddJ3fccYfU1ta2/xszZozEx8e3/6V5SUmJNDc3yyOPPNJhGi4vL8+xb+Xl5XLixAnJy8uTvn37dqhd3ta19EQfuyImJqb9v8+cOSO1tbVy2223yWeffSbHjh3rcGx4eLjk5ua2/xwZGSm5ubly+vTp9t/RFxcXy/Dhw+XGG2/s8PwuTdk6/bV/dXW1X+l+/vz5Ul9fL/fcc4+Ul5dLZWWl5OXlycGDB0VEpKmpya/nj9DBuMG44aS5uVkWL14sDz/8sIwYMcLfp9urhNSvCdavXy8ZGRkSHh4u/fv3l2HDhkmfPh3zTnh4uAwaNKjDY1VVVdLQ0CApKSlXbff06dMiIlJTUyMi0mkaOTk5+apTz5e7NPX4RdfO9kQfu6KiokJ+9rOfydtvvy2NjY0dag0NDR1+Tk1N7fTHVhkZGSJy8cN4yy23SFVVlbz//vuSnJx81etden5u3XnnnVJYWChPPPGEjB49WkREhg4dKr/4xS8kPz8/pJYKwT+MG4wbTtasWSO1tbWyfPnybmkvGIVUGBg7dmz7XwVfS1RUVKcPeltbm6SkpMi2bduues61brSeFEx9rK+vl+zsbPnSl74kTz31lAwZMkSio6Pl8OHD8vjjj0tbW1uX22xra5ORI0fK6tWrr1ofPHiw2263W7hwocybN0/+8Y9/SGRkpGRlZcmvfvUrEfn/wQZ6MG70jN46bjQ0NEhBQYH86Ec/ksbGxvYQc/bsWTHGSHV1tcTGxl4zcPUWIRUGvqghQ4ZISUmJTJgwocM01pXS0tJE5GLavv7669sf//jjjzv9Ze7VriEicvToUZkyZco1j7vW1F9P9NFfe/fulU8++UR27NjRYV3+pb++vtKpU6c6LcWqrKwUkYu7golcfH7vvfeeTJ482a/pT7fi4uJk3Lhx7T+XlJRITEyMTJgwwfNrIzQwbnRNbx036urq5OzZs7Jy5UpZuXJlp3p6errcfffdvX6ZYUj9zcAXNXv2bGltbZUVK1Z0qrW0tEh9fb2IXPzdYkREhBQWFooxpv2YtWvXOl5j9OjRkp6eLmvXrm1v75LL27p04195TE/00V9hYWGd+t3c3CwvvfTSVY9vaWmRl19+ucOxL7/8siQnJ8uYMWNE5OLz+89//iOvvPJKp/Obmprk3Llz1j51dYnQ5f785z/Ljh075Ac/+IFcd911X6gN6MO40TW9ddxISUmRnTt3dvo3adIkiY6Olp07d8pPf/pTaxu9ATMDIpKdnS25ubny9NNPy7vvvitTp06ViIgIqaqqkuLiYnnhhRdk1qxZkpycLD/+8Y/l6aeflunTp0tOTo6Ul5fLH/7wB0lKSrJeo0+fPrJhwwaZMWOGZGVl
ybx582TgwIFy7NgxqaiokN27d4uItN/kixYtkmnTpklYWJjce++9PdLHyx08eFAKCgo6PT5x4kQZP368JCYmypw5c2TRokXi8/lk69atHT7kl0tNTZVnn31WqqurJSMjQ7Zv3y7vvvuubNq0qX1Z0/333y9FRUXy8MMPS2lpqUyYMEFaW1vl2LFjUlRUJLt377ZO5fq7RKimpkZmz54tM2fOlAEDBkhFRYVs3LhRRo0aJb/85S/9fHUAxo2rCcVxIzY2Vr71rW91evz111+Xv/3tb1et9UoBWsXQrS4tEfr73/9uPc5pKcimTZvMmDFjTExMjElISDAjR440+fn55tSpU+3HtLa2muXLl5uBAweamJgYM3HiRHP06NFOy1auXCJ0yYEDB8wdd9xhEhISTFxcnBk1apQpLCxsr7e0tJhHHnnEJCcnG5/P12m5UHf28VpE5Jr/VqxYYYwxpqyszNxyyy0mJibGpKammvz8fLN79+5Ozzk7O9tkZmaagwcPmnHjxpno6GiTlpZmXnzxxU7XbW5uNs8++6zJzMw0UVFRJjEx0YwZM8YsX77cNDQ0tB/nZonQp59+au6++24zYMAAExkZadLT083jjz/eaakhQh/jBuNGV5YWXinUlhb6jLlGLAMAACrwNwMAAChHGAAAQDnCAAAAyhEGAABQjjAAAIByhAEAAJQjDAAAoJzfOxD2xH7xAOx647YgTmNHv379PL3+J5984mn7XvffH8H+HJ3659R+oM/3h9v3wOvX2GnsYGYAAADlCAMAAChHGAAAQDnCAAAAyhEGAABQjjAAAIByfi8tBIBA8HrJVk8sO3PidR/cvgZeC/TSQX+ef29fAuuEmQEAAJQjDAAAoBxhAAAA5QgDAAAoRxgAAEA5wgAAAMoRBgAAUI59BgB4yuv12V6vwXfSHevDA72OPhi+htmNQN8D3SHQ7yEzAwAAKEcYAABAOcIAAADKEQYAAFCOMAAAgHKEAQAAlCMMAACgnM8YY/w60Ofzui8AHPj5cQ0qTmNHd3zXvJv2nQT6e+bRO3h9n7n9nDiNHcwMAACgHGEAAADlCAMAAChHGAAAQDnCAAAAyhEGAABQjjAAAIBy4YHuAIDQ5vX3tHu9D0BPfNd9sO9lEOj3wElPvEduBft9zswAAADKEQYAAFCOMAAAgHKEAQAAlCMMAACgHGEAAADlCAMAACjnM35+QbrTd5ID8J6fH9eg4vXY0RvWmIc6t2vo4czta1xbW2utMzMAAIByhAEAAJQjDAAAoBxhAAAA5QgDAAAoRxgAAEA5wgAAAMqFB7oD+H8RERHWelpamrU+Z84ca/2GG26w1rOysqz14uJia72wsNBaP336tLUOXA1r1J05rUFPTEy01gM9dqxbt85ab2trs9ZD4R5xu4+A0/lOmBkAAEA5wgAAAMoRBgAAUI4wAACAcoQBAACUIwwAAKAcYQAAAOV8xs8vSPf6O8lDQVhYmLU+d+5ca33JkiXWenp6urXe2tpqrTc1NVnrTmJjY6317du3W+v333+/te60lhgifn5cg0pSUpK17nZ9dG9YY+70HFNSUqz18ePHW+uhPnYsWrTI1fW7g9f3mdf7DDiNHcwMAACgHGEAAADlCAMAAChHGAAAQDnCAAAAyhEGAABQjjAAAIByhAEAAJRj06EuSEtLs9ZXrFhhrX//+9+31s+cOWOtb9682Vrft2+ftb5z505r3cmRI0es9czMTGs9NzfXWn/llVe63CdteuOmQ27HDrebvQTDpkbnz5+31ouKiqz1sWPHWutOY0dZWZm1/uqrr1rrTmOLE6fzR4wYYa0/9NBD1vo777zT5T5dzp97JNCbDrnFpkMAAMCKMAAAgHKEAQAAlCMMAACgHGEAAADlCAMAAChHGAAAQLnwQHcgWDjtISAisnv3bmv9+uuvt9bXrVtnra9Zs8Zar6mpsda9du7cOVfn5+XlWevbt2+31hsbG11dHzo5rQ93u77bn7GjsLDQWnc7drz11lvW+uHDh611r6Wmpro6f/HixdZ6eXm5tV5dXW2te72HgIj3+124bZ+ZAQAAlCMMAACgHGEAAADlCAMAAChHGAAAQDnCAAAAyhEGAABQjn0G/ufJJ590PCYjI8Naf+aZZ6z1JUuWdKlPwWbt2rXW+m9/+1trffjw4dZ6bGystc4+A72T2/XRTnW37btdY94dY8eOHTus9Zdeesla9/o5ul3D7jR2LFu2zFq/7bbbrPW4uLgu9qgjf56f1+v8A42ZAQAAlCMMAACgHGEAAADlCAMAAChHGAAAQDnCAAAAyhEGAABQTs0+A0OGDLHWs7OzHdv46KOPrPVNmzZ1qU/a7Nu3z1pvaGjooZ6gJ3m9Bt4tp/6NHTvWWp84caLjNU6ePGmtux07Av0aOqmsrHR1/pEjR6z1iooKV+37I9j3EXB7DzAzAACAcoQBAACUIwwAAKAcYQAAAOUIAwAAKEcYAABAOcIAAADKqdln4LHHHrPWnfYhEBE5e/astf6Nb3zDWo+Pj7fWa2pqrPUzZ85Y61779re/7er8qqoqa72pqclV+whObtc/B3qfgunTp1vrffv2dWyjvr7eWk9MTLTWtY8db731Vjf1JHCc7lOn+9zrfQ6YGQAAQDnCAAAAyhEGAABQjjAAAIByhAEAAJQjDAAAoBxhAAAA5dTsM/Daa69Z67Nnz3Zsw2mdqNM1nPzlL3+x1v/73/+6at+tyZMnuzo/KSnJWo+IiLDWP//8c1fXR2AE+/fAOzl06JC1/sEHHzi24TR27Nmzx1qvra211j/99FNr3WnsqKystNbdevDBB611p+fnNHb0BoHeL8MJMwMAAChHGAAAQDnCAAAAyhEGAABQjjAAAIByhAEAAJQjDAAAoJzPGGP8OtDn87ovAbVs2TLHY5YuXep9RxRLTU211gO9z0Iw8PPjGlSc1oi73Ycg0Ouzv/vd7zoes27duh7oyRcX7Ov8nfo3cOBAa72lpcVaD/Q9JOL9fhxOYwczAwAAKEcYAABAOcIAAADKEQYAAFCOMAAAgHKEAQAAlCMMAACgHPsM/E9MTIzjMRMmTHB1je985zvW+rBhw1y1n5iYaK2PHDnSVftey8/Pt9ZXrVrVQz0JXr1xnwGvx45ArxHvibHj5ptvttZnzJjhqv0vf/nL1npkZKS17nYfAqd9BJy8+uqr1npvGDuc9hlwe587vcbMDAAAoBxhAAAA5QgDAAAoRxgAAEA5wgAAAMoRBgAAUI4wAACAcuGB7kCwaGpqcjympKTE1TXcnu/E630G+vfvb61v3LjRWnfqH9Ab9cTYsX37dmv9+eefd9V+oMeODRs2WOtu9yHoDZz2EfB6HwJmBgAAUI4wAACAcoQBAACUIwwAAKAcYQAAAOUIAwAAKEcYAABAOfYZCCF1dXXW+jvvvOOq/SeffNJad7uPwG9+8xtX5yM0uV1/7bb97uD1GnG3r4ETr8cOJ0lJSda609jRHa9PoPcB8BozAwAAKEcYAABAOcIAAADKEQYAAFCOMAAAgHKEAQAAlCMMAACgnM8
YY/w60Ofzui8IckeOHLHWMzMzrfWamhprffTo0da60z4KGvj5cQ0qTmvEnXi9hj7Y13+L9P7XoKKiwloPCwuz1qOjo631hISELvept3H7HtXW1lrrzAwAAKAcYQAAAOUIAwAAKEcYAABAOcIAAADKEQYAAFCOMAAAgHLhge4AgseUKVOs9aFDh7pqf9asWdY6+wjo5HYNvdv1116v4RfpHXsZ2Di9Rvfcc4+1/vnnn1vr/fv3t9anTZtmrWvg9B64vceYGQAAQDnCAAAAyhEGAABQjjAAAIByhAEAAJQjDAAAoBxhAAAA5dhnQJFBgwZZ6ytXrrTWo6KirPV///vf1np1dbW1jtDk9Tr+ntgnINC83kvBqX7zzTdb688995y17jT2NDQ0WOvBMHZ4/R4Eei8KZgYAAFCOMAAAgHKEAQAAlCMMAACgHGEAAADlCAMAAChHGAAAQDn2GVBk3rx51npWVpar9svKyqx1DevB0Vmg109z37l/D5zGjpiYGFft//rXv7bWg2GNvts+eN1Ht/c5MwMAAChHGAAAQDnCAAAAyhEGAABQjjAAAIByhAEAAJQjDAAAoBxhAAAA5dh0KIQMHjzYWnfaOMTn83VndwARcb9Zi9vNVNy2H+hNk3qC09jxwAMPWOsJCQnd2R0EADMDAAAoRxgAAEA5wgAAAMoRBgAAUI4wAACAcoQBAACUIwwAAKCczxhj/DqQNehBb9iwYdZ6aWmptT5gwABX129ubrbWhw4daq2fPHnS1fU18PPjGlScxo5g3wcg0NcPBnv37rXWb7rpJk+vn5WVZa27HTvc7mXhD6/303DiNHYwMwAAgHKEAQAAlCMMAACgHGEAAADlCAMAAChHGAAAQDnCAAAAyoUHugPoPh988IG1vmDBAmv997//vavrO51/6tQpV+2jdwr1dfj+rA93+xp4vdeBU/tOY8e+ffus9draWmv9tddes9bfe+89az3Qa/h7gtfPkZkBAACUIwwAAKAcYQAAAOUIAwAAKEcYAABAOcIAAADKEQYAAFCOfQYU+ec//2mtFxcXW+sjRoyw1nfs2GGtt7W1WesITW7XyHu9hr4neN0Hr9uvqKiw1jds2GCtZ2ZmWutOY0cwCPa9Ityez8wAAADKEQYAAFCOMAAAgHKEAQAAlCMMAACgHGEAAADlCAMAACjnM8YYvw70+bzuCwAHfn5cg0pSUpK1Hgz7ACC4uV1D3x283ifA7efA6fq1tbXWOjMDAAAoRxgAAEA5wgAAAMoRBgAAUI4wAACAcoQBAACUIwwAAKAc+wwAvUgo7jPgxOt9CJzWZ3u9vtyfawS7QO8D4PY98uf1D/R94vYecRo7mBkAAEA5wgAAAMoRBgAAUI4wAACAcoQBAACUIwwAAKAcYQAAAOX83mcAAACEJmYGAABQjjAAAIByhAEAAJQjDAAAoBxhAAAA5QgDAAAoRxgAAEA5wgAAAMoRBgAAUO7/AEYqFjqwVNwqAAAAAElFTkSuQmCC\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## XGBoost without Defense\n", + "\n", + "The adversarial example crafted above can also deceive the XGBoost model." + ], + "metadata": { + "id": "iTS_qx0Ug_Ry" + } + }, + { + "cell_type": "code", + "source": [ + "min_leaf = 1\n", + "depth = 6\n", + "learning_rate = 0.3\n", + "boosting_rounds = 10\n", + "lam = 1.0\n", + "gamma = 0.0\n", + "eps = 1.0\n", + "min_child_weight = -1 * float(\"inf\")\n", + "subsample_cols = 0.8\n", + "\n", + "X_train_normalized = ((X_train / 255) * 2 - 1).reshape(-1, 28 * 28).tolist()\n", + "X_test_normalized = ((X_test / 255) * 2 - 1).reshape(-1, 28 * 28).tolist()\n", + "\n", + "p0 = XGBoostClient(\n", + " X_train_normalized,\n", + " 10,\n", + " list(range(28 * 28)),\n", + " 0,\n", + " min_leaf,\n", + " subsample_cols,\n", + " 32,\n", + " False,\n", + " 0,\n", + ")\n", + "parties = [p0]\n", + "\n", + "clf = XGBoostClassifierAPI(\n", + " 10,\n", + " subsample_cols,\n", + " min_child_weight,\n", + " depth,\n", + " min_leaf,\n", + " learning_rate,\n", + " boosting_rounds,\n", + " lam,\n", + " gamma,\n", + " eps,\n", + " -1,\n", + " 0,\n", + " 1.0,\n", + " 1,\n", + " True,\n", + " False,\n", + ")\n", + "clf.fit(parties, y_train.tolist())\n", + "\n", + "predicted_proba = clf.predict_proba(X_train_normalized)\n", + "print(\n", + " \"Train Accuracy: \",\n", + " accuracy_score(np.array(predicted_proba).argmax(axis=1), y_train),\n", + ")\n", + "predicted_proba = clf.predict_proba(X_test_normalized)\n", + "print(\n", + " \"Test Accuracy: \", accuracy_score(np.array(predicted_proba).argmax(axis=1), y_test)\n", + ")\n", + "\n", + "print(\n", + " \"Predicted Label without Attack: \",\n", + " np.array(\n", + " clf.predict_proba(x_origin[0][0].detach().numpy().reshape(1, -1).tolist())\n", + " ).argmax(1),\n", + ")\n", + "print(\n", + " \"Predicted Label with Attack: \",\n", + " np.array(\n", + " clf.predict_proba(perturbed_x[0][0].detach().numpy().reshape(1, -1).tolist())\n", + " ).argmax(1),\n", + ")" + ], + 
"metadata": { + "id": "bqEk1VA-m0Ih", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "85646211-41ce-4bb4-9f08-bbda04155447" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Train Accuracy: 0.998\n", + "Test Accuracy: 0.836\n", + "Predicted Label without Attack: [9]\n", + "Predicted Label with Attack: [5]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## XGBoost using Attack-Cost Constraints" + ], + "metadata": { + "id": "C_fdBRIChCKQ" + } + }, + { + "cell_type": "code", + "source": [ + "p0 = XGBoostClient(\n", + " X_train_normalized,\n", + " 10,\n", + " list(range(28 * 28)),\n", + " 0,\n", + " min_leaf,\n", + " subsample_cols,\n", + " 32,\n", + " False,\n", + " 0,\n", + ")\n", + "# You can set the attack-cost constraint to each feature\n", + "p0.set_cost_constraint_map([(-0.2, 0.2)] * (28 * 28))\n", + "parties = [p0]\n", + "\n", + "clf = XGBoostClassifierAPI(\n", + " 10,\n", + " subsample_cols,\n", + " min_child_weight,\n", + " depth,\n", + " min_leaf,\n", + " learning_rate,\n", + " boosting_rounds,\n", + " lam,\n", + " gamma,\n", + " eps,\n", + " -1,\n", + " 0,\n", + " 1.0,\n", + " 1,\n", + " True,\n", + " True,\n", + ")\n", + "clf.fit(parties, y_train.tolist())\n", + "\n", + "predicted_proba = clf.predict_proba(X_train_normalized)\n", + "print(\n", + " \"Train Accuracy: \",\n", + " accuracy_score(np.array(predicted_proba).argmax(axis=1), y_train),\n", + ")\n", + "predicted_proba = clf.predict_proba(X_test_normalized)\n", + "print(\n", + " \"Test Accuracy: \", accuracy_score(np.array(predicted_proba).argmax(axis=1), y_test)\n", + ")\n", + "\n", + "print(\n", + " \"Predicted Label without Attack: \",\n", + " np.array(\n", + " clf.predict_proba(x_origin[0][0].detach().numpy().reshape(1, -1).tolist())\n", + " ).argmax(1),\n", + ")\n", + "print(\n", + " \"Predicted Label with Attack: \",\n", + " np.array(\n", + " clf.predict_proba(perturbed_x[0][0].detach().numpy().reshape(1, -1).tolist())\n", + " ).argmax(1),\n", + ")" + ], + "metadata": { + "id": "1vxloHfQq06E", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "5ea8f965-0767-4fd7-bcde-55b8afd8b841" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Train Accuracy: 0.975\n", + "Test Accuracy: 0.84\n", + "Predicted Label without Attack: [9]\n", + "Predicted Label with Attack: [9]\n" + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/docs/source/notebooks/evasion.rst b/docs/source/notebooks/evasion.rst index 98ca1201..323bc038 100644 --- a/docs/source/notebooks/evasion.rst +++ b/docs/source/notebooks/evasion.rst @@ -4,4 +4,5 @@ Evasion Attack .. toctree:: :maxdepth: 1 - aijack_evasion_attack \ No newline at end of file + aijack_evasion_attack + aijack_transferbility_and_robustness diff --git a/src/aijack/attack/__init__.py b/src/aijack/attack/__init__.py index a99c9916..a65dd2f5 100644 --- a/src/aijack/attack/__init__.py +++ b/src/aijack/attack/__init__.py @@ -1,7 +1,7 @@ """Submodule for attack algorithms against machine learning. 
""" from .base_attack import BaseAttacker # noqa: F401 -from .evasion import Evasion_attack_sklearn # noqa: F401 +from .evasion import Evasion_attack_sklearn, FGSMAttacker # noqa: F401 from .inversion import ( # noqa: F401 MI_FACE, GANAttackClientManager, diff --git a/src/aijack/attack/evasion/__init__.py b/src/aijack/attack/evasion/__init__.py index a342004b..092744dd 100644 --- a/src/aijack/attack/evasion/__init__.py +++ b/src/aijack/attack/evasion/__init__.py @@ -2,5 +2,6 @@ machine learning model cannot correctly classify. """ from .evasion_attack import Evasion_attack_sklearn # noqa: F401 +from .fgsm import FGSMAttacker # noqa: F401 -__all__ = ["Evasion_attack_sklearn"] +__all__ = ["Evasion_attack_sklearn", "FGSMAttacker"] diff --git a/src/aijack/attack/evasion/fgsm.py b/src/aijack/attack/evasion/fgsm.py new file mode 100644 index 00000000..aadfa80c --- /dev/null +++ b/src/aijack/attack/evasion/fgsm.py @@ -0,0 +1,44 @@ +import torch + +from ..base_attack import BaseAttacker + + +class FGSMAttacker(BaseAttacker): + def __init__( + self, + target_model, + criterion, + eps=0.3, + grad_lower_bound=-0.1, + grad_upper_bound=0.1, + output_lower_bound=-1.0, + output_upper_bound=1.0, + ): + super().__init__(target_model) + + self.criterion = criterion + self.eps = eps + self.grad_lower_bound = grad_lower_bound + self.grad_upper_bound = grad_upper_bound + self.output_lower_bound = output_lower_bound + self.output_upper_bound = output_upper_bound + + def attack(self, data): + x, y = data + x.requires_grad = True + + self.target_model.zero_grad() + output = self.target_model(x) + loss = self.criterion(output, y) + loss.backward() + grad = x.grad.data + + sign_data_grad = grad.sign() + noise = torch.clamp( + self.eps * sign_data_grad, self.grad_lower_bound, self.grad_upper_bound + ) + perturbed_x = x + noise + perturbed_x = torch.clamp( + perturbed_x, self.output_lower_bound, self.output_upper_bound + ) + return perturbed_x diff --git a/src/aijack/collaborative/tree/core/model.h b/src/aijack/collaborative/tree/core/model.h index ba2a7b8e..0e8ed881 100644 --- a/src/aijack/collaborative/tree/core/model.h +++ b/src/aijack/collaborative/tree/core/model.h @@ -1,8 +1,8 @@ #pragma once -#include +#include #include #include -#include +#include using namespace std; /** @@ -10,24 +10,23 @@ using namespace std; * * @tparam PartyName */ -template -struct TreeModelBase -{ - TreeModelBase(){}; +template struct TreeModelBase { + // TreeModelBase(){}; - /** - * @brief Function to train the model given the parties and ground-truth labels. - * - * @param parties The vector of parties. - * @param y The vector of ground-truth vectors - */ - virtual void fit(vector &parties, vector &y) = 0; + /** + * @brief Function to train the model given the parties and ground-truth + * labels. + * + * @param parties The vector of parties. + * @param y The vector of ground-truth vectors + */ + virtual void fit(vector &parties, vector &y) = 0; - /** - * @brief Function to return the predicted scores of the given data. - * - * @param X The feature matrix. - * @return vector The vector of predicted raw scores. - */ - virtual vector> predict_raw(vector> &X) = 0; + /** + * @brief Function to return the predicted scores of the given data. + * + * @param X The feature matrix. + * @return vector The vector of predicted raw scores. 
+ */ + virtual vector> predict_raw(vector> &X) = 0; }; diff --git a/src/aijack/collaborative/tree/core/node.h b/src/aijack/collaborative/tree/core/node.h index 632721bf..f5a4adfb 100644 --- a/src/aijack/collaborative/tree/core/node.h +++ b/src/aijack/collaborative/tree/core/node.h @@ -1,12 +1,12 @@ #pragma once -#include -#include +#include #include #include #include -#include +#include #include #include +#include using namespace std; /** @@ -14,110 +14,112 @@ using namespace std; * * @tparam PartyType Type of party. */ -template -struct Node -{ - vector parties; - vector y; - vector idxs; - - int num_classes; - int depth; - int active_party_id; - int n_job; - - int party_id, record_id; - int row_count, num_parties; - float score; - vector val; - - int best_party_id = -1; - int best_col_id = -1; - int best_threshold_id = -1; - - float best_score = -1 * numeric_limits::infinity(); - int is_leaf_flag = -1; // -1:not calculated yer, 0: is not leaf, 1: is leaf - - Node(){}; - - /** - * @brief Get the idxs object - * - * @return vector - */ - virtual vector get_idxs() = 0; - - /** - * @brief Get the party id object - * - * @return int - */ - virtual int get_party_id() = 0; - - /** - * @brief Get the record id object - * - * @return int - */ - virtual int get_record_id() = 0; - - /** - * @brief Get the value assigned to this node. - * - * @return float - */ - virtual vector get_val() = 0; - - /** - * @brief Get the evaluation score of this node. - * - * @return float - */ - virtual float get_score() = 0; - - /** - * @brief Get the num of parties used for this node. - * - * @return int - */ - virtual int get_num_parties() = 0; - - /** - * @brief Compute the weight (val) of this node. - * - * @return vector - */ - virtual vector compute_weight() = 0; - - /** - * @brief Find the best split which gives the best score (gain). - * - * @return tuple - */ - virtual tuple find_split() = 0; - - /** - * @brief Generate the children nodes. - * - * @param best_party_id The index of the best party. - * @param best_col_id The index of the best feature. - * @param best_threshold_id The index of the best threshold. - */ - virtual void make_children_nodes(int best_party_id, int best_col_id, int best_threshold_id) = 0; - - /** - * @brief Return true if this node is a leaf. - * - * @return true - * @return false - */ - virtual bool is_leaf() = 0; - - /** - * @brief Return true if the node is pure; the assigned labels to this node consist of a unique label. - * - * @return true - * @return false - */ - virtual bool is_pure() = 0; +template struct Node { + vector &parties; + vector &y; + vector idxs; + + int num_classes; + int depth; + int active_party_id; + int n_job; + + int party_id, record_id; + int row_count, num_parties; + float score; + vector val; + + int best_party_id = -1; + int best_col_id = -1; + int best_threshold_id = -1; + + float best_score = -1 * numeric_limits::infinity(); + int is_leaf_flag = -1; // -1:not calculated yer, 0: is not leaf, 1: is leaf + + // Node(){}; + Node(vector &parties_, vector &idxs_, vector &y_) + : parties(parties_), idxs(idxs_), y(y_) {} + + /** + * @brief Get the idxs object + * + * @return vector + */ + virtual vector get_idxs() = 0; + + /** + * @brief Get the party id object + * + * @return int + */ + virtual int get_party_id() = 0; + + /** + * @brief Get the record id object + * + * @return int + */ + virtual int get_record_id() = 0; + + /** + * @brief Get the value assigned to this node. 
+ * + * @return float + */ + virtual vector get_val() = 0; + + /** + * @brief Get the evaluation score of this node. + * + * @return float + */ + virtual float get_score() = 0; + + /** + * @brief Get the num of parties used for this node. + * + * @return int + */ + virtual int get_num_parties() = 0; + + /** + * @brief Compute the weight (val) of this node. + * + * @return vector + */ + virtual vector compute_weight() = 0; + + /** + * @brief Find the best split which gives the best score (gain). + * + * @return tuple + */ + virtual tuple find_split() = 0; + + /** + * @brief Generate the children nodes. + * + * @param best_party_id The index of the best party. + * @param best_col_id The index of the best feature. + * @param best_threshold_id The index of the best threshold. + */ + virtual void make_children_nodes(int best_party_id, int best_col_id, + int best_threshold_id) = 0; + + /** + * @brief Return true if this node is a leaf. + * + * @return true + * @return false + */ + virtual bool is_leaf() = 0; + + /** + * @brief Return true if the node is pure; the assigned labels to this node + * consist of a unique label. + * + * @return true + * @return false + */ + virtual bool is_pure() = 0; }; diff --git a/src/aijack/collaborative/tree/core/nodeapi.h b/src/aijack/collaborative/tree/core/nodeapi.h index 33467409..1b5d50ae 100644 --- a/src/aijack/collaborative/tree/core/nodeapi.h +++ b/src/aijack/collaborative/tree/core/nodeapi.h @@ -1,248 +1,195 @@ #pragma once -#include -#include +#include #include #include #include -#include +#include +#include #include #include -#include +#include using namespace std; -template -struct NodeAPI -{ - - NodeAPI(){}; - - float get_leaf_purity(NodeType *node, int tot_cnt) - { - float leaf_purity = 0; - if (node->is_leaf()) - { - int cnt_idxs = node->idxs.size(); - if (cnt_idxs == 0) - { - leaf_purity = 0.0; - } - else - { - int cnt_zero = 0; - for (int i = 0; i < node->idxs.size(); i++) - { - if (node->y[node->idxs[i]] == 0) - { - cnt_zero += 1; - } - } - leaf_purity = max(float(cnt_zero) / float(cnt_idxs), - 1 - float(cnt_zero) / float(cnt_idxs)); - leaf_purity = leaf_purity * (float(cnt_idxs) / float(tot_cnt)); - } - } - else - { - leaf_purity = get_leaf_purity(node->left, tot_cnt) + get_leaf_purity(node->right, tot_cnt); +template struct NodeAPI { + + NodeAPI(){}; + + float get_leaf_purity(NodeType *node, int tot_cnt) { + float leaf_purity = 0; + if (node->is_leaf()) { + int cnt_idxs = node->idxs.size(); + if (cnt_idxs == 0) { + leaf_purity = 0.0; + } else { + int cnt_zero = 0; + for (int i = 0; i < node->idxs.size(); i++) { + if (node->y[node->idxs[i]] == 0) { + cnt_zero += 1; + } } - return leaf_purity; + leaf_purity = max(float(cnt_zero) / float(cnt_idxs), + 1 - float(cnt_zero) / float(cnt_idxs)); + leaf_purity = leaf_purity * (float(cnt_idxs) / float(tot_cnt)); + } + } else { + leaf_purity = get_leaf_purity(node->left, tot_cnt) + + get_leaf_purity(node->right, tot_cnt); } - - string print(NodeType *node, bool show_purity = false, bool binary_color = true, int target_party_id = -1) - { - pair result = recursive_print(node, "", false, show_purity, binary_color, target_party_id); - if (result.second) - { - return ""; - } - else - { - return result.first; - } + return leaf_purity; + } + + string print(NodeType *node, bool show_purity = false, + bool binary_color = true, int target_party_id = -1) { + pair result = recursive_print(node, "", false, show_purity, + binary_color, target_party_id); + if (result.second) { + return ""; + } else { + return 
result.first; } - - string print_leaf(NodeType *node, bool show_purity, bool binary_color) - { - string node_info = to_string(node->get_val()[0]); - if (show_purity) - { - int cnt_idxs = node->idxs.size(); - if (cnt_idxs == 0) - { - node_info += ", null"; - } - else - { - int cnt_zero = 0; - for (int i = 0; i < node->idxs.size(); i++) - { - if (node->y[node->idxs[i]] == 0) - { - cnt_zero += 1; - } - } - float purity = max(float(cnt_zero) / float(cnt_idxs), - 1 - float(cnt_zero) / float(cnt_idxs)); - node_info += ", "; - - if (binary_color) - { - if (purity < 0.7) - { - node_info += "\033[32m"; - } - else if (purity < 0.9) - { - node_info += "\033[33m"; - } - else - { - node_info += "\033[31m"; - } - node_info += to_string(purity); - node_info += " ("; - node_info += to_string(cnt_zero); - node_info += ", "; - node_info += to_string(cnt_idxs - cnt_zero); - node_info += ")"; - node_info += "\033[0m"; - } - else - { - node_info += to_string(purity); - } - } - } - else - { - node_info += ", ["; - int temp_id; - for (int i = 0; i < node->idxs.size(); i++) - { - temp_id = node->idxs[i]; - if (binary_color) - { - if (node->y[temp_id] == 0) - { - node_info += "\033[32m"; - node_info += to_string(temp_id); - node_info += "\033[0m"; - } - else - { - node_info += to_string(temp_id); - } - } - else - { - node_info += to_string(temp_id); - } - node_info += ", "; - } - node_info += "]"; + } + + string print_leaf(NodeType *node, bool show_purity, bool binary_color) { + string node_info = to_string(node->get_val()[0]); + if (show_purity) { + int cnt_idxs = node->idxs.size(); + if (cnt_idxs == 0) { + node_info += ", null"; + } else { + int cnt_zero = 0; + for (int i = 0; i < node->idxs.size(); i++) { + if (node->y[node->idxs[i]] == 0) { + cnt_zero += 1; + } } - - return node_info; - } - - pair recursive_print(NodeType *node, string prefix, bool isleft, bool show_purity, - bool binary_color, int target_party_id = -1) - { - string node_info; - bool skip_flag; - if (node->is_leaf()) - { - skip_flag = node->depth <= 0 && target_party_id != -1 && node->party_id != target_party_id; - if (skip_flag) - { - node_info = ""; - } - else - { - node_info = print_leaf(node, show_purity, binary_color); - } - node_info = prefix + "|-- " + node_info; - node_info += "\n"; + float purity = max(float(cnt_zero) / float(cnt_idxs), + 1 - float(cnt_zero) / float(cnt_idxs)); + node_info += ", "; + + if (binary_color) { + if (purity < 0.7) { + node_info += "\033[32m"; + } else if (purity < 0.9) { + node_info += "\033[33m"; + } else { + node_info += "\033[31m"; + } + node_info += to_string(purity); + node_info += " ("; + node_info += to_string(cnt_zero); + node_info += ", "; + node_info += to_string(cnt_idxs - cnt_zero); + node_info += ")"; + node_info += "\033[0m"; + } else { + node_info += to_string(purity); } - else - { - node_info += to_string(node->get_party_id()); - node_info += ", "; - node_info += to_string(node->get_record_id()); - node_info = prefix + "|-- " + node_info; - - string next_prefix = ""; - if (isleft) - { - next_prefix += "| "; - } - else - { - next_prefix += " "; - } - - pair left_node_info_and_skip_flag = recursive_print(node->left, prefix + next_prefix, true, - show_purity, binary_color, target_party_id); - pair right_node_info_and_skip_flag = recursive_print(node->right, prefix + next_prefix, false, - show_purity, binary_color, target_party_id); - if (left_node_info_and_skip_flag.second && right_node_info_and_skip_flag.second) - { - node_info += " -> " + print_leaf(node, show_purity, binary_color); - 
node_info += "\n"; - } - else - { - node_info += "\n"; - node_info += left_node_info_and_skip_flag.first; - node_info += right_node_info_and_skip_flag.first; - } - - skip_flag = false; + } + } else { + node_info += ", ["; + int temp_id; + for (int i = 0; i < node->idxs.size(); i++) { + temp_id = node->idxs[i]; + if (binary_color) { + if (node->y[temp_id] == 0) { + node_info += "\033[32m"; + node_info += to_string(temp_id); + node_info += "\033[0m"; + } else { + node_info += to_string(temp_id); + } + } else { + node_info += to_string(temp_id); } + node_info += ", "; + } + node_info += "]"; + } - return make_pair(node_info, skip_flag); + return node_info; + } + + pair recursive_print(NodeType *node, string prefix, bool isleft, + bool show_purity, bool binary_color, + int target_party_id = -1) { + string node_info; + bool skip_flag; + if (node->is_leaf()) { + skip_flag = node->depth <= 0 && target_party_id != -1 && + node->party_id != target_party_id; + if (skip_flag) { + node_info = ""; + } else { + node_info = print_leaf(node, show_purity, binary_color); + } + node_info = prefix + "|-- " + node_info; + node_info += "\n"; + } else { + node_info += to_string(node->get_party_id()); + node_info += ", "; + node_info += to_string(node->get_record_id()); + node_info = prefix + "|-- " + node_info; + + string next_prefix = ""; + if (isleft) { + next_prefix += "| "; + } else { + next_prefix += " "; + } + + pair left_node_info_and_skip_flag = + recursive_print(node->left, prefix + next_prefix, true, show_purity, + binary_color, target_party_id); + pair right_node_info_and_skip_flag = + recursive_print(node->right, prefix + next_prefix, false, show_purity, + binary_color, target_party_id); + if (left_node_info_and_skip_flag.second && + right_node_info_and_skip_flag.second) { + node_info += " -> " + print_leaf(node, show_purity, binary_color); + node_info += "\n"; + } else { + node_info += "\n"; + node_info += left_node_info_and_skip_flag.first; + node_info += right_node_info_and_skip_flag.first; + } + + skip_flag = false; } - vector predict_row(NodeType *node, vector &xi) - { - queue que; - que.push(node); - - NodeType *temp_node; - while (!que.empty()) - { - temp_node = que.front(); - que.pop(); - - if (temp_node->is_leaf()) - { - return temp_node->val; - } - else - { - if (node->parties[temp_node->party_id].is_left(temp_node->record_id, xi)) - { - que.push(temp_node->left); - } - else - { - que.push(temp_node->right); - } - } + return make_pair(node_info, skip_flag); + } + + vector predict_row(NodeType *node, vector &xi) { + queue que; + que.push(node); + + NodeType *temp_node; + while (!que.empty()) { + temp_node = que.front(); + que.pop(); + + if (temp_node->is_leaf()) { + return temp_node->val; + } else { + if (node->parties[temp_node->party_id].is_left(temp_node->record_id, + xi)) { + que.push(temp_node->left); + } else { + que.push(temp_node->right); } - - vector nan_vec(node->num_classes, nan("")); - return nan_vec; + } } - vector> predict(NodeType *node, vector> &x_new) - { - int x_new_size = x_new.size(); - vector> y_pred(x_new_size); - for (int i = 0; i < x_new_size; i++) - { - y_pred[i] = predict_row(node, x_new[i]); - } - return y_pred; + vector nan_vec(node->num_classes, 0); + return nan_vec; + } + + vector> predict(NodeType *node, vector> &x_new) { + int x_new_size = x_new.size(); + vector> y_pred(x_new_size); + for (int i = 0; i < x_new_size; i++) { + y_pred[i] = predict_row(node, x_new[i]); } + return y_pred; + } }; diff --git a/src/aijack/collaborative/tree/core/party.h 
b/src/aijack/collaborative/tree/core/party.h index 5c8e6997..1f7f554b 100644 --- a/src/aijack/collaborative/tree/core/party.h +++ b/src/aijack/collaborative/tree/core/party.h @@ -1,205 +1,175 @@ #pragma once +#include "../utils/utils.h" +#include #include -#include -#include +#include #include #include #include -#include -#include -#include -#include +#include #include -#include +#include +#include #include +#include +#include #include -#include -#include "../utils/utils.h" +#include using namespace std; -struct Party -{ - vector> x; // a feature vector of this party - vector feature_id; // id of the features - int party_id; // id of this party - int min_leaf; - float subsample_cols; // ratio of subsampled columuns - bool use_missing_value; - int seed; - - int col_count; // the number of columns - int subsample_col_count; - - int num_classes; - - unordered_map> lookup_table; // record_id: (feature_id, threshold, missing_value_dir) - vector temp_column_subsample; - vector> temp_thresholds; // feature_id->threshold - - Party() {} - Party(vector> x_, int num_classes_, vector feature_id_, int party_id_, - int min_leaf_, float subsample_cols_, - bool use_missing_value_ = false, int seed_ = 0) - { - validate_arguments(x_, feature_id_, party_id_, min_leaf_, subsample_cols_); - x = x_; - num_classes = num_classes_; - feature_id = feature_id_; - party_id = party_id_; - min_leaf = min_leaf_; - subsample_cols = subsample_cols_; - use_missing_value = use_missing_value_; - seed = seed_; - - col_count = x.at(0).size(); - subsample_col_count = max(1, int(subsample_cols * float(col_count))); - } - - void validate_arguments(vector> &x_, vector &feature_id_, int &party_id_, - int min_leaf_, float subsample_cols_) - { - try - { - if (x_.size() == 0) - { - throw invalid_argument("x is empty"); - } - } - catch (std::exception &e) - { - std::cerr << e.what() << std::endl; - } - - try - { - if (x_[0].size() != feature_id_.size()) - { - throw invalid_argument("the number of columns of x is different from the size of feature_id"); - } - } - catch (std::exception &e) - { - std::cerr << e.what() << std::endl; - } - - try - { - if (subsample_cols_ > 1 || subsample_cols_ < 0) - { - throw out_of_range("subsample_cols should be in [1, 0]"); - } - } - catch (std::exception &e) - { - std::cerr << e.what() << std::endl; - } +struct Party { + vector> x; // a feature vector of this party + vector feature_id; // id of the features + int party_id; // id of this party + int min_leaf; + float subsample_cols; // ratio of subsampled columuns + bool use_missing_value; + int seed; + + int col_count; // the number of columns + int subsample_col_count; + + int num_classes; + + unordered_map> + lookup_table; // record_id: (feature_id, threshold, missing_value_dir) + vector temp_column_subsample; + vector> temp_thresholds; // feature_id->threshold + + // Party() {} + Party(vector> x_, int num_classes_, vector feature_id_, + int party_id_, int min_leaf_, float subsample_cols_, + bool use_missing_value_ = false, int seed_ = 0) { + validate_arguments(x_, feature_id_, party_id_, min_leaf_, subsample_cols_); + x = x_; + num_classes = num_classes_; + feature_id = feature_id_; + party_id = party_id_; + min_leaf = min_leaf_; + subsample_cols = subsample_cols_; + use_missing_value = use_missing_value_; + seed = seed_; + + col_count = x.at(0).size(); + subsample_col_count = max(1, int(subsample_cols * float(col_count))); + } + + void validate_arguments(vector> &x_, vector &feature_id_, + int &party_id_, int min_leaf_, + float subsample_cols_) { + 
try { + if (x_.size() == 0) { + throw invalid_argument("x is empty"); + } + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; } - unordered_map> get_lookup_table() - { - return lookup_table; + try { + if (x_[0].size() != feature_id_.size()) { + throw invalid_argument("the number of columns of x is different from " + "the size of feature_id"); + } + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; } - vector get_threshold_candidates(vector &x_col) - { - vector x_col_wo_duplicates = remove_duplicates(x_col); - vector thresholds(x_col_wo_duplicates.size()); - copy(x_col_wo_duplicates.begin(), x_col_wo_duplicates.end(), thresholds.begin()); - sort(thresholds.begin(), thresholds.end()); - return thresholds; + try { + if (subsample_cols_ > 1 || subsample_cols_ < 0) { + throw out_of_range("subsample_cols should be in [1, 0]"); + } + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; } - - bool is_left(int record_id, vector &xi) - { - bool flag; - float x_criterion = xi[feature_id[get<0>(lookup_table[record_id])]]; - if (isnan(x_criterion)) - { - try - { - if (!use_missing_value) - { - throw std::runtime_error("given data contains NaN, but use_missing_value is false"); - } - else - { - flag = get<2>(lookup_table[record_id]) == 0; - } - } - catch (std::exception &e) - { - std::cout << e.what() << std::endl; - } + } + + unordered_map> get_lookup_table() { + return lookup_table; + } + + vector get_threshold_candidates(vector &x_col) { + vector x_col_wo_duplicates = remove_duplicates(x_col); + vector thresholds(x_col_wo_duplicates.size()); + copy(x_col_wo_duplicates.begin(), x_col_wo_duplicates.end(), + thresholds.begin()); + sort(thresholds.begin(), thresholds.end()); + return thresholds; + } + + bool is_left(int record_id, vector &xi) { + bool flag; + float x_criterion = xi[feature_id[get<0>(lookup_table[record_id])]]; + if (isnan(x_criterion)) { + try { + if (!use_missing_value) { + throw std::runtime_error( + "given data contains NaN, but use_missing_value is false"); + } else { + flag = get<2>(lookup_table[record_id]) == 0; } - else - { - flag = x_criterion <= get<1>(lookup_table[record_id]); - } - return flag; + } catch (std::exception &e) { + std::cout << e.what() << std::endl; + } + } else { + flag = x_criterion <= get<1>(lookup_table[record_id]); } - - void subsample_columns() - { - temp_column_subsample.resize(col_count); - iota(temp_column_subsample.begin(), temp_column_subsample.end(), 0); - mt19937 engine(seed); - seed += 1; - shuffle(temp_column_subsample.begin(), temp_column_subsample.end(), engine); + return flag; + } + + void subsample_columns() { + temp_column_subsample.resize(col_count); + iota(temp_column_subsample.begin(), temp_column_subsample.end(), 0); + mt19937 engine(seed); + seed += 1; + shuffle(temp_column_subsample.begin(), temp_column_subsample.end(), engine); + } + + vector split_rows(vector &idxs, int feature_opt_pos, + int threshold_opt_pos) { + // feature_opt_idがthreshold_opt_id以下のindexを返す + int feature_opt_id, missing_dir; + feature_opt_id = + temp_column_subsample[feature_opt_pos % subsample_col_count]; + if (feature_opt_pos > subsample_col_count) { + missing_dir = 1; + } else { + missing_dir = 0; } - - vector split_rows(vector &idxs, int feature_opt_pos, int threshold_opt_pos) - { - // feature_opt_idがthreshold_opt_id以下のindexを返す - int feature_opt_id, missing_dir; - feature_opt_id = temp_column_subsample[feature_opt_pos % subsample_col_count]; - if (feature_opt_pos > subsample_col_count) - { - missing_dir = 1; - } 
- else - { - missing_dir = 0; - } - int row_count = idxs.size(); - vector x_col(row_count); - for (int r = 0; r < row_count; r++) - x_col[r] = x[idxs[r]][feature_opt_id]; - - vector left_idxs; - float threshold = temp_thresholds[feature_opt_pos][threshold_opt_pos]; - for (int r = 0; r < row_count; r++) - if (((!isnan(x_col[r])) && (x_col[r] <= threshold)) || - ((isnan(x_col[r])) && (missing_dir == 1))) - left_idxs.push_back(idxs[r]); - - return left_idxs; + int row_count = idxs.size(); + vector x_col(row_count); + for (int r = 0; r < row_count; r++) + x_col[r] = x[idxs[r]][feature_opt_id]; + + vector left_idxs; + float threshold = temp_thresholds[feature_opt_pos][threshold_opt_pos]; + for (int r = 0; r < row_count; r++) + if (((!isnan(x_col[r])) && (x_col[r] <= threshold)) || + ((isnan(x_col[r])) && (missing_dir == 1))) + left_idxs.push_back(idxs[r]); + + return left_idxs; + } + + int insert_lookup_table(int feature_opt_pos, int threshold_opt_pos) { + int feature_opt_id, missing_dir; + float threshold_opt; + feature_opt_id = + temp_column_subsample[feature_opt_pos % subsample_col_count]; + threshold_opt = temp_thresholds[feature_opt_pos][threshold_opt_pos]; + + if (use_missing_value) { + if (feature_opt_pos > subsample_col_count) { + missing_dir = 1; + } else { + missing_dir = 0; + } + } else { + missing_dir = -1; } - int insert_lookup_table(int feature_opt_pos, int threshold_opt_pos) - { - int feature_opt_id, missing_dir; - float threshold_opt; - feature_opt_id = temp_column_subsample[feature_opt_pos % subsample_col_count]; - threshold_opt = temp_thresholds[feature_opt_pos][threshold_opt_pos]; - - if (use_missing_value) - { - if (feature_opt_pos > subsample_col_count) - { - missing_dir = 1; - } - else - { - missing_dir = 0; - } - } - else - { - missing_dir = -1; - } - - lookup_table.emplace(lookup_table.size(), - make_tuple(feature_opt_id, threshold_opt, missing_dir)); - return lookup_table.size() - 1; - } + lookup_table.emplace( + lookup_table.size(), + make_tuple(feature_opt_id, threshold_opt, missing_dir)); + return lookup_table.size() - 1; + } }; diff --git a/src/aijack/collaborative/tree/core/tree.h b/src/aijack/collaborative/tree/core/tree.h index a158f68d..17c81e0d 100644 --- a/src/aijack/collaborative/tree/core/tree.h +++ b/src/aijack/collaborative/tree/core/tree.h @@ -1,107 +1,97 @@ #pragma once -#include +#include "../core/nodeapi.h" +#include #include #include -#include -#include "../core/nodeapi.h" +#include using namespace std; -template -struct Tree -{ - NodeType dtree; - NodeAPI nodeapi; +template struct Tree { + NodeType dtree; + NodeAPI nodeapi; - Tree() {} + Tree(NodeType dtree_) : dtree(dtree_) {} - /** - * @brief Get the root node object - * - * @return NodeType& - */ - NodeType &get_root_node() - { - return *dtree; - } + /** + * @brief Get the root node object + * + * @return NodeType& + */ + NodeType &get_root_node() { return *dtree; } - /** - * @brief Return the predicted value of the give new sample X - * - * @param X the new sample to be predicted - * @return vector> - */ - vector> predict(vector> &X) - { - return nodeapi.predict(&dtree, X); - } + /** + * @brief Return the predicted value of the give new sample X + * + * @param X the new sample to be predicted + * @return vector> + */ + vector> predict(vector> &X) { + return nodeapi.predict(&dtree, X); + } - /** - * @brief Recursively extract the vector of predictions of the training data from the specified node - * - * @param node target node - * @return vector, vector>>> - */ - vector, vector>>> 
extract_train_prediction_from_node(NodeType &node) - { - if (node.is_leaf()) - { - vector, vector>>> result; - result.push_back(make_pair(node.idxs, - vector>(node.idxs.size(), - node.val))); - return result; - } - else - { - vector, vector>>> left_result = - extract_train_prediction_from_node(*node.left); - vector, vector>>> right_result = - extract_train_prediction_from_node(*node.right); - left_result.insert(left_result.end(), right_result.begin(), right_result.end()); - return left_result; - } + /** + * @brief Recursively extract the vector of predictions of the training data + * from the specified node + * + * @param node target node + * @return vector, vector>>> + */ + vector, vector>>> + extract_train_prediction_from_node(NodeType *node) { + if (node->is_leaf()) { + vector, vector>>> result; + result.push_back(make_pair( + node->idxs, vector>(node->idxs.size(), node->val))); + return result; + } else { + vector, vector>>> left_result = + extract_train_prediction_from_node(node->left); + vector, vector>>> right_result = + extract_train_prediction_from_node(node->right); + left_result.insert(left_result.end(), right_result.begin(), + right_result.end()); + return left_result; + } + } - /** - * @brief Recursively extract the vector of predictions of the training data - * - * @return vector> - */ - vector> get_train_prediction() - { - vector, vector>>> result = extract_train_prediction_from_node(dtree); - vector> y_train_pred(dtree.y.size()); - for (int i = 0; i < result.size(); i++) - { - for (int j = 0; j < result[i].first.size(); j++) - { - y_train_pred[result[i].first[j]] = result[i].second[j]; - } - } - - return y_train_pred; + /** + * @brief Recursively extract the vector of predictions of the training data + * + * @return vector> + */ + vector> get_train_prediction() { + vector, vector>>> result = + extract_train_prediction_from_node(&dtree); + vector> y_train_pred(dtree.y.size()); + for (int i = 0; i < result.size(); i++) { + for (int j = 0; j < result[i].first.size(); j++) { + y_train_pred[result[i].first[j]] = result[i].second[j]; + } } - /** - * @brief Print out the structure of this tree - * - * @param show_purity Show leaf purity of each leaf node, if true - * @param binary_color Color the leaf purity (red: >= 0.8, yellow: 0.7 ~ 0.8, green: < 0.7) - * @param target_party_id The id of the active party - * @return string - */ - string print(bool show_purity = false, bool binary_color = true, int target_party_id = -1) - { - return nodeapi.print(&dtree, show_purity, binary_color, target_party_id); - } + return y_train_pred; + } - /** - * @brief Get the average of leaf purity - * - * @return float - */ - float get_leaf_purity() - { - return nodeapi.get_leaf_purity(&dtree, dtree.idxs.size()); - } + /** + * @brief Print out the structure of this tree + * + * @param show_purity Show leaf purity of each leaf node, if true + * @param binary_color Color the leaf purity (red: >= 0.8, yellow: 0.7 ~ 0.8, + * green: < 0.7) + * @param target_party_id The id of the active party + * @return string + */ + string print(bool show_purity = false, bool binary_color = true, + int target_party_id = -1) { + return nodeapi.print(&dtree, show_purity, binary_color, target_party_id); + } + + /** + * @brief Get the average of leaf purity + * + * @return float + */ + float get_leaf_purity() { + return nodeapi.get_leaf_purity(&dtree, dtree.idxs.size()); + } }; diff --git a/src/aijack/collaborative/tree/secureboost/node.h b/src/aijack/collaborative/tree/secureboost/node.h index cb4a42ec..4ec51370 100644 ---
a/src/aijack/collaborative/tree/secureboost/node.h +++ b/src/aijack/collaborative/tree/secureboost/node.h @@ -1,331 +1,296 @@ #pragma once -#include "../xgboost/node.h" #include "../../../defense/paillier/src/paillier.h" +#include "../xgboost/node.h" #include "party.h" using namespace std; -struct SecureBoostNode : Node -{ - vector parties; - vector> gradient, hessian; - vector> vanila_gradient, vanila_hessian; - float min_child_weight, lam, gamma, eps; - bool use_only_active_party; - SecureBoostNode *left, *right; - - int num_classes; - - SecureBoostNode() {} - SecureBoostNode(vector &parties_, vector &y_, - int num_classes_, - vector> &gradient_, - vector> &hessian_, - vector> &vanila_gradient_, - vector> &vanila_hessian_, - vector &idxs_, float min_child_weight_, float lam_, - float gamma_, float eps_, int depth_, int active_party_id_ = 0, - bool use_only_active_party_ = false, int n_job_ = 1) - { - parties = parties_; - y = y_; - num_classes = num_classes_; - gradient = gradient_; - hessian = hessian_; - vanila_gradient = vanila_gradient_; - vanila_hessian = vanila_hessian_; - idxs = idxs_; - min_child_weight = min_child_weight_; - lam = lam_; - gamma = gamma_; - eps = eps_; - depth = depth_; - active_party_id = active_party_id_; - use_only_active_party = use_only_active_party_; - n_job = n_job_; - - row_count = idxs.size(); - num_parties = parties.size(); - - val = compute_weight(); - // tuple best_split = find_split(); - - if (is_leaf()) - { - is_leaf_flag = 1; - } - else - { - is_leaf_flag = 0; - } - - if (is_leaf_flag == 0) - { - tuple best_split = find_split(); - party_id = get<0>(best_split); - if (party_id != -1) - { - record_id = parties[party_id].insert_lookup_table(get<1>(best_split), get<2>(best_split)); - make_children_nodes(get<0>(best_split), get<1>(best_split), get<2>(best_split)); - } - else - { - is_leaf_flag = 1; - } - } - } - - vector get_idxs() - { - return idxs; - } - - int get_party_id() - { - return party_id; - } - - int get_record_id() - { - return record_id; - } - - vector get_val() - { - return val; - } - - float get_score() - { - return score; - } - - SecureBoostNode get_left() - { - return *left; - } - - SecureBoostNode get_right() - { - return *right; +struct SecureBoostNode : Node { + vector> &gradient, &hessian; + vector> &vanila_gradient, &vanila_hessian; + float min_child_weight, lam, gamma, eps; + bool use_only_active_party; + SecureBoostNode *left, *right; + + int num_classes; + + // SecureBoostNode() {} + SecureBoostNode(vector &parties_, vector &y_, + int num_classes_, + vector> &gradient_, + vector> &hessian_, + vector> &vanila_gradient_, + vector> &vanila_hessian_, vector &idxs_, + float min_child_weight_, float lam_, float gamma_, float eps_, + int depth_, int active_party_id_ = 0, + bool use_only_active_party_ = false, int n_job_ = 1) + : gradient(gradient_), hessian(hessian_), + vanila_gradient(vanila_gradient_), + vanila_hessian(vanila_hessian_), Node(parties_, idxs_, + y_) { + num_classes = num_classes_; + min_child_weight = min_child_weight_; + lam = lam_; + gamma = gamma_; + eps = eps_; + depth = depth_; + active_party_id = active_party_id_; + use_only_active_party = use_only_active_party_; + n_job = n_job_; + + row_count = idxs.size(); + num_parties = parties.size(); + + val = compute_weight(); + // tuple best_split = find_split(); + + if (is_leaf()) { + is_leaf_flag = 1; + } else { + is_leaf_flag = 0; } - int get_num_parties() - { - return parties.size(); + if (is_leaf_flag == 0) { + tuple best_split = find_split(); + party_id = 
get<0>(best_split); + if (party_id != -1) { + record_id = parties[party_id].insert_lookup_table(get<1>(best_split), + get<2>(best_split)); + make_children_nodes(get<0>(best_split), get<1>(best_split), + get<2>(best_split)); + } else { + is_leaf_flag = 1; + } } - - vector compute_weight() - { - return xgboost_compute_weight(row_count, vanila_gradient, vanila_hessian, idxs, lam); - } - - float compute_gain(vector &left_grad, vector &right_grad, vector &left_hess, vector &right_hess) - { - return xgboost_compute_gain(left_grad, right_grad, left_hess, right_hess, gamma, lam); + } + + SecureBoostNode &operator=(const SecureBoostNode &other) { + if (this != &other) { + parties = other.parties; + y = other.y; + idxs = other.idxs; + val = other.val; + + num_classes = other.num_classes; + depth = other.depth; + active_party_id = other.active_party_id; + n_job = other.n_job; + + party_id = other.party_id; + record_id = other.record_id; + row_count = other.row_count; + num_parties = other.num_parties; + is_leaf_flag = other.is_leaf_flag; + + left = other.left; + right = other.right; } - - void find_split_per_party(int party_id_start, int temp_num_parties, vector &sum_grad, vector &sum_hess) - { - int grad_dim = sum_grad.size(); - - for (int temp_party_id = party_id_start; temp_party_id < party_id_start + temp_num_parties; temp_party_id++) - { - - vector, vector>>> search_results; - if (temp_party_id == active_party_id) - { - search_results = parties[temp_party_id].greedy_search_split(vanila_gradient, vanila_hessian, idxs); - } - else - { - vector, vector>>> encrypted_search_result = - parties[temp_party_id].greedy_search_split_encrypt(gradient, hessian, idxs); - int temp_result_size = encrypted_search_result.size(); - search_results.resize(temp_result_size); - int temp_vec_size; - for (int j = 0; j < temp_result_size; j++) - { - temp_vec_size = encrypted_search_result[j].size(); - search_results[j].resize(temp_vec_size); - for (int k = 0; k < temp_vec_size; k++) - { - vector temp_grad_decrypted, temp_hess_decrypted; - temp_grad_decrypted.resize(grad_dim); - temp_hess_decrypted.resize(grad_dim); - - for (int c = 0; c < grad_dim; c++) - { - temp_grad_decrypted[c] = parties[active_party_id] - .sk.decrypt( - encrypted_search_result[j][k].first[c]); - temp_hess_decrypted[c] = parties[active_party_id] - .sk.decrypt( - encrypted_search_result[j][k].second[c]); - } - search_results[j][k] = make_pair(temp_grad_decrypted, temp_hess_decrypted); - } - } - } - - float temp_score; - vector temp_left_grad(grad_dim, 0); - vector temp_left_hess(grad_dim, 0); - vector temp_right_grad(grad_dim, 0); - vector temp_right_hess(grad_dim, 0); - bool skip_flag = false; - - for (int j = 0; j < search_results.size(); j++) - { - temp_score = 0; - - for (int c = 0; c < grad_dim; c++) - { - temp_left_grad[c] = 0; - temp_left_hess[c] = 0; - } - - for (int k = 0; k < search_results[j].size(); k++) - { - for (int c = 0; c < grad_dim; c++) - { - temp_left_grad[c] += search_results[j][k].first[c]; - temp_left_hess[c] += search_results[j][k].second[c]; - } - - skip_flag = false; - for (int c = 0; c < grad_dim; c++) - { - if (temp_left_hess[c] < min_child_weight || - sum_hess[c] - temp_left_hess[c] < min_child_weight) - { - skip_flag = true; - } - } - if (skip_flag) - { - continue; - } - - for (int c = 0; c < grad_dim; c++) - { - temp_right_grad[c] = sum_grad[c] - temp_left_grad[c]; - temp_right_hess[c] = sum_hess[c] - temp_left_hess[c]; - } - - temp_score = compute_gain(temp_left_grad, temp_right_grad, - temp_left_hess, 
temp_right_hess); - - if (temp_score > best_score) - { - best_score = temp_score; - best_party_id = temp_party_id; - best_col_id = j; - best_threshold_id = k; - } - } + return *this; + } + + vector get_idxs() { return idxs; } + + int get_party_id() { return party_id; } + + int get_record_id() { return record_id; } + + vector get_val() { return val; } + + float get_score() { return score; } + + SecureBoostNode get_left() { return *left; } + + SecureBoostNode get_right() { return *right; } + + int get_num_parties() { return parties.size(); } + + vector compute_weight() { + return xgboost_compute_weight(row_count, vanila_gradient, vanila_hessian, + idxs, lam); + } + + float compute_gain(vector &left_grad, vector &right_grad, + vector &left_hess, vector &right_hess) { + return xgboost_compute_gain(left_grad, right_grad, left_hess, right_hess, + gamma, lam); + } + + void find_split_per_party(int party_id_start, int temp_num_parties, + vector &sum_grad, vector &sum_hess) { + int grad_dim = sum_grad.size(); + + for (int temp_party_id = party_id_start; + temp_party_id < party_id_start + temp_num_parties; temp_party_id++) { + + vector, vector>>> search_results; + if (temp_party_id == active_party_id) { + search_results = parties[temp_party_id].greedy_search_split( + vanila_gradient, vanila_hessian, idxs); + } else { + vector, vector>>> + encrypted_search_result = + parties[temp_party_id].greedy_search_split_encrypt( + gradient, hessian, idxs); + int temp_result_size = encrypted_search_result.size(); + search_results.resize(temp_result_size); + int temp_vec_size; + for (int j = 0; j < temp_result_size; j++) { + temp_vec_size = encrypted_search_result[j].size(); + search_results[j].resize(temp_vec_size); + for (int k = 0; k < temp_vec_size; k++) { + vector temp_grad_decrypted, temp_hess_decrypted; + temp_grad_decrypted.resize(grad_dim); + temp_hess_decrypted.resize(grad_dim); + + for (int c = 0; c < grad_dim; c++) { + temp_grad_decrypted[c] = + parties[active_party_id].sk.decrypt( + encrypted_search_result[j][k].first[c]); + temp_hess_decrypted[c] = + parties[active_party_id].sk.decrypt( + encrypted_search_result[j][k].second[c]); } + search_results[j][k] = + make_pair(temp_grad_decrypted, temp_hess_decrypted); + } } - } + } - tuple find_split() - { - vector sum_grad(gradient[0].size(), 0); - vector sum_hess(hessian[0].size(), 0); - for (int i = 0; i < row_count; i++) - { - for (int c = 0; c < sum_grad.size(); c++) - { - sum_grad[c] += vanila_gradient[idxs[i]][c]; - sum_hess[c] += vanila_hessian[idxs[i]][c]; - } - } + float temp_score; + vector temp_left_grad(grad_dim, 0); + vector temp_left_hess(grad_dim, 0); + vector temp_right_grad(grad_dim, 0); + vector temp_right_hess(grad_dim, 0); + bool skip_flag = false; - float temp_score, temp_left_grad, temp_left_hess; + for (int j = 0; j < search_results.size(); j++) { + temp_score = 0; - if (use_only_active_party) - { - find_split_per_party(active_party_id, 1, sum_grad, sum_hess); + for (int c = 0; c < grad_dim; c++) { + temp_left_grad[c] = 0; + temp_left_hess[c] = 0; } - else - { - if (n_job == 1) - { - find_split_per_party(0, num_parties, sum_grad, sum_hess); - } - else - { - vector num_parties_per_thread = get_num_parties_per_process(n_job, num_parties); - int cnt_parties = 0; - vector threads_parties; - for (int i = 0; i < n_job; i++) - { - int local_num_parties = num_parties_per_thread[i]; - thread temp_th([this, cnt_parties, local_num_parties, &sum_grad, &sum_hess] - { this->find_split_per_party(cnt_parties, local_num_parties, sum_grad, sum_hess); 
}); - threads_parties.push_back(move(temp_th)); - cnt_parties += num_parties_per_thread[i]; - } - for (int i = 0; i < num_parties; i++) - { - threads_parties[i].join(); - } + + for (int k = 0; k < search_results[j].size(); k++) { + for (int c = 0; c < grad_dim; c++) { + temp_left_grad[c] += search_results[j][k].first[c]; + temp_left_hess[c] += search_results[j][k].second[c]; + } + + skip_flag = false; + for (int c = 0; c < grad_dim; c++) { + if (temp_left_hess[c] < min_child_weight || + sum_hess[c] - temp_left_hess[c] < min_child_weight) { + skip_flag = true; } + } + if (skip_flag) { + continue; + } + + for (int c = 0; c < grad_dim; c++) { + temp_right_grad[c] = sum_grad[c] - temp_left_grad[c]; + temp_right_hess[c] = sum_hess[c] - temp_left_hess[c]; + } + + temp_score = compute_gain(temp_left_grad, temp_right_grad, + temp_left_hess, temp_right_hess); + + if (temp_score > best_score) { + best_score = temp_score; + best_party_id = temp_party_id; + best_col_id = j; + best_threshold_id = k; + } } - score = best_score; - return make_tuple(best_party_id, best_col_id, best_threshold_id); + } } - - void make_children_nodes(int best_party_id, int best_col_id, int best_threshold_id) - { - // TODO: remove idx with nan values from right_idxs; - vector left_idxs = parties[best_party_id].split_rows(idxs, best_col_id, best_threshold_id); - vector right_idxs; - for (int i = 0; i < row_count; i++) - if (!any_of(left_idxs.begin(), left_idxs.end(), [&](float x) - { return x == idxs[i]; })) - right_idxs.push_back(idxs[i]); - - left = new SecureBoostNode(parties, y, num_classes, gradient, hessian, - vanila_gradient, vanila_hessian, - left_idxs, min_child_weight, - lam, gamma, eps, depth - 1, active_party_id, use_only_active_party); - if (left->is_leaf_flag == 1) - { - left->party_id = party_id; - } - right = new SecureBoostNode(parties, y, num_classes, gradient, hessian, - vanila_gradient, vanila_hessian, - right_idxs, min_child_weight, - lam, gamma, eps, depth - 1, active_party_id, use_only_active_party); - if (right->is_leaf_flag == 1) - { - right->party_id = party_id; - } + } + + tuple find_split() { + vector sum_grad(gradient[0].size(), 0); + vector sum_hess(hessian[0].size(), 0); + for (int i = 0; i < row_count; i++) { + for (int c = 0; c < sum_grad.size(); c++) { + sum_grad[c] += vanila_gradient[idxs[i]][c]; + sum_hess[c] += vanila_hessian[idxs[i]][c]; + } } - bool is_leaf() - { - if (is_leaf_flag == -1) - { - return is_pure() || std::isinf(score) || depth <= 0; + float temp_score, temp_left_grad, temp_left_hess; + + if (use_only_active_party) { + find_split_per_party(active_party_id, 1, sum_grad, sum_hess); + } else { + if (n_job == 1) { + find_split_per_party(0, num_parties, sum_grad, sum_hess); + } else { + vector num_parties_per_thread = + get_num_parties_per_process(n_job, num_parties); + int cnt_parties = 0; + vector threads_parties; + for (int i = 0; i < n_job; i++) { + int local_num_parties = num_parties_per_thread[i]; + thread temp_th( + [this, cnt_parties, local_num_parties, &sum_grad, &sum_hess] { + this->find_split_per_party(cnt_parties, local_num_parties, + sum_grad, sum_hess); + }); + threads_parties.push_back(move(temp_th)); + cnt_parties += num_parties_per_thread[i]; } - else - { - return is_leaf_flag; + for (int i = 0; i < n_job; i++) { + threads_parties[i].join(); } + } } + score = best_score; + return make_tuple(best_party_id, best_col_id, best_threshold_id); + } + + void make_children_nodes(int best_party_id, int best_col_id, + int best_threshold_id) { + // TODO: remove idx with
nan values from right_idxs; + vector left_idxs = + parties[best_party_id].split_rows(idxs, best_col_id, best_threshold_id); + vector right_idxs; + for (int i = 0; i < row_count; i++) + if (!any_of(left_idxs.begin(), left_idxs.end(), + [&](float x) { return x == idxs[i]; })) + right_idxs.push_back(idxs[i]); + + left = new SecureBoostNode(parties, y, num_classes, gradient, hessian, + vanila_gradient, vanila_hessian, left_idxs, + min_child_weight, lam, gamma, eps, depth - 1, + active_party_id, use_only_active_party); + if (left->is_leaf_flag == 1) { + left->party_id = party_id; + } + right = new SecureBoostNode(parties, y, num_classes, gradient, hessian, + vanila_gradient, vanila_hessian, right_idxs, + min_child_weight, lam, gamma, eps, depth - 1, + active_party_id, use_only_active_party); + if (right->is_leaf_flag == 1) { + right->party_id = party_id; + } + } - bool is_pure() - { - set s{}; - for (int i = 0; i < row_count; i++) - { - if (s.insert(y[idxs[i]]).second) - { - if (s.size() == 2) - return false; - } - } - return true; + bool is_leaf() { + if (is_leaf_flag == -1) { + return is_pure() || std::isinf(score) || depth <= 0; + } else { + return is_leaf_flag; + } + } + + bool is_pure() { + set s{}; + for (int i = 0; i < row_count; i++) { + if (s.insert(y[idxs[i]]).second) { + if (s.size() == 2) + return false; + } } + return true; + } }; diff --git a/src/aijack/collaborative/tree/secureboost/party.h b/src/aijack/collaborative/tree/secureboost/party.h index 3c5bf29a..a10abf4a 100644 --- a/src/aijack/collaborative/tree/secureboost/party.h +++ b/src/aijack/collaborative/tree/secureboost/party.h @@ -1,299 +1,251 @@ #pragma once -#include "../xgboost/party.h" #include "../../../defense/paillier/src/paillier.h" +#include "../xgboost/party.h" using namespace std; -struct SecureBoostParty : XGBoostParty -{ - PaillierPublicKey pk; - PaillierSecretKey sk; - - SecureBoostParty() {} - SecureBoostParty(vector> x_, int num_classes_, - vector feature_id_, int party_id_, - int min_leaf_, float subsample_cols_, - int num_precentile_bin_ = 256, - bool use_missing_value_ = false, - int seed_ = 0) : XGBoostParty(x_, num_classes_, feature_id_, party_id_, - min_leaf_, subsample_cols_, - num_precentile_bin_, - use_missing_value_, seed_) {} - - void set_publickey(PaillierPublicKey pk_) - { - pk = pk_; - } - - void set_secretkey(PaillierSecretKey sk_) - { - sk = sk_; - } - - vector, vector>>> greedy_search_split(vector> &gradient, - vector> &hessian, - vector &idxs) - { - // feature_id -> [(grad hess)] - // the threshold of split_candidates_grad_hess[i][j] = temp_thresholds[i][j] - int num_thresholds; - if (use_missing_value) - num_thresholds = subsample_col_count * 2; - else - num_thresholds = subsample_col_count; - vector, vector>>> split_candidates_grad_hess(num_thresholds); - temp_thresholds = vector>(num_thresholds); - - int row_count = idxs.size(); - int recoed_id = 0; - - int grad_dim = gradient[0].size(); - - for (int i = 0; i < subsample_col_count; i++) - { - // extract the necessary data - int k = temp_column_subsample[i]; - vector x_col(row_count); - - int not_missing_values_count = 0; - int missing_values_count = 0; - for (int r = 0; r < row_count; r++) - { - if (!isnan(x[idxs[r]][k])) - { - x_col[not_missing_values_count] = x[idxs[r]][k]; - not_missing_values_count += 1; - } - else - { - missing_values_count += 1; - } - } - x_col.resize(not_missing_values_count); - - vector x_col_idxs(not_missing_values_count); - iota(x_col_idxs.begin(), x_col_idxs.end(), 0); - sort(x_col_idxs.begin(), 
x_col_idxs.end(), [&x_col](size_t i1, size_t i2) - { return x_col[i1] < x_col[i2]; }); - - sort(x_col.begin(), x_col.end()); - - // get percentiles of x_col - vector percentiles = get_threshold_candidates(x_col); - - // enumerate all threshold value (missing value goto right) - int current_min_idx = 0; - int cumulative_left_size = 0; - for (int p = 0; p < percentiles.size(); p++) - { - vector temp_grad(grad_dim, 0); - vector temp_hess(grad_dim, 0); - int temp_left_size = 0; - - for (int r = current_min_idx; r < not_missing_values_count; r++) - { - if (x_col[r] <= percentiles[p]) - { - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; - temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; - } - cumulative_left_size += 1; - } - else - { - current_min_idx = r; - break; - } - } - - if (cumulative_left_size >= min_leaf && - row_count - cumulative_left_size >= min_leaf) - { - split_candidates_grad_hess[i].push_back(make_pair(temp_grad, temp_hess)); - temp_thresholds[i].push_back(percentiles[p]); - } +struct SecureBoostParty : public XGBoostParty { + PaillierPublicKey pk; + PaillierSecretKey sk; + + // SecureBoostParty() {} + SecureBoostParty(vector> x_, int num_classes_, + vector feature_id_, int party_id_, int min_leaf_, + float subsample_cols_, int num_precentile_bin_ = 256, + bool use_missing_value_ = false, int seed_ = 0) + : XGBoostParty(x_, num_classes_, feature_id_, party_id_, min_leaf_, + subsample_cols_, num_precentile_bin_, use_missing_value_, + seed_) {} + + void set_publickey(PaillierPublicKey pk_) { pk = pk_; } + + void set_secretkey(PaillierSecretKey sk_) { sk = sk_; } + + vector, vector>>> + greedy_search_split(vector> &gradient, + vector> &hessian, vector &idxs) { + // feature_id -> [(grad hess)] + // the threshold of split_candidates_grad_hess[i][j] = temp_thresholds[i][j] + int num_thresholds; + if (use_missing_value) + num_thresholds = subsample_col_count * 2; + else + num_thresholds = subsample_col_count; + vector, vector>>> + split_candidates_grad_hess(num_thresholds); + temp_thresholds = vector>(num_thresholds); + + int row_count = idxs.size(); + int recoed_id = 0; + + int grad_dim = gradient[0].size(); + + for (int i = 0; i < subsample_col_count; i++) { + // extract the necessary data + int k = temp_column_subsample[i]; + vector x_col(row_count); + + int not_missing_values_count = 0; + int missing_values_count = 0; + for (int r = 0; r < row_count; r++) { + if (!isnan(x[idxs[r]][k])) { + x_col[not_missing_values_count] = x[idxs[r]][k]; + not_missing_values_count += 1; + } else { + missing_values_count += 1; + } + } + x_col.resize(not_missing_values_count); + + vector x_col_idxs(not_missing_values_count); + iota(x_col_idxs.begin(), x_col_idxs.end(), 0); + sort(x_col_idxs.begin(), x_col_idxs.end(), + [&x_col](size_t i1, size_t i2) { return x_col[i1] < x_col[i2]; }); + + sort(x_col.begin(), x_col.end()); + + // get percentiles of x_col + vector percentiles = get_threshold_candidates(x_col); + + // enumerate all threshold value (missing value goto right) + int current_min_idx = 0; + int cumulative_left_size = 0; + for (int p = 0; p < percentiles.size(); p++) { + vector temp_grad(grad_dim, 0); + vector temp_hess(grad_dim, 0); + int temp_left_size = 0; + + for (int r = current_min_idx; r < not_missing_values_count; r++) { + if (x_col[r] <= percentiles[p]) { + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; + temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; } + cumulative_left_size += 1; + } else { + 
current_min_idx = r; + break; + } + } - // enumerate missing value goto left - if (use_missing_value) - { - int current_max_idx = not_missing_values_count - 1; - int cumulative_right_size = 0; - for (int p = percentiles.size() - 1; p >= 0; p--) - { - vector temp_grad(grad_dim, 0); - vector temp_hess(grad_dim, 0); - int temp_left_size = 0; - - for (int r = current_max_idx; r >= 0; r--) - { - if (x_col[r] >= percentiles[p]) - { - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; - temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; - } - cumulative_right_size += 1; - } - else - { - current_max_idx = r; - break; - } - } - - if (cumulative_right_size >= min_leaf && - row_count - cumulative_right_size >= min_leaf) - { - split_candidates_grad_hess[i + subsample_col_count].push_back(make_pair(temp_grad, - temp_hess)); - temp_thresholds[i + subsample_col_count].push_back(percentiles[p]); - } - } + if (cumulative_left_size >= min_leaf && + row_count - cumulative_left_size >= min_leaf) { + split_candidates_grad_hess[i].push_back( + make_pair(temp_grad, temp_hess)); + temp_thresholds[i].push_back(percentiles[p]); + } + } + + // enumerate missing value goto left + if (use_missing_value) { + int current_max_idx = not_missing_values_count - 1; + int cumulative_right_size = 0; + for (int p = percentiles.size() - 1; p >= 0; p--) { + vector temp_grad(grad_dim, 0); + vector temp_hess(grad_dim, 0); + int temp_left_size = 0; + + for (int r = current_max_idx; r >= 0; r--) { + if (x_col[r] >= percentiles[p]) { + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; + temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; + } + cumulative_right_size += 1; + } else { + current_max_idx = r; + break; } + } + + if (cumulative_right_size >= min_leaf && + row_count - cumulative_right_size >= min_leaf) { + split_candidates_grad_hess[i + subsample_col_count].push_back( + make_pair(temp_grad, temp_hess)); + temp_thresholds[i + subsample_col_count].push_back(percentiles[p]); + } } - - return split_candidates_grad_hess; + } } - vector, vector>>> greedy_search_split_encrypt(vector> &gradient, - vector> &hessian, - vector &idxs) - { - // feature_id -> [(grad hess)] - // the threshold of split_candidates_grad_hess[i][j] = temp_thresholds[i][j] - int num_thresholds; - if (use_missing_value) - num_thresholds = subsample_col_count * 2; - else - num_thresholds = subsample_col_count; - vector, vector>>> split_candidates_grad_hess(num_thresholds); - temp_thresholds = vector>(num_thresholds); - - int row_count = idxs.size(); - int recoed_id = 0; - - int grad_dim = gradient[0].size(); - - for (int i = 0; i < subsample_col_count; i++) - { - // extract the necessary data - int k = temp_column_subsample[i]; - vector x_col(row_count); - - int not_missing_values_count = 0; - int missing_values_count = 0; - for (int r = 0; r < row_count; r++) - { - if (!isnan(x[idxs[r]][k])) - { - x_col[not_missing_values_count] = x[idxs[r]][k]; - not_missing_values_count += 1; - } - else - { - missing_values_count += 1; - } - } - x_col.resize(not_missing_values_count); - - vector x_col_idxs(not_missing_values_count); - iota(x_col_idxs.begin(), x_col_idxs.end(), 0); - sort(x_col_idxs.begin(), x_col_idxs.end(), [&x_col](size_t i1, size_t i2) - { return x_col[i1] < x_col[i2]; }); - - sort(x_col.begin(), x_col.end()); - - // get percentiles of x_col - vector percentiles = get_threshold_candidates(x_col); - - // enumerate all threshold value (missing value goto right) - int current_min_idx = 0; 
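// ----------------------------------------------------------------------------
// Aside: why greedy_search_split_encrypt can sum gradients it cannot read.
// Paillier is additively homomorphic, so adding two ciphertexts yields an
// encryption of the sum. In this protocol the active party encrypts every
// per-sample gradient/hessian once (see fit in secureboost.h), each passive
// party accumulates the ciphertexts per threshold bucket with operator+ as in
// the surrounding loops, and only the active party decrypts the aggregated
// statistics in find_split_per_party. A minimal sketch of that flow, assuming
// the pk.encrypt / sk.decrypt / operator+ interfaces used in this diff; the
// ciphertext type name and the variables enc_grad, acc, and left_rows are
// placeholders for illustration:
//
//   // active party: encrypt each sample's gradient once
//   std::vector<Ciphertext> enc_grad;
//   for (float g : per_sample_grad) enc_grad.push_back(pk.encrypt(g));
//
//   // passive party: homomorphically sum the left-branch gradients
//   Ciphertext acc = pk.encrypt(0);
//   for (int r : left_rows) acc = acc + enc_grad[r];
//
//   // active party: decrypt only the aggregate, never a single sample
//   float left_grad_sum = sk.decrypt(acc);
// ----------------------------------------------------------------------------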
- int cumulative_left_size = 0; - for (int p = 0; p < percentiles.size(); p++) - { - vector temp_grad(grad_dim); - vector temp_hess(grad_dim); - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] = pk.encrypt(0); - temp_hess[c] = pk.encrypt(0); - } - int temp_left_size = 0; - - for (int r = current_min_idx; r < not_missing_values_count; r++) - { - if (x_col[r] <= percentiles[p]) - { - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] = temp_grad[c] + gradient[idxs[x_col_idxs[r]]][c]; - temp_hess[c] = temp_hess[c] + hessian[idxs[x_col_idxs[r]]][c]; - } - cumulative_left_size += 1; - } - else - { - current_min_idx = r; - break; - } - } - - if (cumulative_left_size >= min_leaf && - row_count - cumulative_left_size >= min_leaf) - { - split_candidates_grad_hess[i].push_back(make_pair(temp_grad, temp_hess)); - temp_thresholds[i].push_back(percentiles[p]); - } - } + return split_candidates_grad_hess; + } + + vector, vector>>> + greedy_search_split_encrypt(vector> &gradient, + vector> &hessian, + vector &idxs) { + // feature_id -> [(grad hess)] + // the threshold of split_candidates_grad_hess[i][j] = temp_thresholds[i][j] + int num_thresholds; + if (use_missing_value) + num_thresholds = subsample_col_count * 2; + else + num_thresholds = subsample_col_count; + vector, vector>>> + split_candidates_grad_hess(num_thresholds); + temp_thresholds = vector>(num_thresholds); + + int row_count = idxs.size(); + int recoed_id = 0; + + int grad_dim = gradient[0].size(); + + for (int i = 0; i < subsample_col_count; i++) { + // extract the necessary data + int k = temp_column_subsample[i]; + vector x_col(row_count); + + int not_missing_values_count = 0; + int missing_values_count = 0; + for (int r = 0; r < row_count; r++) { + if (!isnan(x[idxs[r]][k])) { + x_col[not_missing_values_count] = x[idxs[r]][k]; + not_missing_values_count += 1; + } else { + missing_values_count += 1; + } + } + x_col.resize(not_missing_values_count); + + vector x_col_idxs(not_missing_values_count); + iota(x_col_idxs.begin(), x_col_idxs.end(), 0); + sort(x_col_idxs.begin(), x_col_idxs.end(), + [&x_col](size_t i1, size_t i2) { return x_col[i1] < x_col[i2]; }); + + sort(x_col.begin(), x_col.end()); + + // get percentiles of x_col + vector percentiles = get_threshold_candidates(x_col); + + // enumerate all threshold value (missing value goto right) + int current_min_idx = 0; + int cumulative_left_size = 0; + for (int p = 0; p < percentiles.size(); p++) { + vector temp_grad(grad_dim); + vector temp_hess(grad_dim); + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] = pk.encrypt(0); + temp_hess[c] = pk.encrypt(0); + } + int temp_left_size = 0; - // enumerate missing value goto left - if (use_missing_value) - { - int current_max_idx = not_missing_values_count - 1; - int cumulative_right_size = 0; - for (int p = percentiles.size() - 1; p >= 0; p--) - { - vector temp_grad(grad_dim); - vector temp_hess(grad_dim); - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] = pk.encrypt(0); - temp_hess[c] = pk.encrypt(0); - } - int temp_left_size = 0; - - for (int r = current_max_idx; r >= 0; r--) - { - if (x_col[r] >= percentiles[p]) - { - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] = temp_grad[c] + gradient[idxs[x_col_idxs[r]]][c]; - temp_hess[c] = temp_hess[c] + hessian[idxs[x_col_idxs[r]]][c]; - } - cumulative_right_size += 1; - } - else - { - current_max_idx = r; - break; - } - } - - if (cumulative_right_size >= min_leaf && - row_count - cumulative_right_size >= min_leaf) - { - split_candidates_grad_hess[i + 
subsample_col_count].push_back(make_pair(temp_grad, - temp_hess)); - temp_thresholds[i + subsample_col_count].push_back(percentiles[p]); - } - } + for (int r = current_min_idx; r < not_missing_values_count; r++) { + if (x_col[r] <= percentiles[p]) { + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] = temp_grad[c] + gradient[idxs[x_col_idxs[r]]][c]; + temp_hess[c] = temp_hess[c] + hessian[idxs[x_col_idxs[r]]][c]; } + cumulative_left_size += 1; + } else { + current_min_idx = r; + break; + } } - return split_candidates_grad_hess; + if (cumulative_left_size >= min_leaf && + row_count - cumulative_left_size >= min_leaf) { + split_candidates_grad_hess[i].push_back( + make_pair(temp_grad, temp_hess)); + temp_thresholds[i].push_back(percentiles[p]); + } + } + + // enumerate missing value goto left + if (use_missing_value) { + int current_max_idx = not_missing_values_count - 1; + int cumulative_right_size = 0; + for (int p = percentiles.size() - 1; p >= 0; p--) { + vector temp_grad(grad_dim); + vector temp_hess(grad_dim); + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] = pk.encrypt(0); + temp_hess[c] = pk.encrypt(0); + } + int temp_left_size = 0; + + for (int r = current_max_idx; r >= 0; r--) { + if (x_col[r] >= percentiles[p]) { + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] = temp_grad[c] + gradient[idxs[x_col_idxs[r]]][c]; + temp_hess[c] = temp_hess[c] + hessian[idxs[x_col_idxs[r]]][c]; + } + cumulative_right_size += 1; + } else { + current_max_idx = r; + break; + } + } + + if (cumulative_right_size >= min_leaf && + row_count - cumulative_right_size >= min_leaf) { + split_candidates_grad_hess[i + subsample_col_count].push_back( + make_pair(temp_grad, temp_hess)); + temp_thresholds[i + subsample_col_count].push_back(percentiles[p]); + } + } + } } + + return split_candidates_grad_hess; + } }; diff --git a/src/aijack/collaborative/tree/secureboost/secureboost.h b/src/aijack/collaborative/tree/secureboost/secureboost.h index 65653831..b0829c25 100644 --- a/src/aijack/collaborative/tree/secureboost/secureboost.h +++ b/src/aijack/collaborative/tree/secureboost/secureboost.h @@ -1,223 +1,203 @@ #pragma once -#include -#include -#include -#include -#include #include "../../../defense/paillier/src/paillier.h" #include "../core/model.h" #include "../xgboost/loss.h" +#include "party.h" #include "tree.h" +#include +#include +#include +#include +#include using namespace std; -struct SecureBoostBase : TreeModelBase -{ - float subsample_cols; - float min_child_weight; - int depth; - int min_leaf; - float learning_rate; - int boosting_rounds; - float lam; - float gamma; - float eps; - int active_party_id; - int completelly_secure_round; - float init_value; - int n_job; - bool save_loss; - int num_classes; - - LossFunc *lossfunc_obj; - - vector> init_pred; - vector estimators; - vector logging_loss; - - SecureBoostBase(int num_classes_, float subsample_cols_ = 0.8, - float min_child_weight_ = -1 * numeric_limits::infinity(), - int depth_ = 5, int min_leaf_ = 5, - float learning_rate_ = 0.4, int boosting_rounds_ = 5, - float lam_ = 1.5, float gamma_ = 1, float eps_ = 0.1, - int active_party_id_ = -1, int completelly_secure_round_ = 0, - float init_value_ = 1.0, int n_job_ = 1, bool save_loss_ = true) - { - num_classes = num_classes_; - subsample_cols = subsample_cols_; - min_child_weight = min_child_weight_; - depth = depth_; - min_leaf = min_leaf_; - learning_rate = learning_rate_; - boosting_rounds = boosting_rounds_; - lam = lam_; - gamma = gamma_; - eps = eps_; - active_party_id = 
active_party_id_; - completelly_secure_round = completelly_secure_round_; - init_value = init_value_; - n_job = n_job_; - save_loss = save_loss_; - - if (num_classes == 2) - { - lossfunc_obj = new BCELoss(); - } - else - { - lossfunc_obj = new CELoss(num_classes); - } +struct SecureBoostBase : TreeModelBase { + float subsample_cols; + float min_child_weight; + int depth; + int min_leaf; + float learning_rate; + int boosting_rounds; + float lam; + float gamma; + float eps; + int active_party_id; + int completelly_secure_round; + float init_value; + int n_job; + bool save_loss; + int num_classes; + + LossFunc *lossfunc_obj; + + vector> init_pred; + vector estimators; + vector logging_loss; + + vector parties_cp; + + SecureBoostBase(int num_classes_, float subsample_cols_ = 0.8, + float min_child_weight_ = -1 * + numeric_limits::infinity(), + int depth_ = 5, int min_leaf_ = 5, float learning_rate_ = 0.4, + int boosting_rounds_ = 5, float lam_ = 1.5, float gamma_ = 1, + float eps_ = 0.1, int active_party_id_ = -1, + int completelly_secure_round_ = 0, float init_value_ = 1.0, + int n_job_ = 1, bool save_loss_ = true) { + num_classes = num_classes_; + subsample_cols = subsample_cols_; + min_child_weight = min_child_weight_; + depth = depth_; + min_leaf = min_leaf_; + learning_rate = learning_rate_; + boosting_rounds = boosting_rounds_; + lam = lam_; + gamma = gamma_; + eps = eps_; + active_party_id = active_party_id_; + completelly_secure_round = completelly_secure_round_; + init_value = init_value_; + n_job = n_job_; + save_loss = save_loss_; + + if (num_classes == 2) { + lossfunc_obj = new BCELoss(); + } else { + lossfunc_obj = new CELoss(num_classes); } + } - virtual vector> get_init_pred(vector &y) = 0; + virtual vector> get_init_pred(vector &y) = 0; - void load_estimators(vector &_estimators) - { - estimators = _estimators; - } + void load_estimators(vector &_estimators) { + estimators = _estimators; + } + + void clear() { + estimators.clear(); + logging_loss.clear(); + } + + vector get_estimators() { return estimators; } + vector get_parties() { return parties_cp; } - void clear() - { - estimators.clear(); - logging_loss.clear(); + void fit(vector &parties, vector &y) { + try { + if ((active_party_id < 0) || (active_party_id > parties.size())) { + throw invalid_argument("invalid active_party_id"); + } + } catch (std::exception &e) { + std::cout << e.what() << std::endl; } - vector get_estimators() - { - return estimators; + parties_cp = parties; + + int row_count = y.size(); + vector> base_pred; + if (estimators.size() == 0) { + init_pred = get_init_pred(y); + copy(init_pred.begin(), init_pred.end(), back_inserter(base_pred)); + } else { + base_pred.resize(row_count); + for (int j = 0; j < row_count; j++) + for (int c = 0; c < num_classes; c++) + base_pred[j][c] = 0; + + for (int i = 0; i < estimators.size(); i++) { + vector> pred_temp = estimators[i].get_train_prediction(); + for (int j = 0; j < row_count; j++) + for (int c = 0; c < num_classes; c++) + base_pred[j][c] += learning_rate * pred_temp[j][c]; + } } - void fit(vector &parties, vector &y) - { - try - { - if ((active_party_id < 0) || (active_party_id > parties.size())) - { - throw invalid_argument("invalid active_party_id"); - } - } - catch (std::exception &e) - { - std::cout << e.what() << std::endl; + for (int i = 0; i < boosting_rounds; i++) { + vector> vanila_grad = lossfunc_obj->get_grad(base_pred, y); + vector> vanila_hess = lossfunc_obj->get_hess(base_pred, y); + int grad_dim = vanila_grad[0].size(); + + vector> grad( + 
row_count, vector(grad_dim)); + vector> hess( + row_count, vector(grad_dim)); + + for (int j = 0; j < row_count; j++) { + for (int c = 0; c < grad_dim; c++) { + grad[j][c] = + parties[active_party_id].pk.encrypt(vanila_grad[j][c]); + hess[j][c] = + parties[active_party_id].pk.encrypt(vanila_hess[j][c]); } + } - int row_count = y.size(); - vector> base_pred; - if (estimators.size() == 0) - { - init_pred = get_init_pred(y); - copy(init_pred.begin(), init_pred.end(), back_inserter(base_pred)); - } - else - { - base_pred.resize(row_count); - for (int j = 0; j < row_count; j++) - for (int c = 0; c < num_classes; c++) - base_pred[j][c] = 0; - - for (int i = 0; i < estimators.size(); i++) - { - vector> pred_temp = estimators[i].get_train_prediction(); - for (int j = 0; j < row_count; j++) - for (int c = 0; c < num_classes; c++) - base_pred[j][c] += learning_rate * pred_temp[j][c]; - } - } + SecureBoostTree boosting_tree( + parties_cp, y, num_classes, grad, hess, vanila_grad, vanila_hess, + min_child_weight, lam, gamma, eps, min_leaf, depth, active_party_id, + (completelly_secure_round > i), n_job); + vector> pred_temp = boosting_tree.get_train_prediction(); - for (int i = 0; i < boosting_rounds; i++) - { - vector> vanila_grad = lossfunc_obj->get_grad(base_pred, y); - vector> vanila_hess = lossfunc_obj->get_hess(base_pred, y); - int grad_dim = vanila_grad[0].size(); - - vector> grad(row_count, vector(grad_dim)); - vector> hess(row_count, vector(grad_dim)); - - for (int j = 0; j < row_count; j++) - { - for (int c = 0; c < grad_dim; c++) - { - grad[j][c] = parties[active_party_id].pk.encrypt(vanila_grad[j][c]); - hess[j][c] = parties[active_party_id].pk.encrypt(vanila_hess[j][c]); - } - } - - SecureBoostTree boosting_tree = SecureBoostTree(); - boosting_tree.fit(parties, y, num_classes, grad, hess, vanila_grad, vanila_hess, min_child_weight, - lam, gamma, eps, min_leaf, depth, active_party_id, (completelly_secure_round > i), n_job); - vector> pred_temp = boosting_tree.get_train_prediction(); - - for (int j = 0; j < row_count; j++) - for (int c = 0; c < num_classes; c++) - base_pred[j][c] += learning_rate * pred_temp[j][c]; - - estimators.push_back(boosting_tree); - - if (save_loss) - { - logging_loss.push_back(lossfunc_obj->get_loss(base_pred, y)); - } - } - } + for (int j = 0; j < row_count; j++) + for (int c = 0; c < num_classes; c++) + base_pred[j][c] += learning_rate * pred_temp[j][c]; - vector> predict_raw(vector> &X) - { - int pred_dim; - if (num_classes == 2) - { - pred_dim = 1; - } - else - { - pred_dim = num_classes; - } + estimators.push_back(boosting_tree); - int row_count = X.size(); - vector> y_pred(row_count, vector(pred_dim, init_value)); - // copy(init_pred.begin(), init_pred.end(), back_inserter(y_pred)); - int estimators_num = estimators.size(); - for (int i = 0; i < estimators_num; i++) - { - vector> y_pred_temp = estimators[i].predict(X); - for (int j = 0; j < row_count; j++) - { - for (int c = 0; c < pred_dim; c++) - { - y_pred[j][c] += learning_rate * y_pred_temp[j][c]; - } - } - } + if (save_loss) { + logging_loss.push_back(lossfunc_obj->get_loss(base_pred, y)); + } - return y_pred; + parties = parties_cp; } -}; - -struct SecureBoostClassifier : public SecureBoostBase -{ - using SecureBoostBase::SecureBoostBase; - - vector> get_init_pred(vector &y) - { - vector> init_pred(y.size(), vector(num_classes, init_value)); - return init_pred; + } + + vector> predict_raw(vector> &X) { + int pred_dim; + if (num_classes == 2) { + pred_dim = 1; + } else { + pred_dim = num_classes; } - 
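// ----------------------------------------------------------------------------
// Aside: how predict_proba (below) turns boosted raw scores into
// probabilities. predict_raw accumulates init_value plus
// learning_rate * leaf_value over all estimators, so the result is a logit:
// the binary classifier keeps a single logit and applies a sigmoid, while the
// multiclass one applies softmax across num_classes scores. A self-contained
// sketch of the same conversion; the function name proba_from_raw is
// illustrative and not part of AIJack:
//
//   #include <cmath>
//   #include <vector>
//
//   std::vector<float> proba_from_raw(const std::vector<float> &raw) {
//     if (raw.size() == 1) {                 // binary: one logit
//       float p1 = 1.0f / (1.0f + std::exp(-raw[0]));
//       return {1.0f - p1, p1};
//     }
//     std::vector<float> p(raw.size());      // multiclass: softmax
//     float denom = 0.0f;
//     for (float s : raw) denom += std::exp(s);
//     for (std::size_t c = 0; c < raw.size(); c++)
//       p[c] = std::exp(raw[c]) / denom;
//     return p;
//   }
// ----------------------------------------------------------------------------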
vector> predict_proba(vector> &x) - { - vector> raw_score = predict_raw(x); - int row_count = x.size(); - vector> predicted_probas(row_count, vector(num_classes, 0)); - for (int i = 0; i < row_count; i++) - { - if (num_classes == 2) - { - predicted_probas[i][1] = sigmoid(raw_score[i][0]); - predicted_probas[i][0] = 1 - predicted_probas[i][1]; - } - else - { - predicted_probas[i] = softmax(raw_score[i]); - } + int row_count = X.size(); + vector> y_pred(row_count, + vector(pred_dim, init_value)); + // copy(init_pred.begin(), init_pred.end(), back_inserter(y_pred)); + int estimators_num = estimators.size(); + for (int i = 0; i < estimators_num; i++) { + vector> y_pred_temp = estimators[i].predict(X); + for (int j = 0; j < row_count; j++) { + for (int c = 0; c < pred_dim; c++) { + y_pred[j][c] += learning_rate * y_pred_temp[j][c]; } - return predicted_probas; + } + } + + return y_pred; + } +}; + +struct SecureBoostClassifier : public SecureBoostBase { + using SecureBoostBase::SecureBoostBase; + + vector> get_init_pred(vector &y) { + vector> init_pred(y.size(), + vector(num_classes, init_value)); + return init_pred; + } + + vector> predict_proba(vector> &x) { + vector> raw_score = predict_raw(x); + int row_count = x.size(); + vector> predicted_probas(row_count, + vector(num_classes, 0)); + for (int i = 0; i < row_count; i++) { + if (num_classes == 2) { + predicted_probas[i][1] = sigmoid(raw_score[i][0]); + predicted_probas[i][0] = 1 - predicted_probas[i][1]; + } else { + predicted_probas[i] = softmax(raw_score[i]); + } } + return predicted_probas; + } }; diff --git a/src/aijack/collaborative/tree/secureboost/tree.h b/src/aijack/collaborative/tree/secureboost/tree.h index 29f6a59b..32b9b24c 100644 --- a/src/aijack/collaborative/tree/secureboost/tree.h +++ b/src/aijack/collaborative/tree/secureboost/tree.h @@ -1,32 +1,44 @@ #pragma once -#include -#include -#include -#include #include "../core/tree.h" #include "node.h" +#include +#include +#include +#include + +inline SecureBoostNode +make_root(vector &parties, vector y, int num_classes, + vector> &gradient, + vector> &hessian, + vector> &vanila_gradient, + vector> &vanila_hessian, float min_child_weight, + float lam, float gamma, float eps, int min_leaf, int depth, + int active_party_id = 0, bool use_only_active_party = false, + int n_job = 1) { + vector idxs(y.size()); + iota(idxs.begin(), idxs.end(), 0); + for (int i = 0; i < parties.size(); i++) { + parties[i].subsample_columns(); + } + return SecureBoostNode(parties, y, num_classes, gradient, hessian, + vanila_gradient, vanila_hessian, idxs, + min_child_weight, lam, gamma, eps, depth, + active_party_id, use_only_active_party, n_job); +} -struct SecureBoostTree : Tree -{ - SecureBoostTree() {} - void fit(vector &parties, vector y, - int num_classes, - vector> &gradient, - vector> &hessian, - vector> &vanila_gradient, vector> &vanila_hessian, - float min_child_weight, float lam, float gamma, float eps, - int min_leaf, int depth, int active_party_id = 0, - bool use_only_active_party = false, int n_job = 1) - { - vector idxs(y.size()); - iota(idxs.begin(), idxs.end(), 0); - for (int i = 0; i < parties.size(); i++) - { - parties[i].subsample_columns(); - } - dtree = SecureBoostNode(parties, y, num_classes, gradient, hessian, vanila_gradient, - vanila_hessian, idxs, min_child_weight, - lam, gamma, eps, depth, active_party_id, - use_only_active_party, n_job); - } +struct SecureBoostTree : public Tree { + // SecureBoostNode dtree; + // SecureBoostTree() {} + SecureBoostTree(vector &parties, vector 
y, + int num_classes, vector> &gradient, + vector> &hessian, + vector> &vanila_gradient, + vector> &vanila_hessian, float min_child_weight, + float lam, float gamma, float eps, int min_leaf, int depth, + int active_party_id = 0, bool use_only_active_party = false, + int n_job = 1) + : Tree(make_root( + parties, y, num_classes, gradient, hessian, vanila_gradient, + vanila_hessian, min_child_weight, lam, gamma, eps, min_leaf, depth, + active_party_id, use_only_active_party, n_job)) {} }; diff --git a/src/aijack/collaborative/tree/xgboost/node.h b/src/aijack/collaborative/tree/xgboost/node.h index 3758e1b1..8d5c3ef4 100644 --- a/src/aijack/collaborative/tree/xgboost/node.h +++ b/src/aijack/collaborative/tree/xgboost/node.h @@ -1,359 +1,363 @@ #pragma once +#include "../core/node.h" +#include "../utils/metric.h" +#include "../utils/utils.h" +#include "party.h" +#include "utils.h" +#include #include -#include -#include +#include #include #include -#include -#include -#include -#include +#include +#include #include -#include +#include +#include #include -#include +#include +#include #include -#include -#include "party.h" -#include "utils.h" -#include "../core/node.h" -#include "../utils/metric.h" -#include "../utils/utils.h" +#include using namespace std; -struct XGBoostNode : Node -{ - vector parties; - vector> gradient, hessian; - float min_child_weight, lam, gamma, eps; - float best_entropy; - bool use_only_active_party; - XGBoostNode *left, *right; - - int num_classes; - - float entire_datasetsize = 0; - vector entire_class_cnt; - - XGBoostNode() {} - XGBoostNode(vector &parties_, vector &y_, int num_classes_, - vector> &gradient_, - vector> &hessian_, vector &idxs_, - float min_child_weight_, float lam_, float gamma_, float eps_, int depth_, - int active_party_id_ = -1, bool use_only_active_party_ = false, int n_job_ = 1) - { - parties = parties_; - y = y_; - num_classes = num_classes_; - gradient = gradient_; - hessian = hessian_; - idxs = idxs_; - min_child_weight = min_child_weight_; - lam = lam_; - gamma = gamma_; - eps = eps_; - depth = depth_; - active_party_id = active_party_id_; - use_only_active_party = use_only_active_party_; - n_job = n_job_; - - row_count = idxs.size(); - num_parties = parties.size(); - - entire_class_cnt.resize(num_classes, 0); - entire_datasetsize = y.size(); - for (int i = 0; i < entire_datasetsize; i++) - { - entire_class_cnt[int(y[i])] += 1.0; - } - - try - { - if (use_only_active_party && active_party_id > parties.size()) - { - throw invalid_argument("invalid active_party_id"); - } - } - catch (std::exception &e) - { - std::cerr << e.what() << std::endl; - } +struct XGBoostNode : public Node { + // vector &parties; + vector> &gradient, &hessian; + float min_child_weight, lam, gamma, eps; + float best_entropy; + bool use_only_active_party, is_robust; + XGBoostNode *left, *right; + + int num_classes; + + float entire_datasetsize = 0; + vector entire_class_cnt; + + // XGBoostNode() {} + XGBoostNode(vector &parties_, vector &y_, + int num_classes_, vector> &gradient_, + vector> &hessian_, vector &idxs_, + float min_child_weight_, float lam_, float gamma_, float eps_, + int depth_, int active_party_id_ = -1, + bool use_only_active_party_ = false, int n_job_ = 1, + bool is_robust_ = false) + : gradient(gradient_), + hessian(hessian_), Node(parties_, idxs_, y_) { + num_classes = num_classes_; + min_child_weight = min_child_weight_; + lam = lam_; + gamma = gamma_; + eps = eps_; + depth = depth_; + active_party_id = active_party_id_; + use_only_active_party = 
use_only_active_party_; + n_job = n_job_; + is_robust = is_robust_; + + row_count = idxs.size(); + num_parties = parties.size(); + + entire_class_cnt.resize(num_classes, 0); + entire_datasetsize = y.size(); + for (int i = 0; i < entire_datasetsize; i++) { + entire_class_cnt[int(y[i])] += 1.0; + } - val = compute_weight(); + try { + if (use_only_active_party && active_party_id > parties.size()) { + throw invalid_argument("invalid active_party_id"); + } + } catch (std::exception &e) { + std::cerr << e.what() << std::endl; + } - if (is_leaf()) - { - is_leaf_flag = 1; - } - else - { - is_leaf_flag = 0; - } + val = compute_weight(); - if (is_leaf_flag == 0) - { - tuple best_split = find_split(); - party_id = get<0>(best_split); - if (party_id != -1) - { - record_id = parties[party_id].insert_lookup_table(get<1>(best_split), get<2>(best_split)); - make_children_nodes(get<0>(best_split), get<1>(best_split), get<2>(best_split)); - } - else - { - is_leaf_flag = 1; - } - } + if (is_leaf()) { + is_leaf_flag = 1; + } else { + is_leaf_flag = 0; } - vector get_idxs() - { - return idxs; + if (is_leaf_flag == 0) { + tuple best_split = find_split(); + party_id = get<0>(best_split); + if (party_id != -1) { + record_id = parties[party_id].insert_lookup_table(get<1>(best_split), + get<2>(best_split)); + make_children_nodes(get<0>(best_split), get<1>(best_split), + get<2>(best_split)); + } else { + is_leaf_flag = 1; + } } - - int get_party_id() - { - return party_id; + } + + XGBoostNode &operator=(const XGBoostNode &other) { + if (this != &other) { + parties = other.parties; + for (int i = 0; i < other.parties.size(); i++) { + parties[i].lookup_table = other.parties[i].lookup_table; + } + y = other.y; + idxs = other.idxs; + val = other.val; + + num_classes = other.num_classes; + depth = other.depth; + active_party_id = other.active_party_id; + n_job = other.n_job; + + party_id = other.party_id; + record_id = other.record_id; + row_count = other.row_count; + num_parties = other.num_parties; + is_leaf_flag = other.is_leaf_flag; + + left = other.left; + right = other.right; } + return *this; + } - int get_record_id() - { - return record_id; - } + vector get_idxs() { return idxs; } - vector get_val() - { - return val; - } + int get_party_id() { return party_id; } - float get_score() - { - return score; - } + int get_record_id() { return record_id; } - XGBoostNode get_left() - { - return *left; - } + vector get_val() { return val; } - XGBoostNode get_right() - { - return *right; - } + float get_score() { return score; } - int get_num_parties() - { - return parties.size(); - } + XGBoostNode get_left() { return *left; } - vector compute_weight() - { - return xgboost_compute_weight(row_count, gradient, hessian, idxs, lam); - } + XGBoostNode get_right() { return *right; } - float compute_gain(vector &left_grad, vector &right_grad, vector &left_hess, vector &right_hess) - { - return xgboost_compute_gain(left_grad, right_grad, left_hess, right_hess, gamma, lam); - } + int get_num_parties() { return parties.size(); } - void find_split_per_party(int party_id_start, int temp_num_parties, vector &sum_grad, vector &sum_hess, float tot_cnt, vector &temp_y_class_cnt) - { - - vector temp_left_class_cnt, temp_right_class_cnt; - temp_left_class_cnt.resize(num_classes, 0); - temp_right_class_cnt.resize(num_classes, 0); - - int grad_dim = sum_grad.size(); - - for (int temp_party_id = party_id_start; temp_party_id < party_id_start + temp_num_parties; temp_party_id++) - { - - vector, vector, float, vector>>> search_results = - 
parties[temp_party_id].greedy_search_split(gradient, hessian, y, idxs); - - float temp_score, temp_entropy; - vector temp_left_grad(grad_dim, 0); - vector temp_left_hess(grad_dim, 0); - vector temp_right_grad(grad_dim, 0); - vector temp_right_hess(grad_dim, 0); - float temp_left_size, temp_right_size; - bool skip_flag = false; - - for (int j = 0; j < search_results.size(); j++) - { - temp_score = 0; - temp_entropy = 0; - temp_left_size = 0; - temp_right_size = 0; - - for (int c = 0; c < grad_dim; c++) - { - temp_left_grad[c] = 0; - temp_left_hess[c] = 0; - } - - for (int c = 0; c < num_classes; c++) - { - temp_left_class_cnt[c] = 0; - temp_right_class_cnt[c] = 0; - } - - for (int k = 0; k < search_results[j].size(); k++) - { - for (int c = 0; c < grad_dim; c++) - { - temp_left_grad[c] += get<0>(search_results[j][k])[c]; - temp_left_hess[c] += get<1>(search_results[j][k])[c]; - } - temp_left_size += get<2>(search_results[j][k]); - temp_right_size = tot_cnt - temp_left_size; - - for (int c = 0; c < num_classes; c++) - { - temp_left_class_cnt[c] += get<3>(search_results[j][k])[c]; - temp_right_class_cnt[c] = temp_y_class_cnt[c] - temp_left_class_cnt[c]; - } - - skip_flag = false; - for (int c = 0; c < grad_dim; c++) - { - if (temp_left_hess[c] < min_child_weight || - sum_hess[c] - temp_left_hess[c] < min_child_weight) - { - skip_flag = true; - } - } - if (skip_flag) - { - continue; - } - - for (int c = 0; c < grad_dim; c++) - { - temp_right_grad[c] = sum_grad[c] - temp_left_grad[c]; - temp_right_hess[c] = sum_hess[c] - temp_left_hess[c]; - } - - temp_score = compute_gain(temp_left_grad, temp_right_grad, - temp_left_hess, temp_right_hess); - - if (temp_score > best_score) - { - best_score = temp_score; - best_entropy = temp_entropy; - best_party_id = temp_party_id; - best_col_id = j; - best_threshold_id = k; - } - } - } - } - } + vector compute_weight() { + return xgboost_compute_weight(row_count, gradient, hessian, idxs, lam); + } - tuple find_split() - { - vector sum_grad(gradient[0].size(), 0); - vector sum_hess(hessian[0].size(), 0); - for (int i = 0; i < row_count; i++) - { - for (int c = 0; c < sum_grad.size(); c++) - { - sum_grad[c] += gradient[idxs[i]][c]; - sum_hess[c] += hessian[idxs[i]][c]; - } - } + float compute_gain(vector &left_grad, vector &right_grad, + vector &left_hess, vector &right_hess) { + return xgboost_compute_gain(left_grad, right_grad, left_hess, right_hess, + gamma, lam); + } - float tot_cnt = row_count; - vector temp_y_class_cnt(num_classes, 0); - for (int r = 0; r < row_count; r++) - { - temp_y_class_cnt[int(y[idxs[r]])] += 1; - } + void find_split_per_party(int party_id_start, int temp_num_parties, + vector &sum_grad, vector &sum_hess, + float tot_cnt, vector &temp_y_class_cnt) { + + vector temp_left_class_cnt, temp_right_class_cnt; + temp_left_class_cnt.resize(num_classes, 0); + temp_right_class_cnt.resize(num_classes, 0); - float temp_score, temp_left_grad, temp_left_hess; + int grad_dim = sum_grad.size(); - if (use_only_active_party) - { - find_split_per_party(active_party_id, 1, sum_grad, sum_hess, tot_cnt, temp_y_class_cnt); + for (int temp_party_id = party_id_start; + temp_party_id < party_id_start + temp_num_parties; temp_party_id++) { + + vector, vector, float, vector>>> + search_results; + bool robust_flag = + is_robust && (parties[temp_party_id].cost_constraint_map.size() == + parties[temp_party_id].feature_id.size()); + if (!robust_flag) { + search_results = parties[temp_party_id].greedy_search_split( + gradient, hessian, y, idxs); + } else { + 
search_results = parties[temp_party_id].robust_greedy_search_split( + gradient, hessian, y, idxs, gamma, lam); + } + + float temp_score, temp_entropy; + vector temp_left_grad(grad_dim, 0); + vector temp_left_hess(grad_dim, 0); + vector temp_right_grad(grad_dim, 0); + vector temp_right_hess(grad_dim, 0); + float temp_left_size, temp_right_size; + bool skip_flag = false; + + for (int j = 0; j < search_results.size(); j++) { + temp_score = 0; + temp_entropy = 0; + temp_left_size = 0; + temp_right_size = 0; + + for (int c = 0; c < grad_dim; c++) { + temp_left_grad[c] = 0; + temp_left_hess[c] = 0; } - else - { - if (n_job == 1) - { - find_split_per_party(0, num_parties, sum_grad, sum_hess, tot_cnt, temp_y_class_cnt); + + for (int c = 0; c < num_classes; c++) { + temp_left_class_cnt[c] = 0; + temp_right_class_cnt[c] = 0; + } + + for (int k = 0; k < search_results[j].size(); k++) { + for (int c = 0; c < grad_dim; c++) { + if (!robust_flag) { + temp_left_grad[c] += get<0>(search_results[j][k])[c]; + temp_left_hess[c] += get<1>(search_results[j][k])[c]; + } else { + temp_left_grad[c] = get<0>(search_results[j][k])[c]; + temp_left_hess[c] = get<1>(search_results[j][k])[c]; } - else - { - vector num_parties_per_thread = get_num_parties_per_process(n_job, num_parties); - - int cnt_parties = 0; - vector threads_parties; - for (int i = 0; i < n_job; i++) - { - int local_num_parties = num_parties_per_thread[i]; - thread temp_th([this, cnt_parties, local_num_parties, &sum_grad, &sum_hess, tot_cnt, &temp_y_class_cnt] - { this->find_split_per_party(cnt_parties, local_num_parties, sum_grad, sum_hess, tot_cnt, temp_y_class_cnt); }); - threads_parties.push_back(move(temp_th)); - cnt_parties += num_parties_per_thread[i]; - } - for (int i = 0; i < num_parties; i++) - { - threads_parties[i].join(); - } + } + if (!robust_flag) { + temp_left_size += get<2>(search_results[j][k]); + } else { + temp_left_size = get<2>(search_results[j][k]); + } + temp_right_size = tot_cnt - temp_left_size; + + for (int c = 0; c < num_classes; c++) { + if (!robust_flag) { + temp_left_class_cnt[c] += get<3>(search_results[j][k])[c]; + } else { + temp_left_class_cnt[c] = get<3>(search_results[j][k])[c]; } + temp_right_class_cnt[c] = + temp_y_class_cnt[c] - temp_left_class_cnt[c]; + } + + skip_flag = false; + for (int c = 0; c < grad_dim; c++) { + if (temp_left_hess[c] < min_child_weight || + sum_hess[c] - temp_left_hess[c] < min_child_weight) { + skip_flag = true; + } + } + if (skip_flag) { + continue; + } + + for (int c = 0; c < grad_dim; c++) { + temp_right_grad[c] = sum_grad[c] - temp_left_grad[c]; + temp_right_hess[c] = sum_hess[c] - temp_left_hess[c]; + } + + temp_score = compute_gain(temp_left_grad, temp_right_grad, + temp_left_hess, temp_right_hess); + + if (temp_score > best_score) { + best_score = temp_score; + best_entropy = temp_entropy; + best_party_id = temp_party_id; + best_col_id = j; + best_threshold_id = k; + } } + } + } + } + + tuple find_split() { + vector sum_grad(gradient[0].size(), 0); + vector sum_hess(hessian[0].size(), 0); + for (int i = 0; i < row_count; i++) { + for (int c = 0; c < sum_grad.size(); c++) { + sum_grad[c] += gradient[idxs[i]][c]; + sum_hess[c] += hessian[idxs[i]][c]; + } + } - score = best_score; - return make_tuple(best_party_id, best_col_id, best_threshold_id); + float tot_cnt = row_count; + vector temp_y_class_cnt(num_classes, 0); + for (int r = 0; r < row_count; r++) { + temp_y_class_cnt[int(y[idxs[r]])] += 1; } - void make_children_nodes(int best_party_id, int best_col_id, int 
best_threshold_id)
-    {
-        // TODO: remove idx with nan values from right_idxs;
-        vector<int> left_idxs = parties[best_party_id].split_rows(idxs, best_col_id, best_threshold_id);
-        vector<int> right_idxs;
-        for (int i = 0; i < row_count; i++)
-            if (!any_of(left_idxs.begin(), left_idxs.end(), [&](float x)
-                        { return x == idxs[i]; }))
-                right_idxs.push_back(idxs[i]);
-
-        left = new XGBoostNode(parties, y, num_classes, gradient, hessian, left_idxs, min_child_weight,
-                               lam, gamma, eps, depth - 1, active_party_id, use_only_active_party, n_job);
-        if (left->is_leaf_flag == 1)
-        {
-            left->party_id = party_id;
+    float temp_score, temp_left_grad, temp_left_hess;
+
+    if (use_only_active_party) {
+      find_split_per_party(active_party_id, 1, sum_grad, sum_hess, tot_cnt,
+                           temp_y_class_cnt);
+    } else {
+      if (n_job == 1) {
+        find_split_per_party(0, num_parties, sum_grad, sum_hess, tot_cnt,
+                             temp_y_class_cnt);
+      } else {
+        vector<int> num_parties_per_thread =
+            get_num_parties_per_process(n_job, num_parties);
+
+        int cnt_parties = 0;
+        vector<thread> threads_parties;
+        for (int i = 0; i < n_job; i++) {
+          int local_num_parties = num_parties_per_thread[i];
+          thread temp_th([this, cnt_parties, local_num_parties, &sum_grad,
+                          &sum_hess, tot_cnt, &temp_y_class_cnt] {
+            this->find_split_per_party(cnt_parties, local_num_parties, sum_grad,
+                                       sum_hess, tot_cnt, temp_y_class_cnt);
+          });
+          threads_parties.push_back(move(temp_th));
+          cnt_parties += num_parties_per_thread[i];
        }
-        right = new XGBoostNode(parties, y, num_classes, gradient, hessian, right_idxs, min_child_weight,
-                                lam, gamma, eps, depth - 1, active_party_id, use_only_active_party, n_job);
-        if (right->is_leaf_flag == 1)
-        {
-            right->party_id = party_id;
+        // join exactly the n_job worker threads spawned above
+        for (int i = 0; i < n_job; i++) {
+          threads_parties[i].join();
        }
+      }
    }
-    bool is_leaf()
-    {
-        if (is_leaf_flag == -1)
-        {
-            return is_pure() || std::isinf(score) || depth <= 0;
-        }
-        else
-        {
-            return is_leaf_flag;
-        }
+    score = best_score;
+    return make_tuple(best_party_id, best_col_id, best_threshold_id);
+  }
+
+  void make_children_nodes(int best_party_id, int best_col_id,
+                           int best_threshold_id) {
+    vector<int> left_idxs;
+
+    if ((!is_robust) || (parties[best_party_id].cost_constraint_map.size() !=
+                         parties[party_id].feature_id.size())) {
+      left_idxs = parties[best_party_id].split_rows(idxs, best_col_id,
+                                                    best_threshold_id);
+    } else {
+      left_idxs = parties[best_party_id].robust_split_rows(
+          idxs, best_col_id, best_threshold_id, gradient, hessian, y, gamma,
+          lam);
    }
-    bool is_pure()
-    {
-        set<float> s{};
-        for (int i = 0; i < row_count; i++)
-        {
-            if (s.insert(y[idxs[i]]).second)
-            {
-                if (s.size() == 2)
-                    return false;
-            }
-        }
-        return true;
+    vector<int> right_idxs;
+    for (int i = 0; i < row_count; i++)
+      if (!any_of(left_idxs.begin(), left_idxs.end(),
+                  [&](float x) { return x == idxs[i]; }))
+        right_idxs.push_back(idxs[i]);
+
+    left = new XGBoostNode(parties, y, num_classes, gradient, hessian,
+                           left_idxs, min_child_weight, lam, gamma, eps,
+                           depth - 1, active_party_id, use_only_active_party,
+                           n_job, is_robust);
+    if (left->is_leaf_flag == 1) {
+      left->party_id = party_id;
+    }
+    right = new XGBoostNode(parties, y, num_classes, gradient, hessian,
+                            right_idxs, min_child_weight, lam, gamma, eps,
+                            depth - 1, active_party_id, use_only_active_party,
+                            n_job, is_robust);
+    if (right->is_leaf_flag == 1) {
+      right->party_id = party_id;
+    }
+  }
+
+  bool is_leaf() {
+    if (is_leaf_flag == -1) {
+      return is_pure() || std::isinf(score) || depth <= 0;
+    } else {
+      return is_leaf_flag;
+    }
+  }
+
+  bool is_pure() {
+    set<float> s{};
+    for (int i = 0; i < row_count; i++) {
+      if (s.insert(y[idxs[i]]).second) {
+        if (s.size() == 2)
+          return false;
+      }
+    }
+    return true;
+  }
 };
diff --git a/src/aijack/collaborative/tree/xgboost/party.h b/src/aijack/collaborative/tree/xgboost/party.h
index 5d08b5bc..ef54db5f 100644
--- a/src/aijack/collaborative/tree/xgboost/party.h
+++ b/src/aijack/collaborative/tree/xgboost/party.h
@@ -1,176 +1,372 @@
 #pragma once
-#include
-#include "../core/party.h"
 #include "../../../defense/paillier/src/paillier.h"
+#include "../core/party.h"
+#include "utils.h"
+#include
+#include

 using namespace std;

-struct XGBoostParty : Party
-{
-    int num_percentile_bin;
-
-    XGBoostParty() {}
-    XGBoostParty(vector<vector<float>> x_, int num_classes_, vector<int> feature_id_, int party_id_,
-                 int min_leaf_, float subsample_cols_, int num_precentile_bin_ = 256,
-                 bool use_missing_value_ = false, int seed_ = 0) : Party(x_, num_classes_, feature_id_, party_id_,
-                                                                         min_leaf_, subsample_cols_,
-                                                                         use_missing_value_, seed_)
-    {
-        num_percentile_bin = num_precentile_bin_;
+struct XGBoostParty : public Party {
+  int num_percentile_bin;
+  // per-feature (lower, upper) perturbation bounds used by robust training
+  vector<pair<float, float>> cost_constraint_map;
+
+  // XGBoostParty() {}
+  XGBoostParty(vector<vector<float>> x_, int num_classes_,
+               vector<int> feature_id_, int party_id_, int min_leaf_,
+               float subsample_cols_, int num_percentile_bin_ = 256,
+               bool use_missing_value_ = false, int seed_ = 0)
+      : Party(x_, num_classes_, feature_id_, party_id_, min_leaf_,
+              subsample_cols_, use_missing_value_, seed_) {
+    num_percentile_bin = num_percentile_bin_;
+  }
+
+  void
+  set_cost_constraint_map(vector<pair<float, float>> cost_constraint_map_) {
+    cost_constraint_map = cost_constraint_map_;
+  }
+
+  vector<float> get_threshold_candidates(vector<float> &x_col) {
+    if (x_col.size() > num_percentile_bin) {
+      vector<float> probs(num_percentile_bin);
+      // fill probs[0 .. num_percentile_bin - 1] with i / num_percentile_bin
+      for (int i = 1; i <= num_percentile_bin; i++)
+        probs[i - 1] = float(i) / float(num_percentile_bin);
+      vector<float> percentiles_candidate = Quantile(x_col, probs);
+      vector<float> percentiles =
+          remove_duplicates<float>(percentiles_candidate);
+      return percentiles;
+    } else {
+      vector<float> x_col_wo_duplicates = remove_duplicates<float>(x_col);
+      vector<float> percentiles(x_col_wo_duplicates.size());
+      copy(x_col_wo_duplicates.begin(), x_col_wo_duplicates.end(),
+           percentiles.begin());
+      sort(percentiles.begin(), percentiles.end());
+      return percentiles;
    }
+  }
+
+  vector<vector<tuple<vector<float>, vector<float>, float, vector<float>>>>
+  greedy_search_split(const vector<vector<float>> &gradient,
+                      const vector<vector<float>> &hessian, vector<float> &y,
+                      vector<int> &idxs) {
+    // feature_id -> [(grad hess)]
+    // the threshold of split_candidates_grad_hess[i][j] = temp_thresholds[i][j]
+    int num_thresholds;
+    if (use_missing_value)
+      num_thresholds = subsample_col_count * 2;
+    else
+      num_thresholds = subsample_col_count;
+    vector<vector<tuple<vector<float>, vector<float>, float, vector<float>>>>
+        split_candidates_grad_hess(num_thresholds);
+    temp_thresholds = vector<vector<float>>(num_thresholds);
+
+    int row_count = idxs.size();
+    int record_id = 0;
+
+    int grad_dim = gradient[0].size();
+
+    for (int i = 0; i < subsample_col_count; i++) {
+      // extract the necessary data
+      int k = temp_column_subsample[i];
+      vector<float> x_col(row_count);
+
+      int not_missing_values_count = 0;
+      int missing_values_count = 0;
+      for (int r = 0; r < row_count; r++) {
+        if (!isnan(x[idxs[r]][k])) {
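+          // non-missing value: compact it into the head of x_col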
x_col[not_missing_values_count] = x[idxs[r]][k]; + not_missing_values_count += 1; + } else { + missing_values_count += 1; } - else - { - vector x_col_wo_duplicates = remove_duplicates(x_col); - vector percentiles(x_col_wo_duplicates.size()); - copy(x_col_wo_duplicates.begin(), x_col_wo_duplicates.end(), percentiles.begin()); - sort(percentiles.begin(), percentiles.end()); - return percentiles; + } + x_col.resize(not_missing_values_count); + + vector x_col_idxs(not_missing_values_count); + iota(x_col_idxs.begin(), x_col_idxs.end(), 0); + sort(x_col_idxs.begin(), x_col_idxs.end(), + [&x_col](size_t i1, size_t i2) { return x_col[i1] < x_col[i2]; }); + + sort(x_col.begin(), x_col.end()); + + // get percentiles of x_col + vector percentiles = get_threshold_candidates(x_col); + + // enumerate all threshold value (missing value goto right) + int current_min_idx = 0; + int cumulative_left_size = 0; + for (int p = 0; p < percentiles.size(); p++) { + vector temp_grad(grad_dim, 0); + vector temp_hess(grad_dim, 0); + float temp_left_size = 0; + vector temp_left_y_class_cnt(num_classes, 0); + + for (int r = current_min_idx; r < not_missing_values_count; r++) { + if (x_col[r] <= percentiles[p]) { + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; + temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; + } + temp_left_size += 1.0; + temp_left_y_class_cnt[int(y[idxs[x_col_idxs[r]]])] += 1.0; + cumulative_left_size += 1; + } else { + current_min_idx = r; + break; + } + } + + if (cumulative_left_size >= min_leaf && + row_count - cumulative_left_size >= min_leaf) { + split_candidates_grad_hess[i].push_back(make_tuple( + temp_grad, temp_hess, temp_left_size, temp_left_y_class_cnt)); + temp_thresholds[i].push_back(percentiles[p]); + } + } + + // enumerate missing value goto left + if (use_missing_value) { + int current_max_idx = not_missing_values_count - 1; + int cumulative_right_size = 0; + for (int p = percentiles.size() - 1; p >= 0; p--) { + vector temp_grad(grad_dim, 0); + vector temp_hess(grad_dim, 0); + float temp_left_size = 0; + vector temp_left_y_class_cnt(num_classes, 0); + + for (int r = current_max_idx; r >= 0; r--) { + if (x_col[r] >= percentiles[p]) { + for (int c = 0; c < grad_dim; c++) { + temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; + temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; + } + temp_left_size += 1.0; + temp_left_y_class_cnt[int(y[idxs[x_col_idxs[r]]])] += 1.0; + cumulative_right_size += 1; + } else { + current_max_idx = r; + break; + } + } + + if (cumulative_right_size >= min_leaf && + row_count - cumulative_right_size >= min_leaf) { + split_candidates_grad_hess[i + subsample_col_count].push_back( + make_tuple(temp_grad, temp_hess, temp_left_size, + temp_left_y_class_cnt)); + temp_thresholds[i + subsample_col_count].push_back(percentiles[p]); + } } + } } - vector, vector, float, vector>>> greedy_search_split(vector> &gradient, - vector> &hessian, - vector &y, - vector &idxs) - { - // feature_id -> [(grad hess)] - // the threshold of split_candidates_grad_hess[i][j] = temp_thresholds[i][j] - int num_thresholds; - if (use_missing_value) - num_thresholds = subsample_col_count * 2; - else - num_thresholds = subsample_col_count; - vector, vector, float, vector>>> split_candidates_grad_hess(num_thresholds); - temp_thresholds = vector>(num_thresholds); - - int row_count = idxs.size(); - int recoed_id = 0; - - int grad_dim = gradient[0].size(); - - for (int i = 0; i < subsample_col_count; i++) - { - // extract the necessary data - int k = 
temp_column_subsample[i]; - vector x_col(row_count); - - int not_missing_values_count = 0; - int missing_values_count = 0; - for (int r = 0; r < row_count; r++) - { - if (!isnan(x[idxs[r]][k])) - { - x_col[not_missing_values_count] = x[idxs[r]][k]; - not_missing_values_count += 1; - } - else - { - missing_values_count += 1; - } + return split_candidates_grad_hess; + } + + vector robust_split_rows(vector &idxs, int feature_opt_pos, + int threshold_opt_pos, + const vector> &gradient, + const vector> &hessian, + vector &y, float gam, float lam) { + vector left_idxs; + int num_thresholds = subsample_col_count; + int feature_opt_id = + temp_column_subsample[feature_opt_pos % subsample_col_count]; + + int row_count = idxs.size(); + int recoed_id = 0; + + int grad_dim = gradient[0].size(); + + vector x_col(row_count); + for (int r = 0; r < row_count; r++) + x_col[r] = x[idxs[r]][feature_opt_id]; + + float threshold = temp_thresholds[feature_opt_pos][threshold_opt_pos]; + + vector left_grad_confident(grad_dim, 0); + vector right_grad_confident(grad_dim, 0); + vector left_hess_confident(grad_dim, 0); + vector right_hess_confident(grad_dim, 0); + vector uncertain_idxs; + + for (int r = 0; r < row_count; r++) { + if ((!isnan(x_col[r])) && (x_col[r] <= threshold)) { + + if (x_col[r] >= threshold + cost_constraint_map[feature_opt_id].first) { + uncertain_idxs.push_back(idxs[r]); + } else { + left_idxs.push_back(idxs[r]); + for (int c = 0; c < grad_dim; c++) { + left_grad_confident[c] += gradient[idxs[r]][c]; + left_hess_confident[c] += hessian[idxs[r]][c]; + } + } + } else { + if (x_col[r] < threshold + cost_constraint_map[feature_opt_id].second) { + uncertain_idxs.push_back(idxs[r]); + } else { + for (int c = 0; c < grad_dim; c++) { + right_grad_confident[c] += gradient[idxs[r]][c]; + right_hess_confident[c] += hessian[idxs[r]][c]; + } + } + } + } + + for (int r : uncertain_idxs) { + vector left_tmp_grad(left_grad_confident); + vector right_tmp_grad(right_grad_confident); + vector left_tmp_hess(left_hess_confident); + vector right_tmp_hess(right_hess_confident); + + for (int c = 0; c < grad_dim; c++) { + left_tmp_grad[c] += gradient[r][c]; + left_tmp_hess[c] += hessian[r][c]; + right_tmp_grad[c] += gradient[r][c]; + right_tmp_hess[c] += hessian[r][c]; + } + + float left_gain = + xgboost_compute_gain(left_tmp_grad, right_grad_confident, + left_tmp_hess, right_hess_confident, gam, lam); + float right_gain = + xgboost_compute_gain(left_grad_confident, right_tmp_grad, + left_hess_confident, right_tmp_hess, gam, lam); + + if (left_gain < right_gain) { + left_idxs.push_back(r); + left_grad_confident = left_tmp_grad; + left_hess_confident = left_tmp_hess; + } else { + right_grad_confident = right_tmp_grad; + right_hess_confident = right_tmp_hess; + } + } + return left_idxs; + } + + vector, vector, float, vector>>> + robust_greedy_search_split(const vector> &gradient, + const vector> &hessian, + vector &y, vector &idxs, float gam, + float lam) { + + int num_thresholds; + if (use_missing_value) + num_thresholds = subsample_col_count * 2; + else + num_thresholds = subsample_col_count; + vector, vector, float, vector>>> + split_candidates_grad_hess(num_thresholds); + temp_thresholds = vector>(num_thresholds); + + int row_count = idxs.size(); + int recoed_id = 0; + + int grad_dim = gradient[0].size(); + + for (int i = 0; i < subsample_col_count; i++) { + // extract the necessary data + int k = temp_column_subsample[i]; + vector x_col(row_count); + + int not_missing_values_count = 0; + int missing_values_count = 
0; + for (int r = 0; r < row_count; r++) { + if (!isnan(x[idxs[r]][k])) { + x_col[not_missing_values_count] = x[idxs[r]][k]; + not_missing_values_count += 1; + } else { + missing_values_count += 1; + } + } + x_col.resize(not_missing_values_count); + + vector x_col_idxs(not_missing_values_count); + iota(x_col_idxs.begin(), x_col_idxs.end(), 0); + sort(x_col_idxs.begin(), x_col_idxs.end(), + [&x_col](size_t i1, size_t i2) { return x_col[i1] < x_col[i2]; }); + + sort(x_col.begin(), x_col.end()); + + // get percentiles of x_col + vector percentiles = get_threshold_candidates(x_col); + + for (int p = 0; p < percentiles.size(); p++) { + float temp_left_size = 0; + vector temp_left_y_class_cnt(num_classes, 0); + + float threshold_val = percentiles[p]; + vector left_grad_confident(grad_dim, 0); + vector right_grad_confident(grad_dim, 0); + vector left_hess_confident(grad_dim, 0); + vector right_hess_confident(grad_dim, 0); + vector uncertain_idxs; + + for (int r = 0; r < not_missing_values_count; r++) { + if (x_col[r] <= percentiles[p]) { + temp_left_size += 1.0; + temp_left_y_class_cnt[int(y[idxs[x_col_idxs[r]]])] += 1.0; + + if (x_col[r] >= percentiles[p] + cost_constraint_map[k].first) { + uncertain_idxs.push_back(idxs[x_col_idxs[r]]); + } else { + for (int c = 0; c < grad_dim; c++) { + left_grad_confident[c] += gradient[idxs[x_col_idxs[r]]][c]; + left_hess_confident[c] += hessian[idxs[x_col_idxs[r]]][c]; + } } - x_col.resize(not_missing_values_count); - - vector x_col_idxs(not_missing_values_count); - iota(x_col_idxs.begin(), x_col_idxs.end(), 0); - sort(x_col_idxs.begin(), x_col_idxs.end(), [&x_col](size_t i1, size_t i2) - { return x_col[i1] < x_col[i2]; }); - - sort(x_col.begin(), x_col.end()); - - // get percentiles of x_col - vector percentiles = get_threshold_candidates(x_col); - - // enumerate all threshold value (missing value goto right) - int current_min_idx = 0; - int cumulative_left_size = 0; - for (int p = 0; p < percentiles.size(); p++) - { - vector temp_grad(grad_dim, 0); - vector temp_hess(grad_dim, 0); - float temp_left_size = 0; - vector temp_left_y_class_cnt(num_classes, 0); - - for (int r = current_min_idx; r < not_missing_values_count; r++) - { - if (x_col[r] <= percentiles[p]) - { - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; - temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; - } - temp_left_size += 1.0; - temp_left_y_class_cnt[int(y[idxs[x_col_idxs[r]]])] += 1.0; - cumulative_left_size += 1; - } - else - { - current_min_idx = r; - break; - } - } - - if (cumulative_left_size >= min_leaf && - row_count - cumulative_left_size >= min_leaf) - { - split_candidates_grad_hess[i].push_back(make_tuple(temp_grad, temp_hess, temp_left_size, temp_left_y_class_cnt)); - temp_thresholds[i].push_back(percentiles[p]); - } + } else { + if (x_col[r] < percentiles[p] + cost_constraint_map[k].second) { + uncertain_idxs.push_back(idxs[x_col_idxs[r]]); + } else { + for (int c = 0; c < grad_dim; c++) { + right_grad_confident[c] += gradient[idxs[x_col_idxs[r]]][c]; + right_hess_confident[c] += hessian[idxs[x_col_idxs[r]]][c]; + } } + } + } - // enumerate missing value goto left - if (use_missing_value) - { - int current_max_idx = not_missing_values_count - 1; - int cumulative_right_size = 0; - for (int p = percentiles.size() - 1; p >= 0; p--) - { - vector temp_grad(grad_dim, 0); - vector temp_hess(grad_dim, 0); - float temp_left_size = 0; - vector temp_left_y_class_cnt(num_classes, 0); - - for (int r = current_max_idx; r >= 0; r--) - { - if (x_col[r] 
>= percentiles[p]) - { - for (int c = 0; c < grad_dim; c++) - { - temp_grad[c] += gradient[idxs[x_col_idxs[r]]][c]; - temp_hess[c] += hessian[idxs[x_col_idxs[r]]][c]; - } - temp_left_size += 1.0; - temp_left_y_class_cnt[int(y[idxs[x_col_idxs[r]]])] += 1.0; - cumulative_right_size += 1; - } - else - { - current_max_idx = r; - break; - } - } - - if (cumulative_right_size >= min_leaf && - row_count - cumulative_right_size >= min_leaf) - { - split_candidates_grad_hess[i + subsample_col_count].push_back(make_tuple(temp_grad, temp_hess, temp_left_size, temp_left_y_class_cnt)); - temp_thresholds[i + subsample_col_count].push_back(percentiles[p]); - } - } - } + for (int r : uncertain_idxs) { + vector left_tmp_grad(left_grad_confident); + vector right_tmp_grad(right_grad_confident); + vector left_tmp_hess(left_hess_confident); + vector right_tmp_hess(right_hess_confident); + + for (int c = 0; c < grad_dim; c++) { + left_tmp_grad[c] += gradient[r][c]; + left_tmp_hess[c] += hessian[r][c]; + right_tmp_grad[c] += gradient[r][c]; + right_tmp_hess[c] += hessian[r][c]; + } + + float left_gain = xgboost_compute_gain( + left_tmp_grad, right_grad_confident, left_tmp_hess, + right_hess_confident, gam, lam); + float right_gain = xgboost_compute_gain( + left_grad_confident, right_tmp_grad, left_hess_confident, + right_tmp_hess, gam, lam); + + if (left_gain < right_gain) { + left_grad_confident = left_tmp_grad; + left_hess_confident = left_tmp_hess; + } else { + right_grad_confident = right_tmp_grad; + right_hess_confident = right_tmp_hess; + } } - return split_candidates_grad_hess; + if (temp_left_size >= min_leaf && + row_count - temp_left_size >= min_leaf) { + split_candidates_grad_hess[i].push_back( + make_tuple(left_grad_confident, left_hess_confident, + temp_left_size, temp_left_y_class_cnt)); + temp_thresholds[i].push_back(percentiles[p]); + } + } } + + return split_candidates_grad_hess; + } }; diff --git a/src/aijack/collaborative/tree/xgboost/tree.h b/src/aijack/collaborative/tree/xgboost/tree.h index 972c6a36..d8b8d26f 100644 --- a/src/aijack/collaborative/tree/xgboost/tree.h +++ b/src/aijack/collaborative/tree/xgboost/tree.h @@ -1,39 +1,45 @@ #pragma once -#include -#include -#include -#include #include "../core/tree.h" #include "node.h" +#include +#include +#include +#include + +inline XGBoostNode +make_root_node(vector &parties, vector &y, int num_classes, + vector> &gradient, vector> &hessian, + float min_child_weight, float lam, float gamma, float eps, + int min_leaf, int depth, int active_party_id = -1, + bool use_only_active_party = false, int n_job = 1, + bool is_robust = false) { + vector idxs(y.size()); + iota(idxs.begin(), idxs.end(), 0); + for (int i = 0; i < parties.size(); i++) { + parties[i].subsample_columns(); + } -struct XGBoostTree : Tree -{ - XGBoostTree() {} - void fit(vector &parties, vector &y, int num_classes, - vector> &gradient, vector> &hessian, - float min_child_weight, float lam, - float gamma, float eps, int min_leaf, int depth, - int active_party_id = -1, bool use_only_active_party = false, int n_job = 1) - { - vector idxs(y.size()); - iota(idxs.begin(), idxs.end(), 0); - for (int i = 0; i < parties.size(); i++) - { - parties[i].subsample_columns(); - } + return XGBoostNode(parties, y, num_classes, gradient, hessian, idxs, + min_child_weight, lam, gamma, eps, depth, active_party_id, + use_only_active_party, n_job, is_robust); +} - dtree = XGBoostNode(parties, y, num_classes, gradient, hessian, idxs, - min_child_weight, lam, gamma, eps, depth, - active_party_id, 
use_only_active_party, n_job); - } +struct XGBoostTree : public Tree { + XGBoostTree(vector &parties, vector &y, int num_classes, + vector> &gradient, vector> &hessian, + float min_child_weight, float lam, float gamma, float eps, + int min_leaf, int depth, int active_party_id = -1, + bool use_only_active_party = false, int n_job = 1, + bool is_robust = false) + : Tree(make_root_node( + parties, y, num_classes, gradient, hessian, min_child_weight, lam, + gamma, eps, min_leaf, depth, active_party_id, use_only_active_party, + n_job, is_robust)) {} - XGBoostNode get_root_xgboost_node() - { - return dtree; - } + XGBoostNode &get_root_xgboost_node() { return dtree; } - // vector> predict(vector> X) - //{ - // return dtree.predict(X); - // } + // vector> predict(vector> X) + //{ + // return dtree.predict(X); + // } }; diff --git a/src/aijack/collaborative/tree/xgboost/utils.h b/src/aijack/collaborative/tree/xgboost/utils.h index d794070a..b1d92469 100644 --- a/src/aijack/collaborative/tree/xgboost/utils.h +++ b/src/aijack/collaborative/tree/xgboost/utils.h @@ -1,50 +1,47 @@ #pragma once -#include +#include +#include #include #include -#include -#include +#include using namespace std; -float inline xgboost_compute_gain(vector left_grad, vector right_grad, - vector left_hess, vector right_hess, - float gam, float lam) -{ - float left_gain = 0; - float right_gain = 0; - float base_gain = 0; +float inline xgboost_compute_gain(vector &left_grad, + vector &right_grad, + vector &left_hess, + vector &right_hess, float gam, + float lam) { + float left_gain = 0; + float right_gain = 0; + float base_gain = 0; - for (int c = 0; c < left_grad.size(); c++) - { - left_gain += (left_grad[c] * left_grad[c]) / (left_hess[c] + lam); - right_gain += (right_grad[c] * right_grad[c]) / (right_hess[c] + lam); - base_gain += ((left_grad[c] + right_grad[c]) * - (left_grad[c] + right_grad[c]) / (left_hess[c] + right_hess[c] + lam)); - } + for (int c = 0; c < left_grad.size(); c++) { + left_gain += (left_grad[c] * left_grad[c]) / (left_hess[c] + lam); + right_gain += (right_grad[c] * right_grad[c]) / (right_hess[c] + lam); + base_gain += + ((left_grad[c] + right_grad[c]) * (left_grad[c] + right_grad[c]) / + (left_hess[c] + right_hess[c] + lam)); + } - return 0.5 * (left_gain + right_gain - base_gain) - gam; + return 0.5 * (left_gain + right_gain - base_gain) - gam; } -vector inline xgboost_compute_weight(int row_count, - vector> &gradient, vector> &hessian, - vector &idxs, float lam) -{ - int grad_dim = gradient[0].size(); - vector sum_grad(grad_dim, 0); - vector sum_hess(grad_dim, 0); - vector node_weigths(grad_dim, 0); - for (int i = 0; i < row_count; i++) - { - for (int c = 0; c < grad_dim; c++) - { - sum_grad[c] += gradient[idxs[i]][c]; - sum_hess[c] += hessian[idxs[i]][c]; - } +vector inline xgboost_compute_weight( + int row_count, const vector> &gradient, + const vector> &hessian, vector &idxs, float lam) { + int grad_dim = gradient[0].size(); + vector sum_grad(grad_dim, 0); + vector sum_hess(grad_dim, 0); + vector node_weigths(grad_dim, 0); + for (int i = 0; i < row_count; i++) { + for (int c = 0; c < grad_dim; c++) { + sum_grad[c] += gradient[idxs[i]][c]; + sum_hess[c] += hessian[idxs[i]][c]; } + } - for (int c = 0; c < grad_dim; c++) - { - node_weigths[c] = -1 * (sum_grad[c] / (sum_hess[c] + lam)); - } - return node_weigths; + for (int c = 0; c < grad_dim; c++) { + node_weigths[c] = -1 * (sum_grad[c] / (sum_hess[c] + lam)); + } + return node_weigths; } diff --git a/src/aijack/collaborative/tree/xgboost/xgboost.h 
b/src/aijack/collaborative/tree/xgboost/xgboost.h index d326edc9..a3f485bd 100644 --- a/src/aijack/collaborative/tree/xgboost/xgboost.h +++ b/src/aijack/collaborative/tree/xgboost/xgboost.h @@ -1,200 +1,179 @@ #pragma once -#include -#include -#include -#include -#include #include "../core/model.h" -#include "tree.h" #include "loss.h" +#include "party.h" +#include "tree.h" +#include +#include +#include +#include +#include using namespace std; -struct XGBoostBase : TreeModelBase -{ - float subsample_cols; - float min_child_weight; - int depth; - int min_leaf; - float learning_rate; - int boosting_rounds; - float lam; - float gamma; - float eps; - int active_party_id; - int completelly_secure_round; - float init_value; - int n_job; - bool save_loss; - int num_classes; - - float upsilon_Y; - - LossFunc *lossfunc_obj; - - vector> init_pred; - vector estimators; - vector logging_loss; - - XGBoostBase(int num_classes_, float subsample_cols_ = 0.8, - float min_child_weight_ = -1 * numeric_limits::infinity(), - int depth_ = 5, int min_leaf_ = 5, - float learning_rate_ = 0.4, int boosting_rounds_ = 5, - float lam_ = 1.5, float gamma_ = 1, float eps_ = 0.1, - int active_party_id_ = -1, int completelly_secure_round_ = 0, - float init_value_ = 1.0, int n_job_ = 1, bool save_loss_ = true) - { - num_classes = num_classes_; - subsample_cols = subsample_cols_; - min_child_weight = min_child_weight_; - depth = depth_; - min_leaf = min_leaf_; - learning_rate = learning_rate_; - boosting_rounds = boosting_rounds_; - lam = lam_; - gamma = gamma_; - eps = eps_; - active_party_id = active_party_id_; - completelly_secure_round = completelly_secure_round_; - init_value = init_value_; - n_job = n_job_; - save_loss = save_loss_; - - if (num_classes == 2) - { - lossfunc_obj = new BCELoss(); - } - else - { - lossfunc_obj = new CELoss(num_classes); - } - } - - virtual vector> get_init_pred(vector &y) = 0; - - void load_estimators(vector &_estimators) - { - estimators = _estimators; +struct XGBoostBase : TreeModelBase { + float subsample_cols; + float min_child_weight; + int depth; + int min_leaf; + float learning_rate; + int boosting_rounds; + float lam; + float gamma; + float eps; + int active_party_id; + int completelly_secure_round; + float init_value; + int n_job; + bool save_loss; + int num_classes; + bool is_robust; + + float upsilon_Y; + + LossFunc *lossfunc_obj; + + vector> init_pred; + vector estimators; + vector logging_loss; + vector parties_cp; + + XGBoostBase(int num_classes_, float subsample_cols_ = 0.8, + float min_child_weight_ = -1 * numeric_limits::infinity(), + int depth_ = 5, int min_leaf_ = 5, float learning_rate_ = 0.4, + int boosting_rounds_ = 5, float lam_ = 1.5, float gamma_ = 1, + float eps_ = 0.1, int active_party_id_ = -1, + int completelly_secure_round_ = 0, float init_value_ = 1.0, + int n_job_ = 1, bool save_loss_ = true, bool is_robust_ = false) { + num_classes = num_classes_; + subsample_cols = subsample_cols_; + min_child_weight = min_child_weight_; + depth = depth_; + min_leaf = min_leaf_; + learning_rate = learning_rate_; + boosting_rounds = boosting_rounds_; + lam = lam_; + gamma = gamma_; + eps = eps_; + active_party_id = active_party_id_; + completelly_secure_round = completelly_secure_round_; + init_value = init_value_; + n_job = n_job_; + save_loss = save_loss_; + is_robust = is_robust_; + + if (num_classes == 2) { + lossfunc_obj = new BCELoss(); + } else { + lossfunc_obj = new CELoss(num_classes); } - - void clear() - { - estimators.clear(); - logging_loss.clear(); + } + + 
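+  // Initial prediction before any boosting round; supplied by the subclass
+  // (see XGBoostClassifier below, which returns init_value constants).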
virtual vector> get_init_pred(vector &y) = 0; + + void load_estimators(vector &_estimators) { + estimators = _estimators; + } + + void clear() { + estimators.clear(); + logging_loss.clear(); + } + + vector get_estimators() { return estimators; } + vector get_parties() { return parties_cp; } + + void fit(vector &parties, vector &y) { + int row_count = y.size(); + + parties_cp = parties; + + vector> base_pred; + if (estimators.size() == 0) { + init_pred = get_init_pred(y); + copy(init_pred.begin(), init_pred.end(), back_inserter(base_pred)); + } else { + base_pred.resize(row_count); + for (int j = 0; j < row_count; j++) + for (int c = 0; c < num_classes; c++) + base_pred[j][c] = 0; + + for (int i = 0; i < estimators.size(); i++) { + vector> pred_temp = estimators[i].get_train_prediction(); + for (int j = 0; j < row_count; j++) + for (int c = 0; c < num_classes; c++) + base_pred[j][c] += learning_rate * pred_temp[j][c]; + } } - vector get_estimators() - { - return estimators; - } + for (int i = 0; i < boosting_rounds; i++) { + vector> grad = lossfunc_obj->get_grad(base_pred, y); + vector> hess = lossfunc_obj->get_hess(base_pred, y); - void fit(vector &parties, vector &y) - { - int row_count = y.size(); + XGBoostTree boosting_tree( + parties_cp, y, num_classes, grad, hess, min_child_weight, lam, gamma, + eps, min_leaf, depth, active_party_id, (completelly_secure_round > i), + n_job, is_robust); + vector> pred_temp = boosting_tree.get_train_prediction(); + for (int j = 0; j < row_count; j++) + for (int c = 0; c < num_classes; c++) + base_pred[j][c] += learning_rate * pred_temp[j][c]; - vector> base_pred; - if (estimators.size() == 0) - { - init_pred = get_init_pred(y); - copy(init_pred.begin(), init_pred.end(), back_inserter(base_pred)); - } - else - { - base_pred.resize(row_count); - for (int j = 0; j < row_count; j++) - for (int c = 0; c < num_classes; c++) - base_pred[j][c] = 0; - - for (int i = 0; i < estimators.size(); i++) - { - vector> pred_temp = estimators[i].get_train_prediction(); - for (int j = 0; j < row_count; j++) - for (int c = 0; c < num_classes; c++) - base_pred[j][c] += learning_rate * pred_temp[j][c]; - } - } + estimators.push_back(boosting_tree); - for (int i = 0; i < boosting_rounds; i++) - { - vector> grad = lossfunc_obj->get_grad(base_pred, y); - vector> hess = lossfunc_obj->get_hess(base_pred, y); - - XGBoostTree boosting_tree = XGBoostTree(); - boosting_tree.fit(parties, y, num_classes, grad, hess, min_child_weight, - lam, gamma, eps, min_leaf, depth, - active_party_id, (completelly_secure_round > i), n_job); - vector> pred_temp = boosting_tree.get_train_prediction(); - for (int j = 0; j < row_count; j++) - for (int c = 0; c < num_classes; c++) - base_pred[j][c] += learning_rate * pred_temp[j][c]; - - estimators.push_back(boosting_tree); - - if (save_loss) - { - logging_loss.push_back(lossfunc_obj->get_loss(base_pred, y)); - } - } + if (save_loss) { + logging_loss.push_back(lossfunc_obj->get_loss(base_pred, y)); + } + } + } + + vector> predict_raw(vector> &X) { + int pred_dim; + if (num_classes == 2) { + pred_dim = 1; + } else { + pred_dim = num_classes; } - vector> predict_raw(vector> &X) - { - int pred_dim; - if (num_classes == 2) - { - pred_dim = 1; - } - else - { - pred_dim = num_classes; - } - - int row_count = X.size(); - vector> y_pred(row_count, vector(pred_dim, init_value)); - // copy(init_pred.begin(), init_pred.end(), back_inserter(y_pred)); - int estimators_num = estimators.size(); - for (int i = 0; i < estimators_num; i++) - { - vector> y_pred_temp = 
estimators[i].predict(X); - for (int j = 0; j < row_count; j++) - { - for (int c = 0; c < pred_dim; c++) - { - y_pred[j][c] += learning_rate * y_pred_temp[j][c]; - } - } + int row_count = X.size(); + vector> y_pred(row_count, + vector(pred_dim, init_value)); + // copy(init_pred.begin(), init_pred.end(), back_inserter(y_pred)); + int estimators_num = estimators.size(); + for (int i = 0; i < estimators_num; i++) { + vector> y_pred_temp = estimators[i].predict(X); + for (int j = 0; j < row_count; j++) { + for (int c = 0; c < pred_dim; c++) { + y_pred[j][c] += learning_rate * y_pred_temp[j][c]; } - - return y_pred; + } } -}; -struct XGBoostClassifier : public XGBoostBase -{ - using XGBoostBase::XGBoostBase; - - vector> get_init_pred(vector &y) - { - vector> init_pred(y.size(), vector(num_classes, init_value)); - return init_pred; - } + return y_pred; + } +}; - vector> predict_proba(vector> &x) - { - vector> raw_score = predict_raw(x); - int row_count = x.size(); - vector> predicted_probas(row_count, vector(num_classes, 0)); - for (int i = 0; i < row_count; i++) - { - if (num_classes == 2) - { - predicted_probas[i][1] = sigmoid(raw_score[i][0]); - predicted_probas[i][0] = 1 - predicted_probas[i][1]; - } - else - { - predicted_probas[i] = softmax(raw_score[i]); - } - } - return predicted_probas; +struct XGBoostClassifier : public XGBoostBase { + using XGBoostBase::XGBoostBase; + + vector> get_init_pred(vector &y) { + vector> init_pred(y.size(), + vector(num_classes, init_value)); + return init_pred; + } + + vector> predict_proba(vector> &x) { + vector> raw_score = predict_raw(x); + int row_count = x.size(); + vector> predicted_probas(row_count, + vector(num_classes, 0)); + for (int i = 0; i < row_count; i++) { + if (num_classes == 2) { + predicted_probas[i][1] = sigmoid(raw_score[i][0]); + predicted_probas[i][0] = 1 - predicted_probas[i][1]; + } else { + predicted_probas[i] = softmax(raw_score[i]); + } } + return predicted_probas; + } }; diff --git a/src/main.cpp b/src/main.cpp index b922322d..799f45de 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,17 +1,20 @@ -#include -#include -#include #include #include +#include +#include +#include #include +#include "aijack/collaborative/tree/secureboost/party.h" +#include "aijack/collaborative/tree/secureboost/secureboost.h" +#include "aijack/collaborative/tree/xgboost/node.h" +#include "aijack/collaborative/tree/xgboost/party.h" +#include "aijack/collaborative/tree/xgboost/xgboost.h" #include "aijack/defense/dp/core//rdp.cpp" #include "aijack/defense/dp/core//search.cpp" #include "aijack/defense/kanonymity/core/anonymizer.h" -#include "aijack/defense/paillier/src/paillier.h" #include "aijack/defense/paillier/src/keygenerator.h" -#include "aijack/collaborative/tree/xgboost/xgboost.h" -#include "aijack/collaborative/tree/secureboost/secureboost.h" +#include "aijack/defense/paillier/src/paillier.h" #define STRINGIFY(x) #x #define MACRO_STRINGIFY(x) STRINGIFY(x) @@ -22,150 +25,150 @@ namespace py = pybind11; template using overload_cast_ = pybind11::detail::overload_cast_impl; -PYBIND11_MODULE(aijack_cpp_core, m) -{ - m.doc() = R"pbdoc( +PYBIND11_MODULE(aijack_cpp_core, m) { + m.doc() = R"pbdoc( c++ backend for aijack )pbdoc"; - m.def("eps_gaussian", - &eps_gaussian, R"pbdoc(eps_gaussian)pbdoc"); - - m.def("eps_laplace", - &eps_laplace, R"pbdoc(eps_laplace)pbdoc"); - - m.def("eps_randresp", - &eps_randresp, R"pbdoc(eps_randresp)pbdoc"); - - m.def("culc_tightupperbound_lowerbound_of_rdp_with_theorem6and8_of_zhu_2019", - 
&culc_tightupperbound_lowerbound_of_rdp_with_theorem6and8_of_zhu_2019, - R"pbdoc(culc_tightupperbound_lowerbound_of_rdp_with_theorem6and8_of_zhu_2019)pbdoc"); - - m.def("culc_upperbound_of_rdp_with_Sampled_Gaussian_Mechanism", - &culc_upperbound_of_rdp_with_Sampled_Gaussian_Mechanism, - R"pbdoc(culc_upperbound_of_rdp_with_Sampled_Gaussian_Mechanism)pbdoc"); - - m.def("_ternary_search", - &_ternary_search, R"pbdoc(_ternary_search)pbdoc"); - - m.def("_ternary_search_int", - &_ternary_search_int, R"pbdoc(_ternary_search_int)pbdoc"); - - m.def("_greedy_search", - &_greedy_search, R"pbdoc(_greedy_search)pbdoc"); - - m.def("_greedy_search_frac", - &_greedy_search_frac, R"pbdoc(_greey_search_frac)pbdoc"); - - py::class_(m, "PaillierKeyGenerator") - .def(py::init()) - .def("generate_keypair", &PaillierKeyGenerator::generate_keypair); - - py::class_(m, "PaillierPublicKey") - .def("encrypt", &PaillierPublicKey::encrypt) - .def("encrypt", &PaillierPublicKey::encrypt) - .def("encrypt", &PaillierPublicKey::encrypt) - .def("encrypt", &PaillierPublicKey::encrypt) - .def("get_publickeyvalues", &PaillierPublicKey::get_publickeyvalues); - - py::class_(m, "PaillierCipherText") - .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) - .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) - .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) - .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) - .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) - .def("__mul__", overload_cast_()(&PaillierCipherText::operator*)) - .def("__mul__", overload_cast_()(&PaillierCipherText::operator*)) - .def("__mul__", overload_cast_()(&PaillierCipherText::operator*)) - .def("__mul__", overload_cast_()(&PaillierCipherText::operator*)) - .def("get_value", &PaillierCipherText::get_value); - - py::class_(m, "PaillierSecretKey") - .def("decrypt2int", &PaillierSecretKey::decrypt2int) - .def("decrypt2long", &PaillierSecretKey::decrypt2long) - .def("decrypt2float", &PaillierSecretKey::decrypt2float) - .def("decrypt2double", &PaillierSecretKey::decrypt2double) - .def("get_publickeyvalues", &PaillierSecretKey::get_publickeyvalues) - .def("get_secretkeyvalues", &PaillierSecretKey::get_secretkeyvalues); - - py::class_(m, "XGBoostParty") - .def(py::init>, int, vector, int, - int, float, int, bool, int>()) - .def("get_lookup_table", &XGBoostParty::get_lookup_table); - - py::class_(m, "SecureBoostParty") - .def(py::init>, int, vector, int, - int, float, int, bool, int>()) - .def("get_lookup_table", &SecureBoostParty::get_lookup_table) - .def("set_publickey", &SecureBoostParty::set_publickey) - .def("set_secretkey", &SecureBoostParty::set_secretkey); - - py::class_(m, "XGBoostNode") - .def("get_idxs", &XGBoostNode::get_idxs) - .def("get_party_id", &XGBoostNode::get_party_id) - .def("get_record_id", &XGBoostNode::get_record_id) - .def("get_val", &XGBoostNode::get_val) - .def("get_score", &XGBoostNode::get_score) - .def("get_left", &XGBoostNode::get_left) - .def("get_right", &XGBoostNode::get_right) - .def("is_leaf", &XGBoostNode::is_leaf); - - py::class_(m, "SecureBoostNode") - .def("get_idxs", &SecureBoostNode::get_idxs) - .def("get_party_id", &SecureBoostNode::get_party_id) - .def("get_record_id", &SecureBoostNode::get_record_id) - .def("get_val", &SecureBoostNode::get_val) - .def("get_score", &SecureBoostNode::get_score) - .def("get_left", &SecureBoostNode::get_left) - .def("get_right", &SecureBoostNode::get_right) - .def("is_leaf", &SecureBoostNode::is_leaf); - - py::class_(m, 
"XGBoostTree") - .def("get_root_xgboost_node", &XGBoostTree::get_root_xgboost_node) - .def("print", &XGBoostTree::print) - .def("predict", &XGBoostTree::predict); - - py::class_(m, "SecureBoostTree") - .def("print", &SecureBoostTree::print) - .def("predict", &SecureBoostTree::predict); - - py::class_(m, "XGBoostClassifier") - .def(py::init()) - .def("fit", &XGBoostClassifier::fit) - .def("get_init_pred", &XGBoostClassifier::get_init_pred) - .def("load_estimators", &XGBoostClassifier::load_estimators) - .def("get_estimators", &XGBoostClassifier::get_estimators) - .def("predict_raw", &XGBoostClassifier::predict_raw) - .def("predict_proba", &XGBoostClassifier::predict_proba); - - py::class_(m, "SecureBoostClassifier") - .def(py::init()) - .def("fit", &SecureBoostClassifier::fit) - .def("get_init_pred", &SecureBoostClassifier::get_init_pred) - .def("load_estimators", &SecureBoostClassifier::load_estimators) - .def("get_estimators", &SecureBoostClassifier::get_estimators) - .def("predict_raw", &SecureBoostClassifier::predict_raw) - .def("predict_proba", &SecureBoostClassifier::predict_proba); - - py::class_(m, "DataFrame") - .def(py::init, map, int>()) - .def("insert_continuous", &DataFrame::insert_continuous) - .def("insert_categorical", &DataFrame::insert_categorical) - .def("insert_continuous_column", &DataFrame::insert_continuous_column) - .def("insert_categorical_column", &DataFrame::insert_categorical_column) - .def("get_data_continuous", &DataFrame::get_data_continuous) - .def("get_data_categorical", &DataFrame::get_data_categorical); - - py::class_(m, "Mondrian") - .def(py::init()) - .def("get_final_partitions", &Mondrian::get_final_partitions) - .def("anonymize", &Mondrian::anonymize); + m.def("eps_gaussian", &eps_gaussian, R"pbdoc(eps_gaussian)pbdoc"); + + m.def("eps_laplace", &eps_laplace, R"pbdoc(eps_laplace)pbdoc"); + + m.def("eps_randresp", &eps_randresp, R"pbdoc(eps_randresp)pbdoc"); + + m.def( + "culc_tightupperbound_lowerbound_of_rdp_with_theorem6and8_of_zhu_2019", + &culc_tightupperbound_lowerbound_of_rdp_with_theorem6and8_of_zhu_2019, + R"pbdoc(culc_tightupperbound_lowerbound_of_rdp_with_theorem6and8_of_zhu_2019)pbdoc"); + + m.def("culc_upperbound_of_rdp_with_Sampled_Gaussian_Mechanism", + &culc_upperbound_of_rdp_with_Sampled_Gaussian_Mechanism, + R"pbdoc(culc_upperbound_of_rdp_with_Sampled_Gaussian_Mechanism)pbdoc"); + + m.def("_ternary_search", &_ternary_search, R"pbdoc(_ternary_search)pbdoc"); + + m.def("_ternary_search_int", &_ternary_search_int, + R"pbdoc(_ternary_search_int)pbdoc"); + + m.def("_greedy_search", &_greedy_search, R"pbdoc(_greedy_search)pbdoc"); + + m.def("_greedy_search_frac", &_greedy_search_frac, + R"pbdoc(_greey_search_frac)pbdoc"); + + py::class_(m, "PaillierKeyGenerator") + .def(py::init()) + .def("generate_keypair", &PaillierKeyGenerator::generate_keypair); + + py::class_(m, "PaillierPublicKey") + .def("encrypt", &PaillierPublicKey::encrypt) + .def("encrypt", &PaillierPublicKey::encrypt) + .def("encrypt", &PaillierPublicKey::encrypt) + .def("encrypt", &PaillierPublicKey::encrypt) + .def("get_publickeyvalues", &PaillierPublicKey::get_publickeyvalues); + + py::class_(m, "PaillierCipherText") + .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) + .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) + .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) + .def("__add__", overload_cast_()(&PaillierCipherText::operator+)) + .def("__add__", + overload_cast_()(&PaillierCipherText::operator+)) + .def("__mul__", 
overload_cast_()(&PaillierCipherText::operator*)) + .def("__mul__", overload_cast_()(&PaillierCipherText::operator*)) + .def("__mul__", overload_cast_()(&PaillierCipherText::operator*)) + .def("__mul__", overload_cast_()(&PaillierCipherText::operator*)) + .def("get_value", &PaillierCipherText::get_value); + + py::class_(m, "PaillierSecretKey") + .def("decrypt2int", &PaillierSecretKey::decrypt2int) + .def("decrypt2long", &PaillierSecretKey::decrypt2long) + .def("decrypt2float", &PaillierSecretKey::decrypt2float) + .def("decrypt2double", &PaillierSecretKey::decrypt2double) + .def("get_publickeyvalues", &PaillierSecretKey::get_publickeyvalues) + .def("get_secretkeyvalues", &PaillierSecretKey::get_secretkeyvalues); + + py::class_(m, "XGBoostParty") + .def(py::init>, int, vector, int, int, float, + int, bool, int>()) + .def("set_cost_constraint_map", &XGBoostParty::set_cost_constraint_map) + .def("get_lookup_table", &XGBoostParty::get_lookup_table); + + py::class_(m, "SecureBoostParty") + .def(py::init>, int, vector, int, int, float, + int, bool, int>()) + .def("get_lookup_table", &SecureBoostParty::get_lookup_table) + .def("set_publickey", &SecureBoostParty::set_publickey) + .def("set_secretkey", &SecureBoostParty::set_secretkey); + + py::class_(m, "XGBoostNode") + .def("get_idxs", &XGBoostNode::get_idxs) + .def("get_party_id", &XGBoostNode::get_party_id) + .def("get_record_id", &XGBoostNode::get_record_id) + .def("get_num_parties", &XGBoostNode::get_num_parties) + .def("get_val", &XGBoostNode::get_val) + .def("get_score", &XGBoostNode::get_score) + .def("get_left", &XGBoostNode::get_left) + .def("get_right", &XGBoostNode::get_right) + .def("is_leaf", &XGBoostNode::is_leaf); + + py::class_(m, "SecureBoostNode") + .def("get_idxs", &SecureBoostNode::get_idxs) + .def("get_party_id", &SecureBoostNode::get_party_id) + .def("get_record_id", &SecureBoostNode::get_record_id) + .def("get_val", &SecureBoostNode::get_val) + .def("get_score", &SecureBoostNode::get_score) + .def("get_left", &SecureBoostNode::get_left) + .def("get_right", &SecureBoostNode::get_right) + .def("is_leaf", &SecureBoostNode::is_leaf); + + py::class_(m, "XGBoostTree") + .def("get_root_xgboost_node", &XGBoostTree::get_root_xgboost_node) + .def("print", &XGBoostTree::print) + .def("predict", &XGBoostTree::predict); + + py::class_(m, "SecureBoostTree") + .def("print", &SecureBoostTree::print) + .def("predict", &SecureBoostTree::predict); + + py::class_(m, "XGBoostClassifier") + .def(py::init()) + .def("fit", &XGBoostClassifier::fit) + .def("get_init_pred", &XGBoostClassifier::get_init_pred) + .def("load_estimators", &XGBoostClassifier::load_estimators) + .def("get_estimators", &XGBoostClassifier::get_estimators) + .def("get_parties", &XGBoostClassifier::get_parties) + .def("predict_raw", &XGBoostClassifier::predict_raw) + .def("predict_proba", &XGBoostClassifier::predict_proba); + + py::class_(m, "SecureBoostClassifier") + .def(py::init()) + .def("fit", &SecureBoostClassifier::fit) + .def("get_init_pred", &SecureBoostClassifier::get_init_pred) + .def("load_estimators", &SecureBoostClassifier::load_estimators) + .def("get_estimators", &SecureBoostClassifier::get_estimators) + .def("get_parties", &SecureBoostClassifier::get_parties) + .def("predict_raw", &SecureBoostClassifier::predict_raw) + .def("predict_proba", &SecureBoostClassifier::predict_proba); + + py::class_(m, "DataFrame") + .def(py::init, map, int>()) + .def("insert_continuous", &DataFrame::insert_continuous) + .def("insert_categorical", &DataFrame::insert_categorical) + 
.def("insert_continuous_column", &DataFrame::insert_continuous_column) + .def("insert_categorical_column", &DataFrame::insert_categorical_column) + .def("get_data_continuous", &DataFrame::get_data_continuous) + .def("get_data_categorical", &DataFrame::get_data_categorical); + + py::class_(m, "Mondrian") + .def(py::init()) + .def("get_final_partitions", &Mondrian::get_final_partitions) + .def("anonymize", &Mondrian::anonymize); #ifdef VERSION_INFO - m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); + m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO); #else - m.attr("__version__") = "dev"; + m.attr("__version__") = "dev"; #endif } diff --git a/test/collaborative/secureboost/test_secureboost.py b/test/collaborative/secureboost/test_secureboost.py index d9783d84..874395bb 100644 --- a/test/collaborative/secureboost/test_secureboost.py +++ b/test/collaborative/secureboost/test_secureboost.py @@ -62,8 +62,8 @@ def test_secureboost(): [0.3700332045555115, 0.6299667954444885], [0.20040041208267212, 0.7995995879173279], [0.44300776720046997, 0.55699223279953], - [0.3700332045555115, 0.6299667954444885], - [0.3700332045555115, 0.6299667954444885], + [0.2150152325630188, 0.7849847674369812], + [0.2150152325630188, 0.7849847674369812], [0.44300776720046997, 0.55699223279953], [0.20040041208267212, 0.7995995879173279], ]