Skip to content

Commit

Permalink
Before processings (#59)
Browse files Browse the repository at this point in the history
* channel throws if write on closed

* manager sets up a threadpool

* manager sets up random vector each epoch

* Squashed commit of the following:

commit 6284db7
Merge: 06d979b b75a7d2
Author: Jonathan Weiss <[email protected]>
Date:   Sun Dec 11 14:20:26 2022 +0200

    Merge branch 'dev' of github.com:elkanatovey/distribicom into dev

commit 06d979b
Author: Jonathan Weiss <[email protected]>
Date:   Sun Dec 11 14:20:17 2022 +0200

    fix: query_expander construction

commit b75a7d2
Author: Elkana Tovey <[email protected]>
Date:   Sun Dec 11 14:03:55 2022 +0200

    Second stage pir (#57)

    * manager: fix location of commas in test funcs

    * evaluator: add ptx decomposition mult

    * evaluator:add test for ptx embedding

    * evaluator:add ntt transform for EmbeddedCiphertext

    * server.cpp: minor change for debug mode

* code-style fix

* new epoch async prepares for frievald.

* server.cpp: minor change for debug mode

* added test to async query expansion

* verify the query_x_rand_vec in debug mod

* refactor locking in distribute-work

* client_context.hpp: start refactor of clientDB

* fx: safelatch introduced

* refactor safelatch in its own file

* refactor safelatch in its own file

* client_context.hpp: refactor client setting
manager.hpp: start adding partial work to dbs

* client_context.hpp: minor fix for cmake

* convert client db mutex to pointer

* move ClientDB into manager

* matrix_operations.tpp: fix multiplication bug for regular plaintext case
evaluator_wrapper.cpp: create method for calculating expansion ratio
worker_test.cpp: update according to evaluator wrapper

* query_expander.cpp: add version of expansion that returns query as col vector

* server: factor db into manager
matrix_operations: add sync version of scalar dot product
manager: add sync versions of stage two, freivalds still has bugs

* removed promise-hell

* improved async usage

* server.cpp: add partial answer struct to client
manager.cpp: add method to store partial work, start calculate final answer
worker_test: update per mods to client struct

* refactor: queries_dim2 are now not promised

* fx: ntt form db multiply with query_vec

* ensuring worker-manager ntt forms match

* ntt-forms match

* matrix_operations.hpp: fix for scalar  dot product dims

* worker_test.cpp update test for more general query nums.
manager.hpp: add comment

* ntt query-dim-2

* ledger: added promise for verify_worker

* async-verification

* integrating with current server

* fx: removed comments

* duct-tape without sending responses

* todos

Co-authored-by: Jonathan Weiss <[email protected]>
Co-authored-by: elkana <[email protected]>
  • Loading branch information
3 people authored Dec 13, 2022
1 parent 416d434 commit e8a63dd
Show file tree
Hide file tree
Showing 24 changed files with 737 additions and 283 deletions.
3 changes: 2 additions & 1 deletion src/concurrency/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
add_library(concurrency_utils concurrency.h channel.hpp counter.hpp counter.cpp promise.hpp threadpool.hpp threadpool.cpp)
add_library(concurrency_utils concurrency.h channel.hpp counter.hpp counter.cpp promise.hpp threadpool.hpp threadpool.cpp
safelatch.h safelatch.cpp)
cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH MY_PARENT_DIR)
target_include_directories(concurrency_utils PUBLIC ${MY_PARENT_DIR})
3 changes: 3 additions & 0 deletions src/concurrency/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ namespace concurrency {
*/
void write(T t) {
std::lock_guard<std::mutex> lock(m);
if (closed) {
throw std::runtime_error("Channel::write() - channel closed.");
}
q.push(t);
c.notify_one();
}
Expand Down
3 changes: 2 additions & 1 deletion src/concurrency/concurrency.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
#include "channel.hpp"
#include "counter.hpp"
#include "promise.hpp"
#include "threadpool.hpp"
#include "threadpool.hpp"
#include "safelatch.h"
25 changes: 8 additions & 17 deletions src/concurrency/promise.hpp
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@

#pragma once

#include <latch>
#include <memory>
#include <atomic>

#include "safelatch.h"
namespace concurrency {

template<typename T>
class promise {

private:
std::atomic<int> safety;
std::atomic<bool> done;
std::shared_ptr<std::latch> wg;
std::shared_ptr<safelatch> wg;
std::shared_ptr<T> value;

public:
promise(int n, std::shared_ptr<T> &result_store) : safety(n), value(result_store) {
wg = std::make_shared<std::latch>(n);
promise(int n, std::shared_ptr<T> &result_store) : value(result_store) {
wg = std::make_shared<safelatch>(n);
}

promise(int n, std::shared_ptr<T> &&result_store) : safety(n), value(std::move(result_store)) {
wg = std::make_shared<std::latch>(n);
promise(int n, std::shared_ptr<T> &&result_store) : value(std::move(result_store)) {
wg = std::make_shared<safelatch>(n);
}

std::shared_ptr<T> get() {
Expand All @@ -42,23 +38,18 @@ namespace concurrency {
if (done.load()) {
return;
}
if (safety.load() == 0) {
if (wg->done_waiting()) {
throw std::runtime_error("promise set after it was done");
}
value = std::move(val);
}

std::shared_ptr<std::latch> &get_latch() {
std::shared_ptr<safelatch> get_latch() {
return wg;
}

inline void count_down() {
wg->count_down();
auto prev = safety.fetch_add(-1);
if (prev <= 0) {
throw std::runtime_error(
"promise::count_down:: latch's value is less than 0, this is a bug that can lead to deadlock!");
}
}
};

Expand Down
16 changes: 16 additions & 0 deletions src/concurrency/safelatch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include "safelatch.h"

namespace concurrency {
void safelatch::count_down() {
latch::count_down();
auto prev = safety.fetch_add(-1);
if (prev <= 0) {
throw std::runtime_error(
"count_down:: latch's value is less than 0, this is a bug that can lead to deadlock!");
}
}

bool safelatch::done_waiting() {
return safety.load() == 0;
}
}
19 changes: 19 additions & 0 deletions src/concurrency/safelatch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include <latch>
#include <atomic>

namespace concurrency {

class safelatch : public std::latch {
std::atomic<int> safety;
public:
explicit safelatch(int count) : std::latch(count), safety(count) {};

bool done_waiting();

void count_down();

};

}
3 changes: 2 additions & 1 deletion src/concurrency/threadpool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
#include <functional>
#include <memory>
#include "channel.hpp"
#include "safelatch.h"

namespace concurrency {
struct Task {
std::function<void()> f;
std::shared_ptr<std::latch> wg;
std::shared_ptr<safelatch> wg;
};


Expand Down
15 changes: 15 additions & 0 deletions src/math_utils/evaluator_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
#include "evaluator_wrapper.hpp"

namespace math_utils{

/*
* for the correct expansion ratio take last_parms_id() times 2
*/
uint32_t compute_expansion_ratio(const seal::EncryptionParameters &params) {
uint32_t expansion_ratio = 0;
uint32_t pt_bits_per_coeff = log2(params.plain_modulus().value());
for (size_t i = 0; i < params.coeff_modulus().size(); ++i) {
double coeff_bit_size = log2(params.coeff_modulus()[i].value());
expansion_ratio += ceil(coeff_bit_size / pt_bits_per_coeff);
}
return expansion_ratio;
}
}
namespace {
uint32_t compute_expansion_ratio(const seal::EncryptionParameters& params) {
uint32_t expansion_ratio = 0;
Expand Down
4 changes: 4 additions & 0 deletions src/math_utils/evaluator_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,8 @@ namespace math_utils {

void transform_to_ntt_inplace(EmbeddedCiphertext &encoded) const;
};
/*
* for the correct expansion ratio take last_parms_id() times 2
*/
uint32_t compute_expansion_ratio(const seal::EncryptionParameters& params);
}
26 changes: 25 additions & 1 deletion src/math_utils/matrix_operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ namespace math_utils {

void to_ntt(std::vector<seal::Ciphertext> &m) const;

void to_ntt(std::vector<seal::Plaintext> &m) const;

void from_ntt(std::vector<seal::Ciphertext> &m) const;

Expand Down Expand Up @@ -184,6 +185,30 @@ namespace math_utils {
std::unique_ptr<concurrency::promise<matrix<seal::Ciphertext>>>
async_scalar_dot_product(const std::shared_ptr<matrix<U>> &mat,
const std::shared_ptr<std::vector<std::uint64_t>> &vec) const;
/***
*
* @tparam
* @param mat matrix where num of rows equals vec length
* @param vec vec of ints
* @return mat*vec where vec is of size mat.row_len()
*/
template<typename U>
std::shared_ptr<matrix<seal::Ciphertext>>
scalar_dot_product(const std::shared_ptr<matrix<U>> &mat,
const std::shared_ptr<std::vector<std::uint64_t>> &vec) const;


/***
*
* @tparam
* @param mat matrix where num of cols equals vec length
* @param vec vec of ints
* @return mat*vec where vec is of size mat.col_len()
*/
template<typename U>
std::shared_ptr<matrix<seal::Ciphertext>>
scalar_dot_product_col_major(const std::shared_ptr<matrix<U>> &mat,
const std::shared_ptr<std::vector<std::uint64_t>> &vec) const;

private:
void
Expand All @@ -192,7 +217,6 @@ namespace math_utils {
const matrix<seal::Ciphertext> &b,
matrix<seal::Ciphertext> &result) const;

void to_ntt(std::vector<seal::Plaintext> &m) const;
};

}
Expand Down
63 changes: 61 additions & 2 deletions src/math_utils/matrix_operations.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,20 @@ namespace math_utils {
w_evaluator->evaluator->transform_to_ntt_inplace(tmp_result);
}

// assume that in seal::Plaintext case we don't want to turn into splitPlaintexts
for (uint64_t k = 0; k < left.cols; ++k) {
w_evaluator->mult(left(i, k), right(k, j), tmp);
if constexpr ((std::is_same_v<U, seal::Plaintext>))
{
w_evaluator->mult_reg(left(i, k), right(k, j), tmp);
}
else if constexpr (std::is_same_v<V, seal::Plaintext>)
{
w_evaluator->mult_reg(right(k, j), left(i, k), tmp);
}
else
{
w_evaluator->mult(left(i, k), right(k, j), tmp);
}
w_evaluator->add(tmp, tmp_result, tmp_result);
}
return tmp_result;
Expand All @@ -55,7 +67,7 @@ namespace math_utils {
matrix<seal::Ciphertext> &result) const {
verify_correct_dimension(left, right);
verify_not_empty_matrices(left, right);
auto wg = std::make_shared<std::latch>(int(left.rows * right.cols));
auto wg = std::make_shared<concurrency::safelatch>(int(left.rows * right.cols));
for (uint64_t i = 0; i < left.rows; ++i) {
for (uint64_t j = 0; j < right.cols; ++j) {

Expand Down Expand Up @@ -174,5 +186,52 @@ namespace math_utils {
return p;
}

template<typename U>
std::shared_ptr<matrix<seal::Ciphertext>>
MatrixOperations::scalar_dot_product(
const std::shared_ptr<matrix<U>> &mat,
const std::shared_ptr<std::vector<std::uint64_t>> &vec) const {
#ifdef DISTRIBICOM_DEBUG
assert(mat->rows==vec->size());
#endif
auto result_vec = std::make_shared<matrix<seal::Ciphertext>>(mat->cols, 1);
for (uint64_t k = 0; k < mat->cols; k++) {
seal::Ciphertext tmp;
seal::Ciphertext rslt(w_evaluator->context);
for (uint64_t j = 0; j < mat->rows; j++) {
w_evaluator->scalar_multiply((*vec)[j], (*mat)(j, k), tmp);
w_evaluator->add(tmp, rslt, rslt);
}
(*result_vec)(k, 0) = rslt;

}

return result_vec;
}


template<typename U>
std::shared_ptr<matrix<seal::Ciphertext>>
MatrixOperations::scalar_dot_product_col_major(
const std::shared_ptr<matrix<U>> &mat,
const std::shared_ptr<std::vector<std::uint64_t>> &vec) const {
#ifdef DISTRIBICOM_DEBUG
assert(mat->cols==vec->size());
#endif
auto result_vec = std::make_shared<matrix<seal::Ciphertext>>(mat->rows, 1);
for (uint64_t k = 0; k < mat->rows; k++) {
seal::Ciphertext tmp;
seal::Ciphertext rslt(w_evaluator->context);
for (uint64_t j = 0; j < mat->cols; j++) {
w_evaluator->scalar_multiply((*vec)[j], (*mat)(k, j), tmp);
w_evaluator->add(tmp, rslt, rslt);
}
(*result_vec)(k, 0) = rslt;

}

return result_vec;
}


}
23 changes: 23 additions & 0 deletions src/math_utils/query_expander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,27 @@ namespace math_utils {
return promise;
}


std::shared_ptr<concurrency::promise<math_utils::matrix<seal::Ciphertext>>>
QueryExpander::async_expand_to_matrix(std::vector<seal::Ciphertext> query_i, uint64_t n_i, seal::GaloisKeys &galkey) {
auto query_i_cpy = std::make_shared<std::vector<seal::Ciphertext>>(query_i);
auto galkey_cpy = std::make_shared<seal::GaloisKeys>(galkey);
math_utils::matrix<seal::Ciphertext> s;
auto promise = std::make_shared<concurrency::promise<math_utils::matrix<seal::Ciphertext>>>(1, nullptr);

pool->submit(
{
.f =
[&, promise, query_i_cpy, galkey_cpy, n_i]() {
promise->set(
std::make_shared<math_utils::matrix<seal::Ciphertext>>(1,n_i,expand_query(*query_i_cpy, n_i, *galkey_cpy))
);
},
.wg = promise->get_latch(),

}
);
return promise;
}

}
11 changes: 11 additions & 0 deletions src/math_utils/query_expander.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <seal/seal.h>
#include <vector>
#include "concurrency/concurrency.h"
#include "matrix.h"

namespace math_utils {
/***
Expand All @@ -29,6 +30,16 @@ namespace math_utils {
std::shared_ptr<concurrency::promise<std::vector<seal::Ciphertext>>>
async_expand(std::vector<seal::Ciphertext> query_i, uint64_t n_i, seal::GaloisKeys &galkey);

/**
* retiurns a promise of a query that is expanded into a col vector
* @param query_i query to expand
* @param n_i number of cols
* @param galkey
* @return expanded query promise
*/
std::shared_ptr<concurrency::promise<math_utils::matrix<seal::Ciphertext>>>
async_expand_to_matrix(std::vector<seal::Ciphertext> query_i, uint64_t n_i, seal::GaloisKeys &galkey);

std::vector<seal::Ciphertext> __expand_query(const seal::Ciphertext &encrypted,
uint32_t m, seal::GaloisKeys &galkey) const;

Expand Down
Loading

0 comments on commit e8a63dd

Please sign in to comment.