Skip to content

Commit

Permalink
Sealpir multithread (#70)
Browse files Browse the repository at this point in the history
* src/run_sealpir: add runner for sealpir timing multithread
  • Loading branch information
elkanatovey authored Jan 15, 2023
1 parent 2da6869 commit 305fa20
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ target_link_libraries(main1 distribicom_cpp)
# eval executables:
add_executable(main_server server.cpp)
add_executable(worker worker.cpp)
add_executable(run_sealpir run_sealpir.cpp)

target_link_libraries(main_server distribicom_cpp)
target_link_libraries(worker distribicom_cpp)
target_link_libraries(run_sealpir distribicom_cpp)
193 changes: 193 additions & 0 deletions src/run_sealpir.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"
#include <seal/seal.h>
#include <chrono>
#include <memory>
#include <random>
#include <cstdint>
#include <cstddef>
#include "concurrency/concurrency.h"
#include <fstream>

using namespace std::chrono;
using namespace std;
using namespace seal;

bool verify_params(int argc, char *const *argv, uint32_t& num_threads, uint32_t& num_queries, std::string&
log_address) {
if(argc < 3 || std::stoi(argv[1]) < 1 || std::stoi(argv[2]) <= 0)
{
cout << "Usage: " << argv[0] << " <num_threads>" << " <num_queries> " << "<logfile_for_timings>" << endl;
return false;
}
num_threads = std::stoi(argv[1]);
num_queries = std::stoi(argv[2]);
if(argc > 3){
log_address = argv[3];
}
return true;
}

struct indice{
uint64_t ele_index;
uint64_t offset;
};

int main(int argc, char *argv[]) {
uint32_t num_threads;
uint32_t num_queries;
std::string log_address = "run_sealpir_log.txt";
std::string cout_file = "run_sealpir.txt";
auto all_good = verify_params(argc, argv, num_threads, num_queries, log_address);
if (!all_good)
{
return -1;
}
cout << "Main: timing logs saved to: " << log_address<< endl;


uint64_t number_of_items = 1 << 16;
uint64_t size_per_item = 256; // in bytes
uint32_t N = 4096;

// Recommended values: (logt, d) = (20, 2).
uint32_t logt = 20;
uint32_t d = 2;
bool use_symmetric = true; // use symmetric encryption instead of public key (recommended for smaller query)
bool use_batching = true; // pack as many elements as possible into a BFV plaintext (recommended)
bool use_recursive_mod_switching = false;

EncryptionParameters enc_params(scheme_type::bgv);
PirParams pir_params;

// Generates all parameters

cout << "Main: Generating SEAL parameters" << endl;
gen_encryption_params(N, logt, enc_params);

cout << "Main: Verifying SEAL parameters" << endl;
verify_encryption_params(enc_params);
cout << "Main: SEAL parameters are good" << endl;

cout << "Main: Generating PIR parameters" << endl;
gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params, use_symmetric, use_batching, use_recursive_mod_switching);


print_seal_params(enc_params);
print_pir_params(pir_params);

PIRServer server(enc_params, pir_params);

seal::Blake2xbPRNGFactory factory;
auto gen = factory.create();

// Create test database
auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
for (uint64_t i = 0; i < number_of_items; i++) {
for (uint64_t j = 0; j < size_per_item; j++) {
auto val = gen->generate() % 256;
db.get()[(i * size_per_item) + j] = val;
db_copy.get()[(i * size_per_item) + j] = val;
}
}
server.set_database(move(db), number_of_items, size_per_item);
server.preprocess_database();

//create clients and queries
std::vector<PIRClient> clients;
std::vector<PirQuery> queries;
std::vector<PirReply> answers(num_queries);
std::vector<indice> indices(num_queries);

for (int i=0; i<num_queries; i++) {
clients.push_back(PIRClient(enc_params, pir_params));
GaloisKeys galois_keys = clients[i].generate_galois_keys();
server.set_galois_key(i, galois_keys);
uint64_t ele_index = gen->generate() % number_of_items; // element in DB at random position
uint64_t index = clients[i].get_fv_index(ele_index); // index of FV plaintext
uint64_t offset = clients[i].get_fv_offset(ele_index); // offset in FV plaintext
queries.push_back(clients[i].generate_query(index));
indices[i].ele_index = ele_index;
indices[i].offset = offset;

}

std::streambuf *filebuf, *coutbackup;
std::ofstream coutfilestr;
coutfilestr.open (cout_file);
coutbackup = std::cout.rdbuf(); // back up cout's streambuf
filebuf = coutfilestr.rdbuf(); // get file's streambuf
std::cout.rdbuf(filebuf); // assign streambuf to cout

concurrency::threadpool* pool = new concurrency::threadpool(num_threads);
concurrency::safelatch latch(num_queries);
auto time_pool_s = high_resolution_clock::now();
for (int j=0; j<num_queries; j++) {

pool->submit(
std::move(concurrency::Task{
.f = [&, j]() {
answers[j] = server.generate_reply(queries[j], j);
latch.count_down();
},
.wg = nullptr,
.name = "server:generate_reply",
})
);

}
latch.wait();
auto time_pool_e = high_resolution_clock::now();
auto time_pool_us =
duration_cast<microseconds>(time_pool_e - time_pool_s).count();
std::cout << "This is written to the file";
std::cout.rdbuf(coutbackup); // restore cout's original streambuf

coutfilestr.close();


std::streambuf *psbuf, *backup;
std::ofstream filestr;
filestr.open (log_address, std::ios::app);
backup = std::clog.rdbuf(); // back up cout's streambuf
psbuf = filestr.rdbuf(); // get file's streambuf
std::clog.rdbuf(psbuf); // assign streambuf to cout


clog << "Main: pool query processing time: " << time_pool_us / 1000
<< " ms on "<<num_queries << " queries and "<< num_threads << " threads" << endl;
cout << "Main: pool query processing time: " << time_pool_us / 1000
<< " ms on "<<num_queries << " queries and "<< num_threads << " threads" << endl;

std::clog.rdbuf(backup); // restore cout's original streambuf

filestr.close();

delete pool;


for (int i=0; i<num_queries; i++) {
vector<uint8_t> elems = clients[i].decode_reply(answers[i], indices[i].offset);
assert(elems.size() == size_per_item);
bool failed = false;

// Check that we retrieved the correct element
for (uint32_t k = 0; k < size_per_item; k++) {
if (elems[k] != db_copy.get()[(indices[i].ele_index * size_per_item) + k]) {
cout << "Main: elems " << (int)elems[k] << ", db "
<< (int) db_copy.get()[(indices[i].ele_index * size_per_item) + k] << endl;
cout << "Main: PIR result wrong at " << k << endl;
failed = true;
}
}
if(failed){
return -1;
}


}

return 0;
}

0 comments on commit 305fa20

Please sign in to comment.