From e8a63dd002c857bfa085ae3a34c70046e29c2638 Mon Sep 17 00:00:00 2001
From: Jonathan
Date: Tue, 13 Dec 2022 17:21:33 +0200
Subject: [PATCH] Before processings (#59)

* channel throws if write on closed
* manager sets up a threadpool
* manager sets up random vector each epoch
* Squashed commit of the following:

commit 6284db7ec336a7b6a629ba64e0fa06ab462092f0
Merge: 06d979b b75a7d2
Author: Jonathan Weiss
Date: Sun Dec 11 14:20:26 2022 +0200

    Merge branch 'dev' of github.com:elkanatovey/distribicom into dev

commit 06d979b41fd0666fd80dbb8aa020c0e868abb621
Author: Jonathan Weiss
Date: Sun Dec 11 14:20:17 2022 +0200

    fix: query_expander construction

commit b75a7d210503386149fb00a30f5c39b22535340b
Author: Elkana Tovey <40407298+elkanatovey@users.noreply.github.com>
Date: Sun Dec 11 14:03:55 2022 +0200

    Second stage pir (#57)

    * manager: fix location of commas in test funcs
    * evaluator: add ptx decomposition mult
    * evaluator:add test for ptx embedding
    * evaluator:add ntt transform for EmbeddedCiphertext
    * server.cpp: minor change for debug mode
    * code-style fix
    * new epoch async prepares for frievald.
    * server.cpp: minor change for debug mode
    * added test to async query expansion
    * verify the query_x_rand_vec in debug mod
    * refactor locking in distribute-work
    * client_context.hpp: start refactor of clientDB
    * fx: safelatch introduced
    * refactor safelatch in its own file
    * refactor safelatch in its own file
    * client_context.hpp: refactor client setting
      manager.hpp: start adding partial work to dbs
    * client_context.hpp: minor fix for cmake
    * convert client db mutex to pointer
    * move ClientDB into manager
    * matrix_operations.tpp: fix multiplication bug for regular plaintext case
      evaluator_wrapper.cpp: create method for calculating expansion ratio
      worker_test.cpp: update according to evaluator wrapper
    * query_expander.cpp: add version of expansion that returns query as col vector
    * server: factor db into manager
      matrix_operations: add sync version of scalar dot product
      manager: add sync versions of stage two, freivalds still has bugs
    * removed promise-hell
    * improved async usage
    * server.cpp: add partial answer struct to client
      manager.cpp: add method to store partial work, start calculate final answer
      worker_test: update per mods to client struct
    * refactor: queries_dim2 are now not promised
    * fx: ntt form db multiply with query_vec
    * ensuring worker-manager ntt forms match
    * ntt-forms match
    * matrix_operations.hpp: fix for scalar dot product dims
    * worker_test.cpp update test for more general query nums.
manager.hpp: add comment * ntt query-dim-2 * ledger: added promise for verify_worker * async-verification * integrating with current server * fx: removed comments * duct-tape without sending responses * todos Co-authored-by: Jonathan Weiss Co-authored-by: elkana --- src/concurrency/CMakeLists.txt | 3 +- src/concurrency/channel.hpp | 3 + src/concurrency/concurrency.h | 3 +- src/concurrency/promise.hpp | 25 +-- src/concurrency/safelatch.cpp | 16 ++ src/concurrency/safelatch.h | 19 +++ src/concurrency/threadpool.hpp | 3 +- src/math_utils/evaluator_wrapper.cpp | 15 ++ src/math_utils/evaluator_wrapper.hpp | 4 + src/math_utils/matrix_operations.hpp | 26 ++- src/math_utils/matrix_operations.tpp | 63 ++++++- src/math_utils/query_expander.cpp | 23 +++ src/math_utils/query_expander.hpp | 11 ++ src/server.cpp | 36 ++-- src/services/CMakeLists.txt | 2 +- src/services/client_context.cpp | 22 +++ src/services/client_context.hpp | 64 +++++++- src/services/manager.cpp | 210 ++++++++++++++++-------- src/services/manager.hpp | 196 ++++++++++++++++++---- src/services/server.cpp | 89 ++++------ src/services/server.hpp | 6 +- src/services/worker_strategy.cpp | 63 +++---- test/math_utils/query_expander_test.cpp | 71 +++++--- test/services/worker_test.cpp | 47 +++--- 24 files changed, 737 insertions(+), 283 deletions(-) create mode 100644 src/concurrency/safelatch.cpp create mode 100644 src/concurrency/safelatch.h create mode 100644 src/services/client_context.cpp diff --git a/src/concurrency/CMakeLists.txt b/src/concurrency/CMakeLists.txt index b42cdd44..4856ca58 100644 --- a/src/concurrency/CMakeLists.txt +++ b/src/concurrency/CMakeLists.txt @@ -1,3 +1,4 @@ -add_library(concurrency_utils concurrency.h channel.hpp counter.hpp counter.cpp promise.hpp threadpool.hpp threadpool.cpp) +add_library(concurrency_utils concurrency.h channel.hpp counter.hpp counter.cpp promise.hpp threadpool.hpp threadpool.cpp + safelatch.h safelatch.cpp) cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH MY_PARENT_DIR) target_include_directories(concurrency_utils PUBLIC ${MY_PARENT_DIR}) \ No newline at end of file diff --git a/src/concurrency/channel.hpp b/src/concurrency/channel.hpp index 505d4004..371106ec 100644 --- a/src/concurrency/channel.hpp +++ b/src/concurrency/channel.hpp @@ -59,6 +59,9 @@ namespace concurrency { */ void write(T t) { std::lock_guard lock(m); + if (closed) { + throw std::runtime_error("Channel::write() - channel closed."); + } q.push(t); c.notify_one(); } diff --git a/src/concurrency/concurrency.h b/src/concurrency/concurrency.h index 407729b8..16397ace 100644 --- a/src/concurrency/concurrency.h +++ b/src/concurrency/concurrency.h @@ -3,4 +3,5 @@ #include "channel.hpp" #include "counter.hpp" #include "promise.hpp" -#include "threadpool.hpp" \ No newline at end of file +#include "threadpool.hpp" +#include "safelatch.h" \ No newline at end of file diff --git a/src/concurrency/promise.hpp b/src/concurrency/promise.hpp index 4e902d4c..923b6413 100644 --- a/src/concurrency/promise.hpp +++ b/src/concurrency/promise.hpp @@ -1,28 +1,24 @@ #pragma once - -#include #include -#include - +#include "safelatch.h" namespace concurrency { template class promise { private: - std::atomic safety; std::atomic done; - std::shared_ptr wg; + std::shared_ptr wg; std::shared_ptr value; public: - promise(int n, std::shared_ptr &result_store) : safety(n), value(result_store) { - wg = std::make_shared(n); + promise(int n, std::shared_ptr &result_store) : value(result_store) { + wg = std::make_shared(n); } - promise(int n, 
std::shared_ptr &&result_store) : safety(n), value(std::move(result_store)) { - wg = std::make_shared(n); + promise(int n, std::shared_ptr &&result_store) : value(std::move(result_store)) { + wg = std::make_shared(n); } std::shared_ptr get() { @@ -42,23 +38,18 @@ namespace concurrency { if (done.load()) { return; } - if (safety.load() == 0) { + if (wg->done_waiting()) { throw std::runtime_error("promise set after it was done"); } value = std::move(val); } - std::shared_ptr &get_latch() { + std::shared_ptr get_latch() { return wg; } inline void count_down() { wg->count_down(); - auto prev = safety.fetch_add(-1); - if (prev <= 0) { - throw std::runtime_error( - "promise::count_down:: latch's value is less than 0, this is a bug that can lead to deadlock!"); - } } }; diff --git a/src/concurrency/safelatch.cpp b/src/concurrency/safelatch.cpp new file mode 100644 index 00000000..bfff463a --- /dev/null +++ b/src/concurrency/safelatch.cpp @@ -0,0 +1,16 @@ +#include "safelatch.h" + +namespace concurrency { + void safelatch::count_down() { + latch::count_down(); + auto prev = safety.fetch_add(-1); + if (prev <= 0) { + throw std::runtime_error( + "count_down:: latch's value is less than 0, this is a bug that can lead to deadlock!"); + } + } + + bool safelatch::done_waiting() { + return safety.load() == 0; + } +} \ No newline at end of file diff --git a/src/concurrency/safelatch.h b/src/concurrency/safelatch.h new file mode 100644 index 00000000..881e4402 --- /dev/null +++ b/src/concurrency/safelatch.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace concurrency { + + class safelatch : public std::latch { + std::atomic safety; + public: + explicit safelatch(int count) : std::latch(count), safety(count) {}; + + bool done_waiting(); + + void count_down(); + + }; + +} diff --git a/src/concurrency/threadpool.hpp b/src/concurrency/threadpool.hpp index 69d63ed6..8f6f5346 100644 --- a/src/concurrency/threadpool.hpp +++ b/src/concurrency/threadpool.hpp @@ -6,11 +6,12 @@ #include #include #include "channel.hpp" +#include "safelatch.h" namespace concurrency { struct Task { std::function f; - std::shared_ptr wg; + std::shared_ptr wg; }; diff --git a/src/math_utils/evaluator_wrapper.cpp b/src/math_utils/evaluator_wrapper.cpp index 17978a5e..b45530e5 100644 --- a/src/math_utils/evaluator_wrapper.cpp +++ b/src/math_utils/evaluator_wrapper.cpp @@ -1,5 +1,20 @@ #include "evaluator_wrapper.hpp" +namespace math_utils{ + + /* + * for the correct expansion ratio take last_parms_id() times 2 + */ + uint32_t compute_expansion_ratio(const seal::EncryptionParameters ¶ms) { + uint32_t expansion_ratio = 0; + uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value()); + for (size_t i = 0; i < params.coeff_modulus().size(); ++i) { + double coeff_bit_size = log2(params.coeff_modulus()[i].value()); + expansion_ratio += ceil(coeff_bit_size / pt_bits_per_coeff); + } + return expansion_ratio; + } +} namespace { uint32_t compute_expansion_ratio(const seal::EncryptionParameters& params) { uint32_t expansion_ratio = 0; diff --git a/src/math_utils/evaluator_wrapper.hpp b/src/math_utils/evaluator_wrapper.hpp index 9998dd74..4ef4b4a8 100644 --- a/src/math_utils/evaluator_wrapper.hpp +++ b/src/math_utils/evaluator_wrapper.hpp @@ -127,4 +127,8 @@ namespace math_utils { void transform_to_ntt_inplace(EmbeddedCiphertext &encoded) const; }; + /* + * for the correct expansion ratio take last_parms_id() times 2 + */ + uint32_t compute_expansion_ratio(const seal::EncryptionParameters& params); } \ No newline at end of file 
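A minimal usage sketch of the safelatch introduced above, for orientation only (not part of the patch): each worker calls count_down() exactly once, a second call throws instead of silently corrupting the latch, and the waiter relies on the wait() inherited from std::latch. The worker count and lambda body are illustrative.

#include <thread>
#include <vector>
#include "concurrency/safelatch.h"

void safelatch_usage_sketch() {
    const int n_workers = 4;
    concurrency::safelatch latch(n_workers);

    std::vector<std::jthread> workers;
    for (int i = 0; i < n_workers; ++i) {
        workers.emplace_back([&latch] {
            // ... perform this worker's share of the round ...
            latch.count_down();   // exactly one call per worker; a second call throws
        });
    }
    latch.wait();                 // inherited from std::latch; unblocks once the count hits zero
    // done_waiting() now reports true: the safety counter reached zero without going negative.
}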
diff --git a/src/math_utils/matrix_operations.hpp b/src/math_utils/matrix_operations.hpp index ac4c414f..923433db 100644 --- a/src/math_utils/matrix_operations.hpp +++ b/src/math_utils/matrix_operations.hpp @@ -81,6 +81,7 @@ namespace math_utils { void to_ntt(std::vector &m) const; + void to_ntt(std::vector &m) const; void from_ntt(std::vector &m) const; @@ -184,6 +185,30 @@ namespace math_utils { std::unique_ptr>> async_scalar_dot_product(const std::shared_ptr> &mat, const std::shared_ptr> &vec) const; + /*** + * + * @tparam + * @param mat matrix where num of rows equals vec length + * @param vec vec of ints + * @return mat*vec where vec is of size mat.row_len() + */ + template + std::shared_ptr> + scalar_dot_product(const std::shared_ptr> &mat, + const std::shared_ptr> &vec) const; + + + /*** + * + * @tparam + * @param mat matrix where num of cols equals vec length + * @param vec vec of ints + * @return mat*vec where vec is of size mat.col_len() + */ + template + std::shared_ptr> + scalar_dot_product_col_major(const std::shared_ptr> &mat, + const std::shared_ptr> &vec) const; private: void @@ -192,7 +217,6 @@ namespace math_utils { const matrix &b, matrix &result) const; - void to_ntt(std::vector &m) const; }; } diff --git a/src/math_utils/matrix_operations.tpp b/src/math_utils/matrix_operations.tpp index 9d97da83..f7a13965 100644 --- a/src/math_utils/matrix_operations.tpp +++ b/src/math_utils/matrix_operations.tpp @@ -42,8 +42,20 @@ namespace math_utils { w_evaluator->evaluator->transform_to_ntt_inplace(tmp_result); } + // assume that in seal::Plaintext case we don't want to turn into splitPlaintexts for (uint64_t k = 0; k < left.cols; ++k) { - w_evaluator->mult(left(i, k), right(k, j), tmp); + if constexpr ((std::is_same_v)) + { + w_evaluator->mult_reg(left(i, k), right(k, j), tmp); + } + else if constexpr (std::is_same_v) + { + w_evaluator->mult_reg(right(k, j), left(i, k), tmp); + } + else + { + w_evaluator->mult(left(i, k), right(k, j), tmp); + } w_evaluator->add(tmp, tmp_result, tmp_result); } return tmp_result; @@ -55,7 +67,7 @@ namespace math_utils { matrix &result) const { verify_correct_dimension(left, right); verify_not_empty_matrices(left, right); - auto wg = std::make_shared(int(left.rows * right.cols)); + auto wg = std::make_shared(int(left.rows * right.cols)); for (uint64_t i = 0; i < left.rows; ++i) { for (uint64_t j = 0; j < right.cols; ++j) { @@ -174,5 +186,52 @@ namespace math_utils { return p; } + template + std::shared_ptr> + MatrixOperations::scalar_dot_product( + const std::shared_ptr> &mat, + const std::shared_ptr> &vec) const { +#ifdef DISTRIBICOM_DEBUG + assert(mat->rows==vec->size()); +#endif + auto result_vec = std::make_shared>(mat->cols, 1); + for (uint64_t k = 0; k < mat->cols; k++) { + seal::Ciphertext tmp; + seal::Ciphertext rslt(w_evaluator->context); + for (uint64_t j = 0; j < mat->rows; j++) { + w_evaluator->scalar_multiply((*vec)[j], (*mat)(j, k), tmp); + w_evaluator->add(tmp, rslt, rslt); + } + (*result_vec)(k, 0) = rslt; + + } + + return result_vec; + } + + + template + std::shared_ptr> + MatrixOperations::scalar_dot_product_col_major( + const std::shared_ptr> &mat, + const std::shared_ptr> &vec) const { +#ifdef DISTRIBICOM_DEBUG + assert(mat->cols==vec->size()); +#endif + auto result_vec = std::make_shared>(mat->rows, 1); + for (uint64_t k = 0; k < mat->rows; k++) { + seal::Ciphertext tmp; + seal::Ciphertext rslt(w_evaluator->context); + for (uint64_t j = 0; j < mat->cols; j++) { + w_evaluator->scalar_multiply((*vec)[j], (*mat)(k, j), tmp); + 
w_evaluator->add(tmp, rslt, rslt); + } + (*result_vec)(k, 0) = rslt; + + } + + return result_vec; + } + } \ No newline at end of file diff --git a/src/math_utils/query_expander.cpp b/src/math_utils/query_expander.cpp index 9d797ab3..49664389 100644 --- a/src/math_utils/query_expander.cpp +++ b/src/math_utils/query_expander.cpp @@ -219,4 +219,27 @@ namespace math_utils { return promise; } + + std::shared_ptr>> + QueryExpander::async_expand_to_matrix(std::vector query_i, uint64_t n_i, seal::GaloisKeys &galkey) { + auto query_i_cpy = std::make_shared>(query_i); + auto galkey_cpy = std::make_shared(galkey); + math_utils::matrix s; + auto promise = std::make_shared>>(1, nullptr); + + pool->submit( + { + .f = + [&, promise, query_i_cpy, galkey_cpy, n_i]() { + promise->set( + std::make_shared>(1,n_i,expand_query(*query_i_cpy, n_i, *galkey_cpy)) + ); + }, + .wg = promise->get_latch(), + + } + ); + return promise; + } + } \ No newline at end of file diff --git a/src/math_utils/query_expander.hpp b/src/math_utils/query_expander.hpp index ad982884..b254b711 100644 --- a/src/math_utils/query_expander.hpp +++ b/src/math_utils/query_expander.hpp @@ -3,6 +3,7 @@ #include #include #include "concurrency/concurrency.h" +#include "matrix.h" namespace math_utils { /*** @@ -29,6 +30,16 @@ namespace math_utils { std::shared_ptr>> async_expand(std::vector query_i, uint64_t n_i, seal::GaloisKeys &galkey); + /** + * retiurns a promise of a query that is expanded into a col vector + * @param query_i query to expand + * @param n_i number of cols + * @param galkey + * @return expanded query promise + */ + std::shared_ptr>> + async_expand_to_matrix(std::vector query_i, uint64_t n_i, seal::GaloisKeys &galkey); + std::vector __expand_query(const seal::Ciphertext &encrypted, uint32_t m, seal::GaloisKeys &galkey) const; diff --git a/src/server.cpp b/src/server.cpp index 930896f2..3f282e9f 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -22,17 +22,16 @@ int main(int, char *[]) { threads.emplace_back(run_server(wg, server, cnfgs)); server->wait_for_workers(int(cnfgs.number_of_workers())); - server->publish_galois_keys(); // 1 epoch for (int i = 0; i < 1; ++i) { + server->start_epoch(); // 5 rounds in an epoch. - for (int j = 0; j < 5; ++j) { - server->tell_new_round(); - + for (int j = 0; j < 10; ++j) { + // todo: start timing a full round here. auto ledger = server->distribute_work(); - ledger->done.read_for(std::chrono::milliseconds(5000)); // todo: get wait time from the configs. + ledger->done.read_for(std::chrono::milliseconds(5000)); // todo: how much time should we wait? // server should now inspect missing items and run the calculations for them. // server should also be notified by ledger about the rouge workers. @@ -41,9 +40,8 @@ int main(int, char *[]) { // perform step 2. server->run_step_2(ledger); - server->publish_answers(); +// server->publish_answers(); } - } server->send_stop_signal(); @@ -91,13 +89,21 @@ shared_ptr full_server_instance(const distribicom::AppConf auto cols = configs.configs().db_cols(); auto rows = configs.configs().db_rows(); - // todo: fill with random client's data, queries, and galois keys. math_utils::matrix db(rows, cols); - // a single row of ctxs and their respective gal_key. - math_utils::matrix queries(2, cols); - math_utils::matrix gal_keys(1, cols); - - // Copies the matrices. (not efficient, but it's fast enough to ignore). 
- return nullptr; -// return std::make_shared(db, queries, gal_keys, configs); + std::uint64_t i = 0; + for (auto &ptx: db.data) { + ptx = i++; + } + std::map> client_db; + std::uint64_t max_clients = 1 << 16; + for (std::uint64_t j = 0; j < max_clients; ++j) { + // TODO: create queries. + client_db.insert( + { + j, + std::make_unique(), + } + ); + } + return std::make_shared(db, client_db, configs); } diff --git a/src/services/CMakeLists.txt b/src/services/CMakeLists.txt index 44ad209a..a06a3b61 100644 --- a/src/services/CMakeLists.txt +++ b/src/services/CMakeLists.txt @@ -1,6 +1,6 @@ add_library(services worker_strategy.hpp worker_strategy.cpp worker.hpp factory.hpp factory.cpp worker.cpp constants.hpp manager.hpp manager.cpp - server.hpp server.cpp db.hpp db.cpp utils.hpp utils.cpp client_service.cpp client_service.hpp client_context.hpp + server.hpp server.cpp db.hpp db.cpp utils.hpp utils.cpp client_service.cpp client_service.hpp client_context.hpp client_context.cpp manager_workstream.cpp manager_workstream.hpp) target_include_directories(services PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/src/services/client_context.cpp b/src/services/client_context.cpp new file mode 100644 index 00000000..2451a6a4 --- /dev/null +++ b/src/services/client_context.cpp @@ -0,0 +1,22 @@ +#include "client_context.hpp" + +void +services::set_client(std::uint32_t expansion_ratio, std::uint32_t db_rows, int client_id, const seal::GaloisKeys &gkey, + string &gkey_serialised, vector> &query, + const distribicom::ClientQueryRequest &query_marshaled, + unique_ptr &client_info) { + client_info->galois_keys = gkey; + client_info->query = std::move(query); + + client_info->galois_keys_marshaled.set_keys(gkey_serialised); + client_info->galois_keys_marshaled.set_key_pos(client_id); + client_info->query_info_marshaled.CopyFrom(query_marshaled); + client_info->query_info_marshaled.set_mailbox_id(client_id); + client_info->answer_count=0; + client_info->partial_answer = + make_unique>(math_utils::matrix + (db_rows, expansion_ratio)); + client_info->final_answer = + make_unique>(math_utils::matrix + (1, expansion_ratio)); +} diff --git a/src/services/client_context.hpp b/src/services/client_context.hpp index f3736d25..44a7f849 100644 --- a/src/services/client_context.hpp +++ b/src/services/client_context.hpp @@ -3,7 +3,7 @@ #include #include "pir_client.hpp" #include "db.hpp" -#include "manager.hpp" +#include "utils.hpp" #include "distribicom.grpc.pb.h" #include "distribicom.pb.h" @@ -18,11 +18,71 @@ namespace services { PirQuery query; seal::GaloisKeys galois_keys; std::unique_ptr client_stub; + + std::unique_ptr> partial_answer; + std::unique_ptr> final_answer; + std::uint64_t answer_count; }; + + void + set_client(std::uint32_t expansion_ratio, std::uint32_t db_rows, int client_id, + const seal::GaloisKeys &gkey, string &gkey_serialised, vector> &query, + const distribicom::ClientQueryRequest &query_marshaled, + std::unique_ptr &client_info); + + + struct ClientDB { - mutable std::shared_mutex mutex; + std::unique_ptr mutex=std::make_unique(); std::map> id_to_info; std::uint64_t client_counter = 0; + + + struct shared_cdb { + explicit shared_cdb(const ClientDB &db, std::shared_mutex &mtx) : db(db), lck(mtx) {} + + const ClientDB &db; + private: + std::shared_lock lck; + }; + + // gives access to a locked reference of the cdb. + // as long as the shared_mat instance is not destroyed and we can't add a client. + // useful for distributing the DB, or going over multiple rows. 
+ shared_cdb many_reads() { + return shared_cdb(*this, *mutex); + } + + + std::uint64_t add_client(grpc::ServerContext *context, const distribicom::ClientRegistryRequest *request, const distribicom::Configs& pir_configs, std::uint32_t expansion_ratio){ + auto requesting_client = utils::extract_ip(context); + std::string subscribing_client_address = requesting_client + ":" + std::to_string(request->client_port()); + + // creating stub to the client: + auto client_conn = std::make_unique(distribicom::Client::Stub( + grpc::CreateChannel( + subscribing_client_address, + grpc::InsecureChannelCredentials() + ) + )); + + auto client_info = std::make_unique(ClientInfo()); + client_info->galois_keys_marshaled.set_keys(request->galois_keys()); + client_info->client_stub = std::move(client_conn); + + std::unique_lock lock(*mutex); + client_info->galois_keys_marshaled.set_key_pos(client_counter); + client_info->answer_count=0; + client_info->partial_answer = std::make_unique> + (math_utils::matrix(pir_configs.db_rows(), expansion_ratio)); + id_to_info.insert( + {client_counter, std::move(client_info)}); + client_counter += 1; + return client_counter-1; + }; + + + }; } diff --git a/src/services/manager.cpp b/src/services/manager.cpp index c4eb99b5..c1e979aa 100644 --- a/src/services/manager.cpp +++ b/src/services/manager.cpp @@ -11,8 +11,8 @@ namespace services { ::distribicom::Ack *resp) { UNUSED(resp); std::string worker_creds = utils::extract_string_from_metadata( - context->client_metadata(), - constants::credentials_md + context->client_metadata(), + constants::credentials_md ); mtx.lock_shared(); @@ -23,34 +23,26 @@ namespace services { return {grpc::StatusCode::INVALID_ARGUMENT, "worker not registered"}; } - auto round = utils::extract_size_from_metadata(context->client_metadata(), constants::round_md); - auto epoch = utils::extract_size_from_metadata(context->client_metadata(), constants::epoch_md); - - mtx.lock_shared(); - exists = ledgers.find({round, epoch}) != ledgers.end(); - auto ledger = ledgers[{round, epoch}]; - mtx.unlock_shared(); + distribicom::MatrixPart tmp; + auto parts = std::make_shared>(); - if (!exists) { - return { - grpc::StatusCode::INVALID_ARGUMENT, - "ledger not found in {round,epoch}" + std::to_string(round) + "," + std::to_string(epoch) - }; + auto &parts_vec = *parts; + while (reader->Read(&tmp)) { + // TODO: should verify the incoming data - corresponding to the expected {ctx, row, col} from each worker. + auto current_ctx = marshal->unmarshal_seal_object(tmp.ctx().data()); + parts_vec.push_back( + { + std::move(current_ctx), + tmp.row(), + tmp.col() + } + ); } - // TODO: should verify the incoming data - corresponding to the expected {ctx, row, col} from each worker. 
- distribicom::MatrixPart tmp; - while (reader->Read(&tmp)) { - std::uint32_t row = tmp.row(); - std::uint32_t col = tmp.col(); + async_verify_worker(parts, worker_creds); + put_in_result_matrix(parts_vec, this->client_query_manager); -#ifdef DISTRIBICOM_DEBUG - matops->w_evaluator->evaluator->sub_inplace( - ledger->result_mat(row, col), - marshal->unmarshal_seal_object(tmp.ctx().data())); - assert(ledger->result_mat(row, col).is_transparent()); -#endif - } + auto ledger = epoch_data.ledger; ledger->mtx.lock(); ledger->contributed.insert(worker_creds); @@ -62,19 +54,6 @@ namespace services { ledger->done.close(); } - // mat = matrix; - // while (reader->Read(&tmp)) { - // tmp == [ctx, 0, j] - // mat[0,j] = ctx;; - // } - - // A - // B - // C - - // Frievalds(A, B, C) - // Frievalds: DB[:, :] - return {}; } @@ -82,29 +61,16 @@ namespace services { Manager::distribute_work(const math_utils::matrix &db, const ClientDB &all_clients, int rnd, int epoch -#ifdef DISTRIBICOM_DEBUG - ,const seal::GaloisKeys &expansion_key -#endif + #ifdef DISTRIBICOM_DEBUG + , const seal::GaloisKeys &expansion_key + #endif ) { + epoch_data.ledger = new_ledger(db, all_clients); + + #ifdef DISTRIBICOM_DEBUG + create_res_matrix(db, all_clients, expansion_key); + #endif - auto ledger = std::make_shared(); - { - // currently, there is no work distribution. everyone receives the entire DB and the entire queries. - std::shared_lock lock(mtx); - ledgers.insert({{rnd, epoch}, ledger}); - ledger->worker_list = std::vector(); - ledger->worker_list.reserve(work_streams.size()); - all_clients.mutex.lock_shared(); - ledger->result_mat = math_utils::matrix( - db.cols, all_clients.client_counter); - all_clients.mutex.unlock_shared(); - for (auto &worker: work_streams) { - ledger->worker_list.push_back(worker.first); - } - } -#ifdef DISTRIBICOM_DEBUG - create_res_matrix(db, all_clients, expansion_key, ledger); -#endif if (rnd == 1) { grpc::ClientContext context; utils::add_metadata_size(context, services::constants::round_md, rnd); @@ -115,19 +81,44 @@ namespace services { send_db(db, rnd, epoch); + return epoch_data.ledger; + } + + std::shared_ptr + Manager::new_ledger(const math_utils::matrix &db, const ClientDB &all_clients) { + auto ledger = std::make_shared(); + + ledger->worker_list = vector(); + ledger->worker_list.reserve(work_streams.size()); + ledger->result_mat = math_utils::matrix(db.cols, all_clients.client_counter); + + // need to compute DB X epoch_data.query_matrix. 
+ matops->multiply(db, *epoch_data.query_mat_times_randvec, ledger->db_x_queries_x_randvec); + matops->from_ntt(ledger->db_x_queries_x_randvec.data); + + shared_lock lock(mtx); + for (auto &worker: work_streams) { + ledger->worker_list.push_back(worker.first); + ledger->worker_verification_results.insert( + { + worker.first, + std::make_unique>(1, nullptr) + } + ); + } + return ledger; } void Manager::create_res_matrix(const math_utils::matrix &db, const ClientDB &all_clients, - const seal::GaloisKeys &expansion_key, - std::shared_ptr &ledger) const { + const seal::GaloisKeys &expansion_key) const { #ifdef DISTRIBICOM_DEBUG + auto ledger = epoch_data.ledger; auto exp = expansion_key; auto expand_sz = db.cols; - all_clients.mutex.lock_shared(); math_utils::matrix query_mat(expand_sz, all_clients.client_counter); auto col = -1; for (const auto &client: all_clients.id_to_info) { @@ -141,7 +132,6 @@ namespace services { query_mat(i, col) = expanded[i]; } } - all_clients.mutex.unlock_shared(); matops->to_ntt(query_mat.data); @@ -194,7 +184,6 @@ namespace services { void Manager::send_queries(const ClientDB &all_clients) { - std::shared_lock client_db_lock(all_clients.mutex); std::shared_lock lock(mtx); for (auto &worker: work_streams) { @@ -328,26 +317,101 @@ namespace services { return stream; } - // TODO: add epoch number so we can throw out old work and not be confused by it. + void randomise_scalar_vec(std::vector &vec) { + seal::Blake2xbPRNGFactory factory; + auto prng = factory.create({(random_device()) ()}); + uniform_int_distribution dist( + numeric_limits::min(), + numeric_limits::max() + ); + + for (auto &i: vec) { i = prng->generate(); } + } + void Manager::new_epoch(const ClientDB &db) { EpochData ed{ - .worker_to_responsibilities = map_workers_to_responsibilities(db.client_counter), - .queries = {}, + .worker_to_responsibilities = map_workers_to_responsibilities(db.client_counter), + .queries = {}, // todo: consider removing this (not sure we need to store the queries after this func. + .queries_dim2 = {}, + .random_scalar_vector = std::make_shared>(db.client_counter), + .query_mat_times_randvec = {}, }; auto expand_size = app_configs.configs().db_cols(); + auto expand_size_dim2 = app_configs.configs().db_rows(); + + std::vector>>> qs(db.id_to_info.size()); + std::vector>>> qs2( + db.id_to_info.size()); + for (const auto &info: db.id_to_info) { - // expanding the first dimension asynchrounously. - ed.queries[info.first] = expander->async_expand( - info.second->query[0], - expand_size, - info.second->galois_keys + qs[info.first] = expander->async_expand( + info.second->query[0], + expand_size, + info.second->galois_keys + ); + + // setting dim2 in matrix and ntt form. + auto p = std::make_shared>>(1, nullptr); + pool->submit( + { + .f = [&, p]() { + auto mat = std::make_shared>( + 1, expand_size_dim2, // row vec. 
+ expander->expand_query( + info.second->query[1], + expand_size_dim2, + info.second->galois_keys + )); + matops->to_ntt(mat->data); + p->set(mat); + }, + .wg = p->get_latch(), + } ); + qs2[info.first] = p; + + } + + randomise_scalar_vec(*ed.random_scalar_vector); + + auto rows = expand_size; + auto query_mat = std::make_shared>(rows, db.client_counter); + + for (std::uint64_t column = 0; column < qs.size(); column++) { + auto v = qs[column]->get(); + ed.queries[column] = v; + + for (std::uint64_t i = 0; i < rows; i++) { + (*query_mat)(i, column) = (*v)[i]; + } } + qs.clear(); + + auto promise = matops->async_scalar_dot_product( + query_mat, + ed.random_scalar_vector + ); + + for (std::uint64_t column = 0; column < qs2.size(); column++) { + ed.queries_dim2[column] = qs2[column]->get(); + } + qs2.clear(); + + ed.query_mat_times_randvec = promise->get(); + matops->to_ntt(ed.query_mat_times_randvec->data); mtx.lock(); epoch_data = std::move(ed); mtx.unlock(); } + void Manager::wait_on_verification() { + for (const auto &v: epoch_data.ledger->worker_verification_results) { + auto is_valid = *(v.second->get()); + if (!is_valid) { + throw std::runtime_error("wait_on_verification:: invalid verification"); + } + } + } } \ No newline at end of file diff --git a/src/services/manager.hpp b/src/services/manager.hpp index 69aaad3e..c111cc44 100644 --- a/src/services/manager.hpp +++ b/src/services/manager.hpp @@ -15,8 +15,20 @@ #include "manager_workstream.hpp" +namespace { + template + using promise = concurrency::promise; +} + + namespace services { + struct ResultMatPart { + seal::Ciphertext ctx; + std::uint64_t row; + std::uint64_t col; + }; + struct WorkerInfo { std::uint64_t worker_number; std::uint64_t query_range_start; @@ -24,15 +36,10 @@ namespace services { std::vector db_rows; }; - struct EpochData { - std::map worker_to_responsibilities; - // following the same key as the client's db. - std::map>>> queries; - }; /** - * WorkDistributionLedger keeps track on a distributed task. - * it should keep hold on a working task. - */ + * WorkDistributionLedger keeps track on a distributed task for a single round. + * it should keep hold on a working task. + */ struct WorkDistributionLedger { std::shared_mutex mtx; // states the workers that have already contributed their share of the work. @@ -44,43 +51,72 @@ namespace services { // result_mat is the work result of all workers. math_utils::matrix result_mat; + // stored in ntt form. + math_utils::matrix db_x_queries_x_randvec; + + std::map>> worker_verification_results; // open completion will be closed to indicate to anyone waiting. concurrency::Channel done; }; + struct EpochData { + std::shared_ptr ledger; + std::map worker_to_responsibilities; + // following the same key as the client's db. + std::map>> queries; + + // following the same key as the client's db. [NTT FORM] + std::map>> queries_dim2; + + // the following vector will be used to be multiplied against incoming work. 
+ std::shared_ptr> random_scalar_vector; + + // contains promised computation for expanded_queries X random_scalar_vector [NTT FORM] + std::shared_ptr> query_mat_times_randvec; + }; + class Manager : distribicom::Manager::WithCallbackMethod_RegisterAsWorker { private: distribicom::AppConfigs app_configs; std::shared_mutex mtx; + std::shared_ptr pool; concurrency::Counter worker_counter; - std::map, std::shared_ptr> ledgers; std::shared_ptr marshal; - -#ifdef DISTRIBICOM_DEBUG std::shared_ptr matops; std::shared_ptr expander; -#endif - std::map work_streams; EpochData epoch_data; - public: - explicit Manager() {}; - - explicit Manager(const distribicom::AppConfigs &app_configs) : - app_configs(app_configs), - - marshal(marshal::Marshaller::Create(utils::setup_enc_params(app_configs))) -#ifdef DISTRIBICOM_DEBUG - , matops(math_utils::MatrixOperations::Create( - math_utils::EvaluatorWrapper::Create(utils::setup_enc_params(app_configs)))), - expander(math_utils::QueryExpander::Create(utils::setup_enc_params(app_configs))) -#endif - {}; + public: + ClientDB client_query_manager; + services::DB db; + + explicit Manager() : pool(std::make_shared()), db(1, 1) {}; + + explicit Manager(const distribicom::AppConfigs &app_configs, std::map> &client_db, math_utils::matrix &db) : + app_configs(app_configs), + pool(std::make_shared()), + marshal(marshal::Marshaller::Create(utils::setup_enc_params(app_configs))), + matops(math_utils::MatrixOperations::Create( + math_utils::EvaluatorWrapper::Create( + utils::setup_enc_params(app_configs) + ), pool + ) + ), + expander(math_utils::QueryExpander::Create( + utils::setup_enc_params(app_configs), + pool + ) + ), + db(db) { + this->client_query_manager.client_counter = client_db.size(); + this->client_query_manager.id_to_info = std::move(client_db); + }; // a worker should send its work, along with credentials of what it sent. 
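The db_x_queries_x_randvec stored in the ledger and the verify_row() method below implement a Freivalds-style check. The sketch here (not part of the patch, plain integers standing in for ciphertexts) shows the identity being relied on: precompute E = DB * (Q * r) once per epoch with a random challenge vector r, then accept a worker's claimed row product c_i = DB[i,:] * Q only if c_i * r == E[i].

#include <cassert>
#include <cstdint>
#include <vector>

using Vec = std::vector<std::int64_t>;
using Mat = std::vector<Vec>;

Vec mat_vec(const Mat &m, const Vec &v) {
    Vec out(m.size(), 0);
    for (size_t i = 0; i < m.size(); ++i)
        for (size_t j = 0; j < v.size(); ++j)
            out[i] += m[i][j] * v[j];
    return out;
}

int main() {
    Mat db = {{1, 2}, {3, 4}};                   // 2x2 "database"
    Mat q  = {{5, 6}, {7, 8}};                   // 2x2 "query matrix"
    Vec r  = {9, 10};                            // random challenge vector

    Vec expected = mat_vec(db, mat_vec(q, r));   // E = DB * (Q * r), precomputed per epoch

    // worker's result for row 0: DB[0,:] * Q
    Vec row0 = {db[0][0] * q[0][0] + db[0][1] * q[1][0],
                db[0][0] * q[0][1] + db[0][1] * q[1][1]};

    // verification: (row0 * r) must equal E[0]
    std::int64_t check = row0[0] * r[0] + row0[1] * r[1];
    assert(check == expected[0]);
}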
@@ -89,26 +125,106 @@ namespace services { ::distribicom::Ack *response) override; + bool verify_row(std::shared_ptr> &workers_db_row_x_query, + std::uint64_t row_id) { + try { + auto challenge_vec = epoch_data.random_scalar_vector; + + auto db_row_x_query_x_challenge_vec = matops->scalar_dot_product(workers_db_row_x_query, challenge_vec); + auto expected_result = epoch_data.ledger->db_x_queries_x_randvec.data[row_id]; + + matops->w_evaluator->evaluator->sub_inplace(db_row_x_query_x_challenge_vec->data[0], expected_result); + return db_row_x_query_x_challenge_vec->data[0].is_transparent(); + } catch (std::exception &e) { + std::cout << e.what() << std::endl; + return false; + } + } + + void + async_verify_worker(const std::shared_ptr> parts_ptr, const std::string worker_creds) { + pool->submit( + { + .f=[&, parts_ptr, worker_creds]() { + auto &parts = *parts_ptr; + + auto work_responsibility = epoch_data.worker_to_responsibilities[worker_creds]; + auto rows = work_responsibility.db_rows; + auto query_row_len = + work_responsibility.query_range_end - work_responsibility.query_range_start; + + if (query_row_len != epoch_data.queries.size()) { throw std::runtime_error("unimplemented"); } + + for (size_t i = 0; i < rows.size(); i++) { + std::vector temp; + temp.reserve(query_row_len); + for (size_t j = 0; j < query_row_len; j++) { + temp.push_back(parts[j + i * query_row_len].ctx); + } + + auto workers_db_row_x_query = std::make_shared>( + query_row_len, 1, + temp); + auto is_valid = verify_row(workers_db_row_x_query, rows[i]); + if (!is_valid) { + epoch_data.ledger->worker_verification_results[worker_creds]->set( + std::make_unique(false) + ); + return; + } + } + + epoch_data.ledger->worker_verification_results[worker_creds]->set(std::make_unique(true)); + }, + .wg = epoch_data.ledger->worker_verification_results[worker_creds]->get_latch() + } + ); + + }; + + void put_in_result_matrix(const std::vector &parts, ClientDB &all_clients) { + + all_clients.mutex->lock_shared(); + for (const auto &partial_answer: parts) { + math_utils::EmbeddedCiphertext ptx_embedding; + this->matops->w_evaluator->get_ptx_embedding(partial_answer.ctx, ptx_embedding); + this->matops->w_evaluator->transform_to_ntt_inplace(ptx_embedding); + for (size_t i = 0; i < ptx_embedding.size(); i++) { + (*all_clients.id_to_info[partial_answer.col]->partial_answer)(partial_answer.row, i) = std::move( + ptx_embedding[i]); + } + all_clients.id_to_info[partial_answer.col]->answer_count += 1; + } + all_clients.mutex->unlock_shared(); + }; + + void calculate_final_answer() { + for (const auto &client: client_query_manager.id_to_info) { + auto current_query = *epoch_data.queries_dim2[client.first]; + matops->mat_mult(current_query, (*client.second->partial_answer), (*client.second->final_answer)); + } + }; + + // todo: break up query distribution, create unified structure for id lookups, modify ledger accoringly std::shared_ptr distribute_work( - const math_utils::matrix &db, - const ClientDB &all_clients, - int rnd, - int epoch + const math_utils::matrix &db, + const ClientDB &all_clients, + int rnd, + int epoch #ifdef DISTRIBICOM_DEBUG - , const seal::GaloisKeys &expansion_key + , const seal::GaloisKeys &expansion_key #endif ); void wait_for_workers(int i); - void create_res_matrix(const math_utils::matrix &db, const ClientDB &all_clients, - const seal::GaloisKeys &expansion_key, - std::shared_ptr &ledger) const; + const seal::GaloisKeys &expansion_key + ) const; /** * assumes num workers map well to db and queries @@ -122,8 +238,8 
@@ namespace services { void send_queries(const ClientDB &all_clients); ::grpc::ServerWriteReactor<::distribicom::WorkerTaskPart> *RegisterAsWorker( - ::grpc::CallbackServerContext *ctx/*context*/, - const ::distribicom::WorkerRegistryRequest *rqst/*request*/) override; + ::grpc::CallbackServerContext *ctx/*context*/, + const ::distribicom::WorkerRegistryRequest *rqst/*request*/) override; void close() { mtx.lock(); @@ -137,5 +253,13 @@ namespace services { * assumes the given db is thread-safe. */ void new_epoch(const ClientDB &db); + + shared_ptr + new_ledger(const math_utils::matrix &db, const ClientDB &all_clients); + + /** + * Waits on freivalds verify, returns (if any) parts that need to be re-evaluated. + */ + void wait_on_verification(); }; } diff --git a/src/services/server.cpp b/src/services/server.cpp index f8de8295..9236681f 100644 --- a/src/services/server.cpp +++ b/src/services/server.cpp @@ -7,22 +7,20 @@ services::FullServer::FullServer(math_utils::matrix &db, std::map> &client_db, + std::unique_ptr> &client_db, const distribicom::AppConfigs &app_configs) : - db(db), manager(app_configs), pir_configs(app_configs.configs()), - enc_params(utils::setup_enc_params(app_configs)) { - this->client_query_manager.client_counter = client_db.size(); - this->client_query_manager.id_to_info = std::move(client_db); + manager(app_configs, client_db, db), pir_configs(app_configs.configs()), + enc_params(utils::setup_enc_params(app_configs)) { init_pir_data(app_configs); } -services::FullServer::FullServer(const distribicom::AppConfigs &app_configs) : - db(app_configs.configs().db_rows(), app_configs.configs().db_cols()), manager(app_configs), - pir_configs(app_configs.configs()), enc_params(utils::setup_enc_params(app_configs)) { - init_pir_data(app_configs); - -} +//services::FullServer::FullServer(const distribicom::AppConfigs &app_configs) : +// app_configs.configs().db_cols()), manager(app_configs), +// pir_configs(app_configs.configs()), enc_params(utils::setup_enc_params(app_configs)) { +// init_pir_data(app_configs); +// +//} void services::FullServer::init_pir_data(const distribicom::AppConfigs &app_configs) { const auto &configs = app_configs.configs(); @@ -35,39 +33,20 @@ grpc::Status services::FullServer::RegisterAsClient(grpc::ServerContext *context, const distribicom::ClientRegistryRequest *request, distribicom::ClientRegistryReply *response) { +//@todo fix expansion ration calc - try { - - auto requesting_client = utils::extract_ip(context); - std::string subscribing_client_address = requesting_client + ":" + std::to_string(request->client_port()); - - // creating stub to the client: - auto client_conn = std::make_unique(distribicom::Client::Stub( - grpc::CreateChannel( - subscribing_client_address, - grpc::InsecureChannelCredentials() - ) - )); - - auto client_info = std::make_unique(ClientInfo()); - client_info->galois_keys_marshaled.set_keys(request->galois_keys()); - client_info->client_stub = std::move(client_conn); - - std::unique_lock lock(client_query_manager.mutex); - client_info->galois_keys_marshaled.set_key_pos(client_query_manager.client_counter); - client_query_manager.id_to_info.insert( - {client_query_manager.client_counter, std::move(client_info)}); - response->set_mailbox_id(client_query_manager.client_counter); - client_query_manager.client_counter += 1; - - } catch (std::exception &e) { - std::cout << "Error: " << e.what() << std::endl; - return {grpc::StatusCode::INTERNAL, e.what()}; - } - - response->set_num_mailboxes(pir_configs.number_of_elements()); 
+// try { +// response->set_mailbox_id(manager.client_query_manager.add_client(context, request, pir_configs, math_utils::compute_expansion_ratio(enc_params))); +// } catch (std::exception &e) { +// std::cout << "Error: " << e.what() << std::endl; +// return {grpc::StatusCode::INTERNAL, e.what()}; +// } +// +// response->set_num_mailboxes(pir_configs.number_of_elements()); +// +// return grpc::Status::OK; + throw std::runtime_error("unimplemented"); - return grpc::Status::OK; } //@todo this assumes that no one is registering, very dangerous @@ -77,12 +56,12 @@ services::FullServer::StoreQuery(grpc::ServerContext *context, const distribicom auto id = request->mailbox_id(); - std::unique_lock lock(client_query_manager.mutex); - if (client_query_manager.id_to_info.find(id) == client_query_manager.id_to_info.end()) { + std::unique_lock lock(*manager.client_query_manager.mutex); + if (manager.client_query_manager.id_to_info.find(id) == manager.client_query_manager.id_to_info.end()) { return {grpc::StatusCode::NOT_FOUND, "Client not found"}; } - client_query_manager.id_to_info[id]->query_info_marshaled.CopyFrom(*request); + manager.client_query_manager.id_to_info[id]->query_info_marshaled.CopyFrom(*request); response->set_success(true); return grpc::Status::OK; } @@ -109,13 +88,13 @@ std::shared_ptr services::FullServer::distribu // block is to destroy the db handle. { - auto db_handle = db.many_reads(); + auto db_handle = manager.db.many_reads(); -// // todo: set specific round and handle. + std::shared_lock client_db_lock(*(manager.client_query_manager.mutex)); - ledger = manager.distribute_work(db_handle.mat, client_query_manager, 1, 1 + ledger = manager.distribute_work(db_handle.mat, manager.client_query_manager, 1, 1 #ifdef DISTRIBICOM_DEBUG - ,client_query_manager.id_to_info.begin()->second->galois_keys + , manager.client_query_manager.id_to_info.begin()->second->galois_keys #endif ); } @@ -124,10 +103,10 @@ std::shared_ptr services::FullServer::distribu } void services::FullServer::start_epoch() { - std::shared_lock client_db_lock(client_query_manager.mutex); + std::shared_lock client_db_lock(*manager.client_query_manager.mutex); - manager.new_epoch(client_query_manager); - manager.send_galois_keys(client_query_manager); + manager.new_epoch(manager.client_query_manager); + manager.send_galois_keys(manager.client_query_manager); } void services::FullServer::wait_for_workers(int i) { @@ -147,15 +126,15 @@ void services::FullServer::publish_answers() { } void services::FullServer::send_stop_signal() { - throw std::logic_error("not implemented"); + manager.close(); } void services::FullServer::learn_about_rouge_workers(std::shared_ptr) { - throw std::logic_error("not implemented"); + manager.wait_on_verification(); } void services::FullServer::run_step_2(std::shared_ptr) { - throw std::logic_error("not implemented"); + manager.calculate_final_answer(); } void services::FullServer::tell_new_round() { diff --git a/src/services/server.hpp b/src/services/server.hpp index c4492899..0af0d260 100644 --- a/src/services/server.hpp +++ b/src/services/server.hpp @@ -14,8 +14,6 @@ namespace services { // uses both the Manager and the Server services to complete a full distribicom server. class FullServer final : public distribicom::Server::Service { - // used for tests - services::DB db; // using composition to implement the interface of the manager. 
services::Manager manager; @@ -24,8 +22,6 @@ namespace services { PirParams pir_params; seal::EncryptionParameters enc_params; - // concurrency stuff - ClientDB client_query_manager; std::vector> db_write_requests; @@ -36,7 +32,7 @@ namespace services { std::map> &client_db, const distribicom::AppConfigs &app_configs); - explicit FullServer(const distribicom::AppConfigs &app_configs); +// explicit FullServer(const distribicom::AppConfigs &app_configs); grpc::Status diff --git a/src/services/worker_strategy.cpp b/src/services/worker_strategy.cpp index ab836093..e5ea4b4c 100644 --- a/src/services/worker_strategy.cpp +++ b/src/services/worker_strategy.cpp @@ -4,11 +4,11 @@ namespace services::work_strategy { WorkerStrategy::WorkerStrategy(const seal::EncryptionParameters &enc_params, std::unique_ptr &&manager_conn) noexcept - : - query_expander(), matops(), gkeys(), manager_conn(std::move(manager_conn)) { + : + query_expander(), matops(), gkeys(), manager_conn(std::move(manager_conn)) { query_expander = math_utils::QueryExpander::Create(enc_params); matops = math_utils::MatrixOperations::Create( - math_utils::EvaluatorWrapper::Create(enc_params) + math_utils::EvaluatorWrapper::Create(enc_params) ); } @@ -40,9 +40,9 @@ namespace services::work_strategy { } std::for_each( - futures.begin(), - futures.end(), - [](std::future &future) { future.get(); } + futures.begin(), + futures.end(), + [](std::future &future) { future.get(); } ); queries_to_mat(task); @@ -59,28 +59,28 @@ namespace services::work_strategy { const std::vector &&qry) { // TODO: ensure this utilises a threadpool. otherwise this isn't great. return std::async( - [&](int query_pos, int expanded_size, const std::vector &&qry) { - mu.lock_shared(); - auto not_exist = gkeys.find(query_pos) == gkeys.end(); - mu.unlock_shared(); - - if (not_exist) { - std::cout << "WorkerStrategy: galois keys not found for query position:" + - std::to_string(query_pos) - << std::endl; - return 0; - } - - // todo: use async_expander. - auto expanded = query_expander->expand_query(qry, expanded_size, gkeys.find(query_pos)->second); - - - mu.lock(); - queries.insert({query_pos, math_utils::matrix(expanded.size(), 1, expanded)}); - mu.unlock(); - return 1; - }, - query_pos, expanded_size, qry + [&](int query_pos, int expanded_size, const std::vector &&qry) { + mu.lock_shared(); + auto not_exist = gkeys.find(query_pos) == gkeys.end(); + mu.unlock_shared(); + + if (not_exist) { + std::cout << "WorkerStrategy: galois keys not found for query position:" + + std::to_string(query_pos) + << std::endl; + return 0; + } + + // todo: use async_expander. + auto expanded = query_expander->expand_query(qry, expanded_size, gkeys.find(query_pos)->second); + + + mu.lock(); + queries.insert({query_pos, math_utils::matrix(expanded.size(), 1, expanded)}); + mu.unlock(); + return 1; + }, + query_pos, expanded_size, qry ); } @@ -108,8 +108,8 @@ namespace services::work_strategy { void RowMultiplicationStrategy::send_response( - const WorkerServiceTask &task, - math_utils::matrix &computed + const WorkerServiceTask &task, + math_utils::matrix &computed ) { std::map row_to_index; auto row = -1; @@ -122,6 +122,9 @@ namespace services::work_strategy { utils::add_metadata_size(context, constants::round_md, task.round); utils::add_metadata_size(context, constants::epoch_md, task.epoch); + // before sending the response - ensure we send it in reg-form to compress the data a bit more. 
+ matops->from_ntt(computed.data); + distribicom::Ack resp; auto stream = manager_conn->ReturnLocalWork(&context, &resp); diff --git a/test/math_utils/query_expander_test.cpp b/test/math_utils/query_expander_test.cpp index 2a4ac24d..cf6268e9 100644 --- a/test/math_utils/query_expander_test.cpp +++ b/test/math_utils/query_expander_test.cpp @@ -10,26 +10,30 @@ void correct_expansion_test(TestUtils::SetupConfigs); void expanding_full_dimension_query(TestUtils::SetupConfigs); +void async_expansion(TestUtils::SetupConfigs); + int query_expander_test(int, char *[]) { // encryption parameters ensure that we have enough multiplication depth. // coefficient modulus is set automatically. auto cnfgs = TestUtils::SetupConfigs{ - .encryption_params_configs = { - .scheme_type = seal::scheme_type::bfv, - .polynomial_degree = 4096 * 2, - .log_coefficient_modulus = 20, - }, - .pir_params_configs = { - .number_of_items = 2048, - .size_per_item = 288, - .dimensions= 2, - .use_symmetric = false, - .use_batching = true, - .use_recursive_mod_switching = true, - }, + .encryption_params_configs = { + .scheme_type = seal::scheme_type::bfv, + .polynomial_degree = 4096 * 2, + .log_coefficient_modulus = 20, + }, + .pir_params_configs = { + .number_of_items = 2048, + .size_per_item = 288, + .dimensions= 2, + .use_symmetric = false, + .use_batching = true, + .use_recursive_mod_switching = true, + }, }; expanding_full_dimension_query(cnfgs); + async_expansion(cnfgs); + cnfgs.encryption_params_configs.polynomial_degree = 4096; cnfgs.pir_params_configs.number_of_items = 512; cnfgs.pir_params_configs.dimensions = 1; @@ -67,17 +71,14 @@ void correct_expansion_test(TestUtils::SetupConfigs cnfgs) { if (decryption.is_zero() && index != i) { continue; - } - else if (decryption.is_zero()) { + } else if (decryption.is_zero()) { o << "Found zero where index should be"; throw std::runtime_error(o.str()); - } - else if (std::stoi(decryption.to_string()) != 1) { + } else if (std::stoi(decryption.to_string()) != 1) { o << "Query vector at index " << index << " should be 1 but is instead " << decryption.to_string(); throw std::runtime_error(o.str()); - } - else { + } else { if (decryption.to_string() != "1") { o << "Query vector at index " << index << " is not 1"; throw std::runtime_error(o.str()); @@ -177,3 +178,35 @@ std::shared_ptr gen_db(const shared_ptr &all server.set_database(move(db), number_of_items, size_per_item); return server.get_db(); } + +void async_expansion(TestUtils::SetupConfigs cnfgs) { + + assert(cnfgs.pir_params_configs.dimensions == 2); + auto all = TestUtils::setup(cnfgs); + + + auto db_ptr = gen_db(all); + + + PIRClient client(all->encryption_params, all->pir_params); + seal::GaloisKeys galois_keys = client.generate_galois_keys(); + std::uint64_t dim0_size = all->pir_params.nvec[0]; + std::uint64_t dim1_size = all->pir_params.nvec[1]; + + std::uint64_t ele_index = 71; // Choosing the SECOND plaintext that is stored in the DB. that is cell number 71. 
+ std::uint64_t index = client.get_fv_index(ele_index); // index of FV plaintext + std::cout << "Main: element index = " << ele_index << " from [0, " + << all->pir_params.ele_num - 1 << "]" << std::endl; + PirQuery query = client.generate_query(index); + + auto expander = math_utils::QueryExpander::Create(all->encryption_params); + auto promise = expander->async_expand(query[0], dim0_size, galois_keys); + auto expanded_query_dim_0 = expander->expand_query(query[0], dim0_size, galois_keys); + + auto expanded2 = promise->get(); + for (int i = 0; i < expanded_query_dim_0.size(); ++i) { + all->w_evaluator->evaluator->sub_inplace(expanded_query_dim_0[i], (*expanded2)[i]); + assert(expanded_query_dim_0[i].is_transparent()); + } + +} diff --git a/test/services/worker_test.cpp b/test/services/worker_test.cpp index a9291d70..af6dea3b 100644 --- a/test/services/worker_test.cpp +++ b/test/services/worker_test.cpp @@ -21,12 +21,12 @@ int worker_test(int, char *[]) { auto all = TestUtils::setup(TestUtils::DEFAULT_SETUP_CONFIGS); auto cfgs = services::configurations::create_app_configs( - "localhost:" + std::string(server_port), - int(all->encryption_params.poly_modulus_degree()), - 20, - 5, - 5, - 256 + "localhost:" + std::string(server_port), + int(all->encryption_params.poly_modulus_degree()), + 20, + 6, + 6, + 256 ); services::FullServer fs = full_server_instance(all, cfgs); @@ -43,7 +43,7 @@ int worker_test(int, char *[]) { fs.wait_for_workers(1); fs.start_epoch(); - fs.distribute_work(); + fs.learn_about_rouge_workers(fs.distribute_work()); sleep(5); std::cout << "\nshutting down.\n" << std::endl; @@ -55,41 +55,40 @@ int worker_test(int, char *[]) { return 0; } + std::map> -create_client_db(int size, std::shared_ptr &all) { +create_client_db(int size, std::shared_ptr &all, const distribicom::AppConfigs &app_configs) { auto m = marshal::Marshaller::Create(all->encryption_params); std::map> cdb; for (int i = 0; i < size; i++) { - auto client_info = std::make_unique(services::ClientInfo()); auto gkey = all->gal_keys; - client_info->galois_keys = gkey; auto gkey_serialised = m->marshal_seal_object(gkey); - client_info->galois_keys_marshaled.set_keys(gkey_serialised); - client_info->galois_keys_marshaled.set_key_pos(i); std::vector> query = {{all->random_ciphertext()}, {all->random_ciphertext()}}; distribicom::ClientQueryRequest query_marshaled; m->marshal_query_vector(query, query_marshaled); - client_info->query_info_marshaled.CopyFrom(query_marshaled); - client_info->query_info_marshaled.set_mailbox_id(i); - client_info->query = std::move(query); + auto client_info = std::make_unique(services::ClientInfo()); + + services::set_client(math_utils::compute_expansion_ratio(all->seal_context.last_context_data()->parms()) * 2, + app_configs.configs().db_rows(), i, gkey, gkey_serialised, query, query_marshaled, + client_info); cdb.insert( - {i, std::move(client_info)}); + {i, std::move(client_info)}); } return cdb; } services::FullServer full_server_instance(std::shared_ptr &all, const distribicom::AppConfigs &configs) { - auto n = 5; - math_utils::matrix db(n, n); + auto num_clients = configs.configs().db_rows() * 2; + math_utils::matrix db(configs.configs().db_rows(), configs.configs().db_rows()); for (auto &p: db.data) { p = all->random_plaintext(); } - auto cdb = create_client_db(n, all); + auto cdb = create_client_db(num_clients, all, configs); return services::FullServer(db, cdb, configs); } @@ -100,11 +99,11 @@ std::thread setupWorker(std::latch &wg, distribicom::AppConfigs &configs) { return 
std::thread([&] { try { services::Worker worker( - services::configurations::create_worker_configs( - configs, - std::stoi(std::string(worker_port)), - "0.0.0.0" - ) + services::configurations::create_worker_configs( + configs, + std::stoi(std::string(worker_port)), + "0.0.0.0" + ) ); wg.wait();
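Taken together with src/server.cpp above, the per-round flow this patch wires up looks roughly like the following sketch. It is illustrative only: the 5-second wait mirrors the placeholder in the patch, publish_answers() is still disabled, and the include path is assumed from the repository layout.

#include <chrono>
#include <memory>
#include "services/server.hpp"

void run_single_round(services::FullServer &server) {
    auto ledger = server.distribute_work();                  // stage 1: workers compute DB rows x expanded queries
    ledger->done.read_for(std::chrono::milliseconds(5000));  // wait for workers to stream results back
    server.learn_about_rouge_workers(ledger);                // blocks on the Freivalds verification results
    server.run_step_2(ledger);                               // dim-2 multiply of partial answers into final answers
    // server.publish_answers();                             // still commented out in this patch
}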