Skip to content

Commit

Permalink
Implement a new algorithm for the CUDA plugin (#3085)
Browse files Browse the repository at this point in the history
* New cuda plugin implementation

Signed-off-by: YuanTingHsieh <[email protected]>

* Update docstring

* Rename CellTable to GHPairArray for clarity and allow max_num_of_gh_pair_per_launch to be customized

---------

Signed-off-by: YuanTingHsieh <[email protected]>
  • Loading branch information
YuanTingHsieh authored Dec 10, 2024
1 parent e594dcb commit f3e1e0d
Show file tree
Hide file tree
Showing 10 changed files with 1,023 additions and 272 deletions.
471 changes: 422 additions & 49 deletions integration/xgboost/encryption_plugins/cuda_plugin/src/cuda_plugin.h

Large diffs are not rendered by default.

125 changes: 1 addition & 124 deletions integration/xgboost/encryption_plugins/cuda_plugin/src/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,12 @@
const static unsigned int bits=2048;
const static unsigned int key_len=1024;

//const static unsigned int bits=4096;
//const static unsigned int key_len=2048;

//const static unsigned int bits=6144;
//const static unsigned int key_len=3072;


const static int TPB=512;
const static int TPI=32;
const static int window_bits=5;

/** Class **/
struct CgbnPair {
struct GHPair {
cgbn_mem_t<bits> g;
cgbn_mem_t<bits> h;
};
Expand Down Expand Up @@ -155,120 +148,4 @@ void store2Gmp(mpz_t z, cgbn_mem_t<BITS> *address ) {
mpz_import(z, (BITS+31)/32, -1, sizeof(uint32_t), 0, 0, (uint32_t *)address);
}

template<unsigned int BITS>
void initArr(cgbn_mem_t<BITS> *address, int count, int default_value = 0, bool randomize = false){

for(int i = 0; i < count; i++){
int value;
mpz_t n;
mpz_init(n);
if (randomize) {
value = i; // rand();
} else {
value = default_value;
}
mpz_set_si(n, value);

store2Cgbn(address + i, n);

gmp_printf("input%d:%Zd\n", i, n);
mpz_clear(n);
}
}
template<unsigned int BITS>
void printCgbn(cgbn_mem_t<BITS> *h_ptr, int print_count){
for(int i = 0; i < print_count; i++){
mpz_t n;
mpz_init(n);
store2Gmp(n, h_ptr + i);
gmp_printf("printCgbn [%d]:%Zd\n",i, n);
mpz_clear(n);
}
}

template<unsigned int BITS>
void printDevCgbn(cgbn_mem_t<BITS> *d_ptr, int print_count, std::string name="cipher"){

int mem_size=sizeof(cgbn_mem_t<BITS>)*print_count;
cgbn_mem_t<BITS>* h_plains_ptr=(cgbn_mem_t<BITS>* )malloc(mem_size);
cudaMemcpy(h_plains_ptr, d_ptr, mem_size, cudaMemcpyDeviceToHost);

for(int i = 0; i < print_count; i++){
mpz_t n;
mpz_init(n);
store2Gmp(n, h_plains_ptr + i);
gmp_printf("printDevCgbn %s[%d]:%Zd\n",name,i, n);
mpz_clear(n);
}


free(h_plains_ptr);
}

template<unsigned int BITS>
void printDevGH(CgbnPair *d_ptr, int print_count, std::string name="cipher"){

int mem_size=sizeof(CgbnPair) * print_count;
CgbnPair* h_plains_ptr=(CgbnPair *)malloc(mem_size);
cudaMemcpy(h_plains_ptr, d_ptr, mem_size, cudaMemcpyDeviceToHost);

for(int i = 0; i < print_count; i++){
mpz_t g, h;
mpz_init(g);
mpz_init(h);
CgbnPair p = *(h_plains_ptr +i);
store2Gmp(g, &p.g);
store2Gmp(h, &p.h);
gmp_printf("printDevCgbn %s[%d]:g %Zd, h %Zd\n",name, i, g, h);
mpz_clear(g);
mpz_clear(h);
}

free(h_plains_ptr);
}

template<unsigned int BITS>
void compArr(cgbn_mem_t<BITS> *a, cgbn_mem_t<BITS> *b,int count){
int mem_size=sizeof(cgbn_mem_t<BITS>)*count;
cgbn_mem_t<BITS>* ha=(cgbn_mem_t<BITS>* )malloc(mem_size);
cudaMemcpy(ha, a, mem_size, cudaMemcpyDeviceToHost);

cgbn_mem_t<BITS>* hb=(cgbn_mem_t<BITS>* )malloc(mem_size);
cudaMemcpy(hb, b, mem_size, cudaMemcpyDeviceToHost);


for(int i = 0; i < count; i++){
int res=0;
mpz_t na, nb;
mpz_init(na);
store2Gmp(na, ha + i);

mpz_init(nb);
store2Gmp(nb, hb + i);

res=mpz_cmp(na, nb);
if(res!=0){
std::cout<<"res= "<<res<<" Incorrect at i= "<<i<<std::endl;;
gmp_printf(" a=%Zd \n",na);
gmp_printf(" b=%Zd \n",nb);
exit(1);
}

mpz_clear(na);
mpz_clear(nb);
}
std::cout<<"Correct!"<<std::endl;
}

bool compare_result(const std::vector<double> &a, const std::vector<double> &b, double eps=1e-6) {
if (a.size() != b.size()) return false;
for (auto i = 0; i < a.size(); ++i) {
if (fabs(a[i] - b[i]) >= eps) {
std::cout << "Fatal Error at position " << i << " " << a[i] << " " << b[i] << std::endl;
return false;
}
}
return true;
}

#endif // CUDA_UTILS_H
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "delegated_plugin.h"
#include "old_cuda_plugin.h"
#include "cuda_plugin.h"

namespace nvflare {
Expand All @@ -24,6 +25,8 @@ DelegatedPlugin::DelegatedPlugin(std::vector<std::pair<std::string_view, std::st
auto name = get_string(args, "name");
if (name == "cuda_paillier") {
plugin_ = new CUDAPlugin(args);
} else if (name == "cuda_paillier_old") {
plugin_ = new OldCUDAPlugin(args);
} else {
throw std::invalid_argument{"Unknown plugin name: " + name};
}
Expand Down
Loading

0 comments on commit f3e1e0d

Please sign in to comment.