Skip to content

Commit

Permalink
Merge pull request #109 from sakridge/wip_gpu
Browse files Browse the repository at this point in the history
Change for cuda verify integration
  • Loading branch information
garious authored Apr 6, 2018
2 parents bc6d6b2 + f4466c8 commit a7f59ef
Show file tree
Hide file tree
Showing 13 changed files with 400 additions and 109 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ codecov = { repository = "solana-labs/solana", branch = "master", service = "git
[features]
unstable = []
ipv6 = []
cuda = []

[dependencies]
rayon = "1.0.0"
Expand All @@ -54,5 +55,7 @@ untrusted = "0.5.1"
bincode = "1.0.0"
chrono = { version = "0.4.0", features = ["serde"] }
log = "^0.4.1"
env_logger = "^0.4.1"
matches = "^0.1.6"
byteorder = "^1.2.1"
libc = "^0.2.1"
12 changes: 12 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use std::env;

fn main() {
if !env::var("CARGO_FEATURE_CUDA").is_err() {
println!("cargo:rustc-link-search=native=.");
println!("cargo:rustc-link-lib=static=cuda_verify_ed25519");
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=dylib=cudadevrt");
}
}
10 changes: 6 additions & 4 deletions src/accountant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
//! on behalf of the caller, and a private low-level API for when they have
//! already been signed and verified.
extern crate libc;

use chrono::prelude::*;
use event::Event;
use hash::Hash;
Expand Down Expand Up @@ -104,19 +106,19 @@ impl Accountant {

/// Process a Transaction that has already been verified.
pub fn process_verified_transaction(&self, tr: &Transaction) -> Result<()> {
if self.get_balance(&tr.from).unwrap_or(0) < tr.tokens {
if self.get_balance(&tr.from).unwrap_or(0) < tr.data.tokens {
return Err(AccountingError::InsufficientFunds);
}

if !self.reserve_signature_with_last_id(&tr.sig, &tr.last_id) {
if !self.reserve_signature_with_last_id(&tr.sig, &tr.data.last_id) {
return Err(AccountingError::InvalidTransferSignature);
}

if let Some(x) = self.balances.read().unwrap().get(&tr.from) {
*x.write().unwrap() -= tr.tokens;
*x.write().unwrap() -= tr.data.tokens;
}

let mut plan = tr.plan.clone();
let mut plan = tr.data.plan.clone();
plan.apply_witness(&Witness::Timestamp(*self.last_time.read().unwrap()));

if let Some(ref payment) = plan.final_payment() {
Expand Down
155 changes: 114 additions & 41 deletions src/accountant_skel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,27 @@ use accountant::Accountant;
use bincode::{deserialize, serialize};
use entry::Entry;
use event::Event;
use ecdsa;
use hash::Hash;
use historian::Historian;
use packet;
use packet::SharedPackets;
use rayon::prelude::*;
use recorder::Signal;
use result::Result;
use serde_json;
use signature::PublicKey;
use std::cmp::max;
use std::collections::VecDeque;
use std::io::Write;
use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, SendError};
use std::sync::mpsc::{channel, Receiver, SendError, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{spawn, JoinHandle};
use std::time::Duration;
use streamer;
use packet;
use std::sync::{Arc, Mutex};
use transaction::Transaction;
use std::collections::VecDeque;

pub struct AccountantSkel<W: Write + Send + 'static> {
acc: Accountant,
Expand All @@ -44,14 +47,14 @@ impl Request {
/// Verify the request is valid.
pub fn verify(&self) -> bool {
match *self {
Request::Transaction(ref tr) => tr.verify(),
Request::Transaction(ref tr) => tr.verify_plan(),
_ => true,
}
}
}

/// Parallel verfication of a batch of requests.
fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> {
pub fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> {
reqs.into_par_iter().filter({ |x| x.0.verify() }).collect()
}

Expand Down Expand Up @@ -84,16 +87,20 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
}

/// Process Request items sent by clients.
pub fn log_verified_request(&mut self, msg: Request) -> Option<Response> {
pub fn log_verified_request(&mut self, msg: Request, verify: u8) -> Option<Response> {
match msg {
Request::Transaction(_) if verify == 0 => {
trace!("Transaction failed sigverify");
None
}
Request::Transaction(tr) => {
if let Err(err) = self.acc.process_verified_transaction(&tr) {
eprintln!("Transaction error: {:?}", err);
trace!("Transaction error: {:?}", err);
} else if let Err(SendError(_)) = self.historian
.sender
.send(Signal::Event(Event::Transaction(tr)))
.send(Signal::Event(Event::Transaction(tr.clone())))
{
eprintln!("Channel send error");
error!("Channel send error");
}
None
}
Expand All @@ -105,46 +112,87 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
}
}

fn verifier(
recvr: &streamer::PacketReceiver,
sendr: &Sender<(Vec<SharedPackets>, Vec<Vec<u8>>)>,
) -> Result<()> {
let timer = Duration::new(1, 0);
let msgs = recvr.recv_timeout(timer)?;
trace!("got msgs");
let mut v = Vec::new();
v.push(msgs);
while let Ok(more) = recvr.try_recv() {
trace!("got more msgs");
v.push(more);
}
info!("batch {}", v.len());
let chunk = max(1, (v.len() + 3) / 4);
let chunks: Vec<_> = v.chunks(chunk).collect();
let rvs: Vec<_> = chunks
.into_par_iter()
.map(|x| ecdsa::ed25519_verify(&x.to_vec()))
.collect();
for (v, r) in v.chunks(chunk).zip(rvs) {
sendr.send((v.to_vec(), r))?;
}
Ok(())
}

pub fn deserialize_packets(p: &packet::Packets) -> Vec<Option<(Request, SocketAddr)>> {
// TODO: deserealize in parallel
let mut r = vec![];
for x in &p.packets {
let rsp_addr = x.meta.addr();
let sz = x.meta.size;
if let Ok(req) = deserialize(&x.data[0..sz]) {
r.push(Some((req, rsp_addr)));
} else {
r.push(None);
}
}
r
}

fn process(
obj: &Arc<Mutex<AccountantSkel<W>>>,
packet_receiver: &streamer::PacketReceiver,
verified_receiver: &Receiver<(Vec<SharedPackets>, Vec<Vec<u8>>)>,
blob_sender: &streamer::BlobSender,
packet_recycler: &packet::PacketRecycler,
blob_recycler: &packet::BlobRecycler,
) -> Result<()> {
let timer = Duration::new(1, 0);
let msgs = packet_receiver.recv_timeout(timer)?;
let msgs_ = msgs.clone();
let mut rsps = VecDeque::new();
{
let mut reqs = vec![];
for packet in &msgs.read().unwrap().packets {
let rsp_addr = packet.meta.addr();
let sz = packet.meta.size;
let req = deserialize(&packet.data[0..sz])?;
reqs.push((req, rsp_addr));
}
let reqs = filter_valid_requests(reqs);
for (req, rsp_addr) in reqs {
if let Some(resp) = obj.lock().unwrap().log_verified_request(req) {
let blob = blob_recycler.allocate();
{
let mut b = blob.write().unwrap();
let v = serialize(&resp)?;
let len = v.len();
b.data[..len].copy_from_slice(&v);
b.meta.size = len;
b.meta.set_addr(&rsp_addr);
let (mms, vvs) = verified_receiver.recv_timeout(timer)?;
for (msgs, vers) in mms.into_iter().zip(vvs.into_iter()) {
let msgs_ = msgs.clone();
let mut rsps = VecDeque::new();
{
let reqs = Self::deserialize_packets(&((*msgs).read().unwrap()));
for (data, v) in reqs.into_iter().zip(vers.into_iter()) {
if let Some((req, rsp_addr)) = data {
if !req.verify() {
continue;
}
if let Some(resp) = obj.lock().unwrap().log_verified_request(req, v) {
let blob = blob_recycler.allocate();
{
let mut b = blob.write().unwrap();
let v = serialize(&resp)?;
let len = v.len();
b.data[..len].copy_from_slice(&v);
b.meta.size = len;
b.meta.set_addr(&rsp_addr);
}
rsps.push_back(blob);
}
}
rsps.push_back(blob);
}
}
if !rsps.is_empty() {
//don't wake up the other side if there is nothing
blob_sender.send(rsps)?;
}
packet_recycler.recycle(msgs_);
}
if !rsps.is_empty() {
//don't wake up the other side if there is nothing
blob_sender.send(rsps)?;
}
packet_recycler.recycle(msgs_);
Ok(())
}

Expand All @@ -169,11 +217,21 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
let (blob_sender, blob_receiver) = channel();
let t_responder =
streamer::responder(write, exit.clone(), blob_recycler.clone(), blob_receiver);
let (verified_sender, verified_receiver) = channel();

let exit_ = exit.clone();
let t_verifier = spawn(move || loop {
let e = Self::verifier(&packet_receiver, &verified_sender);
if e.is_err() && exit_.load(Ordering::Relaxed) {
break;
}
});

let skel = obj.clone();
let t_server = spawn(move || loop {
let e = AccountantSkel::process(
&skel,
&packet_receiver,
&verified_receiver,
&blob_sender,
&packet_recycler,
&blob_recycler,
Expand All @@ -182,6 +240,21 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
break;
}
});
Ok(vec![t_receiver, t_responder, t_server])
Ok(vec![t_receiver, t_responder, t_server, t_verifier])
}
}

#[cfg(test)]
mod tests {
use accountant_skel::Request;
use bincode::serialize;
use ecdsa;
use transaction::{memfind, test_tx};
#[test]
fn test_layout() {
let tr = test_tx();
let tx = serialize(&tr).unwrap();
let packet = serialize(&Request::Transaction(tr)).unwrap();
assert_matches!(memfind(&packet, &tx), Some(ecdsa::TX_OFFSET));
}
}
4 changes: 3 additions & 1 deletion src/bin/testnode.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
extern crate env_logger;
extern crate serde_json;
extern crate solana;

Expand All @@ -11,6 +12,7 @@ use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};

fn main() {
env_logger::init().unwrap();
let addr = "127.0.0.1:8000";
let stdin = io::stdin();
let mut entries = stdin
Expand All @@ -27,7 +29,7 @@ fn main() {
// transfer to oneself.
let entry1: Entry = entries.next().unwrap();
let deposit = if let Event::Transaction(ref tr) = entry1.events[0] {
tr.plan.final_payment()
tr.data.plan.final_payment()
} else {
None
};
Expand Down
Loading

0 comments on commit a7f59ef

Please sign in to comment.