diff --git a/integration/xgboost/processor/CMakeLists.txt b/integration/xgboost/processor/CMakeLists.txt
new file mode 100644
index 0000000000..d29b246377
--- /dev/null
+++ b/integration/xgboost/processor/CMakeLists.txt
@@ -0,0 +1,52 @@
+cmake_minimum_required(VERSION 3.19)
+project(proc_nvflare LANGUAGES CXX C VERSION 1.0)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_BUILD_TYPE Debug)
+
+option(GOOGLE_TEST "Build google tests" OFF)
+
+file(GLOB_RECURSE LIB_SRC
+    "src/*.h"
+    "src/*.cc"
+    )
+
+add_library(proc_nvflare SHARED ${LIB_SRC})
+set(XGB_SRC ${proc_nvflare_SOURCE_DIR}/../../../../xgboost)
+target_include_directories(proc_nvflare PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include
+    ${XGB_SRC}/src
+    ${XGB_SRC}/rabit/include
+    ${XGB_SRC}/include
+    ${XGB_SRC}/dmlc-core/include)
+
+link_directories(${XGB_SRC}/lib/)
+
+if (APPLE)
+    add_link_options("LINKER:-object_path_lto,$_lto.o")
+    add_link_options("LINKER:-cache_path_lto,${CMAKE_BINARY_DIR}/LTOCache")
+endif ()
+
+target_link_libraries(proc_nvflare ${XGB_SRC}/lib/libxgboost${CMAKE_SHARED_LIBRARY_SUFFIX})
+
+#-- Unit Tests
+if(GOOGLE_TEST)
+    find_package(GTest REQUIRED)
+    enable_testing()
+    add_executable(proc_test)
+    target_link_libraries(proc_test PRIVATE proc_nvflare)
+
+    target_include_directories(proc_test PRIVATE ${proc_nvflare_SOURCE_DIR}/src/include
+        ${XGB_SRC}/src
+        ${XGB_SRC}/rabit/include
+        ${XGB_SRC}/include
+        ${XGB_SRC}/dmlc-core/include
+        ${XGB_SRC}/tests)
+
+    add_subdirectory(${proc_nvflare_SOURCE_DIR}/tests)
+
+    add_test(
+        NAME TestProcessor
+        COMMAND proc_test
+        WORKING_DIRECTORY ${proc_nvflare_BINARY_DIR})
+
+endif()
\ No newline at end of file
diff --git a/integration/xgboost/processor/README.md b/integration/xgboost/processor/README.md
new file mode 100644
index 0000000000..08afc24e42
--- /dev/null
+++ b/integration/xgboost/processor/README.md
@@ -0,0 +1,21 @@
+# Build Instruction
+
+Building this plugin requires the XGBoost source code. Check out the XGBoost source and build it with the FEDERATED plugin first:
+
+```
+cd xgboost
+mkdir build
+cd build
+cmake .. -DPLUGIN_FEDERATED=ON
+make
+```
+
+Then build this processor plugin against that source tree:
+
+```
+cd NVFlare/integration/xgboost/processor
+mkdir build
+cd build
+cmake ..
+make
+```
diff --git a/integration/xgboost/processor/src/README.md b/integration/xgboost/processor/src/README.md
new file mode 100644
index 0000000000..a10dae75ed
--- /dev/null
+++ b/integration/xgboost/processor/src/README.md
@@ -0,0 +1,7 @@
+# encoding-plugins
+Processor Plugin for NVFlare
+
+This plugin is a companion for NVFlare-based encryption. It processes the data so it can
+be properly decoded by the Python code running in NVFlare.
+
+All the encryption happens in the local gRPC client/server, so no encryption is needed in this plugin.
diff --git a/integration/xgboost/processor/src/dam/README.md b/integration/xgboost/processor/src/dam/README.md
new file mode 100644
index 0000000000..ba65423e65
--- /dev/null
+++ b/integration/xgboost/processor/src/dam/README.md
@@ -0,0 +1,8 @@
+# DAM (Direct-Accessible Marshaller)
+
+A simple serialization library that has no dependencies; the data
+is directly accessible in C/C++ without copying.
+
+To make the data accessible in C, the following rules must be followed:
+
+1. Numeric values must be stored in native byte order.
+2. Numeric values must start at 64-bit (8-byte) boundaries.
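The same layout is implemented twice in this PR (C++ in `dam.cc` below, Python in `nvflare/app_common/xgb/sec/dam.py`). As a quick illustration of the two rules above, here is a minimal sketch of one DAM buffer; the helper name is hypothetical, and `struct` in native mode gives native byte order:

```python
import struct

SIGNATURE = "NVDADAM1"  # 8 bytes; also serves as the format/version marker

def encode_int_array(data_set_id: int, values: list) -> bytes:
    # Prefix: 8-byte signature + int64 size + int64 data set id = 24 bytes,
    # so the first numeric field starts on an 8-byte boundary (rule 2).
    size = 24 + 16 + 8 * len(values)                   # prefix + (type, length) + payload
    header = SIGNATURE.encode("utf-8") + struct.pack("qq", size, data_set_id)
    entry = struct.pack("qq", 257, len(values))        # 257 = DATA_TYPE_INT_ARRAY
    payload = struct.pack(f"{len(values)}q", *values)  # native byte order (rule 1)
    return header + entry + payload

buf = encode_int_array(123, [1, 2, 3])
assert len(buf) == struct.unpack_from("q", buf, 8)[0]  # size field lives at offset 8
```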
diff --git a/integration/xgboost/processor/src/dam/dam.cc b/integration/xgboost/processor/src/dam/dam.cc
new file mode 100644
index 0000000000..d768d497dd
--- /dev/null
+++ b/integration/xgboost/processor/src/dam/dam.cc
@@ -0,0 +1,146 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <cstring>
+#include "dam.h"
+
+void print_buffer(uint8_t *buffer, int size) {
+  for (int i = 0; i < size; i++) {
+    auto c = buffer[i];
+    std::cout << std::hex << (int) c << " ";
+  }
+  std::cout << std::endl << std::dec;
+}
+
+// DamEncoder ======
+void DamEncoder::AddFloatArray(std::vector<double> &value) {
+  if (encoded) {
+    std::cout << "Buffer is already encoded" << std::endl;
+    return;
+  }
+  auto buf_size = value.size() * 8;
+  uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size));
+  memcpy(buffer, value.data(), buf_size);
+  // print_buffer(reinterpret_cast<uint8_t *>(value.data()), value.size() * 8);
+  entries->push_back(new Entry(kDataTypeFloatArray, buffer, value.size()));
+}
+
+void DamEncoder::AddIntArray(std::vector<int64_t> &value) {
+  std::cout << "AddIntArray called, size: " << value.size() << std::endl;
+  if (encoded) {
+    std::cout << "Buffer is already encoded" << std::endl;
+    return;
+  }
+  auto buf_size = value.size() * 8;
+  std::cout << "Allocating " << buf_size << " bytes" << std::endl;
+  uint8_t *buffer = static_cast<uint8_t *>(malloc(buf_size));
+  memcpy(buffer, value.data(), buf_size);
+  // print_buffer(buffer, buf_size);
+  entries->push_back(new Entry(kDataTypeIntArray, buffer, value.size()));
+}
+
+std::uint8_t *DamEncoder::Finish(size_t &size) {
+  encoded = true;
+
+  size = calculate_size();
+  auto buf = static_cast<uint8_t *>(malloc(size));
+  auto pointer = buf;
+  memcpy(pointer, kSignature, strlen(kSignature));
+  memcpy(pointer + 8, &size, 8);
+  memcpy(pointer + 16, &data_set_id, 8);
+
+  pointer += kPrefixLen;
+  for (auto entry : *entries) {
+    memcpy(pointer, &entry->data_type, 8);
+    pointer += 8;
+    memcpy(pointer, &entry->size, 8);
+    pointer += 8;
+    int len = 8 * entry->size;
+    memcpy(pointer, entry->pointer, len);
+    free(entry->pointer);
+    pointer += len;
+    // print_buffer(entry->pointer, entry->size*8);
+  }
+
+  if ((pointer - buf) != size) {
+    std::cout << "Invalid encoded size: " << (pointer - buf) << std::endl;
+    return nullptr;
+  }
+
+  return buf;
+}
+
+std::size_t DamEncoder::calculate_size() {
+  auto size = kPrefixLen;
+
+  for (auto entry : *entries) {
+    size += 16;  // The Type and Len
+    size += entry->size * 8;  // All supported data types are 8 bytes
+  }
+
+  return size;
+}
+
+
+// DamDecoder ======
+
+DamDecoder::DamDecoder(std::uint8_t *buffer, std::size_t size) {
+  this->buffer = buffer;
+  this->buf_size = size;
+  this->pos = buffer + kPrefixLen;
+  if (size >= kPrefixLen) {
+    memcpy(&len, buffer + 8, 8);
+    memcpy(&data_set_id, buffer + 16, 8);
+  } else {
+    len = 0;
+    data_set_id = 0;
+  }
+}
+
+bool DamDecoder::IsValid() {
+  return buf_size >= kPrefixLen && memcmp(buffer, kSignature, strlen(kSignature)) == 0;
+}
+
+std::vector<int64_t> DamDecoder::DecodeIntArray() {
+  auto type = *reinterpret_cast<int64_t *>(pos);
+  if (type != kDataTypeIntArray) {
+    std::cout << "Data type " << type << " doesn't match Int Array" << std::endl;
+    return std::vector<int64_t>();
+  }
+  pos += 8;
+
+  auto len = *reinterpret_cast<int64_t *>(pos);
+  pos += 8;
+  auto ptr = reinterpret_cast<int64_t *>(pos);
+  pos += 8 * len;
+  return std::vector<int64_t>(ptr, ptr + len);
+}
+
+std::vector<double> DamDecoder::DecodeFloatArray() {
+  auto type = *reinterpret_cast<int64_t *>(pos);
+  if (type != kDataTypeFloatArray) {
+    std::cout << "Data type " << type << " doesn't match Float Array" << std::endl;
+    return std::vector<double>();
+  }
+  pos += 8;
+
+  auto len = *reinterpret_cast<int64_t *>(pos);
+  pos += 8;
+
+  auto ptr = reinterpret_cast<double *>(pos);
+  pos += 8 * len;
+  return std::vector<double>(ptr, ptr + len);
+}
diff --git a/integration/xgboost/processor/src/include/dam.h b/integration/xgboost/processor/src/include/dam.h
new file mode 100644
index 0000000000..e6afd44299
--- /dev/null
+++ b/integration/xgboost/processor/src/include/dam.h
@@ -0,0 +1,93 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <cstdint>
+#include <cstddef>
+#include <vector>
+
+const char kSignature[] = "NVDADAM1";  // DAM (Direct Accessible Marshalling) V1
+const int kPrefixLen = 24;
+
+const int kDataTypeInt = 1;
+const int kDataTypeFloat = 2;
+const int kDataTypeString = 3;
+const int kDataTypeIntArray = 257;
+const int kDataTypeFloatArray = 258;
+
+const int kDataTypeMap = 1025;
+
+class Entry {
+ public:
+  int64_t data_type;
+  uint8_t *pointer;
+  int64_t size;
+
+  Entry(int64_t data_type, uint8_t *pointer, int64_t size) {
+    this->data_type = data_type;
+    this->pointer = pointer;
+    this->size = size;
+  }
+};
+
+class DamEncoder {
+ private:
+  bool encoded = false;
+  int64_t data_set_id;
+  std::vector<Entry *> *entries = new std::vector<Entry *>();
+
+ public:
+  explicit DamEncoder(int64_t data_set_id) {
+    this->data_set_id = data_set_id;
+  }
+
+  void AddIntArray(std::vector<int64_t> &value);
+
+  void AddFloatArray(std::vector<double> &value);
+
+  std::uint8_t *Finish(size_t &size);
+
+ private:
+  std::size_t calculate_size();
+};
+
+class DamDecoder {
+ private:
+  std::uint8_t *buffer = nullptr;
+  std::size_t buf_size = 0;
+  std::uint8_t *pos = nullptr;
+  std::size_t remaining = 0;
+  int64_t data_set_id = 0;
+  int64_t len = 0;
+
+ public:
+  explicit DamDecoder(std::uint8_t *buffer, std::size_t size);
+
+  size_t Size() {
+    return len;
+  }
+
+  int64_t GetDataSetId() {
+    return data_set_id;
+  }
+
+  bool IsValid();
+
+  std::vector<int64_t> DecodeIntArray();
+
+  std::vector<double> DecodeFloatArray();
+};
+
+void print_buffer(uint8_t *buffer, int size);
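For reference, the encoded size computed by `calculate_size()` in `dam.cc` is easy to predict from the entry lengths alone. A small Python check (values chosen to match the arrays used in `test_dam.cc` later in this PR):

```python
PREFIX_LEN = 24  # 8-byte signature + int64 size + int64 data set id

def dam_size(array_lengths: list) -> int:
    # Each entry carries an int64 type and an int64 length (16 bytes),
    # followed by its items; all supported item types are 8 bytes wide.
    return PREFIX_LEN + sum(16 + 8 * n for n in array_lengths)

# A buffer holding a 4-element float array and a 3-element int array:
assert dam_size([4, 3]) == 24 + (16 + 32) + (16 + 24)  # == 112
```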
diff --git a/integration/xgboost/processor/src/include/nvflare_processor.h b/integration/xgboost/processor/src/include/nvflare_processor.h
new file mode 100644
index 0000000000..52cf42920f
--- /dev/null
+++ b/integration/xgboost/processor/src/include/nvflare_processor.h
@@ -0,0 +1,71 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include <iostream>
+#include <map>
+#include <string>
+#include <vector>
+#include "processing/processor.h"
+
+const int kDataSetHGPairs = 1;
+const int kDataSetAggregation = 2;
+const int kDataSetAggregationWithFeatures = 3;
+const int kDataSetAggregationResult = 4;
+
+class NVFlareProcessor : public processing::Processor {
+ private:
+  bool active_ = false;
+  const std::map<std::string, std::string> *params_;
+  std::vector<double> *gh_pairs_{nullptr};
+  std::vector<uint32_t> cuts_;
+  std::vector<int> slots_;
+  bool feature_sent_ = false;
+  std::vector<int64_t> features_;
+
+ public:
+  void Initialize(bool active, std::map<std::string, std::string> params) override {
+    this->active_ = active;
+    this->params_ = &params;
+  }
+
+  void Shutdown() override {
+    this->gh_pairs_ = nullptr;
+    this->cuts_.clear();
+    this->slots_.clear();
+  }
+
+  void FreeBuffer(void *buffer) override {
+    free(buffer);
+  }
+
+  void *ProcessGHPairs(size_t &size, std::vector<double> &pairs) override;
+
+  void *HandleGHPairs(size_t &size, void *buffer, size_t buf_size) override;
+
+  void InitAggregationContext(const std::vector<uint32_t> &cuts, std::vector<int> &slots) override {
+    if (this->slots_.empty()) {
+      this->cuts_ = std::vector<uint32_t>(cuts);
+      this->slots_ = std::vector<int>(slots);
+    } else {
+      std::cout << "Multiple calls to InitAggregationContext" << std::endl;
+    }
+  }
+
+  void *ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) override;
+
+  std::vector<double> HandleAggregation(void *buffer, size_t buf_size) override;
+};
\ No newline at end of file
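The dataset IDs above tie the C++ plugin to the Python side of the PR. As a rough sketch (not the shipped code path), this is what the Python side does with the `kDataSetHGPairs` buffer that `ProcessGHPairs()` emits: one DAM float array holding g and h interleaved. The shipped `ProcessorDataConverter.decode_gh_pairs()` additionally scales the floats to integers:

```python
from nvflare.app_common.xgb.sec.dam import DamDecoder

DATA_SET_GH_PAIRS = 1  # mirrors kDataSetHGPairs in nvflare_processor.h

def decode_gh(buffer: bytes):
    decoder = DamDecoder(buffer)
    if not decoder.is_valid() or decoder.get_data_set_id() != DATA_SET_GH_PAIRS:
        return None
    flat = decoder.decode_float_array()  # [g0, h0, g1, h1, ...]
    # Re-pair the interleaved floats into (g, h) tuples.
    return [(flat[2 * i], flat[2 * i + 1]) for i in range(len(flat) // 2)]
```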
diff --git a/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc
new file mode 100644
index 0000000000..dce1701f7e
--- /dev/null
+++ b/integration/xgboost/processor/src/nvflare-plugin/nvflare_processor.cc
@@ -0,0 +1,166 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <iostream>
+#include <cstring>
+#include "nvflare_processor.h"
+#include "dam.h"
+
+const char kPluginName[] = "nvflare";
+
+using std::vector;
+using std::cout;
+using std::endl;
+
+void *NVFlareProcessor::ProcessGHPairs(size_t &size, std::vector<double> &pairs) {
+  cout << "ProcessGHPairs called with pairs size: " << pairs.size() << endl;
+  gh_pairs_ = new std::vector<double>(pairs);
+
+  DamEncoder encoder(kDataSetHGPairs);
+  encoder.AddFloatArray(pairs);
+  auto buffer = encoder.Finish(size);
+
+  return buffer;
+}
+
+void *NVFlareProcessor::HandleGHPairs(size_t &size, void *buffer, size_t buf_size) {
+  cout << "HandleGHPairs called with buffer size: " << buf_size << " Active: " << active_ << endl;
+  size = buf_size;
+  return buffer;
+}
+
+void *NVFlareProcessor::ProcessAggregation(size_t &size, std::map<int, std::vector<int>> nodes) {
+  cout << "ProcessAggregation called with " << nodes.size() << " nodes" << endl;
+
+  int64_t data_set_id;
+  if (!feature_sent_) {
+    data_set_id = kDataSetAggregationWithFeatures;
+    feature_sent_ = true;
+  } else {
+    data_set_id = kDataSetAggregation;
+  }
+
+  DamEncoder encoder(data_set_id);
+
+  // Add cuts pointers
+  vector<int64_t> cuts_vec;
+  for (auto value : cuts_) {
+    cuts_vec.push_back(value);
+  }
+  encoder.AddIntArray(cuts_vec);
+
+  auto num_features = cuts_.size() - 1;
+  auto num_samples = slots_.size() / num_features;
+  cout << "Samples: " << num_samples << " Features: " << num_features << endl;
+
+  if (data_set_id == kDataSetAggregationWithFeatures) {
+    if (features_.empty()) {
+      for (std::size_t f = 0; f < num_features; f++) {
+        auto slot = slots_[f];
+        if (slot >= 0) {
+          features_.push_back(f);
+        }
+      }
+    }
+    cout << "Including feature size: " << features_.size() << endl;
+    encoder.AddIntArray(features_);
+
+    vector<int64_t> bins;
+    for (int i = 0; i < num_samples; i++) {
+      for (auto f : features_) {
+        auto index = f + i * num_features;
+        if (index >= slots_.size()) {
+          cout << "Index is out of range " << index << endl;
+        }
+        auto slot = slots_[index];
+        bins.push_back(slot);
+      }
+    }
+    encoder.AddIntArray(bins);
+  }
+
+  // Add nodes to build
+  vector<int64_t> node_vec;
+  for (const auto &kv : nodes) {
+    std::cout << "Node: " << kv.first << " Rows: " << kv.second.size() << std::endl;
+    node_vec.push_back(kv.first);
+  }
+  encoder.AddIntArray(node_vec);
+
+  // For each node, get the row_id/slot pair
+  for (const auto &kv : nodes) {
+    vector<int64_t> rows;
+    for (auto row : kv.second) {
+      rows.push_back(row);
+    }
+    encoder.AddIntArray(rows);
+  }
+
+  auto buffer = encoder.Finish(size);
+  return buffer;
+}
+
+std::vector<double> NVFlareProcessor::HandleAggregation(void *buffer, size_t buf_size) {
+  cout << "HandleAggregation called with buffer size: " << buf_size << endl;
+  auto remaining = buf_size;
+  char *pointer = reinterpret_cast<char *>(buffer);
+
+  // The buffer is concatenated by AllGather. It may contain multiple DAM buffers
+  std::vector<double> result;
+  auto max_slot = cuts_.back();
+  auto array_size = 2 * max_slot * sizeof(double);
+  double *slots = static_cast<double *>(malloc(array_size));
+  while (remaining > kPrefixLen) {
+    DamDecoder decoder(reinterpret_cast<uint8_t *>(pointer), remaining);
+    if (!decoder.IsValid()) {
+      cout << "Not DAM encoded buffer ignored at offset: " << (int)(pointer - (char *)buffer) << endl;
+      break;
+    }
+    auto size = decoder.Size();
+    auto node_list = decoder.DecodeIntArray();
+    for (auto node : node_list) {
+      memset(slots, 0, array_size);
+      auto feature_list = decoder.DecodeIntArray();
+      // Convert per-feature histo to a flat one
+      for (auto f : feature_list) {
+        auto base = cuts_[f];
+        auto bins = decoder.DecodeFloatArray();
+        auto n = bins.size() / 2;
+        for (int i = 0; i < n; i++) {
+          auto index = base + i;
+          slots[2 * index] += bins[2 * i];
+          slots[2 * index + 1] += bins[2 * i + 1];
+        }
+      }
+      result.insert(result.end(), slots, slots + 2 * max_slot);
+    }
+    remaining -= size;
+    pointer += size;
+  }
+  free(slots);
+
+  return result;
+}
+
+extern "C" {
+
+processing::Processor *LoadProcessor(char *plugin_name) {
+  if (strcasecmp(plugin_name, kPluginName) != 0) {
+    cout << "Unknown plugin name: " << plugin_name << endl;
+    return nullptr;
+  }
+
+  return new NVFlareProcessor();
+}
+}
diff --git a/integration/xgboost/processor/tests/CMakeLists.txt b/integration/xgboost/processor/tests/CMakeLists.txt
new file mode 100644
index 0000000000..893d8738dc
--- /dev/null
+++ b/integration/xgboost/processor/tests/CMakeLists.txt
@@ -0,0 +1,14 @@
+file(GLOB_RECURSE TEST_SOURCES "*.cc")
+
+target_sources(proc_test PRIVATE ${TEST_SOURCES})
+
+target_include_directories(proc_test
+    PRIVATE
+    ${GTEST_INCLUDE_DIRS}
+    ${proc_nvflare_SOURCE_DIR}/tests
+    ${proc_nvflare_SOURCE_DIR}/src)
+
+message("Include Dir: ${GTEST_INCLUDE_DIRS}")
+target_link_libraries(proc_test
+    PRIVATE
+    ${GTEST_LIBRARIES})
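To make the wire format of `ProcessAggregation()` above concrete, here is a sketch of the payload order it writes, expressed as a Python decode (names hypothetical; the shipped decoder is `ProcessorDataConverter.decode_aggregation_context()` later in this PR):

```python
from nvflare.app_common.xgb.sec.dam import DamDecoder

DATA_SET_AGGREGATION_WITH_FEATURES = 3  # first message also carries features/slots

def decode_aggregation(buffer: bytes):
    decoder = DamDecoder(buffer)
    cuts = decoder.decode_int_array()            # cumulative bin boundaries per feature
    features, slots = None, None
    if decoder.get_data_set_id() == DATA_SET_AGGREGATION_WITH_FEATURES:
        features = decoder.decode_int_array()    # active feature ids
        slots = decoder.decode_int_array()       # row-major [sample, feature] bin slots
    nodes = decoder.decode_int_array()           # tree nodes to build
    rows = {node: decoder.decode_int_array() for node in nodes}  # row ids per node
    return cuts, features, slots, nodes, rows
```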
diff --git a/integration/xgboost/processor/tests/test_dam.cc b/integration/xgboost/processor/tests/test_dam.cc
new file mode 100644
index 0000000000..292161adba
--- /dev/null
+++ b/integration/xgboost/processor/tests/test_dam.cc
@@ -0,0 +1,39 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "gtest/gtest.h"
+#include "dam.h"
+
+TEST(DamTest, TestEncodeDecode) {
+  double float_array[] = {1.1, 1.2, 1.3, 1.4};
+  int64_t int_array[] = {123, 456, 789};
+
+  DamEncoder encoder(123);
+  auto float_input = std::vector<double>(float_array, float_array + 4);
+  encoder.AddFloatArray(float_input);
+  auto int_input = std::vector<int64_t>(int_array, int_array + 3);
+  encoder.AddIntArray(int_input);
+  size_t size;
+  auto buf = encoder.Finish(size);
+  std::cout << "Encoded size is " << size << std::endl;
+
+  DamDecoder decoder(buf, size);
+  EXPECT_EQ(decoder.IsValid(), true);
+  EXPECT_EQ(decoder.GetDataSetId(), 123);
+
+  auto float_vec = decoder.DecodeFloatArray();
+  EXPECT_EQ(0, memcmp(float_vec.data(), float_array, float_vec.size() * 8));
+
+  auto int_vec = decoder.DecodeIntArray();
+  EXPECT_EQ(0, memcmp(int_vec.data(), int_array, int_vec.size() * 8));
+}
diff --git a/integration/xgboost/processor/tests/test_main.cc b/integration/xgboost/processor/tests/test_main.cc
new file mode 100644
index 0000000000..20612fe4e4
--- /dev/null
+++ b/integration/xgboost/processor/tests/test_main.cc
@@ -0,0 +1,21 @@
+/**
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "gtest/gtest.h"
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
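The accumulation loop in `HandleAggregation()` above folds each per-feature histogram into one flat (g, h) array keyed by global slot. A small numeric illustration of that mapping (illustrative helper, not shipped code): for feature `f`, local bin `i` lands at global slot `cuts[f] + i`.

```python
def flatten(cuts, per_feature_bins):
    # per_feature_bins: {feature_id: [g0, h0, g1, h1, ...]}
    max_slot = cuts[-1]
    slots = [0.0] * (2 * max_slot)
    for f, bins in per_feature_bins.items():
        base = cuts[f]
        for i in range(len(bins) // 2):
            slots[2 * (base + i)] += bins[2 * i]          # g
            slots[2 * (base + i) + 1] += bins[2 * i + 1]  # h
    return slots

# Two features with cuts [0, 2, 5]: feature 0 has 2 bins, feature 1 has 3.
print(flatten([0, 2, 5], {0: [1.0, 0.5, 2.0, 0.25], 1: [3.0, 1.0, 0.0, 0.0, 4.0, 2.0]}))
```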
diff --git a/nvflare/app_common/xgb/csv_loader.py b/nvflare/app_common/xgb/csv_loader.py
new file mode 100644
index 0000000000..d61104bee1
--- /dev/null
+++ b/nvflare/app_common/xgb/csv_loader.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import xgboost as xgb
+
+from nvflare.app_common.xgb.data_loader import XGBDataLoader
+
+
+class CsvDataLoader(XGBDataLoader):
+    def __init__(self, rank: int, folder: str):
+        """Reads the CSV dataset and returns the XGB data matrix.
+
+        Args:
+            rank: Rank of the site
+            folder: Folder to find the CSV files
+        """
+        self.rank = rank
+        self.folder = folder
+
+    def load_data(self, client_id: str):
+        train_path = f"{self.folder}/site-{self.rank + 1}/train.csv"
+        valid_path = f"{self.folder}/site-{self.rank + 1}/valid.csv"
+
+        if self.rank == 0:
+            label = "&label_column=0"
+        else:
+            label = ""
+
+        # for vertical XGBoost, read from CSV with label_column and set data_split_mode to 2 (column mode)
+        train_data = xgb.DMatrix(train_path + f"?format=csv{label}", data_split_mode=2)
+        valid_data = xgb.DMatrix(valid_path + f"?format=csv{label}", data_split_mode=2)
+
+        return train_data, valid_data
diff --git a/nvflare/app_common/xgb/executor.py b/nvflare/app_common/xgb/executor.py
index 11b9e2c54f..e91a1ed31e 100644
--- a/nvflare/app_common/xgb/executor.py
+++ b/nvflare/app_common/xgb/executor.py
@@ -32,7 +32,7 @@ def __init__(
         adaptor_component_id: str,
         configure_task_name=Constant.CONFIG_TASK_NAME,
         start_task_name=Constant.START_TASK_NAME,
-        req_timeout=10.0,
+        req_timeout=60.0,
     ):
         """Constructor
diff --git a/nvflare/app_common/xgb/fed_executor.py b/nvflare/app_common/xgb/fed_executor.py
index f64ecf5f32..636ca77cbb 100644
--- a/nvflare/app_common/xgb/fed_executor.py
+++ b/nvflare/app_common/xgb/fed_executor.py
@@ -27,7 +27,7 @@ def __init__(
         verbose_eval=False,
         use_gpus=False,
         int_server_grpc_options=None,
-        req_timeout=10.0,
+        req_timeout=60.0,
         model_file_name="model.json",
         in_process=True,
     ):
diff --git a/nvflare/app_common/xgb/grpc_client.py b/nvflare/app_common/xgb/grpc_client.py
index a3cbe62102..9f44cc278f 100644
--- a/nvflare/app_common/xgb/grpc_client.py
+++ b/nvflare/app_common/xgb/grpc_client.py
@@ -80,7 +80,9 @@ def send_allgather(self, seq_num, rank, data: bytes):
             send_buffer=data,
         )
 
+        self.logger.info(f"Allgather is sending {len(data)} bytes Rank: {rank} Seq: {seq_num}")
         result = self.stub.Allgather(req)
+
         if not isinstance(result, pb2.AllgatherReply):
             self.logger.error(f"expect reply to be pb2.AllgatherReply but got {type(result)}")
             return None
diff --git a/nvflare/app_common/xgb/paillier/__init__.py b/nvflare/app_common/xgb/he/__init__.py
similarity index 100%
rename from nvflare/app_common/xgb/paillier/__init__.py
rename to nvflare/app_common/xgb/he/__init__.py
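A hypothetical usage of the `CsvDataLoader` above (the folder path is illustrative; the `site-1/train.csv`, `site-1/valid.csv` layout follows the paths built in `load_data()`):

```python
from nvflare.app_common.xgb.csv_loader import CsvDataLoader

loader = CsvDataLoader(rank=0, folder="/data/vertical_xgb")  # rank 0 holds the labels
train_dmat, valid_dmat = loader.load_data(client_id="site-1")
# Only rank 0 appends "&label_column=0"; all ranks load with data_split_mode=2
# (column-wise split, as used for vertical federated XGBoost here).
```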
diff --git a/nvflare/app_common/xgb/paillier/adder.py b/nvflare/app_common/xgb/he/adder.py
similarity index 74%
rename from nvflare/app_common/xgb/paillier/adder.py
rename to nvflare/app_common/xgb/he/adder.py
index 085bba1a2f..8eebac8d4a 100644
--- a/nvflare/app_common/xgb/paillier/adder.py
+++ b/nvflare/app_common/xgb/he/adder.py
@@ -13,14 +13,17 @@
 # limitations under the License.
 import concurrent.futures
+import itertools
 
-from nvflare.app_common.xgb.aggr import Aggregator
+from nvflare.app_common.xgb.he.aggr import Aggregator
+from .cipher import Cipher
 from .util import encode_encrypted_numbers_to_str
 
 
 class Adder:
-    def __init__(self, max_workers=10):
+    def __init__(self, cipher: Cipher, max_workers=10):
+        self.cipher = cipher
         self.exe = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
 
     def add(self, encrypted_numbers, features, sample_groups=None, encode_sum=True):
@@ -50,17 +53,17 @@ def add(self, encrypted_numbers, features, sample_groups=None, encode_sum=True):
             gid, sample_id_list = g
             items.append((encode_sum, fid, encrypted_numbers, mask, num_bins, gid, sample_id_list))
 
-        results = self.exe.map(_do_add, items)
+        results = self.exe.map(_do_add, items, itertools.repeat(self.cipher))
         rl = []
         for r in results:
             rl.append(r)
         return rl
 
 
-def _do_add(item):
+def _do_add(item, cipher):
     encode_sum, fid, encrypted_numbers, mask, num_bins, gid, sample_id_list = item
     # bins = [0 for _ in range(num_bins)]
-    aggr = Aggregator()
+    aggr = Aggregator(cipher)
 
     bins = aggr.aggregate(
         gh_values=encrypted_numbers,
@@ -68,24 +71,6 @@ def _do_add(item, cipher):
         num_bins=num_bins,
         sample_ids=sample_id_list,
     )
-    #
-    # if not sample_id_list:
-    #     # all samples
-    #     for sample_id in range(len(encrypted_numbers)):
-    #         bid = mask[sample_id]
-    #         if bins[bid] == 0:
-    #             # avoid plain_text + cypher_text, which could be slow!
-    #             bins[bid] = encrypted_numbers[sample_id]
-    #         else:
-    #             bins[bid] += encrypted_numbers[sample_id]
-    # else:
-    #     for sample_id in sample_id_list:
-    #         bid = mask[sample_id]
-    #         if bins[bid] == 0:
-    #             # avoid plain_text + cypher_text, which could be slow!
-    #             bins[bid] = encrypted_numbers[sample_id]
-    #         else:
-    #             bins[bid] += encrypted_numbers[sample_id]
 
     if encode_sum:
         sums = encode_encrypted_numbers_to_str(bins)
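The `Adder` change above threads the cipher into each worker through a second iterable: `Executor.map` zips its iterables, so pairing the items with `itertools.repeat(...)` hands the same cipher to every `_do_add` call. A minimal standalone illustration of the pattern (toy function, not the shipped code):

```python
import concurrent.futures
import itertools

def work(item, shared):
    # Each call receives one item plus the same shared object.
    return f"{shared}:{item * item}"

if __name__ == "__main__":
    with concurrent.futures.ProcessPoolExecutor(max_workers=2) as exe:
        # map() zips the two iterables; repeat() supplies `shared` every time.
        print(list(exe.map(work, [1, 2, 3], itertools.repeat("job"))))
        # -> ['job:1', 'job:4', 'job:9']
```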
diff --git a/nvflare/app_common/xgb/aggr.py b/nvflare/app_common/xgb/he/aggr.py
similarity index 63%
rename from nvflare/app_common/xgb/aggr.py
rename to nvflare/app_common/xgb/he/aggr.py
index 9da4088611..a6982a34cf 100644
--- a/nvflare/app_common/xgb/aggr.py
+++ b/nvflare/app_common/xgb/he/aggr.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,15 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from nvflare.app_common.xgb.he.cipher import Cipher
 
 
 class Aggregator:
-    def __init__(self, initial_value=0):
+    def __init__(self, cipher: Cipher, initial_value=0):
+        self.cipher = cipher
         self.initial_value = initial_value
 
-    def add(self, a, b):
-        return a + b
-
     def _update_aggregation(self, gh_values, sample_bin_assignment, sample_id, aggr):
         bin_id = sample_bin_assignment[sample_id]
         sample_value = gh_values[sample_id]
@@ -28,7 +27,7 @@ def _update_aggregation(self, gh_values, sample_bin_assignment, sample_id, aggr):
             # avoid add since sample_value may be cypher-text!
             aggr[bin_id] = sample_value
         else:
-            aggr[bin_id] = self.add(current_value, sample_value)
+            aggr[bin_id] = self.cipher.add(current_value, sample_value)
 
     def aggregate(self, gh_values: list, sample_bin_assignment, num_bins, sample_ids):
         aggr_result = [self.initial_value] * num_bins
diff --git a/nvflare/app_common/xgb/he/cipher.py b/nvflare/app_common/xgb/he/cipher.py
new file mode 100644
index 0000000000..7b928feeaa
--- /dev/null
+++ b/nvflare/app_common/xgb/he/cipher.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+from typing import Any, Union
+
+ClearText = Union[float, int]
+
+
+class Cipher(ABC):
+    """An abstract class for Homomorphic Encryption operations"""
+
+    @abstractmethod
+    def generate_keys(self, key_length: int):
+        pass
+
+    @abstractmethod
+    def get_public_key_str(self) -> str:
+        pass
+
+    @abstractmethod
+    def set_public_key_str(self, public_key_str: str):
+        pass
+
+    @abstractmethod
+    def encode_cipher_text(self, cipher_text: Any) -> Any:
+        pass
+
+    @abstractmethod
+    def decode_cipher_text(self, encoded_cipher_text: Any) -> Any:
+        pass
+
+    @abstractmethod
+    def encrypt(self, clear_text: ClearText) -> Any:
+        pass
+
+    @abstractmethod
+    def decrypt(self, cipher_text: Any) -> ClearText:
+        pass
+
+    @abstractmethod
+    def add(self, a: Any, b: Any) -> Any:
+        pass
diff --git a/nvflare/app_common/xgb/paillier/decrypter.py b/nvflare/app_common/xgb/he/decrypter.py
similarity index 100%
rename from nvflare/app_common/xgb/paillier/decrypter.py
rename to nvflare/app_common/xgb/he/decrypter.py
diff --git a/nvflare/app_common/xgb/paillier/encryptor.py b/nvflare/app_common/xgb/he/encryptor.py
similarity index 97%
rename from nvflare/app_common/xgb/paillier/encryptor.py
rename to nvflare/app_common/xgb/he/encryptor.py
index 108098a2b1..d506c2a107 100644
--- a/nvflare/app_common/xgb/paillier/encryptor.py
+++ b/nvflare/app_common/xgb/he/encryptor.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
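To make the division of labor concrete: `Aggregator` never adds values itself; it delegates to whatever `Cipher.add` is supplied. A toy, non-encrypting cipher is enough to exercise it. This is purely illustrative (not a shipped class), and it assumes `sample_ids=None` means "all samples", as the removed comment in `adder.py` suggests:

```python
from nvflare.app_common.xgb.he.aggr import Aggregator
from nvflare.app_common.xgb.he.cipher import Cipher

class PlainCipher(Cipher):
    # Illustrative no-op "cipher": every abstract method is stubbed out, and
    # add() is ordinary addition, so Aggregator can be tried in clear text.
    def generate_keys(self, key_length): pass
    def get_public_key_str(self): return ""
    def set_public_key_str(self, public_key_str): pass
    def encode_cipher_text(self, cipher_text): return cipher_text
    def decode_cipher_text(self, encoded_cipher_text): return encoded_cipher_text
    def encrypt(self, clear_text): return clear_text
    def decrypt(self, cipher_text): return cipher_text
    def add(self, a, b): return a + b

aggr = Aggregator(PlainCipher())
# Four samples assigned to two bins; bin 0 gets samples 0 and 2.
print(aggr.aggregate(gh_values=[1, 2, 3, 4], sample_bin_assignment=[0, 1, 0, 1],
                     num_bins=2, sample_ids=None))  # expected -> [4, 6]
```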
diff --git a/nvflare/app_common/xgb/he/phe_cipher.py b/nvflare/app_common/xgb/he/phe_cipher.py
new file mode 100644
index 0000000000..e90cdbd63a
--- /dev/null
+++ b/nvflare/app_common/xgb/he/phe_cipher.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any
+
+import phe
+
+from nvflare.app_common.xgb.he.cipher import Cipher, ClearText
+
+
+class PheCipher(Cipher):
+    def __init__(self):
+        self.public_key = None
+        self.private_key = None
+
+    def generate_keys(self, key_length: int):
+        self.public_key, self.private_key = phe.paillier.generate_paillier_keypair(n_length=key_length)
+
+    def get_public_key_str(self) -> str:
+        return phe.util.int_to_base64(self.public_key.n)
+
+    def set_public_key_str(self, public_key_str: str):
+        self.public_key = phe.paillier.PaillierPublicKey(n=phe.util.base64_to_int(public_key_str))
+
+    def encode_cipher_text(self, cipher_text: Any) -> Any:
+        if not isinstance(cipher_text, phe.paillier.EncryptedNumber):
+            raise TypeError(f"Invalid type {type(cipher_text)}")
+
+        return phe.util.int_to_base64(cipher_text.ciphertext()), cipher_text.exponent
+
+    def decode_cipher_text(self, encoded_cipher_text: Any) -> Any:
+        cipher_str, exp = encoded_cipher_text
+        return phe.paillier.EncryptedNumber(
+            self.public_key, ciphertext=phe.util.base64_to_int(cipher_str), exponent=exp)
+
+    def encrypt(self, clear_text: ClearText) -> Any:
+        return self.public_key.encrypt(clear_text)
+
+    def decrypt(self, cipher_text: Any) -> ClearText:
+        return self.private_key.decrypt(cipher_text)
+
+    def add(self, a: Any, b: Any) -> Any:
+        return a + b
diff --git a/nvflare/app_common/xgb/paillier/util.py b/nvflare/app_common/xgb/he/util.py
similarity index 99%
rename from nvflare/app_common/xgb/paillier/util.py
rename to nvflare/app_common/xgb/he/util.py
index 2e3c290836..6b70469a0d 100644
--- a/nvflare/app_common/xgb/paillier/util.py
+++ b/nvflare/app_common/xgb/he/util.py
@@ -14,8 +14,6 @@
 
 import json
 
-import phe
-
 SCALE_FACTOR = 10000000000000
diff --git a/nvflare/app_common/xgb/mock/mock_data_converter.py b/nvflare/app_common/xgb/mock/mock_data_converter.py
index b840f6c623..6f6f90b88d 100644
--- a/nvflare/app_common/xgb/mock/mock_data_converter.py
+++ b/nvflare/app_common/xgb/mock/mock_data_converter.py
@@ -17,7 +17,7 @@
 from typing import Dict, List, Tuple
 
 from nvflare.apis.fl_context import FLContext
-from nvflare.app_common.xgb.aggr import Aggregator
+from nvflare.app_common.xgb.he.aggr import Aggregator
 from nvflare.app_common.xgb.defs import Constant
 from nvflare.app_common.xgb.sec.data_converter import (
     AggregationContext,
diff --git a/nvflare/app_common/xgb/proto/federated.proto b/nvflare/app_common/xgb/proto/federated.proto
index f412204813..a37e63526b 100644
--- a/nvflare/app_common/xgb/proto/federated.proto
+++ b/nvflare/app_common/xgb/proto/federated.proto
@@ -1,9 +1,10 @@
/*!
- * Copyright 2022-2023 XGBoost contributors + * Copyright 2022 XGBoost contributors + * This is federated.old.proto from XGBoost */ syntax = "proto3"; -package xgboost.collective.federated; +package xgboost.federated; service Federated { rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} @@ -13,18 +14,14 @@ service Federated { } enum DataType { - HALF = 0; - FLOAT = 1; - DOUBLE = 2; - LONG_DOUBLE = 3; - INT8 = 4; - INT16 = 5; - INT32 = 6; - INT64 = 7; - UINT8 = 8; - UINT16 = 9; - UINT32 = 10; - UINT64 = 11; + INT8 = 0; + UINT8 = 1; + INT32 = 2; + UINT32 = 3; + INT64 = 4; + UINT64 = 5; + FLOAT = 6; + DOUBLE = 7; } enum ReduceOperation { @@ -82,4 +79,4 @@ message BroadcastRequest { message BroadcastReply { bytes receive_buffer = 1; -} \ No newline at end of file +} diff --git a/nvflare/app_common/xgb/proto/federated_pb2.py b/nvflare/app_common/xgb/proto/federated_pb2.py index ba80c1e5d6..45095b265e 100644 --- a/nvflare/app_common/xgb/proto/federated_pb2.py +++ b/nvflare/app_common/xgb/proto/federated_pb2.py @@ -1,25 +1,12 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: federated.proto +# Protobuf Python Version: 4.25.1 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -27,33 +14,33 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x1cxgboost.collective.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xd2\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x39\n\tdata_type\x18\x04 \x01(\x0e\x32&.xgboost.collective.federated.DataType\x12G\n\x10reduce_operation\x18\x05 \x01(\x0e\x32-.xgboost.collective.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 
\x01(\x0c*\x96\x01\n\x08\x44\x61taType\x12\x08\n\x04HALF\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06\x44OUBLE\x10\x02\x12\x0f\n\x0bLONG_DOUBLE\x10\x03\x12\x08\n\x04INT8\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\t\n\x05UINT8\x10\x08\x12\n\n\x06UINT16\x10\t\x12\n\n\x06UINT32\x10\n\x12\n\n\x06UINT64\x10\x0b*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xc2\x03\n\tFederated\x12k\n\tAllgather\x12..xgboost.collective.federated.AllgatherRequest\x1a,.xgboost.collective.federated.AllgatherReply\"\x00\x12n\n\nAllgatherV\x12/.xgboost.collective.federated.AllgatherVRequest\x1a-.xgboost.collective.federated.AllgatherVReply\"\x00\x12k\n\tAllreduce\x12..xgboost.collective.federated.AllreduceRequest\x1a,.xgboost.collective.federated.AllreduceReply\"\x00\x12k\n\tBroadcast\x12..xgboost.collective.federated.BroadcastRequest\x1a,.xgboost.collective.federated.BroadcastReply\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x11xgboost.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xbc\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12.\n\tdata_type\x18\x04 \x01(\x0e\x32\x1b.xgboost.federated.DataType\x12<\n\x10reduce_operation\x18\x05 \x01(\x0e\x32\".xgboost.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c*d\n\x08\x44\x61taType\x12\x08\n\x04INT8\x10\x00\x12\t\n\x05UINT8\x10\x01\x12\t\n\x05INT32\x10\x02\x12\n\n\x06UINT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06UINT64\x10\x05\x12\t\n\x05\x46LOAT\x10\x06\x12\n\n\x06\x44OUBLE\x10\x07*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xea\x02\n\tFederated\x12U\n\tAllgather\x12#.xgboost.federated.AllgatherRequest\x1a!.xgboost.federated.AllgatherReply\"\x00\x12X\n\nAllgatherV\x12$.xgboost.federated.AllgatherVRequest\x1a\".xgboost.federated.AllgatherVReply\"\x00\x12U\n\tAllreduce\x12#.xgboost.federated.AllreduceRequest\x1a!.xgboost.federated.AllreduceReply\"\x00\x12U\n\tBroadcast\x12#.xgboost.federated.BroadcastRequest\x1a!.xgboost.federated.BroadcastReply\"\x00\x62\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'federated_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'federated_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == 
False: - DESCRIPTOR._options = None - _DATATYPE._serialized_start=687 - _DATATYPE._serialized_end=837 - _REDUCEOPERATION._serialized_start=839 - _REDUCEOPERATION._serialized_end=933 - _ALLGATHERREQUEST._serialized_start=49 - _ALLGATHERREQUEST._serialized_end=127 - _ALLGATHERREPLY._serialized_start=129 - _ALLGATHERREPLY._serialized_end=169 - _ALLGATHERVREQUEST._serialized_start=171 - _ALLGATHERVREQUEST._serialized_end=250 - _ALLGATHERVREPLY._serialized_start=252 - _ALLGATHERVREPLY._serialized_end=293 - _ALLREDUCEREQUEST._serialized_start=296 - _ALLREDUCEREQUEST._serialized_end=506 - _ALLREDUCEREPLY._serialized_start=508 - _ALLREDUCEREPLY._serialized_end=548 - _BROADCASTREQUEST._serialized_start=550 - _BROADCASTREQUEST._serialized_end=642 - _BROADCASTREPLY._serialized_start=644 - _BROADCASTREPLY._serialized_end=684 - _FEDERATED._serialized_start=936 - _FEDERATED._serialized_end=1386 + _globals['_DATATYPE']._serialized_start=653 + _globals['_DATATYPE']._serialized_end=753 + _globals['_REDUCEOPERATION']._serialized_start=755 + _globals['_REDUCEOPERATION']._serialized_end=849 + _globals['_ALLGATHERREQUEST']._serialized_start=38 + _globals['_ALLGATHERREQUEST']._serialized_end=116 + _globals['_ALLGATHERREPLY']._serialized_start=118 + _globals['_ALLGATHERREPLY']._serialized_end=158 + _globals['_ALLGATHERVREQUEST']._serialized_start=160 + _globals['_ALLGATHERVREQUEST']._serialized_end=239 + _globals['_ALLGATHERVREPLY']._serialized_start=241 + _globals['_ALLGATHERVREPLY']._serialized_end=282 + _globals['_ALLREDUCEREQUEST']._serialized_start=285 + _globals['_ALLREDUCEREQUEST']._serialized_end=473 + _globals['_ALLREDUCEREPLY']._serialized_start=475 + _globals['_ALLREDUCEREPLY']._serialized_end=515 + _globals['_BROADCASTREQUEST']._serialized_start=517 + _globals['_BROADCASTREQUEST']._serialized_end=609 + _globals['_BROADCASTREPLY']._serialized_start=611 + _globals['_BROADCASTREPLY']._serialized_end=651 + _globals['_FEDERATED']._serialized_start=852 + _globals['_FEDERATED']._serialized_end=1214 # @@protoc_insertion_point(module_scope) diff --git a/nvflare/app_common/xgb/proto/federated_pb2.pyi b/nvflare/app_common/xgb/proto/federated_pb2.pyi index 8e2a7e740e..750db95a25 100644 --- a/nvflare/app_common/xgb/proto/federated_pb2.pyi +++ b/nvflare/app_common/xgb/proto/federated_pb2.pyi @@ -3,98 +3,108 @@ from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union -BITWISE_AND: ReduceOperation -BITWISE_OR: ReduceOperation -BITWISE_XOR: ReduceOperation DESCRIPTOR: _descriptor.FileDescriptor -DOUBLE: DataType -FLOAT: DataType -HALF: DataType -INT16: DataType + +class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + INT8: _ClassVar[DataType] + UINT8: _ClassVar[DataType] + INT32: _ClassVar[DataType] + UINT32: _ClassVar[DataType] + INT64: _ClassVar[DataType] + UINT64: _ClassVar[DataType] + FLOAT: _ClassVar[DataType] + DOUBLE: _ClassVar[DataType] + +class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + MAX: _ClassVar[ReduceOperation] + MIN: _ClassVar[ReduceOperation] + SUM: _ClassVar[ReduceOperation] + BITWISE_AND: _ClassVar[ReduceOperation] + BITWISE_OR: _ClassVar[ReduceOperation] + BITWISE_XOR: _ClassVar[ReduceOperation] +INT8: DataType +UINT8: DataType INT32: DataType +UINT32: DataType INT64: DataType -INT8: DataType -LONG_DOUBLE: DataType +UINT64: DataType +FLOAT: DataType +DOUBLE: DataType MAX: ReduceOperation MIN: 
ReduceOperation SUM: ReduceOperation -UINT16: DataType -UINT32: DataType -UINT64: DataType -UINT8: DataType - -class AllgatherReply(_message.Message): - __slots__ = ["receive_buffer"] - RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] - receive_buffer: bytes - def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... +BITWISE_AND: ReduceOperation +BITWISE_OR: ReduceOperation +BITWISE_XOR: ReduceOperation class AllgatherRequest(_message.Message): - __slots__ = ["rank", "send_buffer", "sequence_number"] + __slots__ = ("sequence_number", "rank", "send_buffer") + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] - SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + sequence_number: int rank: int send_buffer: bytes - sequence_number: int def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... -class AllgatherVReply(_message.Message): - __slots__ = ["receive_buffer"] +class AllgatherReply(_message.Message): + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllgatherVRequest(_message.Message): - __slots__ = ["rank", "send_buffer", "sequence_number"] + __slots__ = ("sequence_number", "rank", "send_buffer") + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] - SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + sequence_number: int rank: int send_buffer: bytes - sequence_number: int def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... -class AllreduceReply(_message.Message): - __slots__ = ["receive_buffer"] +class AllgatherVReply(_message.Message): + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... class AllreduceRequest(_message.Message): - __slots__ = ["data_type", "rank", "reduce_operation", "send_buffer", "sequence_number"] - DATA_TYPE_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("sequence_number", "rank", "send_buffer", "data_type", "reduce_operation") + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] - REDUCE_OPERATION_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] - SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] - data_type: DataType + DATA_TYPE_FIELD_NUMBER: _ClassVar[int] + REDUCE_OPERATION_FIELD_NUMBER: _ClassVar[int] + sequence_number: int rank: int - reduce_operation: ReduceOperation send_buffer: bytes - sequence_number: int + data_type: DataType + reduce_operation: ReduceOperation def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., data_type: _Optional[_Union[DataType, str]] = ..., reduce_operation: _Optional[_Union[ReduceOperation, str]] = ...) -> None: ... -class BroadcastReply(_message.Message): - __slots__ = ["receive_buffer"] +class AllreduceReply(_message.Message): + __slots__ = ("receive_buffer",) RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] receive_buffer: bytes def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... 
class BroadcastRequest(_message.Message): - __slots__ = ["rank", "root", "send_buffer", "sequence_number"] + __slots__ = ("sequence_number", "rank", "send_buffer", "root") + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] RANK_FIELD_NUMBER: _ClassVar[int] - ROOT_FIELD_NUMBER: _ClassVar[int] SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] - SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + ROOT_FIELD_NUMBER: _ClassVar[int] + sequence_number: int rank: int - root: int send_buffer: bytes - sequence_number: int + root: int def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., root: _Optional[int] = ...) -> None: ... -class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - -class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] +class BroadcastReply(_message.Message): + __slots__ = ("receive_buffer",) + RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] + receive_buffer: bytes + def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... diff --git a/nvflare/app_common/xgb/proto/federated_pb2_grpc.py b/nvflare/app_common/xgb/proto/federated_pb2_grpc.py index 206d8474da..f906ff4def 100644 --- a/nvflare/app_common/xgb/proto/federated_pb2_grpc.py +++ b/nvflare/app_common/xgb/proto/federated_pb2_grpc.py @@ -1,17 +1,3 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc @@ -29,22 +15,22 @@ def __init__(self, channel): channel: A grpc.Channel. 
""" self.Allgather = channel.unary_unary( - '/xgboost.collective.federated.Federated/Allgather', + '/xgboost.federated.Federated/Allgather', request_serializer=federated__pb2.AllgatherRequest.SerializeToString, response_deserializer=federated__pb2.AllgatherReply.FromString, ) self.AllgatherV = channel.unary_unary( - '/xgboost.collective.federated.Federated/AllgatherV', + '/xgboost.federated.Federated/AllgatherV', request_serializer=federated__pb2.AllgatherVRequest.SerializeToString, response_deserializer=federated__pb2.AllgatherVReply.FromString, ) self.Allreduce = channel.unary_unary( - '/xgboost.collective.federated.Federated/Allreduce', + '/xgboost.federated.Federated/Allreduce', request_serializer=federated__pb2.AllreduceRequest.SerializeToString, response_deserializer=federated__pb2.AllreduceReply.FromString, ) self.Broadcast = channel.unary_unary( - '/xgboost.collective.federated.Federated/Broadcast', + '/xgboost.federated.Federated/Broadcast', request_serializer=federated__pb2.BroadcastRequest.SerializeToString, response_deserializer=federated__pb2.BroadcastReply.FromString, ) @@ -102,7 +88,7 @@ def add_FederatedServicer_to_server(servicer, server): ), } generic_handler = grpc.method_handlers_generic_handler( - 'xgboost.collective.federated.Federated', rpc_method_handlers) + 'xgboost.federated.Federated', rpc_method_handlers) server.add_generic_rpc_handlers((generic_handler,)) @@ -121,7 +107,7 @@ def Allgather(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allgather', + return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Allgather', federated__pb2.AllgatherRequest.SerializeToString, federated__pb2.AllgatherReply.FromString, options, channel_credentials, @@ -138,7 +124,7 @@ def AllgatherV(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/AllgatherV', + return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/AllgatherV', federated__pb2.AllgatherVRequest.SerializeToString, federated__pb2.AllgatherVReply.FromString, options, channel_credentials, @@ -155,7 +141,7 @@ def Allreduce(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allreduce', + return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Allreduce', federated__pb2.AllreduceRequest.SerializeToString, federated__pb2.AllreduceReply.FromString, options, channel_credentials, @@ -172,7 +158,7 @@ def Broadcast(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Broadcast', + return grpc.experimental.unary_unary(request, target, '/xgboost.federated.Federated/Broadcast', federated__pb2.BroadcastRequest.SerializeToString, federated__pb2.BroadcastReply.FromString, options, channel_credentials, diff --git a/nvflare/app_common/xgb/runners/xgb_client_runner.py b/nvflare/app_common/xgb/runners/xgb_client_runner.py index 80955a27af..e1aea9af91 100644 --- a/nvflare/app_common/xgb/runners/xgb_client_runner.py +++ b/nvflare/app_common/xgb/runners/xgb_client_runner.py @@ -97,7 +97,7 @@ def run(self, ctx: dict): self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) self._num_rounds = ctx.get(Constant.RUNNER_CTX_NUM_ROUNDS) 
self._server_addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR)
-        self._data_loader = ctx.get(Constant.RUNNER_CTX_DATA_LOADER)
+        # self._data_loader = ctx.get(Constant.RUNNER_CTX_DATA_LOADER)
         self._tb_dir = ctx.get(Constant.RUNNER_CTX_TB_DIR)
         self._model_dir = ctx.get(Constant.RUNNER_CTX_MODEL_DIR)
diff --git a/nvflare/app_common/xgb/sec/client_handler.py b/nvflare/app_common/xgb/sec/client_handler.py
index fee7ccfb8f..dae48fdb9f 100644
--- a/nvflare/app_common/xgb/sec/client_handler.py
+++ b/nvflare/app_common/xgb/sec/client_handler.py
@@ -18,9 +18,8 @@
 from nvflare.apis.fl_component import FLComponent
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
-from nvflare.app_common.xgb.aggr import Aggregator
+from nvflare.app_common.xgb.he.aggr import Aggregator
 from nvflare.app_common.xgb.defs import Constant
-from nvflare.app_common.xgb.mock.mock_data_converter import MockDataConverter
-from nvflare.app_common.xgb.paillier.adder import Adder
-from nvflare.app_common.xgb.paillier.decrypter import Decrypter
-from nvflare.app_common.xgb.paillier.encryptor import Encryptor
+from nvflare.app_common.xgb.he.adder import Adder
+from nvflare.app_common.xgb.he.decrypter import Decrypter
+from nvflare.app_common.xgb.he.encryptor import Encryptor
@@ -34,6 +33,7 @@
     split,
 )
 from nvflare.app_common.xgb.sec.data_converter import FeatureAggregationResult
+from nvflare.app_common.xgb.sec.processor_data_converter import ProcessorDataConverter
 from nvflare.app_common.xgb.sec.sec_handler import SecurityHandler
 
 
@@ -47,7 +47,7 @@ def __init__(self, key_length=1024, num_workers=10):
         self.encryptor = None
         self.adder = None
         self.decrypter = None
-        self.data_converter = MockDataConverter()
+        self.data_converter = ProcessorDataConverter()
         self.encrypted_ghs = None
         self.clear_ghs = None  # for label client: list of tuples (g, h)
         self.original_gh_buffer = None
@@ -168,7 +168,9 @@
             )
             return
 
-        self.info(fl_ctx, f"_process_before_all_gather_v: non-label client - do encrypted aggr for grp {groups}")
+        self.info(
+            fl_ctx, f"_process_before_all_gather_v: non-label client - do encrypted aggr for {len(groups)} groups"
+        )
         start = time.time()
         aggr_result = self.adder.add(self.encrypted_ghs, self.feature_masks, groups, encode_sum=True)
         self.info(fl_ctx, f"got aggr result for {len(aggr_result)} features in {time.time()-start} secs")
diff --git a/nvflare/app_common/xgb/sec/dam.py b/nvflare/app_common/xgb/sec/dam.py
new file mode 100644
index 0000000000..8367f9f86d
--- /dev/null
+++ b/nvflare/app_common/xgb/sec/dam.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import struct
+from io import BytesIO
+from typing import List
+
+SIGNATURE = "NVDADAM1"  # DAM (Direct Accessible Marshalling) V1
+PREFIX_LEN = 24
+
+DATA_TYPE_INT = 1
+DATA_TYPE_FLOAT = 2
+DATA_TYPE_STRING = 3
+DATA_TYPE_INT_ARRAY = 257
+DATA_TYPE_FLOAT_ARRAY = 258
+
+
+class DamEncoder:
+    def __init__(self, data_set_id: int):
+        self.data_set_id = data_set_id
+        self.entries = []
+        self.buffer = BytesIO()
+
+    def add_int_array(self, value: List[int]):
+        self.entries.append((DATA_TYPE_INT_ARRAY, value))
+
+    def add_float_array(self, value: List[float]):
+        self.entries.append((DATA_TYPE_FLOAT_ARRAY, value))
+
+    def finish(self) -> bytes:
+        size = PREFIX_LEN
+        for entry in self.entries:
+            size += 16  # per-entry header: int64 data type + int64 element count
+            size += len(entry[1]) * 8  # all supported element types are 8 bytes
+
+        self.write_str(SIGNATURE)
+        self.write_int64(size)
+        self.write_int64(self.data_set_id)
+
+        for entry in self.entries:
+            data_type, value = entry
+            self.write_int64(data_type)
+            self.write_int64(len(value))
+
+            for x in value:
+                if data_type == DATA_TYPE_INT_ARRAY:
+                    self.write_int64(x)
+                else:
+                    self.write_float(x)
+
+        return self.buffer.getvalue()
+
+    def write_int64(self, value: int):
+        self.buffer.write(struct.pack("q", value))  # native byte order, per DAM rule 1
+
+    def write_float(self, value: float):
+        self.buffer.write(struct.pack("d", value))
+
+    def write_str(self, value: str):
+        self.buffer.write(value.encode("utf-8"))
+
+
+class DamDecoder:
+    def __init__(self, buffer: bytes):
+        self.buffer = buffer
+        self.pos = 0
+        if len(buffer) >= PREFIX_LEN:
+            self.signature = self.read_string(8)
+            self.size = self.read_int64()
+            self.data_set_id = self.read_int64()
+        else:
+            self.signature = None
+            self.size = 0
+            self.data_set_id = 0
+
+    def is_valid(self):
+        return self.signature == SIGNATURE
+
+    def get_data_set_id(self):
+        return self.data_set_id
+
+    def decode_int_array(self) -> List[int]:
+        data_type = self.read_int64()
+        if data_type != DATA_TYPE_INT_ARRAY:
+            raise RuntimeError("Invalid data type for int array")
+
+        num = self.read_int64()
+        result = [0] * num
+        for i in range(num):
+            result[i] = self.read_int64()
+
+        return result
+
+    def decode_float_array(self) -> List[float]:
+        data_type = self.read_int64()
+        if data_type != DATA_TYPE_FLOAT_ARRAY:
+            raise RuntimeError("Invalid data type for float array")
+
+        num = self.read_int64()
+        result = [0.0] * num
+        for i in range(num):
+            result[i] = self.read_float()
+
+        return result
+
+    def read_string(self, length: int) -> str:
+        result = self.buffer[self.pos : self.pos + length].decode("latin1")  # latin1 maps every byte, so a bad signature never raises
+        self.pos += length
+        return result
+
+    def read_int64(self) -> int:
+        (result,) = struct.unpack_from("q", self.buffer, self.pos)
+        self.pos += 8
+        return result
+
+    def read_float(self) -> float:
+        (result,) = struct.unpack_from("d", self.buffer, self.pos)
+        self.pos += 8
+        return result
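[Editor's note] A minimal round trip through the classes above, with the byte count worked out (24-byte prefix + 16-byte entry header + one 8-byte element = 48); the data-set id and payload are illustrative values:

    from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder

    encoder = DamEncoder(7)          # arbitrary data set id
    encoder.add_int_array([42])      # one int64 element
    buffer = encoder.finish()
    assert len(buffer) == 48         # 24 + (8 + 8) + 1 * 8

    decoder = DamDecoder(buffer)
    assert decoder.is_valid() and decoder.get_data_set_id() == 7
    assert decoder.decode_int_array() == [42]
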
diff --git a/nvflare/app_common/xgb/sec/processor_data_converter.py b/nvflare/app_common/xgb/sec/processor_data_converter.py new file mode 100644 index 0000000000..f9d1287256 --- /dev/null +++ b/nvflare/app_common/xgb/sec/processor_data_converter.py @@ -0,0 +1,135 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, List, Optional, Tuple
+
+from nvflare.apis.fl_context import FLContext
+from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder
+from nvflare.app_common.xgb.sec.data_converter import (
+    AggregationContext,
+    DataConverter,
+    FeatureAggregationResult,
+    FeatureContext,
+)
+
+DATA_SET_GH_PAIRS = 1
+DATA_SET_AGGREGATION = 2
+DATA_SET_AGGREGATION_WITH_FEATURES = 3
+DATA_SET_AGGREGATION_RESULT = 4
+
+SCALE_FACTOR = 1000000.0  # Preserve 6 decimal places
+
+
+class ProcessorDataConverter(DataConverter):
+    def __init__(self):
+        super().__init__()
+        self.features = []
+        self.feature_list = None
+        self.num_samples = 0
+
+    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> Optional[List[Tuple[int, int]]]:
+        decoder = DamDecoder(buffer)
+        if not decoder.is_valid():
+            return None
+
+        if decoder.get_data_set_id() != DATA_SET_GH_PAIRS:
+            raise RuntimeError(f"Data is not for GH Pairs: {decoder.get_data_set_id()}")
+
+        float_array = decoder.decode_float_array()
+        result = []
+        self.num_samples = len(float_array) // 2  # each sample contributes a (g, h) pair
+
+        for i in range(self.num_samples):
+            result.append((self.float_to_int(float_array[2 * i]), self.float_to_int(float_array[2 * i + 1])))
+
+        return result
+
+    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> Optional[AggregationContext]:
+        decoder = DamDecoder(buffer)
+        if not decoder.is_valid():
+            return None
+        data_set_id = decoder.get_data_set_id()
+        cuts = decoder.decode_int_array()
+
+        if data_set_id == DATA_SET_AGGREGATION_WITH_FEATURES:
+            self.feature_list = decoder.decode_int_array()
+            num = len(self.feature_list)
+            slots = decoder.decode_int_array()  # row-major: one slot per (sample, feature)
+            num_samples = len(slots) // num
+            for i in range(num):
+                bin_assignment = []
+                for row_id in range(num_samples):
+                    _, bin_num = self.slot_to_bin(cuts, slots[row_id * num + i])
+                    bin_assignment.append(bin_num)
+
+                bin_size = self.get_bin_size(cuts, self.feature_list[i])
+                feature_ctx = FeatureContext(self.feature_list[i], bin_assignment, bin_size)
+                self.features.append(feature_ctx)
+        elif data_set_id != DATA_SET_AGGREGATION:
+            raise RuntimeError(f"Invalid DataSet: {data_set_id}")
+
+        node_list = decoder.decode_int_array()
+        sample_groups = {}
+        for node in node_list:
+            row_ids = decoder.decode_int_array()
+            sample_groups[node] = row_ids
+
+        return AggregationContext(self.features, sample_groups)
+
+    def encode_aggregation_result(
+        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
+    ) -> bytes:
+        encoder = DamEncoder(DATA_SET_AGGREGATION_RESULT)
+        node_list = sorted(aggr_results.keys())
+        encoder.add_int_array(node_list)
+        for node in node_list:
+            result_list = aggr_results.get(node)
+            feature_list = [result.feature_id for result in result_list]
+            encoder.add_int_array(feature_list)
+            for result in result_list:
+                encoder.add_float_array(self.to_float_array(result))
+
+        return encoder.finish()
+
+    @staticmethod
+    def get_bin_size(cuts: List[int], feature_id: int) -> int:
+        return cuts[feature_id + 1] - cuts[feature_id]
+
+    @staticmethod
+    def slot_to_bin(cuts: List[int], slot: int) -> Tuple[int, int]:
+        if slot < 0 or slot >= cuts[-1]:
+            raise RuntimeError(f"Invalid slot {slot}, out of range [0-{cuts[-1]-1}]")
+
+        for i in range(len(cuts) - 1):
+            if cuts[i] <= slot < cuts[i + 1]:
+                bin_num = slot - cuts[i]
+                return i, bin_num
+
+        raise RuntimeError(f"Logic error. 
Slot {slot}, out of range [0-{cuts[-1] - 1}]") + + @staticmethod + def float_to_int(value: float) -> int: + return int(value * SCALE_FACTOR) + + @staticmethod + def int_to_float(value: int) -> float: + return value / SCALE_FACTOR + + @staticmethod + def to_float_array(result: FeatureAggregationResult) -> List[float]: + float_array = [] + for (g, h) in result.aggregated_hist: + float_array.append(ProcessorDataConverter.int_to_float(g)) + float_array.append(ProcessorDataConverter.int_to_float(h)) + + return float_array diff --git a/tests/unit_test/app_common/xgb/sec/dam_test.py b/tests/unit_test/app_common/xgb/sec/dam_test.py new file mode 100644 index 0000000000..2ba44e1468 --- /dev/null +++ b/tests/unit_test/app_common/xgb/sec/dam_test.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder + +DATA_SET = 123456 +INT_ARRAY = [123, 456, 789] +FLOAT_ARRAY = [1.2, 2.3, 3.4, 4.5] + + +class TestDam: + def test_encode_decode(self): + encoder = DamEncoder(DATA_SET) + encoder.add_int_array(INT_ARRAY) + encoder.add_float_array(FLOAT_ARRAY) + buffer = encoder.finish() + + decoder = DamDecoder(buffer) + assert decoder.is_valid() + assert decoder.get_data_set_id() == DATA_SET + + int_array = decoder.decode_int_array() + assert int_array == INT_ARRAY + + float_array = decoder.decode_float_array() + assert float_array == FLOAT_ARRAY diff --git a/tests/unit_test/app_common/xgb/sec/data_converter_test.py b/tests/unit_test/app_common/xgb/sec/data_converter_test.py new file mode 100644 index 0000000000..ccf2165529 --- /dev/null +++ b/tests/unit_test/app_common/xgb/sec/data_converter_test.py @@ -0,0 +1,126 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
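[Editor's note] The fixtures below exercise the cut-point logic just added: with cuts = [0, 2, 5, 10], feature 0 owns global slots 0-1, feature 1 owns 2-4, and feature 2 owns 5-9. For example (values chosen to match the aggr_buffer fixture):

    from nvflare.app_common.xgb.sec.processor_data_converter import ProcessorDataConverter

    cuts = [0, 2, 5, 10]
    assert ProcessorDataConverter.slot_to_bin(cuts, 6) == (2, 1)  # slot 6 is bin 1 of feature 2
    assert ProcessorDataConverter.get_bin_size(cuts, 2) == 5      # feature 2 has 5 bins
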
+from typing import Dict, List + +import pytest + +from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder +from nvflare.app_common.xgb.sec.data_converter import FeatureAggregationResult +from nvflare.app_common.xgb.sec.processor_data_converter import ( + DATA_SET_AGGREGATION_WITH_FEATURES, + DATA_SET_GH_PAIRS, + ProcessorDataConverter, +) + + +class TestDataConverter: + @pytest.fixture() + def data_converter(self): + yield ProcessorDataConverter() + + @pytest.fixture() + def gh_buffer(self): + + gh = [0.1, 0.2, 1.2, 1.2, 2.1, 2.2, 3.1, 3.2, 4.1, 4.2, 5.1, 5.2, 6.1, 6.2, 7.1, 7.2, 8.1, 8.2, 9.1, 9.2] + + encoder = DamEncoder(DATA_SET_GH_PAIRS) + encoder.add_float_array(gh) + return encoder.finish() + + @pytest.fixture() + def aggr_buffer(self): + + encoder = DamEncoder(DATA_SET_AGGREGATION_WITH_FEATURES) + + cuts = [0, 2, 5, 10] + encoder.add_int_array(cuts) + + features = [0, 2] + encoder.add_int_array(features) + + slots = [ + 0, + 5, + 1, + 9, + 1, + 6, + 0, + 7, + 0, + 9, + 0, + 8, + 1, + 5, + 0, + 6, + 0, + 8, + 1, + 5, + ] + encoder.add_int_array(slots) + + nodes_to_build = [0, 1] + encoder.add_int_array(nodes_to_build) + + row_id_1 = [0, 3, 6, 8] + row_id_2 = [1, 2, 4, 5, 7, 9] + encoder.add_int_array(row_id_1) + encoder.add_int_array(row_id_2) + + return encoder.finish() + + @pytest.fixture() + def aggr_results(self) -> Dict[int, List[FeatureAggregationResult]]: + feature0 = [(1100000, 1200000), (1200000, 1300000)] + feature2 = [(1100000, 1200000), (2100000, 2200000), (3100000, 3200000), (4100000, 4200000), (5100000, 5200000)] + + aggr_result0 = FeatureAggregationResult(0, feature0) + aggr_result2 = FeatureAggregationResult(2, feature2) + result_list = [aggr_result0, aggr_result2] + return {0: result_list, 1: result_list} + + def test_decode(self, data_converter, gh_buffer, aggr_buffer): + gh_pair = data_converter.decode_gh_pairs(gh_buffer, None) + assert len(gh_pair) == data_converter.num_samples + + context = data_converter.decode_aggregation_context(aggr_buffer, None) + assert len(context.features) == 2 + f1 = context.features[0] + assert f1.feature_id == 0 + assert f1.num_bins == 2 + assert f1.sample_bin_assignment == [0, 1, 1, 0, 0, 0, 1, 0, 0, 1] + + f2 = context.features[1] + assert f2.feature_id == 2 + assert f2.num_bins == 5 + assert f2.sample_bin_assignment == [0, 4, 1, 2, 4, 3, 0, 1, 3, 0] + + def test_encode(self, data_converter, aggr_results): + + # Simulate the state of converter after decode call + data_converter.feature_list = [0, 2] + buffer = data_converter.encode_aggregation_result(aggr_results, None) + + decoder = DamDecoder(buffer) + node_list = decoder.decode_int_array() + assert node_list == [0, 1] + + histo0 = decoder.decode_float_array() + assert histo0 == [1.1, 1.2, 1.2, 1.3] + + histo2 = decoder.decode_float_array() + assert histo2 == [1.1, 1.2, 2.1, 2.2, 3.1, 3.2, 4.1, 4.2, 5.1, 5.2]
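[Editor's note] One property these equality assertions depend on: SCALE_FACTOR keeps six decimal places, and dividing the scaled int64 back by 1e6 rounds to exactly the same double as the original literal, so plain == comparisons on the decoded histograms are safe for values like these. A quick check of the fixed-point round trip (illustrative values only):

    from nvflare.app_common.xgb.sec.processor_data_converter import ProcessorDataConverter

    assert ProcessorDataConverter.float_to_int(1.1) == 1100000
    assert ProcessorDataConverter.int_to_float(1100000) == 1.1

Both test modules should run under plain pytest from the repo root, assuming the nvflare package is importable, e.g. pytest tests/unit_test/app_common/xgb/sec/.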