Skip to content

Commit

Permalink
Deduplicate arbitrator RustBytes types
Browse files Browse the repository at this point in the history
  • Loading branch information
PlasmaPower committed Dec 12, 2024
1 parent 2b3b823 commit 0afaf50
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 90 deletions.
79 changes: 62 additions & 17 deletions arbitrator/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use once_cell::sync::OnceCell;
use static_assertions::const_assert_eq;
use std::{
ffi::CStr,
marker::PhantomData,
num::NonZeroUsize,
os::raw::{c_char, c_int},
path::Path,
Expand All @@ -59,11 +60,67 @@ pub struct CByteArray {
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct RustByteArray {
pub struct RustSlice<'a> {
pub ptr: *const u8,
pub len: usize,
pub phantom: PhantomData<&'a [u8]>,
}

impl<'a> RustSlice<'a> {
pub fn new(slice: &'a [u8]) -> Self {
if slice.is_empty() {
return Self {
ptr: ptr::null(),
len: 0,
phantom: PhantomData,
};
}
Self {
ptr: slice.as_ptr(),
len: slice.len(),
phantom: PhantomData,
}
}
}

#[repr(C)]
pub struct RustBytes {
pub ptr: *mut u8,
pub len: usize,
pub capacity: usize,
pub cap: usize,
}

impl RustBytes {
pub unsafe fn into_vec(self) -> Vec<u8> {
Vec::from_raw_parts(self.ptr, self.len, self.cap)
}

pub unsafe fn write(&mut self, mut vec: Vec<u8>) {
if vec.capacity() == 0 {
*self = RustBytes {
ptr: ptr::null_mut(),
len: 0,
cap: 0,
};
return;
}
self.ptr = vec.as_mut_ptr();
self.len = vec.len();
self.cap = vec.capacity();
std::mem::forget(vec);
}
}

/// Frees the vector. Does nothing when the vector is null.
///
/// # Safety
///
/// Must only be called once per vec.
#[no_mangle]
pub unsafe extern "C" fn free_rust_bytes(vec: RustBytes) {
if !vec.ptr.is_null() {
drop(vec.into_vec())
}
}

#[no_mangle]
Expand Down Expand Up @@ -410,18 +467,6 @@ pub unsafe extern "C" fn arbitrator_module_root(mach: *mut Machine) -> Bytes32 {

#[no_mangle]
#[cfg(feature = "native")]
pub unsafe extern "C" fn arbitrator_gen_proof(mach: *mut Machine) -> RustByteArray {
let mut proof = (*mach).serialize_proof();
let ret = RustByteArray {
ptr: proof.as_mut_ptr(),
len: proof.len(),
capacity: proof.capacity(),
};
std::mem::forget(proof);
ret
}

#[no_mangle]
pub unsafe extern "C" fn arbitrator_free_proof(proof: RustByteArray) {
drop(Vec::from_raw_parts(proof.ptr, proof.len, proof.capacity))
pub unsafe extern "C" fn arbitrator_gen_proof(mach: *mut Machine, out: *mut RustBytes) {
(*out).write((*mach).serialize_proof());
}
3 changes: 2 additions & 1 deletion arbitrator/stylus/src/evm_api.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// Copyright 2022-2024, Offchain Labs, Inc.
// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE

use crate::{GoSliceData, RustSlice};
use crate::GoSliceData;
use arbutil::evm::{
api::{EvmApiMethod, Gas, EVM_API_METHOD_REQ_OFFSET},
req::RequestHandler,
};
use prover::RustSlice;

#[repr(C)]
pub struct NativeRequestHandler {
Expand Down
90 changes: 22 additions & 68 deletions arbitrator/stylus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ use cache::{deserialize_module, CacheMetrics, InitCache};
use evm_api::NativeRequestHandler;
use eyre::ErrReport;
use native::NativeInstance;
use prover::programs::{prelude::*, StylusData};
use prover::{
programs::{prelude::*, StylusData},
RustBytes,
};
use run::RunProgram;
use std::{marker::PhantomData, mem, ptr};
use std::ptr;
use target_cache::{target_cache_get, target_cache_set};

pub use brotli;
Expand Down Expand Up @@ -76,52 +79,15 @@ impl DataReader for GoSliceData {
}
}

#[repr(C)]
pub struct RustSlice<'a> {
ptr: *const u8,
len: usize,
phantom: PhantomData<&'a [u8]>,
}

impl<'a> RustSlice<'a> {
fn new(slice: &'a [u8]) -> Self {
Self {
ptr: slice.as_ptr(),
len: slice.len(),
phantom: PhantomData,
}
}
}

#[repr(C)]
pub struct RustBytes {
ptr: *mut u8,
len: usize,
cap: usize,
unsafe fn write_err(output: &mut RustBytes, err: ErrReport) -> UserOutcomeKind {
output.write(err.debug_bytes());
UserOutcomeKind::Failure
}

impl RustBytes {
unsafe fn into_vec(self) -> Vec<u8> {
Vec::from_raw_parts(self.ptr, self.len, self.cap)
}

unsafe fn write(&mut self, mut vec: Vec<u8>) {
self.ptr = vec.as_mut_ptr();
self.len = vec.len();
self.cap = vec.capacity();
mem::forget(vec);
}

unsafe fn write_err(&mut self, err: ErrReport) -> UserOutcomeKind {
self.write(err.debug_bytes());
UserOutcomeKind::Failure
}

unsafe fn write_outcome(&mut self, outcome: UserOutcome) -> UserOutcomeKind {
let (status, outs) = outcome.into_data();
self.write(outs);
status
}
unsafe fn write_outcome(output: &mut RustBytes, outcome: UserOutcome) -> UserOutcomeKind {
let (status, outs) = outcome.into_data();
output.write(outs);
status
}

/// "activates" a user wasm.
Expand Down Expand Up @@ -164,7 +130,7 @@ pub unsafe extern "C" fn stylus_activate(
gas,
) {
Ok(val) => val,
Err(err) => return output.write_err(err),
Err(err) => return write_err(output, err),
};

*module_hash = module.hash();
Expand Down Expand Up @@ -194,16 +160,16 @@ pub unsafe extern "C" fn stylus_compile(
let output = &mut *output;
let name = match String::from_utf8(name.slice().to_vec()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};
let target = match target_cache_get(&name) {
Ok(val) => val,
Err(err) => return output.write_err(err),
Err(err) => return write_err(output, err),
};

let asm = match native::compile(wasm, version, debug, target) {
Ok(val) => val,
Err(err) => return output.write_err(err),
Err(err) => return write_err(output, err),
};

output.write(asm);
Expand All @@ -218,7 +184,7 @@ pub unsafe extern "C" fn wat_to_wasm(wat: GoSliceData, output: *mut RustBytes) -
let output = &mut *output;
let wasm = match wasmer::wat2wasm(wat.slice()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};
output.write(wasm.into_owned());
UserOutcomeKind::Success
Expand All @@ -241,16 +207,16 @@ pub unsafe extern "C" fn stylus_target_set(
let output = &mut *output;
let name = match String::from_utf8(name.slice().to_vec()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};

let desc_str = match String::from_utf8(description.slice().to_vec()) {
Ok(val) => val,
Err(err) => return output.write_err(err.into()),
Err(err) => return write_err(output, err.into()),
};

if let Err(err) = target_cache_set(name, desc_str, native) {
return output.write_err(err);
return write_err(output, err);
};

UserOutcomeKind::Success
Expand Down Expand Up @@ -298,8 +264,8 @@ pub unsafe extern "C" fn stylus_call(
};

let status = match instance.run_main(&calldata, config, ink) {
Err(e) | Ok(UserOutcome::Failure(e)) => output.write_err(e.wrap_err("call failed")),
Ok(outcome) => output.write_outcome(outcome),
Err(e) | Ok(UserOutcome::Failure(e)) => write_err(output, e.wrap_err("call failed")),
Ok(outcome) => write_outcome(output, outcome),
};
let ink_left = match status {
UserOutcomeKind::OutOfStack => Ink(0), // take all gas when out of stack
Expand Down Expand Up @@ -352,18 +318,6 @@ pub extern "C" fn stylus_reorg_vm(_block: u64, arbos_tag: u32) {
InitCache::clear_long_term(arbos_tag);
}

/// Frees the vector. Does nothing when the vector is null.
///
/// # Safety
///
/// Must only be called once per vec.
#[no_mangle]
pub unsafe extern "C" fn stylus_drop_vec(vec: RustBytes) {
if !vec.ptr.is_null() {
mem::drop(vec.into_vec())
}
}

/// Gets cache metrics.
///
/// # Safety
Expand Down
2 changes: 1 addition & 1 deletion arbos/programs/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func (vec *rustBytes) intoBytes() []byte {
}

func (vec *rustBytes) drop() {
C.stylus_drop_vec(*vec)
C.free_rust_bytes(*vec)
}

func goSlice(slice []byte) C.GoSliceData {
Expand Down
7 changes: 4 additions & 3 deletions validator/server_arb/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,10 @@ func (m *ArbitratorMachine) ProveNextStep() []byte {
m.mutex.Lock()
defer m.mutex.Unlock()

rustProof := C.arbitrator_gen_proof(m.ptr)
proofBytes := C.GoBytes(unsafe.Pointer(rustProof.ptr), C.int(rustProof.len))
C.arbitrator_free_proof(rustProof)
output := &C.RustBytes{}
C.arbitrator_gen_proof(m.ptr, output)
proofBytes := C.GoBytes(unsafe.Pointer(output.ptr), C.int(output.len))
C.free_rust_bytes(*output)

return proofBytes
}
Expand Down

0 comments on commit 0afaf50

Please sign in to comment.