Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sealpir multithread #70

Merged
merged 3 commits into from
Jan 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ target_link_libraries(main1 distribicom_cpp)
# eval executables:
add_executable(main_server server.cpp)
add_executable(worker worker.cpp)
add_executable(run_sealpir run_sealpir.cpp)

target_link_libraries(main_server distribicom_cpp)
target_link_libraries(worker distribicom_cpp)
target_link_libraries(run_sealpir distribicom_cpp)
193 changes: 193 additions & 0 deletions src/run_sealpir.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#include "pir.hpp"
#include "pir_client.hpp"
#include "pir_server.hpp"
#include <seal/seal.h>
#include <chrono>
#include <memory>
#include <random>
#include <cstdint>
#include <cstddef>
#include "concurrency/concurrency.h"
#include <fstream>

using namespace std::chrono;
using namespace std;
using namespace seal;

bool verify_params(int argc, char *const *argv, uint32_t& num_threads, uint32_t& num_queries, std::string&
log_address) {
if(argc < 3 || std::stoi(argv[1]) < 1 || std::stoi(argv[2]) <= 0)
{
cout << "Usage: " << argv[0] << " <num_threads>" << " <num_queries> " << "<logfile_for_timings>" << endl;
return false;
}
num_threads = std::stoi(argv[1]);
num_queries = std::stoi(argv[2]);
if(argc > 3){
log_address = argv[3];
}
return true;
}

struct indice{
uint64_t ele_index;
uint64_t offset;
};

int main(int argc, char *argv[]) {
uint32_t num_threads;
uint32_t num_queries;
std::string log_address = "run_sealpir_log.txt";
std::string cout_file = "run_sealpir.txt";
auto all_good = verify_params(argc, argv, num_threads, num_queries, log_address);
if (!all_good)
{
return -1;
}
cout << "Main: timing logs saved to: " << log_address<< endl;


uint64_t number_of_items = 1 << 16;
uint64_t size_per_item = 256; // in bytes
uint32_t N = 4096;

// Recommended values: (logt, d) = (20, 2).
uint32_t logt = 20;
uint32_t d = 2;
bool use_symmetric = true; // use symmetric encryption instead of public key (recommended for smaller query)
bool use_batching = true; // pack as many elements as possible into a BFV plaintext (recommended)
bool use_recursive_mod_switching = false;

EncryptionParameters enc_params(scheme_type::bgv);
PirParams pir_params;

// Generates all parameters

cout << "Main: Generating SEAL parameters" << endl;
gen_encryption_params(N, logt, enc_params);

cout << "Main: Verifying SEAL parameters" << endl;
verify_encryption_params(enc_params);
cout << "Main: SEAL parameters are good" << endl;

cout << "Main: Generating PIR parameters" << endl;
gen_pir_params(number_of_items, size_per_item, d, enc_params, pir_params, use_symmetric, use_batching, use_recursive_mod_switching);


print_seal_params(enc_params);
print_pir_params(pir_params);

PIRServer server(enc_params, pir_params);

seal::Blake2xbPRNGFactory factory;
auto gen = factory.create();

// Create test database
auto db(make_unique<uint8_t[]>(number_of_items * size_per_item));
auto db_copy(make_unique<uint8_t[]>(number_of_items * size_per_item));
for (uint64_t i = 0; i < number_of_items; i++) {
for (uint64_t j = 0; j < size_per_item; j++) {
auto val = gen->generate() % 256;
db.get()[(i * size_per_item) + j] = val;
db_copy.get()[(i * size_per_item) + j] = val;
}
}
server.set_database(move(db), number_of_items, size_per_item);
server.preprocess_database();

//create clients and queries
std::vector<PIRClient> clients;
std::vector<PirQuery> queries;
std::vector<PirReply> answers(num_queries);
std::vector<indice> indices(num_queries);

for (int i=0; i<num_queries; i++) {
clients.push_back(PIRClient(enc_params, pir_params));
GaloisKeys galois_keys = clients[i].generate_galois_keys();
server.set_galois_key(i, galois_keys);
uint64_t ele_index = gen->generate() % number_of_items; // element in DB at random position
uint64_t index = clients[i].get_fv_index(ele_index); // index of FV plaintext
uint64_t offset = clients[i].get_fv_offset(ele_index); // offset in FV plaintext
queries.push_back(clients[i].generate_query(index));
indices[i].ele_index = ele_index;
indices[i].offset = offset;

}

std::streambuf *filebuf, *coutbackup;
std::ofstream coutfilestr;
coutfilestr.open (cout_file);
coutbackup = std::cout.rdbuf(); // back up cout's streambuf
filebuf = coutfilestr.rdbuf(); // get file's streambuf
std::cout.rdbuf(filebuf); // assign streambuf to cout

concurrency::threadpool* pool = new concurrency::threadpool(num_threads);
concurrency::safelatch latch(num_queries);
auto time_pool_s = high_resolution_clock::now();
for (int j=0; j<num_queries; j++) {

pool->submit(
std::move(concurrency::Task{
.f = [&, j]() {
answers[j] = server.generate_reply(queries[j], j);
latch.count_down();
},
.wg = nullptr,
.name = "server:generate_reply",
})
);

}
latch.wait();
auto time_pool_e = high_resolution_clock::now();
auto time_pool_us =
duration_cast<microseconds>(time_pool_e - time_pool_s).count();
std::cout << "This is written to the file";
std::cout.rdbuf(coutbackup); // restore cout's original streambuf

coutfilestr.close();


std::streambuf *psbuf, *backup;
std::ofstream filestr;
filestr.open (log_address, std::ios::app);
backup = std::clog.rdbuf(); // back up cout's streambuf
psbuf = filestr.rdbuf(); // get file's streambuf
std::clog.rdbuf(psbuf); // assign streambuf to cout


clog << "Main: pool query processing time: " << time_pool_us / 1000
<< " ms on "<<num_queries << " queries and "<< num_threads << " threads" << endl;
cout << "Main: pool query processing time: " << time_pool_us / 1000
<< " ms on "<<num_queries << " queries and "<< num_threads << " threads" << endl;

std::clog.rdbuf(backup); // restore cout's original streambuf

filestr.close();

delete pool;


for (int i=0; i<num_queries; i++) {
vector<uint8_t> elems = clients[i].decode_reply(answers[i], indices[i].offset);
assert(elems.size() == size_per_item);
bool failed = false;

// Check that we retrieved the correct element
for (uint32_t k = 0; k < size_per_item; k++) {
if (elems[k] != db_copy.get()[(indices[i].ele_index * size_per_item) + k]) {
cout << "Main: elems " << (int)elems[k] << ", db "
<< (int) db_copy.get()[(indices[i].ele_index * size_per_item) + k] << endl;
cout << "Main: PIR result wrong at " << k << endl;
failed = true;
}
}
if(failed){
return -1;
}


}

return 0;
}