Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concurrency refactor #35

Merged
merged 3 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@


add_subdirectory(concurrency)
add_subdirectory(math_utils)
add_subdirectory(marshal)
add_subdirectory(services)
add_subdirectory(internals)

add_library(distribicom_cpp distributed_pir.hpp distributed_pir.cpp master.hpp master.cpp worker.hpp worker.cpp
FreivaldsVector.hpp FreivaldsVector.cpp)
add_library(distribicom_cpp old_src/distributed_pir.hpp old_src/distributed_pir.cpp old_src/master.hpp old_src/master.cpp old_src/worker.hpp old_src/worker.cpp
old_src/FreivaldsVector.hpp old_src/FreivaldsVector.cpp)


target_include_directories(distribicom_cpp PUBLIC marshal math_utils services ${com_sealpir_SOURCE_DIR}/src ${CMAKE_CURRENT_SOURCE_DIR})
Expand Down
3 changes: 3 additions & 0 deletions src/concurrency/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
add_library(concurrency_utils channel.hpp waitgroup.hpp waitgroup.cpp)
cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH MY_PARENT_DIR)
target_include_directories(concurrency_utils PUBLIC ${MY_PARENT_DIR})
1 change: 1 addition & 0 deletions src/concurrency/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The files here are related to local concurrency
95 changes: 95 additions & 0 deletions src/concurrency/channel.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#pragma once


#include <queue>
#include <mutex>
#include <condition_variable>
namespace concurrency {

template<class T>
struct Result {
T answer;
bool ok;
};


template<class T>
class Channel {
public:
Channel() : q(), m(), c() {}

~Channel() {
close();
}

// not allowing copy or moving of a channel.
Channel(const Channel &) = delete;

Channel(Channel &&) = delete;


/**
* Adds an element to the queue.
*/
void write(T t) {
std::lock_guard<std::mutex> lock(m);
q.push(t);
c.notify_one();
}

/**
* Get the "front"-element.
* If there is nothing to read from the channel, wait till an element was written on another thread.
*/
Result<T> read() {
std::unique_lock<std::mutex> lock(m);
// returns if q is not empty, and if channel is not closed. // todo verify this expression.
c.wait(lock, [&] { return (!q.empty() || closed); });
if (closed) { return Result<T>{T(), false}; } // result is not OK.
T val = q.front();
q.pop();
return Result<T>{val, true};
}

/**
* read_for behaves like read(), but ensures the caller does not block forever on the channel.
* After the given duration the channel returns a result of read failure.
*/ // TODO: should return something that indicates timeout.
template<typename _Rep, typename _Period>
Result<T> read_for(const std::chrono::duration<_Rep, _Period> &dur) {
auto max_timeout = std::chrono::steady_clock::now() + dur;

std::unique_lock<std::mutex> lock(m);
// returns if q is not empty, and if channel is not closed.
c.wait_until(lock, max_timeout, [&] { return (!q.empty() || closed); });

if (std::chrono::steady_clock::now() > max_timeout) {
return Result<T>{T(), false}; // timeout passed..
}

if (closed) {
return Result<T>{T(), false};
} // result is not OK.

T val = q.front();
q.pop();
return Result<T>{val, true};
}

/**
* Closes the channel, anyone attempting to read from a closed channel should quickly receive read failure.
*/
void close() {
std::lock_guard<std::mutex> lock(m);
closed = true;
c.notify_all();
}


private:
bool closed = false;
std::queue<T> q;
mutable std::mutex m;
std::condition_variable c;
};
}
30 changes: 30 additions & 0 deletions src/concurrency/waitgroup.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "waitgroup.hpp"

namespace concurrency {
void WaitGroup::add(int delta) {
{
count.fetch_add(delta); // TODO: understand order of operation parameter in these funcs.
}
}


void WaitGroup::done() {
{
auto prev_val = count.fetch_add(-1);
if (prev_val == 1) {
cv.notify_all();
return;
}
if (prev_val <= 0) {
throw std::runtime_error("negative counter in wait group!");
}
}
}

void WaitGroup::wait() {
{
std::unique_lock<std::mutex> lock(m);
cv.wait(lock, [this] { return count.load() == 0; });
}
}
}
26 changes: 26 additions & 0 deletions src/concurrency/waitgroup.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#include <queue>
#include <condition_variable>

namespace concurrency {
class WaitGroup {
private:
std::mutex m;
std::condition_variable cv;
std::atomic_int32_t count = 0;
public:
WaitGroup() : m(), cv() {}

// not allowing move or copy.
WaitGroup(const WaitGroup &) = delete;

WaitGroup(WaitGroup &&) = delete;

void add(int delta);

void done();

void wait();
};
}
4 changes: 3 additions & 1 deletion src/marshal/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
add_library(marshal_package marshal.hpp marshal.cpp)

target_include_directories(marshal_package PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH MY_PARENT_DIR)

target_include_directories(marshal_package PUBLIC ${MY_PARENT_DIR})
target_link_libraries(marshal_package SEAL::seal)
8 changes: 5 additions & 3 deletions src/math_utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
add_library(multiplication_utils matrix_multiplier.hpp matrix_multiplier.cpp evaluator_wrapper.hpp evaluator_wrapper.cpp query_expander.hpp query_expander.cpp matrix.h channel.h)
add_library(multiplication_utils matrix_operations.hpp matrix_operations.cpp evaluator_wrapper.hpp evaluator_wrapper.cpp query_expander.hpp query_expander.cpp matrix.h)

target_include_directories(multiplication_utils PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(multiplication_utils SEAL::seal)
cmake_path(GET CMAKE_CURRENT_SOURCE_DIR PARENT_PATH MY_PARENT_DIR)

target_include_directories(multiplication_utils PUBLIC ${MY_PARENT_DIR})
target_link_libraries(multiplication_utils concurrency_utils SEAL::seal)
126 changes: 0 additions & 126 deletions src/math_utils/channel.h

This file was deleted.

2 changes: 1 addition & 1 deletion src/math_utils/evaluator_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include "evaluator_wrapper.hpp"

namespace multiplication_utils {
namespace math_utils {

void EvaluatorWrapper::multiply_add(const std::uint64_t left, const seal::Plaintext &right,
seal::Plaintext &sum) const {
Expand Down
4 changes: 2 additions & 2 deletions src/math_utils/evaluator_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <cassert>
#include "defines.h"

namespace multiplication_utils {
namespace math_utils {
/***
* SplitPlaintextNTTForm represents a plaintext that was intentionally split into two plaintexts such that their
* sum equals the original.
Expand All @@ -28,7 +28,7 @@ namespace multiplication_utils {
std::unique_ptr<seal::Evaluator> evaluator;

/***
* Creates and returns a an initialized matrix_multiplier
* Creates and returns a an initialized MatrixOperations
* @param evaluator
* @param enc_params
* @return a matrix multiplier
Expand Down
4 changes: 2 additions & 2 deletions src/math_utils/matrix.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@

#pragma once

#include "vector"
#include <vector>
#include <cstdint>
#include <cassert>

Expand All @@ -13,7 +13,7 @@
#endif // DISTRIBICOM_DEBUG


namespace multiplication_utils {
namespace math_utils {
template<typename T>
class matrix {
public:
Expand Down
Loading