diff --git a/Cargo.lock b/Cargo.lock index 7a991dabeab..36a6961de25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3803,6 +3803,7 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" name = "sdk-test-module" version = "0.1.0" dependencies = [ + "anyhow", "log", "spacetimedb", ] @@ -4444,6 +4445,7 @@ dependencies = [ "lazy_static", "log", "prost", + "rand 0.8.5", "spacetimedb-client-api-messages", "spacetimedb-lib", "spacetimedb-sats", diff --git a/crates/bindings-csharp/Runtime/Module.cs b/crates/bindings-csharp/Runtime/Module.cs index a6b357359a4..2de25ac4003 100644 --- a/crates/bindings-csharp/Runtime/Module.cs +++ b/crates/bindings-csharp/Runtime/Module.cs @@ -189,6 +189,7 @@ private static byte[] DescribeModule() private static string? CallReducer( uint id, byte[] sender_identity, + byte[] sender_address, ulong timestamp, byte[] args ) @@ -197,7 +198,7 @@ byte[] args { using var stream = new MemoryStream(args); using var reader = new BinaryReader(stream); - reducers[(int)id].Invoke(reader, new(sender_identity, timestamp)); + reducers[(int)id].Invoke(reader, new(sender_identity, sender_address, timestamp)); if (stream.Position != stream.Length) { throw new Exception("Unrecognised extra bytes in the reducer arguments"); diff --git a/crates/bindings-csharp/Runtime/Runtime.cs b/crates/bindings-csharp/Runtime/Runtime.cs index 1927efa052c..d96a91ce168 100644 --- a/crates/bindings-csharp/Runtime/Runtime.cs +++ b/crates/bindings-csharp/Runtime/Runtime.cs @@ -216,14 +216,61 @@ public override int GetHashCode() => public static SpacetimeDB.SATS.TypeInfo GetSatsTypeInfo() => satsTypeInfo; } + public struct Address : IEquatable
+ { + private readonly byte[] bytes; + + public Address(byte[] bytes) => this.bytes = bytes; + + public static readonly Address Zero = new(new byte[16]); + + public bool Equals(Address other) => + StructuralComparisons.StructuralEqualityComparer.Equals(bytes, other.bytes); + + public override bool Equals(object? obj) => obj is Address other && Equals(other); + + public static bool operator ==(Address left, Address right) => left.Equals(right); + + public static bool operator !=(Address left, Address right) => !left.Equals(right); + + public override int GetHashCode() => + StructuralComparisons.StructuralEqualityComparer.GetHashCode(bytes); + + public override string ToString() => BitConverter.ToString(bytes); + + private static SpacetimeDB.SATS.TypeInfo
satsTypeInfo = + new( + // We need to set type info to inlined address type as `generate` CLI currently can't recognise type references for built-ins. + new SpacetimeDB.SATS.ProductType + { + { "__address_bytes", SpacetimeDB.SATS.BuiltinType.BytesTypeInfo.AlgebraicType } + }, + // Concern: We use this "packed" representation (as Bytes) + // in the caller_id field of reducer arguments, + // but in table rows, + // we send the "unpacked" representation as a product value. + // It's possible that these happen to be identical + // because BSATN is minimally self-describing, + // but that doesn't seem like something we should count on. + reader => new(SpacetimeDB.SATS.BuiltinType.BytesTypeInfo.Read(reader)), + (writer, value) => + SpacetimeDB.SATS.BuiltinType.BytesTypeInfo.Write(writer, value.bytes) + ); + + public static SpacetimeDB.SATS.TypeInfo
GetSatsTypeInfo() => satsTypeInfo; + } + public class DbEventArgs : EventArgs { public readonly Identity Sender; public readonly DateTimeOffset Time; + public readonly Address? Address; - public DbEventArgs(byte[] senderIdentity, ulong timestamp_us) + public DbEventArgs(byte[] senderIdentity, byte[] senderAddress, ulong timestamp_us) { Sender = new Identity(senderIdentity); + var addr = new Address(senderAddress); + Address = addr == Runtime.Address.Zero ? null : addr; // timestamp is in microseconds; the easiest way to convert those w/o losing precision is to get Unix origin and add ticks which are 0.1ms each. Time = DateTimeOffset.UnixEpoch.AddTicks(10 * (long)timestamp_us); } @@ -233,11 +280,11 @@ public DbEventArgs(byte[] senderIdentity, ulong timestamp_us) public static event Action? OnDisconnect; // Note: this is accessed by C bindings. - private static string? IdentityConnected(byte[] sender_identity, ulong timestamp) + private static string? IdentityConnected(byte[] sender_identity, byte[] sender_address, ulong timestamp) { try { - OnConnect?.Invoke(new(sender_identity, timestamp)); + OnConnect?.Invoke(new(sender_identity, sender_address, timestamp)); return null; } catch (Exception e) @@ -247,11 +294,11 @@ public DbEventArgs(byte[] senderIdentity, ulong timestamp_us) } // Note: this is accessed by C bindings. - private static string? IdentityDisconnected(byte[] sender_identity, ulong timestamp) + private static string? IdentityDisconnected(byte[] sender_identity, byte[] sender_address, ulong timestamp) { try { - OnDisconnect?.Invoke(new(sender_identity, timestamp)); + OnDisconnect?.Invoke(new(sender_identity, sender_address, timestamp)); return null; } catch (Exception e) diff --git a/crates/bindings-csharp/Runtime/bindings.c b/crates/bindings-csharp/Runtime/bindings.c index 0486b53dadd..f3aeae6075d 100644 --- a/crates/bindings-csharp/Runtime/bindings.c +++ b/crates/bindings-csharp/Runtime/bindings.c @@ -434,33 +434,37 @@ static Buffer return_result_buf(MonoObject* str) { __attribute__((export_name("__call_reducer__"))) Buffer __call_reducer__( uint32_t id, - Buffer sender_, + Buffer sender_id_, + Buffer sender_address_, uint64_t timestamp, Buffer args_) { - MonoArray* sender = stdb_buffer_consume(sender_); + MonoArray* sender_id = stdb_buffer_consume(sender_id_); + MonoArray* sender_address = stdb_buffer_consume(sender_address_); MonoArray* args = stdb_buffer_consume(args_); return return_result_buf(INVOKE_DOTNET_METHOD( "SpacetimeDB.Runtime.dll", "SpacetimeDB.Module", "FFI", "CallReducer", - NULL, &id, sender, ×tamp, args)); + NULL, &id, sender_id, sender_address, ×tamp, args)); } __attribute__((export_name("__identity_connected__"))) Buffer -__identity_connected__(Buffer sender_, uint64_t timestamp) { - MonoArray* sender = stdb_buffer_consume(sender_); +__identity_connected__(Buffer sender_id_, Buffer sender_address_, uint64_t timestamp) { + MonoArray* sender_id = stdb_buffer_consume(sender_id_); + MonoArray* sender_address = stdb_buffer_consume(sender_address_); return return_result_buf( INVOKE_DOTNET_METHOD("SpacetimeDB.Runtime.dll", "SpacetimeDB", "Runtime", - "IdentityConnected", NULL, sender, ×tamp)); + "IdentityConnected", NULL, sender_id, sender_address, ×tamp)); } __attribute__((export_name("__identity_disconnected__"))) Buffer -__identity_disconnected__(Buffer sender_, uint64_t timestamp) { - MonoArray* sender = stdb_buffer_consume(sender_); +__identity_disconnected__(Buffer sender_id_, Buffer sender_address_, uint64_t timestamp) { + MonoArray* sender_id = stdb_buffer_consume(sender_id_); + MonoArray* sender_address = stdb_buffer_consume(sender_address_); return return_result_buf( INVOKE_DOTNET_METHOD("SpacetimeDB.Runtime.dll", "SpacetimeDB", "Runtime", - "IdentityDisconnected", NULL, sender, ×tamp)); + "IdentityDisconnected", NULL, sender_id, sender_address, ×tamp)); } // Shims to avoid dependency on WASI in the generated Wasm file. diff --git a/crates/bindings-macro/src/lib.rs b/crates/bindings-macro/src/lib.rs index 5713410530c..80b418707fa 100644 --- a/crates/bindings-macro/src/lib.rs +++ b/crates/bindings-macro/src/lib.rs @@ -368,10 +368,23 @@ fn gen_reducer(original_function: ItemFn, reducer_name: &str, extra: ReducerExtr }; let generated_function = quote! { - fn __reducer(__sender: spacetimedb::sys::Buffer, __timestamp: u64, __args: &[u8]) -> spacetimedb::sys::Buffer { + // NOTE: double-underscoring names here is unnecessary, as Rust macros are hygienic. + fn __reducer( + __sender: spacetimedb::sys::Buffer, + __caller_address: spacetimedb::sys::Buffer, + __timestamp: u64, + __args: &[u8] + ) -> spacetimedb::sys::Buffer { #(spacetimedb::rt::assert_reducerarg::<#arg_tys>();)* #(spacetimedb::rt::assert_reducerret::<#ret_ty>();)* - spacetimedb::rt::invoke_reducer(#func_name, __sender, __timestamp, __args, |_res| { #epilogue }) + spacetimedb::rt::invoke_reducer( + #func_name, + __sender, + __caller_address, + __timestamp, + __args, + |_res| { #epilogue }, + ) } }; @@ -889,8 +902,12 @@ fn spacetimedb_connect_disconnect(item: TokenStream, connect: bool) -> syn::Resu let emission = quote! { const _: () = { #[export_name = #connect_disconnect_symbol] - extern "C" fn __connect_disconnect(__sender: spacetimedb::sys::Buffer, __timestamp: u64) -> spacetimedb::sys::Buffer { - spacetimedb::rt::invoke_connection_func(#func_name, __sender, __timestamp) + extern "C" fn __connect_disconnect( + __sender: spacetimedb::sys::Buffer, + __caller_address: spacetimedb::sys::Buffer, + __timestamp: u64, + ) -> spacetimedb::sys::Buffer { + spacetimedb::rt::invoke_connection_func(#func_name, __sender, __caller_address, __timestamp) } }; diff --git a/crates/bindings/src/lib.rs b/crates/bindings/src/lib.rs index 072fd3a430d..0a6c5b00883 100644 --- a/crates/bindings/src/lib.rs +++ b/crates/bindings/src/lib.rs @@ -23,6 +23,7 @@ pub use spacetimedb_bindings_macro::{duration, query, spacetimedb, TableType}; pub use sats::SpacetimeType; pub use spacetimedb_lib; pub use spacetimedb_lib::sats; +pub use spacetimedb_lib::Address; pub use spacetimedb_lib::AlgebraicValue; pub use spacetimedb_lib::Identity; pub use timestamp::Timestamp; @@ -51,6 +52,14 @@ pub struct ReducerContext { pub sender: Identity, /// The time at which the reducer was started. pub timestamp: Timestamp, + /// The `Address` of the client that invoked the reducer. + /// + /// `None` if no `Address` was supplied to the `/database/call` HTTP endpoint, + /// or via the CLI's `spacetime call` subcommand. + /// + /// For automatic reducers, i.e. `init`, `update` and scheduled reducers, + /// this will be the module's `Address`. + pub address: Option
, } impl ReducerContext { @@ -59,6 +68,7 @@ impl ReducerContext { Self { sender: Identity::__dummy(), timestamp: Timestamp::UNIX_EPOCH, + address: None, } } } diff --git a/crates/bindings/src/rt.rs b/crates/bindings/src/rt.rs index 55060def87f..26d9e9ed486 100644 --- a/crates/bindings/src/rt.rs +++ b/crates/bindings/src/rt.rs @@ -14,7 +14,7 @@ use spacetimedb_lib::de::{self, Deserialize, SeqProductAccess}; use spacetimedb_lib::sats::typespace::TypespaceBuilder; use spacetimedb_lib::sats::{impl_deserialize, impl_serialize, AlgebraicType, AlgebraicTypeRef, ProductTypeElement}; use spacetimedb_lib::ser::{Serialize, SerializeSeqProduct}; -use spacetimedb_lib::{bsatn, Identity, MiscModuleExport, ModuleDef, ReducerDef, TableDef, TypeAlias}; +use spacetimedb_lib::{bsatn, Address, Identity, MiscModuleExport, ModuleDef, ReducerDef, TableDef, TypeAlias}; use sys::Buffer; pub use once_cell::sync::{Lazy, OnceCell}; @@ -28,11 +28,12 @@ pub use once_cell::sync::{Lazy, OnceCell}; pub fn invoke_reducer<'a, A: Args<'a>, T>( reducer: impl Reducer<'a, A, T>, sender: Buffer, + client_address: Buffer, timestamp: u64, args: &'a [u8], epilogue: impl FnOnce(Result<(), &str>), ) -> Buffer { - let ctx = assemble_context(sender, timestamp); + let ctx = assemble_context(sender, timestamp, client_address); // Deserialize the arguments from a bsatn encoding. let SerDeArgs(args) = bsatn::from_slice(args).expect("unable to decode args"); @@ -71,21 +72,42 @@ pub fn create_index(index_name: &str, table_id: u32, index_type: sys::raw::Index pub fn invoke_connection_func( f: impl Fn(ReducerContext) -> R, sender: Buffer, + client_address: Buffer, timestamp: u64, ) -> Buffer { - let ctx = assemble_context(sender, timestamp); + let ctx = assemble_context(sender, timestamp, client_address); let res = with_timestamp_set(ctx.timestamp, || f(ctx).into_result()); cvt_result(res) } -/// Creates a reducer context from the given `sender` and `timestamp`. -fn assemble_context(sender: Buffer, timestamp: u64) -> ReducerContext { +/// Creates a reducer context from the given `sender`, `timestamp` and `client_address`. +/// +/// `sender` must contain 32 bytes, from which we will read an `Identity`. +/// +/// `timestamp` is a count of microseconds since the Unix epoch. +/// +/// `client_address` must contain 16 bytes, from which we will read an `Address`. +/// The all-zeros `client_address` (constructed by [`Address::__dummy`]) is used as a sentinel, +/// and translated to `None`. +fn assemble_context(sender: Buffer, timestamp: u64, client_address: Buffer) -> ReducerContext { let sender = Identity::from_byte_array(sender.read_array::<32>()); let timestamp = Timestamp::UNIX_EPOCH + Duration::from_micros(timestamp); - ReducerContext { sender, timestamp } + let address = Address::from_arr(&client_address.read_array::<16>()); + + let address = if address == Address::__dummy() { + None + } else { + Some(address) + }; + + ReducerContext { + sender, + timestamp, + address, + } } /// Converts `errno` into a string message. @@ -471,7 +493,7 @@ impl TypespaceBuilder for ModuleBuilder { static DESCRIBERS: Mutex> = Mutex::new(Vec::new()); /// A reducer function takes in `(Sender, Timestamp, Args)` and writes to a new `Buffer`. -pub type ReducerFn = fn(Buffer, u64, &[u8]) -> Buffer; +pub type ReducerFn = fn(Buffer, Buffer, u64, &[u8]) -> Buffer; static REDUCERS: OnceCell> = OnceCell::new(); /// Describes the module into a serialized form that is returned and writes the set of `REDUCERS`. @@ -497,8 +519,14 @@ extern "C" fn __describe_module__() -> Buffer { /// /// The result of the reducer is written into a fresh buffer. #[no_mangle] -extern "C" fn __call_reducer__(id: usize, sender: Buffer, timestamp: u64, args: Buffer) -> Buffer { +extern "C" fn __call_reducer__( + id: usize, + sender: Buffer, + caller_address: Buffer, + timestamp: u64, + args: Buffer, +) -> Buffer { let reducers = REDUCERS.get().unwrap(); let args = args.read(); - reducers[id](sender, timestamp, &args) + reducers[id](sender, caller_address, timestamp, &args) } diff --git a/crates/cli/src/subcommands/generate/csharp.rs b/crates/cli/src/subcommands/generate/csharp.rs index d8972b0e402..2f0a9b6022e 100644 --- a/crates/cli/src/subcommands/generate/csharp.rs +++ b/crates/cli/src/subcommands/generate/csharp.rs @@ -79,6 +79,8 @@ fn ty_fmt<'a>(ctx: &'a GenCtx, ty: &'a AlgebraicType, namespace: &'a str) -> imp // The only type that is allowed here is the identity type. All other types should fail. if prod.is_identity() { write!(f, "SpacetimeDB.Identity") + } else if prod.is_address() { + write!(f, "SpacetimeDB.Address") } else { unimplemented!() } @@ -183,6 +185,12 @@ fn convert_type<'a>( "SpacetimeDB.Identity.From({}.AsProductValue().elements[0].AsBytes())", value ) + } else if product.is_address() { + write!( + f, + "(SpacetimeDB.Address)SpacetimeDB.Address.From({}.AsProductValue().elements[0].AsBytes())", + value + ) } else { unimplemented!() } @@ -959,6 +967,8 @@ fn autogen_csharp_access_funcs_for_struct( AlgebraicType::Product(product) => { if product.is_identity() { ("Identity".into(), "SpacetimeDB.Identity") + } else if product.is_address() { + ("Address".into(), "SpacetimeDB.Address") } else { // TODO: We don't allow filtering on tuples right now, // it's possible we may consider it for the future. @@ -1031,6 +1041,13 @@ fn autogen_csharp_access_funcs_for_struct( col_i ) .unwrap(); + } else if field_type == "Address" { + writeln!( + output, + "var compareValue = (Address)Address.From(productValue.elements[{}].AsProductValue().elements[0].AsBytes());", + col_i + ) + .unwrap(); } else { writeln!( output, @@ -1395,7 +1412,7 @@ pub fn autogen_csharp_reducer(ctx: &GenCtx, reducer: &ReducerDef, namespace: &st writeln!(output, "args.{arg_name} = {convert};").unwrap(); } - writeln!(output, "dbEvent.FunctionCall.CallInfo = new ReducerEvent(ReducerType.{func_name_pascal_case}, \"{func_name}\", dbEvent.Timestamp, Identity.From(dbEvent.CallerIdentity.ToByteArray()), dbEvent.Message, dbEvent.Status, args);").unwrap(); + writeln!(output, "dbEvent.FunctionCall.CallInfo = new ReducerEvent(ReducerType.{func_name_pascal_case}, \"{func_name}\", dbEvent.Timestamp, Identity.From(dbEvent.CallerIdentity.ToByteArray()), Address.From(dbEvent.CallerAddress.ToByteArray()), dbEvent.Message, dbEvent.Status, args);").unwrap(); } // Closing brace for Event parsing function @@ -1496,12 +1513,12 @@ pub fn autogen_csharp_globals(items: &[GenItem], namespace: &str) -> Vec { - if !product.is_identity() { + if !product.is_special() { continue; } } @@ -491,6 +491,8 @@ pub fn encode_type<'a>( AlgebraicType::Product(product) => { if product.is_identity() { write!(f, "Identity.from_string({value})") + } else if product.is_address() { + write!(f, "Address.from_string({value})") } else { unimplemented!() } @@ -528,7 +530,7 @@ pub fn autogen_python_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { writeln!(output, "# WILL NOT BE SAVED. MODIFY TABLES IN RUST INSTEAD.").unwrap(); writeln!(output).unwrap(); - writeln!(output, "from typing import List, Callable").unwrap(); + writeln!(output, "from typing import List, Callable, Optional").unwrap(); writeln!(output).unwrap(); writeln!( @@ -537,6 +539,7 @@ pub fn autogen_python_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { ) .unwrap(); writeln!(output, "from spacetimedb_sdk.spacetimedb_client import Identity").unwrap(); + writeln!(output, "from spacetimedb_sdk.spacetimedb_client import Address").unwrap(); writeln!(output).unwrap(); @@ -621,7 +624,7 @@ pub fn autogen_python_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { writeln!( output, - "def register_on_{}(callback: Callable[[Identity, str, str{}], None]):", + "def register_on_{}(callback: Callable[[Identity, Optional[Address], str, str{}], None]):", reducer.name.to_case(Case::Snake), callback_sig_str ) diff --git a/crates/cli/src/subcommands/generate/rust.rs b/crates/cli/src/subcommands/generate/rust.rs index 1906b24e653..179ad5b64d0 100644 --- a/crates/cli/src/subcommands/generate/rust.rs +++ b/crates/cli/src/subcommands/generate/rust.rs @@ -65,6 +65,9 @@ pub fn write_type(ctx: &impl Fn(AlgebraicTypeRef) -> String, out: &mut AlgebraicType::Product(p) if p.is_identity() => { write!(out, "Identity").unwrap(); } + AlgebraicType::Product(p) if p.is_address() => { + write!(out, "Address").unwrap(); + } AlgebraicType::Product(ProductType { elements }) => { print_comma_sep_braced(out, elements, |out: &mut W, elem: &ProductTypeElement| { if let Some(name) = &elem.name { @@ -86,6 +89,9 @@ pub fn write_type(ctx: &impl Fn(AlgebraicTypeRef) -> String, out: &mut // on generated types, and notably, `HashMap` is not itself `Hash`, // so any type that holds a `Map` cannot derive `Hash` and cannot // key a `Map`. + // UPDATE: No, `AlgebraicType::Map` is supposed to be `BTreeMap`. Fix this. + // This will require deriving `Ord` for generated types, + // and is likely to be a big headache. write!(out, "HashMap::<").unwrap(); write_type(ctx, out, &ty.key_ty); write!(out, ", ").unwrap(); @@ -155,6 +161,7 @@ const ALLOW_UNUSED: &str = "#[allow(unused)]"; const SPACETIMEDB_IMPORTS: &[&str] = &[ ALLOW_UNUSED, "use spacetimedb_sdk::{", + "\tAddress,", "\tsats::{ser::Serialize, de::Deserialize},", "\ttable::{TableType, TableIter, TableWithPrimaryKey},", "\treducer::{Reducer, ReducerCallbackId, Status},", @@ -632,7 +639,7 @@ pub fn autogen_rust_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { writeln!(out, "{}", ALLOW_UNUSED).unwrap(); write!( out, - "pub fn on_{}(mut __callback: impl FnMut(&Identity, &Status", + "pub fn on_{}(mut __callback: impl FnMut(&Identity, Option
, &Status", func_name ) .unwrap(); @@ -646,7 +653,7 @@ pub fn autogen_rust_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { |out| { write!(out, "{}", type_name).unwrap(); out.delimited_block( - "::on_reducer(move |__identity, __status, __args| {", + "::on_reducer(move |__identity, __addr, __status, __args| {", |out| { write!(out, "let ").unwrap(); print_reducer_struct_literal(out, reducer); @@ -655,6 +662,7 @@ pub fn autogen_rust_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { "__callback(", |out| { writeln!(out, "__identity,").unwrap(); + writeln!(out, "__addr,").unwrap(); writeln!(out, "__status,").unwrap(); for arg_name in iter_reducer_arg_names(reducer) { writeln!(out, "{},", arg_name.unwrap()).unwrap(); @@ -675,7 +683,7 @@ pub fn autogen_rust_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { writeln!(out, "{}", ALLOW_UNUSED).unwrap(); write!( out, - "pub fn once_on_{}(__callback: impl FnOnce(&Identity, &Status", + "pub fn once_on_{}(__callback: impl FnOnce(&Identity, Option
, &Status", func_name ) .unwrap(); @@ -689,7 +697,7 @@ pub fn autogen_rust_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { |out| { write!(out, "{}", type_name).unwrap(); out.delimited_block( - "::once_on_reducer(move |__identity, __status, __args| {", + "::once_on_reducer(move |__identity, __addr, __status, __args| {", |out| { write!(out, "let ").unwrap(); print_reducer_struct_literal(out, reducer); @@ -698,6 +706,7 @@ pub fn autogen_rust_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String { "__callback(", |out| { writeln!(out, "__identity,").unwrap(); + writeln!(out, "__addr,").unwrap(); writeln!(out, "__status,").unwrap(); for arg_name in iter_reducer_arg_names(reducer) { writeln!(out, "{},", arg_name.unwrap()).unwrap(); diff --git a/crates/cli/src/subcommands/generate/typescript.rs b/crates/cli/src/subcommands/generate/typescript.rs index 289154e84d1..95844c98322 100644 --- a/crates/cli/src/subcommands/generate/typescript.rs +++ b/crates/cli/src/subcommands/generate/typescript.rs @@ -51,6 +51,8 @@ fn ty_fmt<'a>(ctx: &'a GenCtx, ty: &'a AlgebraicType, ref_prefix: &'a str) -> im // The only type that is allowed here is the identity type. All other types should fail. if prod.is_identity() { write!(f, "Identity") + } else if prod.is_address() { + write!(f, "Address") } else { unimplemented!() } @@ -131,6 +133,8 @@ fn convert_type<'a>( AlgebraicType::Product(product) => { if product.is_identity() { write!(f, "new Identity({}.asProductValue().elements[0].asBytes())", value) + } else if product.is_address() { + write!(f, "new Address({}.asProductValue().elements[0].asBytes())", value) } else { unimplemented!() } @@ -641,7 +645,7 @@ fn autogen_typescript_product_table_common( writeln!(output).unwrap(); writeln!(output, "// @ts-ignore").unwrap(); - writeln!(output, "import {{ __SPACETIMEDB__, AlgebraicType, ProductType, BuiltinType, ProductTypeElement, SumType, SumTypeVariant, IDatabaseTable, AlgebraicValue, ReducerEvent, Identity }} from \"@clockworklabs/spacetimedb-sdk\";").unwrap(); + writeln!(output, "import {{ __SPACETIMEDB__, AlgebraicType, ProductType, BuiltinType, ProductTypeElement, SumType, SumTypeVariant, IDatabaseTable, AlgebraicValue, ReducerEvent, Identity, Address }} from \"@clockworklabs/spacetimedb-sdk\";").unwrap(); let mut imports = Vec::new(); generate_imports(ctx, &product_type.elements, &mut imports, None); @@ -991,6 +995,8 @@ fn autogen_typescript_access_funcs_for_struct( AlgebraicType::Product(product) => { if product.is_identity() { "Identity" + } else if product.is_address() { + "Address" } else { // TODO: We don't allow filtering on tuples right now, its possible we may consider it for the future. continue; @@ -1046,7 +1052,7 @@ fn autogen_typescript_access_funcs_for_struct( writeln!(output, "{{").unwrap(); { indent_scope!(output); - if typescript_field_type == "Identity" { + if typescript_field_type == "Identity" || typescript_field_type == "Address" { writeln!(output, "if (instance.{typescript_field_name_camel}.isEqual(value)) {{",).unwrap(); { indent_scope!(output); @@ -1144,7 +1150,7 @@ pub fn autogen_typescript_reducer(ctx: &GenCtx, reducer: &ReducerDef) -> String writeln!(output).unwrap(); writeln!(output, "// @ts-ignore").unwrap(); - writeln!(output, "import {{ __SPACETIMEDB__, AlgebraicType, ProductType, BuiltinType, ProductTypeElement, IDatabaseTable, AlgebraicValue, ReducerArgsAdapter, SumTypeVariant, Serializer, Identity, ReducerEvent }} from \"@clockworklabs/spacetimedb-sdk\";").unwrap(); + writeln!(output, "import {{ __SPACETIMEDB__, AlgebraicType, ProductType, BuiltinType, ProductTypeElement, IDatabaseTable, AlgebraicValue, ReducerArgsAdapter, SumTypeVariant, Serializer, Identity, Address, ReducerEvent }} from \"@clockworklabs/spacetimedb-sdk\";").unwrap(); let mut imports = Vec::new(); generate_imports( diff --git a/crates/client-api-messages/protobuf/client_api.proto b/crates/client-api-messages/protobuf/client_api.proto index 5014a10575d..de76c13ee6a 100644 --- a/crates/client-api-messages/protobuf/client_api.proto +++ b/crates/client-api-messages/protobuf/client_api.proto @@ -28,7 +28,7 @@ message Message { } } -/// Received by database from client to inform of user's identity and token. +/// Received by database from client to inform of user's identity, token and client address. /// /// The database will always send an `IdentityToken` message /// as the first message for a new WebSocket connection. @@ -39,6 +39,7 @@ message Message { message IdentityToken { bytes identity = 1; string token = 2; + bytes address = 3; } // TODO: Evaluate if it makes sense for this to also include the @@ -110,6 +111,15 @@ message Subscribe { /// - `energy_quanta_used` and `host_execution_duration_micros` seem self-explanatory; /// they describe the amount of energy credits consumed by running the reducer, /// and how long it took to run. +/// +/// - `callerAddress` is the 16-byte address of the user who requested the reducer run. +/// The all-zeros address is a sentinel which denotes no address. +/// `init` and `update` reducers will have a `callerAddress` +/// if and only if one was provided to the `publish` HTTP endpoint. +/// Scheduled reducers will never have a `callerAddress`. +/// Reducers invoked by HTTP will have a `callerAddress` +/// if and only if one was provided to the `call` HTTP endpoint. +/// Reducers invoked by WebSocket will always have a `callerAddress`. message Event { enum Status { committed = 0; @@ -128,6 +138,8 @@ message Event { int64 energy_quanta_used = 6; uint64 host_execution_duration_micros = 7; + + bytes callerAddress = 8; } // TODO: Maybe call this StateUpdate if it's implied to be a subscription update diff --git a/crates/client-api/src/lib.rs b/crates/client-api/src/lib.rs index 8afd1dae11f..8e41ab9537b 100644 --- a/crates/client-api/src/lib.rs +++ b/crates/client-api/src/lib.rs @@ -138,6 +138,7 @@ pub trait ControlStateWriteAccess: Send + Sync { async fn publish_database( &self, identity: &Identity, + publisher_address: Option
, spec: DatabaseDef, ) -> spacetimedb::control_db::Result>; @@ -240,9 +241,10 @@ impl ControlStateWriteAccess for ArcEnv async fn publish_database( &self, identity: &Identity, + publisher_address: Option
, spec: DatabaseDef, ) -> spacetimedb::control_db::Result> { - self.0.publish_database(identity, spec).await + self.0.publish_database(identity, publisher_address, spec).await } async fn delete_database(&self, identity: &Identity, address: &Address) -> spacetimedb::control_db::Result<()> { @@ -392,9 +394,10 @@ impl ControlStateWriteAccess for Arc { async fn publish_database( &self, identity: &Identity, + publisher_address: Option
, spec: DatabaseDef, ) -> spacetimedb::control_db::Result> { - (**self).publish_database(identity, spec).await + (**self).publish_database(identity, publisher_address, spec).await } async fn delete_database(&self, identity: &Identity, address: &Address) -> spacetimedb::control_db::Result<()> { diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index dc788a047ff..0bc7a7f0e9b 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -21,6 +21,7 @@ use spacetimedb::identity::Identity; use spacetimedb::json::client_api::StmtResultJson; use spacetimedb::messages::control_db::{Database, DatabaseInstance, HostType}; use spacetimedb::sql::execute::execute; +use spacetimedb_lib::address::AddressForUrl; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::name::{self, DnsLookupResponse, DomainName, DomainParsingError, PublishOp, PublishResult}; use spacetimedb_lib::recovery::{RecoveryCode, RecoveryCodeResponse}; @@ -33,6 +34,7 @@ use crate::auth::{ SpacetimeAuth, SpacetimeAuthHeader, SpacetimeEnergyUsed, SpacetimeExecutionDurationMicros, SpacetimeIdentity, SpacetimeIdentityToken, }; +use crate::routes::subscribe::generate_random_address; use crate::util::{ByteStringBody, NameOrAddress}; use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate}; @@ -51,6 +53,11 @@ pub struct CallParams { reducer: String, } +#[derive(Deserialize)] +pub struct CallQueryParams { + client_address: Option, +} + pub async fn call( State(worker_ctx): State, auth: SpacetimeAuthHeader, @@ -58,6 +65,7 @@ pub async fn call( name_or_address, reducer, }): Path, + Query(CallQueryParams { client_address }): Query, ByteStringBody(body): ByteStringBody, ) -> axum::response::Result { let SpacetimeAuth { @@ -93,10 +101,22 @@ pub async fn call( } }; - if let Err(e) = module.call_identity_connected_disconnected(caller_identity, true).await { + // HTTP callers always need an address to provide to connect/disconnect, + // so generate one if none was provided. + let client_address = client_address + .map(Address::from) + .unwrap_or_else(generate_random_address); + + if let Err(e) = module + .call_identity_connected_disconnected(caller_identity, client_address, true) + .await + { return Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into()); } - let result = match module.call_reducer(caller_identity, None, &reducer, args).await { + let result = match module + .call_reducer(caller_identity, Some(client_address), None, &reducer, args) + .await + { Ok(rcr) => Ok(rcr), Err(e) => { let status_code = match e { @@ -117,7 +137,7 @@ pub async fn call( }; if let Err(e) = module - .call_identity_connected_disconnected(caller_identity, false) + .call_identity_connected_disconnected(caller_identity, client_address, false) .await { return Err((StatusCode::NOT_FOUND, format!("{:#}", anyhow::anyhow!(e))).into()); @@ -579,7 +599,7 @@ pub struct DNSParams { #[derive(Deserialize)] pub struct ReverseDNSParams { - database_address: Address, + database_address: AddressForUrl, } #[derive(Deserialize)] @@ -608,6 +628,8 @@ pub async fn reverse_dns( State(ctx): State, Path(ReverseDNSParams { database_address }): Path, ) -> axum::response::Result { + let database_address = Address::from(database_address); + let names = ctx.reverse_lookup(&database_address).map_err(log_and_500)?; let response = name::ReverseDNSResponse { names }; @@ -748,6 +770,7 @@ pub struct PublishDatabaseQueryParams { #[serde(default)] clear: bool, name_or_address: Option, + client_address: Option, } pub async fn publish( @@ -757,7 +780,13 @@ pub async fn publish( auth: SpacetimeAuthHeader, body: Bytes, ) -> axum::response::Result> { - let PublishDatabaseQueryParams { name_or_address, clear } = query_params; + let PublishDatabaseQueryParams { + name_or_address, + clear, + client_address, + } = query_params; + + let client_address = client_address.map(Address::from); // You should not be able to publish to a database that you do not own // so, unless you are the owner, this will fail. @@ -803,6 +832,7 @@ pub async fn publish( let maybe_updated = ctx .publish_database( &auth.identity, + client_address, DatabaseDef { address: db_addr, program_bytes: body.into(), @@ -841,7 +871,7 @@ pub async fn publish( #[derive(Deserialize)] pub struct DeleteDatabaseParams { - address: Address, + address: AddressForUrl, } pub async fn delete_database( @@ -851,6 +881,8 @@ pub async fn delete_database( ) -> axum::response::Result { let auth = auth_or_unauth(auth)?; + let address = Address::from(address); + ctx.delete_database(&auth.identity, &address) .await .map_err(log_and_500)?; @@ -861,7 +893,7 @@ pub async fn delete_database( #[derive(Deserialize)] pub struct SetNameQueryParams { domain: String, - address: Address, + address: AddressForUrl, } pub async fn set_name( @@ -871,6 +903,8 @@ pub async fn set_name( ) -> axum::response::Result { let auth = auth_or_unauth(auth)?; + let address = Address::from(address); + let database = ctx .get_database_by_address(&address) .map_err(log_and_500)? diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 9cabbce3670..cd76847fbd2 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -2,7 +2,7 @@ use std::mem; use std::pin::pin; use std::time::Duration; -use axum::extract::{Path, State}; +use axum::extract::{Path, Query, State}; use axum::response::IntoResponse; use axum::TypedHeader; use futures::{SinkExt, StreamExt}; @@ -12,6 +12,8 @@ use spacetimedb::client::messages::{IdentityTokenMessage, ServerMessage}; use spacetimedb::client::{ClientActorId, ClientClosed, ClientConnection, DataMessage, MessageHandleError, Protocol}; use spacetimedb::host::NoSuchModule; use spacetimedb::util::future_queue; +use spacetimedb_lib::address::AddressForUrl; +use spacetimedb_lib::Address; use tokio::sync::mpsc; use crate::auth::{SpacetimeAuthHeader, SpacetimeIdentity, SpacetimeIdentityToken}; @@ -31,9 +33,23 @@ pub struct SubscribeParams { pub name_or_address: NameOrAddress, } +#[derive(Deserialize)] +pub struct SubscribeQueryParams { + pub client_address: Option, +} + +// TODO: is this a reasonable way to generate client addresses? +// For DB addresses, [`ControlDb::alloc_spacetime_address`] +// maintains a global counter, and hashes the next value from that counter +// with some constant salt. +pub fn generate_random_address() -> Address { + Address::from_arr(&rand::random()) +} + pub async fn handle_websocket( State(ctx): State, Path(SubscribeParams { name_or_address }): Path, + Query(SubscribeQueryParams { client_address }): Query, forwarded_for: Option>, auth: SpacetimeAuthHeader, ws: WebSocketUpgrade, @@ -43,7 +59,18 @@ where { let auth = auth.get_or_create(&ctx).await?; - let address = name_or_address.resolve(&ctx).await?.into(); + let client_address = client_address + .map(Address::from) + .unwrap_or_else(generate_random_address); + + if client_address == Address::__dummy() { + Err(( + StatusCode::BAD_REQUEST, + "Invalid client address: the all-zeros Address is reserved.", + ))?; + } + + let db_address = name_or_address.resolve(&ctx).await?.into(); let (res, ws_upgrade, protocol) = ws.select_protocol([(BIN_PROTOCOL, Protocol::Binary), (TEXT_PROTOCOL, Protocol::Text)]); @@ -52,8 +79,9 @@ where // TODO: Should also maybe refactor the code and the protocol to allow a single websocket // to connect to multiple modules + let database = ctx - .get_database_by_address(&address) + .get_database_by_address(&db_address) .unwrap() .ok_or(StatusCode::BAD_REQUEST)?; let database_instance = ctx @@ -79,6 +107,7 @@ where let client_id = ClientActorId { identity: auth.identity, + address: client_address, name: ctx.client_actor_index().next_client_name(), }; @@ -123,6 +152,7 @@ where let message = IdentityTokenMessage { identity: auth.identity, identity_token, + address: client_address, }; if let Err(ClientClosed) = client.send_message(message).await { log::warn!("client closed before identity token was sent") @@ -250,7 +280,7 @@ async fn ws_client_actor(client: ClientConnection, mut ws: WebSocketStream, mut let _ = client.module.subscription().remove_subscriber(client.id); let _ = client .module - .call_identity_connected_disconnected(client.id.identity, false) + .call_identity_connected_disconnected(client.id.identity, client.id.address, false) .await; } diff --git a/crates/client-api/src/util.rs b/crates/client-api/src/util.rs index a11d1e37b27..57570687ed1 100644 --- a/crates/client-api/src/util.rs +++ b/crates/client-api/src/util.rs @@ -12,6 +12,7 @@ use bytestring::ByteString; use http::{HeaderName, HeaderValue, Request, StatusCode}; use spacetimedb::address::Address; +use spacetimedb_lib::address::AddressForUrl; use spacetimedb_lib::name::DomainName; use crate::routes::database::DomainParsingRejection; @@ -65,14 +66,14 @@ impl headers::Header for XForwardedFor { #[derive(Clone, Debug)] pub enum NameOrAddress { - Address(Address), + Address(AddressForUrl), Name(String), } impl NameOrAddress { pub fn into_string(self) -> String { match self { - NameOrAddress::Address(addr) => addr.to_hex(), + NameOrAddress::Address(addr) => Address::from(addr).to_hex(), NameOrAddress::Name(name) => name, } } @@ -98,7 +99,7 @@ impl NameOrAddress { ) -> axum::response::Result> { Ok(match self { Self::Address(addr) => Ok(ResolvedAddress { - address: *addr, + address: Address::from(*addr), domain: None, }), Self::Name(name) => { @@ -133,7 +134,7 @@ impl<'de> serde::Deserialize<'de> for NameOrAddress { { String::deserialize(deserializer).map(|s| { if let Ok(addr) = Address::from_hex(&s) { - NameOrAddress::Address(addr) + NameOrAddress::Address(AddressForUrl::from(addr)) } else { NameOrAddress::Name(s) } @@ -144,7 +145,7 @@ impl<'de> serde::Deserialize<'de> for NameOrAddress { impl fmt::Display for NameOrAddress { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Address(addr) => f.write_str(&addr.to_hex()), + Self::Address(addr) => f.write_str(&Address::from(*addr).to_hex()), Self::Name(name) => f.write_str(name), } } diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index da5b62ba122..252f296d4b9 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -9,10 +9,12 @@ pub mod messages; pub use client_connection::{ClientClosed, ClientConnection, ClientConnectionSender, DataMessage, Protocol}; pub use client_connection_index::ClientActorIndex; pub use message_handlers::MessageHandleError; +use spacetimedb_lib::Address; #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] pub struct ClientActorId { pub identity: Identity, + pub address: Address, pub name: ClientName, } @@ -21,6 +23,12 @@ pub struct ClientName(pub u64); impl fmt::Display for ClientActorId { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ClientActorId({}/{})", self.identity.to_hex(), self.name.0) + write!( + f, + "ClientActorId({}@{}/{})", + self.identity.to_hex(), + self.address.to_hex(), + self.name.0 + ) } } diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index cfc4795799b..82744e5e62d 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -106,7 +106,9 @@ impl ClientConnection { // TODO: Right now this is connecting clients directly to an instance, but their requests should be // logically subscribed to the database, not any particular instance. We should handle failover for // them and stuff. Not right now though. - module.call_identity_connected_disconnected(id.identity, true).await?; + module + .call_identity_connected_disconnected(id.identity, id.address, true) + .await?; // Buffer up to 64 client messages let (sendtx, sendrx) = mpsc::channel::(64); @@ -150,7 +152,13 @@ impl ClientConnection { pub async fn call_reducer(&self, reducer: &str, args: ReducerArgs) -> Result { self.module - .call_reducer(self.id.identity, Some(self.sender()), reducer, args) + .call_reducer( + self.id.identity, + Some(self.id.address), + Some(self.sender()), + reducer, + args, + ) .await } diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index 84ee825c769..edc394f7999 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -9,6 +9,7 @@ use base64::Engine; use bytes::Bytes; use bytestring::ByteString; use prost::Message as _; +use spacetimedb_lib::Address; use super::messages::{ServerMessage, TransactionUpdateMessage}; use super::{ClientConnection, DataMessage}; @@ -147,6 +148,7 @@ impl DecodedMessage<'_> { res.map_err(|(reducer, err)| MessageExecutionError { reducer: reducer.map(str::to_owned), caller_identity: client.id.identity, + caller_address: Some(client.id.address), err, }) } @@ -158,6 +160,7 @@ impl DecodedMessage<'_> { pub struct MessageExecutionError { pub reducer: Option, pub caller_identity: Identity, + pub caller_address: Option
, #[source] pub err: anyhow::Error, } @@ -167,6 +170,7 @@ impl MessageExecutionError { ModuleEvent { timestamp: Timestamp::now(), caller_identity: self.caller_identity, + caller_address: self.caller_address, function_call: ModuleFunctionCall { reducer: self.reducer.unwrap_or_else(|| "".to_owned()), args: Default::default(), diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index f4ba25e4d7f..f09bfdd019d 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -1,7 +1,7 @@ use base64::Engine; use prost::Message as _; use spacetimedb_client_api_messages::client_api::{OneOffQueryResponse, OneOffTable}; -use spacetimedb_lib::relation::MemTable; +use spacetimedb_lib::{relation::MemTable, Address}; use crate::host::module_host::{DatabaseUpdate, EventStatus, ModuleEvent}; use crate::identity::Identity; @@ -29,6 +29,7 @@ pub trait ServerMessage: Sized { pub struct IdentityTokenMessage { pub identity: Identity, pub identity_token: String, + pub address: Address, } impl ServerMessage for IdentityTokenMessage { @@ -36,6 +37,7 @@ impl ServerMessage for IdentityTokenMessage { MessageJson::IdentityToken(IdentityTokenJson { identity: self.identity.to_hex(), token: self.identity_token, + address: self.address.to_hex(), }) } fn serialize_binary(self) -> Message { @@ -43,6 +45,7 @@ impl ServerMessage for IdentityTokenMessage { r#type: Some(message::Type::IdentityToken(IdentityToken { identity: self.identity.as_bytes().to_vec(), token: self.identity_token, + address: self.address.as_slice().to_vec(), })), } } @@ -72,6 +75,7 @@ impl ServerMessage for TransactionUpdateMessage<'_> { }, energy_quanta_used: event.energy_quanta_used.0, message: errmsg, + caller_address: event.caller_address.unwrap_or(Address::ZERO).to_hex(), }; let subscription_update = database_update.into_json(); @@ -100,6 +104,7 @@ impl ServerMessage for TransactionUpdateMessage<'_> { message: errmsg, energy_quanta_used: event.energy_quanta_used.0 as i64, host_execution_duration_micros: event.host_execution_duration.as_micros() as u64, + caller_address: event.caller_address.unwrap_or(Address::zero()).as_slice().to_vec(), }; let subscription_update = database_update.into_protobuf(); diff --git a/crates/core/src/database_instance_context.rs b/crates/core/src/database_instance_context.rs index e424a9c74a4..4e296016c56 100644 --- a/crates/core/src/database_instance_context.rs +++ b/crates/core/src/database_instance_context.rs @@ -19,6 +19,7 @@ pub struct DatabaseInstanceContext { pub address: Address, pub logger: Arc>, pub relational_db: Arc, + pub publisher_address: Option
, } impl DatabaseInstanceContext { @@ -37,6 +38,7 @@ impl DatabaseInstanceContext { database.address, db_path, &log_path, + database.publisher_address, ) } @@ -56,6 +58,7 @@ impl DatabaseInstanceContext { address: Address, db_path: PathBuf, log_path: &Path, + publisher_address: Option
, ) -> Arc { let message_log = match config.storage { Storage::Memory => None, @@ -83,6 +86,7 @@ impl DatabaseInstanceContext { relational_db: Arc::new( RelationalDB::open(db_path, message_log, odb, address, config.fsync != FsyncPolicy::Never).unwrap(), ), + publisher_address, }) } diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index f910556d385..33e1318bd25 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -16,7 +16,7 @@ use base64::{engine::general_purpose::STANDARD as BASE_64_STD, Engine as _}; use futures::{Future, FutureExt}; use indexmap::IndexMap; use spacetimedb_lib::relation::MemTable; -use spacetimedb_lib::{ReducerDef, TableDef}; +use spacetimedb_lib::{Address, ReducerDef, TableDef}; use spacetimedb_sats::{ProductValue, Typespace, WithTypespace}; use std::collections::HashMap; use std::fmt; @@ -186,6 +186,7 @@ pub struct ModuleFunctionCall { pub struct ModuleEvent { pub timestamp: Timestamp, pub caller_identity: Identity, + pub caller_address: Option
, pub function_call: ModuleFunctionCall, pub status: EventStatus, pub energy_quanta_used: EnergyDiff, @@ -195,6 +196,7 @@ pub struct ModuleEvent { #[derive(Debug)] pub struct ModuleInfo { pub identity: Identity, + pub address: Address, pub module_hash: Hash, pub typespace: Typespace, pub reducers: IndexMap, @@ -239,12 +241,13 @@ pub trait ModuleInstance: Send + 'static { fn call_reducer( &mut self, caller_identity: Identity, + caller_address: Option
, client: Option, reducer_id: usize, args: ArgsTuple, ) -> ReducerCallResult; - fn call_connect_disconnect(&mut self, identity: Identity, connected: bool); + fn call_connect_disconnect(&mut self, identity: Identity, caller_address: Address, connected: bool); } // TODO: figure out how we want to handle traps. maybe it should just not return to the LendingPool and @@ -279,16 +282,19 @@ impl ModuleInstance for AutoReplacingModuleInstance { fn call_reducer( &mut self, caller_identity: Identity, + caller_address: Option
, client: Option, reducer_id: usize, args: ArgsTuple, ) -> ReducerCallResult { - let ret = self.inst.call_reducer(caller_identity, client, reducer_id, args); + let ret = self + .inst + .call_reducer(caller_identity, caller_address, client, reducer_id, args); self.check_trap(); ret } - fn call_connect_disconnect(&mut self, identity: Identity, connected: bool) { - self.inst.call_connect_disconnect(identity, connected); + fn call_connect_disconnect(&mut self, identity: Identity, caller_address: Address, connected: bool) { + self.inst.call_connect_disconnect(identity, caller_address, connected); self.check_trap(); } } @@ -503,15 +509,17 @@ impl ModuleHost { pub async fn call_identity_connected_disconnected( &self, caller_identity: Identity, + caller_address: Address, connected: bool, ) -> Result<(), NoSuchModule> { - self.call(move |inst| inst.call_connect_disconnect(caller_identity, connected)) + self.call(move |inst| inst.call_connect_disconnect(caller_identity, caller_address, connected)) .await } async fn call_reducer_inner( &self, caller_identity: Identity, + caller_address: Option
, client: Option, reducer_name: &str, args: ReducerArgs, @@ -524,7 +532,7 @@ impl ModuleHost { let args = args.into_tuple(self.info.typespace.with_type(schema))?; - self.call(move |inst| inst.call_reducer(caller_identity, client, reducer_id, args)) + self.call(move |inst| inst.call_reducer(caller_identity, caller_address, client, reducer_id, args)) .await .map_err(Into::into) } @@ -532,12 +540,13 @@ impl ModuleHost { pub async fn call_reducer( &self, caller_identity: Identity, + caller_address: Option
, client: Option, reducer_name: &str, args: ReducerArgs, ) -> Result { let res = self - .call_reducer_inner(caller_identity, client, reducer_name, args) + .call_reducer_inner(caller_identity, caller_address, client, reducer_name, args) .await; let log_message = match &res { diff --git a/crates/core/src/host/scheduler.rs b/crates/core/src/host/scheduler.rs index bd40200a4f6..cb08a48b021 100644 --- a/crates/core/src/host/scheduler.rs +++ b/crates/core/src/host/scheduler.rs @@ -268,12 +268,15 @@ impl SchedulerActor { let scheduled: ScheduledReducer = bsatn::from_slice(&scheduled).unwrap(); let db = self.db.clone(); tokio::spawn(async move { - let identity = module_host.info().identity; + let info = module_host.info(); + let identity = info.identity; // TODO: pass a logical "now" timestamp to this reducer call, but there's some // intricacies to get right (how much drift to tolerate? what kind of tokio::time::MissedTickBehavior do we want?) let res = module_host .call_reducer( identity, + // Scheduled reducers take `None` as the caller address. + None, None, &scheduled.reducer, ReducerArgs::Bsatn(scheduled.bsatn_args.into()), diff --git a/crates/core/src/host/wasm_common.rs b/crates/core/src/host/wasm_common.rs index f6c71327fa7..ac2df193403 100644 --- a/crates/core/src/host/wasm_common.rs +++ b/crates/core/src/host/wasm_common.rs @@ -111,10 +111,25 @@ const PREINIT_SIG: StaticFuncSig = FuncSig::new(&[], &[]); const INIT_SIG: StaticFuncSig = FuncSig::new(&[], &[WasmType::I32]); const DESCRIBE_MODULE_SIG: StaticFuncSig = FuncSig::new(&[], &[WasmType::I32]); const CALL_REDUCER_SIG: StaticFuncSig = FuncSig::new( - &[WasmType::I32, WasmType::I32, WasmType::I64, WasmType::I32], + &[ + WasmType::I32, // Reducer ID + WasmType::I32, // Sender `Identity` buffer + WasmType::I32, // Sender `Address` buffer + WasmType::I64, // Timestamp + WasmType::I32, // Args buffer + ], + &[ + WasmType::I32, // Result buffer + ], +); +const CONN_DISCONN_SIG: StaticFuncSig = FuncSig::new( + &[ + WasmType::I32, // Sender `Identity` buffer + WasmType::I32, // Sender `Address` buffer + WasmType::I64, // Timestamp + ], &[WasmType::I32], ); -const CONN_DISCONN_SIG: StaticFuncSig = FuncSig::new(&[WasmType::I32, WasmType::I64], &[WasmType::I32]); #[derive(thiserror::Error, Debug)] pub enum ValidationError { diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 446969c51de..7d6963c8520 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -11,7 +11,7 @@ use bytes::Bytes; use nonempty::NonEmpty; use spacetimedb_lib::buffer::DecodeError; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::{bsatn, IndexType, ModuleDef}; +use spacetimedb_lib::{bsatn, Address, IndexType, ModuleDef}; use spacetimedb_vm::expr::CrudExpr; use crate::client::ClientConnectionSender; @@ -60,7 +60,8 @@ pub trait WasmInstance: Send + Sync + 'static { &mut self, reducer_id: usize, budget: EnergyQuanta, - sender: &[u8; 32], + sender_identity: &[u8; 32], + sender_address: &[u8; 16], timestamp: Timestamp, arg_bytes: Bytes, ) -> ExecuteResult; @@ -69,7 +70,8 @@ pub trait WasmInstance: Send + Sync + 'static { &mut self, connect: bool, budget: EnergyQuanta, - sender: &[u8; 32], + sender_identity: &[u8; 32], + sender_address: &[u8; 16], timestamp: Timestamp, ) -> ExecuteResult; @@ -174,6 +176,7 @@ impl WasmModuleHostActor { let info = Arc::new(ModuleInfo { identity: database_instance_context.identity, + address: database_instance_context.address, module_hash, typespace, reducers, @@ -392,8 +395,12 @@ impl ModuleInstance for WasmModuleInstance { Some(reducer_id) => { self.system_logger().info("Invoking `init` reducer"); let caller_identity = self.database_instance_context().identity; + // If a caller address was passed to the `/database/publish` HTTP endpoint, + // the init/update reducer will receive it as the caller address. + // This is useful for bootstrapping the control DB in SpacetimeDB-cloud. + let caller_address = self.database_instance_context().publisher_address; let client = None; - self.call_reducer_internal(Some(tx), caller_identity, client, reducer_id, args) + self.call_reducer_internal(Some(tx), caller_identity, caller_address, client, reducer_id, args) } }; @@ -457,9 +464,19 @@ impl ModuleInstance for WasmModuleInstance { Some(reducer_id) => { self.system_logger().info("Invoking `update` reducer"); let caller_identity = self.database_instance_context().identity; + // If a caller address was passed to the `/database/publish` HTTP endpoint, + // the init/update reducer will receive it as the caller address. + // This is useful for bootstrapping the control DB in SpacetimeDB-cloud. + let caller_address = self.database_instance_context().publisher_address; let client = None; - let res = - self.call_reducer_internal(Some(tx), caller_identity, client, reducer_id, ArgsTuple::default()); + let res = self.call_reducer_internal( + Some(tx), + caller_identity, + caller_address, + client, + reducer_id, + ArgsTuple::default(), + ); Some(res) } }; @@ -476,15 +493,16 @@ impl ModuleInstance for WasmModuleInstance { fn call_reducer( &mut self, caller_identity: Identity, + caller_address: Option
, client: Option, reducer_id: usize, args: ArgsTuple, ) -> ReducerCallResult { - self.call_reducer_internal(None, caller_identity, client, reducer_id, args) + self.call_reducer_internal(None, caller_identity, caller_address, client, reducer_id, args) } #[tracing::instrument(skip_all)] - fn call_connect_disconnect(&mut self, identity: Identity, connected: bool) { + fn call_connect_disconnect(&mut self, caller_identity: Identity, caller_address: Address, connected: bool) { let has_function = if connected { self.func_names.conn } else { @@ -502,7 +520,8 @@ impl ModuleInstance for WasmModuleInstance { None, InstanceOp::ConnDisconn { conn: connected, - sender: &identity, + sender_identity: &caller_identity, + sender_address: &caller_address, timestamp, }, ); @@ -524,7 +543,10 @@ impl ModuleInstance for WasmModuleInstance { args: ArgsTuple::default(), }, status, - caller_identity: identity, + caller_identity, + // Conn/disconn always get a caller address, + // as WebSockets always have an address. + caller_address: Some(caller_address), energy_quanta_used: energy.used, host_execution_duration: start_instant.elapsed(), }; @@ -552,6 +574,7 @@ impl WasmModuleInstance { &mut self, tx: Option, caller_identity: Identity, + caller_address: Option
, client: Option, reducer_id: usize, mut args: ArgsTuple, @@ -568,7 +591,8 @@ impl WasmModuleInstance { tx, InstanceOp::Reducer { id: reducer_id, - sender: &caller_identity, + sender_identity: &caller_identity, + sender_address: &caller_address.unwrap_or(Address::__dummy()), timestamp, arg_bytes: args.get_bsatn().clone(), }, @@ -582,6 +606,7 @@ impl WasmModuleInstance { let event = ModuleEvent { timestamp, caller_identity, + caller_address, function_call: ModuleFunctionCall { reducer: reducerdef.name.clone(), args, @@ -632,7 +657,9 @@ impl WasmModuleInstance { module_hash: self.info.module_hash, module_identity: self.info.identity, caller_identity: match op { - InstanceOp::Reducer { sender, .. } | InstanceOp::ConnDisconn { sender, .. } => *sender, + InstanceOp::Reducer { sender_identity, .. } | InstanceOp::ConnDisconn { sender_identity, .. } => { + *sender_identity + } }, reducer_name: func_ident, }; @@ -645,19 +672,30 @@ impl WasmModuleInstance { let (tx, result) = tx_slot.set(tx, || match op { InstanceOp::Reducer { id, - sender, + sender_identity, + sender_address, timestamp, arg_bytes, - } => self - .instance - .call_reducer(id, budget, sender.as_bytes(), timestamp, arg_bytes), + } => self.instance.call_reducer( + id, + budget, + sender_identity.as_bytes(), + &sender_address.as_slice(), + timestamp, + arg_bytes, + ), InstanceOp::ConnDisconn { conn, - sender, + sender_identity, + sender_address, + timestamp, + } => self.instance.call_connect_disconnect( + conn, + budget, + sender_identity.as_bytes(), + &sender_address.as_slice(), timestamp, - } => self - .instance - .call_connect_disconnect(conn, budget, sender.as_bytes(), timestamp), + ), }); let ExecuteResult { @@ -938,13 +976,15 @@ struct SchemaUpdates { enum InstanceOp<'a> { Reducer { id: usize, - sender: &'a Identity, + sender_identity: &'a Identity, + sender_address: &'a Address, timestamp: Timestamp, arg_bytes: Bytes, }, ConnDisconn { conn: bool, - sender: &'a Identity, + sender_identity: &'a Identity, + sender_address: &'a Address, timestamp: Timestamp, }, } diff --git a/crates/core/src/host/wasmer/wasmer_module.rs b/crates/core/src/host/wasmer/wasmer_module.rs index 34b453bdd6b..898ddb79a3f 100644 --- a/crates/core/src/host/wasmer/wasmer_module.rs +++ b/crates/core/src/host/wasmer/wasmer_module.rs @@ -265,15 +265,29 @@ impl module_host_actor::WasmInstance for WasmerInstance { &mut self, reducer_id: usize, budget: EnergyQuanta, - sender: &[u8; 32], + sender_identity: &[u8; 32], + sender_address: &[u8; 16], timestamp: Timestamp, arg_bytes: Bytes, ) -> module_host_actor::ExecuteResult { - self.call_tx_function::<(u32, u32, u64, u32), 2>( + self.call_tx_function::<(u32, u32, u32, u64, u32), 3>( CALL_REDUCER_DUNDER, budget, - [sender.to_vec().into(), arg_bytes], - |func, store, [sender, args]| func.call(store, reducer_id as u32, sender.0, timestamp.0, args.0), + [ + sender_identity.to_vec().into(), + sender_address.to_vec().into(), + arg_bytes, + ], + |func, store, [sender_identity, sender_address, args]| { + func.call( + store, + reducer_id as u32, + sender_identity.0, + sender_address.0, + timestamp.0, + args.0, + ) + }, ) } @@ -281,18 +295,21 @@ impl module_host_actor::WasmInstance for WasmerInstance { &mut self, connect: bool, budget: EnergyQuanta, - sender: &[u8; 32], + sender_identity: &[u8; 32], + sender_address: &[u8; 16], timestamp: Timestamp, ) -> module_host_actor::ExecuteResult { - self.call_tx_function::<(u32, u64), 1>( + self.call_tx_function::<(u32, u32, u64), 2>( if connect { IDENTITY_CONNECTED_DUNDER } else { IDENTITY_DISCONNECTED_DUNDER }, budget, - [sender.to_vec().into()], - |func, store, [sender]| func.call(store, sender.0, timestamp.0), + [sender_identity.to_vec().into(), sender_address.to_vec().into()], + |func, store, [sender_identity, sender_address]| { + func.call(store, sender_identity.0, sender_address.0, timestamp.0) + }, ) } diff --git a/crates/core/src/json/client_api.rs b/crates/core/src/json/client_api.rs index 9371ec0334b..4533d94bb3d 100644 --- a/crates/core/src/json/client_api.rs +++ b/crates/core/src/json/client_api.rs @@ -45,8 +45,9 @@ impl MessageJson { #[derive(Debug, Clone, Serialize)] pub struct IdentityTokenJson { - pub identity: String, + pub identity: String, // in hex pub token: String, + pub address: String, // in hex } #[derive(Debug, Clone, Serialize)] @@ -84,6 +85,7 @@ pub struct EventJson { pub function_call: FunctionCallJson, pub energy_quanta_used: i128, pub message: String, + pub caller_address: String, // hex address } #[derive(Debug, Clone, Serialize)] diff --git a/crates/core/src/messages/control_db.rs b/crates/core/src/messages/control_db.rs index 9821ea763bd..592a40ec448 100644 --- a/crates/core/src/messages/control_db.rs +++ b/crates/core/src/messages/control_db.rs @@ -28,6 +28,7 @@ pub struct Database { pub host_type: HostType, pub num_replicas: u32, pub program_bytes_address: Hash, + pub publisher_address: Option
, } #[derive(Clone, PartialEq, Serialize, Deserialize)] diff --git a/crates/lib/src/address.rs b/crates/lib/src/address.rs index a67aa97f06b..b0007448466 100644 --- a/crates/lib/src/address.rs +++ b/crates/lib/src/address.rs @@ -1,44 +1,66 @@ -use std::{fmt::Display, net::Ipv6Addr}; - use anyhow::Context as _; use hex::FromHex as _; -use sats::{impl_deserialize, impl_serialize, impl_st}; +use sats::{impl_deserialize, impl_serialize, impl_st, AlgebraicType, ProductTypeElement}; +use spacetimedb_bindings_macro::{Deserialize, Serialize}; +use std::{fmt, net::Ipv6Addr}; use crate::sats; -/// This is the address for a SpacetimeDB database. It is a unique identifier -/// for a particular database and once set for a database, does not change. +/// This is the address for a SpacetimeDB database or client connection. +/// +/// It is a unique identifier for a particular database and once set for a database, +/// does not change. /// -/// TODO: Evaluate other possible names: `DatabaseAddress`, `SPAddress` -/// TODO: Evaluate replacing this with a literal Ipv6Address which is assigned -/// permanently to a database. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct Address(u128); - -impl Display for Address { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// TODO: Evaluate other possible names: `DatabaseAddress`, `SPAddress` +// TODO: Evaluate replacing this with a literal Ipv6Address +// which is assigned permanently to a database. +// This is likely +#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Address { + __address_bytes: [u8; 16], +} + +impl_st!([] Address, _ts => AlgebraicType::product(vec![ + ProductTypeElement::new_named(AlgebraicType::bytes(), "__address_bytes") +])); + +impl fmt::Display for Address { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.to_hex()) } } +impl fmt::Debug for Address { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Address").field(&format_args!("{self}")).finish() + } +} + impl Address { const ABBREVIATION_LEN: usize = 16; - pub const ZERO: Self = Self(0); + pub const ZERO: Self = Self { + __address_bytes: [0; 16], + }; pub fn from_arr(arr: &[u8; 16]) -> Self { - Self(u128::from_be_bytes(*arr)) + Self { __address_bytes: *arr } } pub fn zero() -> Self { - Self(0) + Self { + __address_bytes: [0; 16], + } + } + + pub fn from_u128(u: u128) -> Self { + Self::from_arr(&u.to_be_bytes()) } pub fn from_hex(hex: &str) -> Result { <[u8; 16]>::from_hex(hex) .context("Addresses must be 32 hex characters (16 bytes) in length.") - .map(u128::from_be_bytes) - .map(Self) + .map(|arr| Self::from_arr(&arr)) } pub fn to_hex(self) -> String { @@ -53,15 +75,15 @@ impl Address { let slice = slice.as_ref(); let mut dst = [0u8; 16]; dst.copy_from_slice(slice); - Self(u128::from_be_bytes(dst)) + Self::from_arr(&dst) } pub fn as_slice(&self) -> [u8; 16] { - self.0.to_be_bytes() + self.__address_bytes } pub fn to_ipv6(self) -> Ipv6Addr { - Ipv6Addr::from(self.0) + Ipv6Addr::from(self.__address_bytes) } #[allow(dead_code)] @@ -69,50 +91,69 @@ impl Address { self.to_ipv6().to_string() } + #[doc(hidden)] + pub fn __dummy() -> Self { + Self::zero() + } + pub fn to_u128(&self) -> u128 { - self.0 + u128::from_be_bytes(self.__address_bytes) } } impl From for Address { fn from(value: u128) -> Self { - Self(value) + Self::from_u128(value) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct AddressForUrl(u128); + +impl From
for AddressForUrl { + fn from(addr: Address) -> Self { + AddressForUrl(u128::from_be_bytes(addr.__address_bytes)) } } -impl_serialize!([] Address, (self, ser) => self.0.to_be_bytes().serialize(ser)); -impl_deserialize!([] Address, de => <[u8; 16]>::deserialize(de).map(|v| Self(u128::from_be_bytes(v)))); +impl From for Address { + fn from(addr: AddressForUrl) -> Self { + Address::from_u128(addr.0) + } +} + +impl_serialize!([] AddressForUrl, (self, ser) => self.0.to_be_bytes().serialize(ser)); +impl_deserialize!([] AddressForUrl, de => <[u8; 16]>::deserialize(de).map(|v| Self(u128::from_be_bytes(v)))); +impl_st!([] AddressForUrl, _ts => AlgebraicType::bytes()); #[cfg(feature = "serde")] -impl serde::Serialize for Address { +impl serde::Serialize for AddressForUrl { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { - self.to_hex().serialize(serializer) + Address::from(*self).to_hex().serialize(serializer) } } #[cfg(feature = "serde")] -impl<'de> serde::Deserialize<'de> for Address { +impl<'de> serde::Deserialize<'de> for AddressForUrl { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let s = String::deserialize(deserializer)?; - Address::from_hex(&s).map_err(serde::de::Error::custom) + Address::from_hex(&s).map_err(serde::de::Error::custom).map(Self::from) } } -impl_st!([] Address, _ts => sats::AlgebraicType::bytes()); - #[cfg(test)] mod tests { use super::*; #[test] fn test_bsatn_roundtrip() { - let addr = Address(rand::random()); + let addr = Address::from_u128(rand::random()); let ser = sats::bsatn::to_vec(&addr).unwrap(); let de = sats::bsatn::from_slice(&ser).unwrap(); assert_eq!(addr, de); @@ -124,10 +165,12 @@ mod tests { #[test] fn test_serde_roundtrip() { - let addr = Address(rand::random()); - let ser = serde_json::to_vec(&addr).unwrap(); - let de = serde_json::from_slice(&ser).unwrap(); - assert_eq!(addr, de); + let addr = Address::from_u128(rand::random()); + let to_url = AddressForUrl::from(addr); + let ser = serde_json::to_vec(&to_url).unwrap(); + let de = serde_json::from_slice::(&ser).unwrap(); + let from_url = Address::from(de); + assert_eq!(addr, from_url); } } } diff --git a/crates/sats/src/product_type.rs b/crates/sats/src/product_type.rs index a63589098bc..0c7698daef5 100644 --- a/crates/sats/src/product_type.rs +++ b/crates/sats/src/product_type.rs @@ -53,6 +53,22 @@ impl ProductType { _ => false, } } + + /// Returns whether this is the special case of `spacetimedb_lib::Address`. + pub fn is_address(&self) -> bool { + match &*self.elements { + [ProductTypeElement { + name: Some(name), + algebraic_type, + }] => name == "__address_bytes" && algebraic_type.is_bytes(), + _ => false, + } + } + + /// Returns whether this is a special known type, currently `Address` or `Identity`. + pub fn is_special(&self) -> bool { + self.is_identity() || self.is_address() + } } impl> FromIterator for ProductType { diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 86e4d6a0923..335d784ad26 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -23,6 +23,7 @@ im.workspace = true lazy_static.workspace = true log.workspace = true prost.workspace = true +rand.workspace = true tokio.workspace = true tokio-tungstenite.workspace = true diff --git a/crates/sdk/examples/cursive-chat/main.rs b/crates/sdk/examples/cursive-chat/main.rs index 1d9378b0089..d5c2b78d887 100644 --- a/crates/sdk/examples/cursive-chat/main.rs +++ b/crates/sdk/examples/cursive-chat/main.rs @@ -8,6 +8,7 @@ use spacetimedb_sdk::{ reducer::Status, subscribe, table::{TableType, TableWithPrimaryKey}, + Address, }; use cursive::{ @@ -140,7 +141,7 @@ fn register_callbacks(send: UiSend) { // ## Save credentials to a file /// Our `on_connect` callback: save our credentials to a file. -fn on_connected(creds: &Credentials) { +fn on_connected(creds: &Credentials, _address: Address) { if let Err(e) = save_credentials(CREDS_DIR, creds) { eprintln!("Failed to save credentials: {:?}", e); } @@ -257,8 +258,8 @@ fn on_sub_applied(send: UiSend) -> impl FnMut() + Send + 'static { // ## Warn if set_name failed /// Our `on_set_name` callback: print a warning if the reducer failed. -fn on_name_set(send: UiSend) -> impl FnMut(&Identity, &Status, &String) { - move |_sender, status, name| { +fn on_name_set(send: UiSend) -> impl FnMut(&Identity, Option
, &Status, &String) { + move |_sender_id, _sender_addr, status, name| { if let Status::Failed(err) = status { send.unbounded_send(UiMessage::NameRejected { current_name: user_name_or_identity(&User::filter_by_identity(identity().unwrap()).unwrap()), @@ -273,8 +274,8 @@ fn on_name_set(send: UiSend) -> impl FnMut(&Identity, &Status, &String) { // ## Warn if a message was rejected /// Our `on_send_message` callback: print a warning if the reducer failed. -fn on_message_sent(send: UiSend) -> impl FnMut(&Identity, &Status, &String) { - move |_sender, status, text| { +fn on_message_sent(send: UiSend) -> impl FnMut(&Identity, Option
, &Status, &String) { + move |_sender_id, _sender_addr, status, text| { if let Status::Failed(err) = status { send.unbounded_send(UiMessage::MessageRejected { rejected_message: text.clone(), diff --git a/crates/sdk/examples/cursive-chat/module_bindings/message.rs b/crates/sdk/examples/cursive-chat/module_bindings/message.rs index e44493675b7..761877354f3 100644 --- a/crates/sdk/examples/cursive-chat/module_bindings/message.rs +++ b/crates/sdk/examples/cursive-chat/module_bindings/message.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] diff --git a/crates/sdk/examples/cursive-chat/module_bindings/mod.rs b/crates/sdk/examples/cursive-chat/module_bindings/mod.rs index a9780510ccf..1d35ab65282 100644 --- a/crates/sdk/examples/cursive-chat/module_bindings/mod.rs +++ b/crates/sdk/examples/cursive-chat/module_bindings/mod.rs @@ -16,6 +16,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; use std::sync::Arc; diff --git a/crates/sdk/examples/cursive-chat/module_bindings/send_message_reducer.rs b/crates/sdk/examples/cursive-chat/module_bindings/send_message_reducer.rs index 5cef95354c8..1c879dd39e1 100644 --- a/crates/sdk/examples/cursive-chat/module_bindings/send_message_reducer.rs +++ b/crates/sdk/examples/cursive-chat/module_bindings/send_message_reducer.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] @@ -27,21 +28,21 @@ pub fn send_message(text: String) { #[allow(unused)] pub fn on_send_message( - mut __callback: impl FnMut(&Identity, &Status, &String) + Send + 'static, + mut __callback: impl FnMut(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SendMessageArgs::on_reducer(move |__identity, __status, __args| { + SendMessageArgs::on_reducer(move |__identity, __addr, __status, __args| { let SendMessageArgs { text } = __args; - __callback(__identity, __status, text); + __callback(__identity, __addr, __status, text); }) } #[allow(unused)] pub fn once_on_send_message( - __callback: impl FnOnce(&Identity, &Status, &String) + Send + 'static, + __callback: impl FnOnce(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SendMessageArgs::once_on_reducer(move |__identity, __status, __args| { + SendMessageArgs::once_on_reducer(move |__identity, __addr, __status, __args| { let SendMessageArgs { text } = __args; - __callback(__identity, __status, text); + __callback(__identity, __addr, __status, text); }) } diff --git a/crates/sdk/examples/cursive-chat/module_bindings/set_name_reducer.rs b/crates/sdk/examples/cursive-chat/module_bindings/set_name_reducer.rs index 785defcbec9..b809a8b71c6 100644 --- a/crates/sdk/examples/cursive-chat/module_bindings/set_name_reducer.rs +++ b/crates/sdk/examples/cursive-chat/module_bindings/set_name_reducer.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] @@ -27,21 +28,21 @@ pub fn set_name(name: String) { #[allow(unused)] pub fn on_set_name( - mut __callback: impl FnMut(&Identity, &Status, &String) + Send + 'static, + mut __callback: impl FnMut(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SetNameArgs::on_reducer(move |__identity, __status, __args| { + SetNameArgs::on_reducer(move |__identity, __addr, __status, __args| { let SetNameArgs { name } = __args; - __callback(__identity, __status, name); + __callback(__identity, __addr, __status, name); }) } #[allow(unused)] pub fn once_on_set_name( - __callback: impl FnOnce(&Identity, &Status, &String) + Send + 'static, + __callback: impl FnOnce(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SetNameArgs::once_on_reducer(move |__identity, __status, __args| { + SetNameArgs::once_on_reducer(move |__identity, __addr, __status, __args| { let SetNameArgs { name } = __args; - __callback(__identity, __status, name); + __callback(__identity, __addr, __status, name); }) } diff --git a/crates/sdk/examples/cursive-chat/module_bindings/user.rs b/crates/sdk/examples/cursive-chat/module_bindings/user.rs index 7d472ca97eb..9e22256bf19 100644 --- a/crates/sdk/examples/cursive-chat/module_bindings/user.rs +++ b/crates/sdk/examples/cursive-chat/module_bindings/user.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] diff --git a/crates/sdk/examples/quickstart-chat/main.rs b/crates/sdk/examples/quickstart-chat/main.rs index 68d5dca70f0..7559aa504f7 100644 --- a/crates/sdk/examples/quickstart-chat/main.rs +++ b/crates/sdk/examples/quickstart-chat/main.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ reducer::Status, subscribe, table::{TableType, TableWithPrimaryKey}, + Address, }; // # Our main function @@ -53,7 +54,7 @@ fn register_callbacks() { // ## Save credentials to a file /// Our `on_connect` callback: save our credentials to a file. -fn on_connected(creds: &Credentials) { +fn on_connected(creds: &Credentials, _address: Address) { if let Err(e) = save_credentials(CREDS_DIR, creds) { eprintln!("Failed to save credentials: {:?}", e); } @@ -131,7 +132,7 @@ fn on_sub_applied() { // ## Warn if set_name failed /// Our `on_set_name` callback: print a warning if the reducer failed. -fn on_name_set(_sender: &Identity, status: &Status, name: &String) { +fn on_name_set(_sender_id: &Identity, _sender_addr: Option
, status: &Status, name: &String) { if let Status::Failed(err) = status { eprintln!("Failed to change name to {:?}: {}", name, err); } @@ -140,7 +141,7 @@ fn on_name_set(_sender: &Identity, status: &Status, name: &String) { // ## Warn if a message was rejected /// Our `on_send_message` callback: print a warning if the reducer failed. -fn on_message_sent(_sender: &Identity, status: &Status, text: &String) { +fn on_message_sent(_sender: &Identity, _sender_addr: Option
, status: &Status, text: &String) { if let Status::Failed(err) = status { eprintln!("Failed to send message {:?}: {}", text, err); } @@ -160,7 +161,7 @@ fn on_disconnected() { const HOST: &str = "http://localhost:3000"; /// The module name we chose when we published our module. -const DB_NAME: &str = "chat"; +const DB_NAME: &str = "quickstart-chat"; /// Load credentials from a file and connect to the database. fn connect_to_db() { diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/message.rs b/crates/sdk/examples/quickstart-chat/module_bindings/message.rs index e44493675b7..761877354f3 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/message.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/message.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs b/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs index a9780510ccf..1d35ab65282 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/mod.rs @@ -16,6 +16,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; use std::sync::Arc; diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/send_message_reducer.rs b/crates/sdk/examples/quickstart-chat/module_bindings/send_message_reducer.rs index 5cef95354c8..1c879dd39e1 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/send_message_reducer.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/send_message_reducer.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] @@ -27,21 +28,21 @@ pub fn send_message(text: String) { #[allow(unused)] pub fn on_send_message( - mut __callback: impl FnMut(&Identity, &Status, &String) + Send + 'static, + mut __callback: impl FnMut(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SendMessageArgs::on_reducer(move |__identity, __status, __args| { + SendMessageArgs::on_reducer(move |__identity, __addr, __status, __args| { let SendMessageArgs { text } = __args; - __callback(__identity, __status, text); + __callback(__identity, __addr, __status, text); }) } #[allow(unused)] pub fn once_on_send_message( - __callback: impl FnOnce(&Identity, &Status, &String) + Send + 'static, + __callback: impl FnOnce(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SendMessageArgs::once_on_reducer(move |__identity, __status, __args| { + SendMessageArgs::once_on_reducer(move |__identity, __addr, __status, __args| { let SendMessageArgs { text } = __args; - __callback(__identity, __status, text); + __callback(__identity, __addr, __status, text); }) } diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/set_name_reducer.rs b/crates/sdk/examples/quickstart-chat/module_bindings/set_name_reducer.rs index 785defcbec9..b809a8b71c6 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/set_name_reducer.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/set_name_reducer.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] @@ -27,21 +28,21 @@ pub fn set_name(name: String) { #[allow(unused)] pub fn on_set_name( - mut __callback: impl FnMut(&Identity, &Status, &String) + Send + 'static, + mut __callback: impl FnMut(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SetNameArgs::on_reducer(move |__identity, __status, __args| { + SetNameArgs::on_reducer(move |__identity, __addr, __status, __args| { let SetNameArgs { name } = __args; - __callback(__identity, __status, name); + __callback(__identity, __addr, __status, name); }) } #[allow(unused)] pub fn once_on_set_name( - __callback: impl FnOnce(&Identity, &Status, &String) + Send + 'static, + __callback: impl FnOnce(&Identity, Option
, &Status, &String) + Send + 'static, ) -> ReducerCallbackId { - SetNameArgs::once_on_reducer(move |__identity, __status, __args| { + SetNameArgs::once_on_reducer(move |__identity, __addr, __status, __args| { let SetNameArgs { name } = __args; - __callback(__identity, __status, name); + __callback(__identity, __addr, __status, name); }) } diff --git a/crates/sdk/examples/quickstart-chat/module_bindings/user.rs b/crates/sdk/examples/quickstart-chat/module_bindings/user.rs index 7d472ca97eb..9e22256bf19 100644 --- a/crates/sdk/examples/quickstart-chat/module_bindings/user.rs +++ b/crates/sdk/examples/quickstart-chat/module_bindings/user.rs @@ -9,6 +9,7 @@ use spacetimedb_sdk::{ sats::{de::Deserialize, ser::Serialize}, spacetimedb_lib, table::{TableIter, TableType, TableWithPrimaryKey}, + Address, }; #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] diff --git a/crates/sdk/src/background_connection.rs b/crates/sdk/src/background_connection.rs index a82097bcb60..e8dba369f25 100644 --- a/crates/sdk/src/background_connection.rs +++ b/crates/sdk/src/background_connection.rs @@ -313,11 +313,19 @@ impl BackgroundDbConnection { IntoUri: TryInto, >::Error: std::error::Error + Send + Sync + 'static, { + let client_address = { + let mut lock = self.credentials.lock().expect("CredentialStore Mutex is poisoned"); + lock.get_or_init_address() + }; // `block_in_place` is required here, as tokio won't allow us to call // `block_on` if it would block the current thread of an outer runtime let connection = tokio::task::block_in_place(|| { - self.handle - .block_on(DbConnection::connect(spacetimedb_uri, db_name, credentials.as_ref())) + self.handle.block_on(DbConnection::connect( + spacetimedb_uri, + db_name, + credentials.as_ref(), + client_address, + )) })?; let client_cache = Arc::new(ClientCache::new(module.clone())); diff --git a/crates/sdk/src/callbacks.rs b/crates/sdk/src/callbacks.rs index b35a404b9c4..a4420765dad 100644 --- a/crates/sdk/src/callbacks.rs +++ b/crates/sdk/src/callbacks.rs @@ -23,7 +23,9 @@ use crate::{ reducer::{AnyReducerEvent, Reducer, Status}, spacetime_module::SpacetimeModule, table::TableType, + Address, }; +use anyhow::Context; use anymap::{any::Any, Map}; use futures::stream::StreamExt; use futures_channel::mpsc; @@ -137,12 +139,12 @@ impl std::fmt::Debug for CallbackId { // - `Credentials` -> `&Credentials`, for `on_connect`. // - `()` -> `()`, for `on_subscription_applied`. // - `(T, T, U)` -> `(&T, &T, U)`, for `TableWithPrimaryKey::on_update`. -// - `(Identity, Status, R)` -> `(&Identity, &Status, &R)`, for `Reducer::on_reducer`. +// - `(Identity, Option
, Status, R)` -> `(&Identity, Option
, &Status, &R)`, for `Reducer::on_reducer`. -impl OwnedArgs for Credentials { - type Borrowed<'a> = &'a Credentials; - fn borrow(&self) -> &Credentials { - self +impl OwnedArgs for (Credentials, Address) { + type Borrowed<'a> = (&'a Credentials, Address); + fn borrow(&self) -> (&Credentials, Address) { + (&self.0, self.1) } } @@ -178,13 +180,13 @@ impl OwnedArgs for (T, T, Option>) { } } -impl OwnedArgs for (Identity, Status, R) +impl OwnedArgs for (Identity, Option
, Status, R) where R: Send + 'static, { - type Borrowed<'a> = (&'a Identity, &'a Status, &'a R); - fn borrow(&self) -> (&Identity, &Status, &R) { - (&self.0, &self.1, &self.2) + type Borrowed<'a> = (&'a Identity, Option
, &'a Status, &'a R); + fn borrow(&self) -> (&Identity, Option
, &Status, &R) { + (&self.0, self.1, &self.2, &self.3) } } @@ -450,9 +452,9 @@ fn uncurry_update_callback( /// /// This function is intended specifically for `Reducer::on_reducer` callbacks. fn uncurry_reducer_callback( - mut f: impl for<'a> FnMut(&'a Identity, &'a Status, &'a R) + Send + 'static, -) -> impl for<'a> FnMut((&'a Identity, &'a Status, &'a R)) + Send + 'static { - move |(identity, status, reducer)| f(identity, status, reducer) + mut f: impl for<'a> FnMut(&'a Identity, Option
, &'a Status, &'a R) + Send + 'static, +) -> impl for<'a> FnMut((&'a Identity, Option
, &'a Status, &'a R)) + Send + 'static { + move |(identity, address, status, reducer)| f(identity, address, status, reducer) } /// A collection of registered callbacks for `on_insert`, `on_delete` and `on_update` events @@ -684,9 +686,9 @@ impl ReducerCallbacks { self.module = Some(module); } - pub(crate) fn find_callbacks(&mut self) -> &mut CallbackMap<(Identity, Status, R)> { + pub(crate) fn find_callbacks(&mut self) -> &mut CallbackMap<(Identity, Option
, Status, R)> { self.callbacks - .entry::>() + .entry::, Status, R)>>() .or_insert_with(|| CallbackMap::spawn(&self.runtime)) } @@ -705,6 +707,7 @@ impl ReducerCallbacks { ) -> Option> { let client_api_messages::Event { caller_identity, + caller_address, function_call: Some(function_call), status, message, @@ -715,6 +718,12 @@ impl ReducerCallbacks { return None; }; let identity = Identity::from_bytes(caller_identity); + let address = Address::from_slice(caller_address); + let address = if address == Address::zero() { + None + } else { + Some(address) + }; let Some(status) = parse_status(status, message) else { log::warn!("Received Event with unknown status {:?}", status); return None; @@ -727,7 +736,7 @@ impl ReducerCallbacks { Ok(instance) => { // TODO: should reducer callbacks' `OwnedArgs` impl take an `Arc` rather than an `R`? self.find_callbacks::() - .invoke((identity, status, instance.clone()), state); + .invoke((identity, address, status, instance.clone()), state); Some(Arc::new(wrap(instance))) } } @@ -738,8 +747,8 @@ impl ReducerCallbacks { // TODO: reduce monomorphization by accepting `Box` instead of `impl Callback` pub(crate) fn register_on_reducer( &mut self, - callback: impl FnMut(&Identity, &Status, &R) + Send + 'static, - ) -> CallbackId<(Identity, Status, R)> { + callback: impl FnMut(&Identity, Option
, &Status, &R) + Send + 'static, + ) -> CallbackId<(Identity, Option
, Status, R)> { self.find_callbacks::() .insert(Box::new(uncurry_reducer_callback(callback))) } @@ -752,14 +761,14 @@ impl ReducerCallbacks { // since [`CallbackMap::insert_oneshot`] boxes its wrapper callback. pub(crate) fn register_on_reducer_oneshot( &mut self, - callback: impl FnOnce(&Identity, &Status, &R) + Send + 'static, - ) -> CallbackId<(Identity, Status, R)> { + callback: impl FnOnce(&Identity, Option
, &Status, &R) + Send + 'static, + ) -> CallbackId<(Identity, Option
, Status, R)> { self.find_callbacks::() - .insert_oneshot(move |(identity, status, args)| callback(identity, status, args)) + .insert_oneshot(move |(identity, address, status, args)| callback(identity, address, status, args)) } /// Unregister a previously-registered on-reducer callback identified by `id`. - pub(crate) fn unregister_on_reducer(&mut self, id: CallbackId<(Identity, Status, R)>) { + pub(crate) fn unregister_on_reducer(&mut self, id: CallbackId<(Identity, Option
, Status, R)>) { self.find_callbacks::().remove(id); } @@ -792,11 +801,46 @@ pub(crate) struct CredentialStore { /// from the database. credentials: Option, + address: Option
, + /// Any `on_connect` callbacks to run when credentials become available. - callbacks: CallbackMap, + callbacks: CallbackMap<(Credentials, Address)>, } impl CredentialStore { + pub(crate) fn use_saved_address(&mut self, file: &str) -> anyhow::Result
{ + if let Some(address) = self.address { + panic!( + "Cannot use_saved_address when an address has already been generated. Generated address: {:?}", + address + ); + } + match std::fs::read(file) { + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + let file = AsRef::::as_ref(file); + let addr = Address::from_arr(&rand::random()); + let addr_bytes = bsatn::to_vec(&addr).context("Error serializing Address")?; + + if let Some(parent) = file.parent() { + if parent != AsRef::::as_ref("") { + std::fs::create_dir_all(parent).context("Error creating parent directory for address file")?; + } + } + + std::fs::write(file, addr_bytes).context("Error writing Address to file")?; + + self.address = Some(addr); + Ok(addr) + } + Err(e) => Err(e).context("Error reading BSATN-encoded Address from file"), + Ok(file_contents) => { + let addr = bsatn::from_slice::
(&file_contents) + .context("Error decoding BSATN-encoded Address from file")?; + self.address = Some(addr); + Ok(addr) + } + } + } /// Construct a `CredentialStore` for a not-yet-connected `BackgroundDbConnection` /// containing no credentials. /// @@ -804,6 +848,7 @@ impl CredentialStore { pub(crate) fn without_credentials(runtime: &runtime::Handle) -> Self { CredentialStore { credentials: None, + address: None, callbacks: CallbackMap::spawn(runtime), } } @@ -815,26 +860,38 @@ impl CredentialStore { self.credentials = credentials; } + pub(crate) fn get_or_init_address(&mut self) -> Address { + if let Some(addr) = self.address { + addr + } else { + let addr = Address::from_arr(&rand::random()); + self.address = Some(addr); + addr + } + } + /// Register an on-connect callback to run when the client's `Credentials` become available. // TODO: reduce monomorphization by accepting `Box` instead of `impl Callback` pub(crate) fn register_on_connect( &mut self, - callback: impl FnMut(&Credentials) + Send + 'static, - ) -> CallbackId { - self.callbacks.insert(Box::new(callback)) + mut callback: impl FnMut(&Credentials, Address) + Send + 'static, + ) -> CallbackId<(Credentials, Address)> { + self.callbacks + .insert(Box::new(move |(creds, addr)| callback(creds, addr))) } /// Register an on-connect callback which will run at most once, /// then unregister itself. pub(crate) fn register_on_connect_oneshot( &mut self, - callback: impl FnOnce(&Credentials) + Send + 'static, - ) -> CallbackId { - self.callbacks.insert_oneshot(callback) + callback: impl FnOnce(&Credentials, Address) + Send + 'static, + ) -> CallbackId<(Credentials, Address)> { + self.callbacks + .insert_oneshot(move |(creds, addr)| callback(creds, addr)) } /// Unregister a previously-registered on-connect callback identified by `id`. - pub(crate) fn unregister_on_connect(&mut self, id: CallbackId) { + pub(crate) fn unregister_on_connect(&mut self, id: CallbackId<(Credentials, Address)>) { self.callbacks.remove(id); } @@ -847,9 +904,14 @@ impl CredentialStore { /// /// Either way, invoke any on-connect callbacks with the received credentials. pub(crate) fn handle_identity_token(&mut self, msg: client_api_messages::IdentityToken, state: ClientCacheView) { - let client_api_messages::IdentityToken { identity, token } = msg; - if identity.is_empty() || token.is_empty() { - log::warn!("Received IdentityToken message with emtpy identity and/or empty token"); + let client_api_messages::IdentityToken { + identity, + token, + address, + } = msg; + if identity.is_empty() || token.is_empty() || address.is_empty() { + // TODO: panic? + log::warn!("Received IdentityToken message with emtpy identity, token and/or address"); return; } @@ -858,7 +920,17 @@ impl CredentialStore { token: Token { string: token }, }; - self.callbacks.invoke(creds.clone(), state); + let address = Address::from_slice(&address); + + if Some(address) != self.address { + log::error!( + "Address provided by the server does not match local record. Server: {:?} Local: {:?}", + address, + self.address, + ); + } + + self.callbacks.invoke((creds.clone(), address), state); if let Some(existing_creds) = &self.credentials { // If we already have credentials, make sure that they match. Log an error if @@ -892,6 +964,11 @@ impl CredentialStore { pub(crate) fn credentials(&self) -> Option { self.credentials.clone() } + + /// Return the current connection's `Address`, if it is stored. + pub(crate) fn address(&self) -> Option
{ + self.address + } } /// Manages running `on_subscription_applied` callbacks after `subscribe` calls. diff --git a/crates/sdk/src/identity.rs b/crates/sdk/src/identity.rs index f9eebf80ecf..e1286bc704c 100644 --- a/crates/sdk/src/identity.rs +++ b/crates/sdk/src/identity.rs @@ -3,6 +3,7 @@ use crate::global_connection::with_credential_store; use anyhow::{anyhow, Context, Result}; use spacetimedb_lib::de::Deserialize; use spacetimedb_lib::ser::Serialize; +use spacetimedb_lib::Address; use spacetimedb_sats::bsatn; // TODO: impl ser/de for `Identity`, `Token`, `Credentials` so that clients can stash them // to disk and use them to re-connect. @@ -75,7 +76,7 @@ pub struct Credentials { #[derive(Copy, Clone)] pub struct ConnectCallbackId { - id: CallbackId, + id: CallbackId<(Credentials, Address)>, } /// Register a callback to be invoked upon authentication with the database. @@ -102,7 +103,7 @@ pub struct ConnectCallbackId { /// /// The returned `ConnectCallbackId` can be passed to `remove_on_connect` to unregister /// the callback. -pub fn on_connect(callback: impl FnMut(&Credentials) + Send + 'static) -> ConnectCallbackId { +pub fn on_connect(callback: impl FnMut(&Credentials, Address) + Send + 'static) -> ConnectCallbackId { let id = with_credential_store(|cred_store| cred_store.register_on_connect(callback)); ConnectCallbackId { id } } @@ -125,7 +126,7 @@ pub fn on_connect(callback: impl FnMut(&Credentials) + Send + 'static) -> Connec /// /// The returned `ConnectCallbackId` can be passed to `remove_on_connect` to unregister /// the callback. -pub fn once_on_connect(callback: impl FnOnce(&Credentials) + Send + 'static) -> ConnectCallbackId { +pub fn once_on_connect(callback: impl FnOnce(&Credentials, Address) + Send + 'static) -> ConnectCallbackId { let id = with_credential_store(|cred_store| cred_store.register_on_connect_oneshot(callback)); ConnectCallbackId { id } } @@ -166,6 +167,13 @@ pub fn credentials() -> Result { with_credential_store(|cred_store| cred_store.credentials().ok_or(anyhow!("Credentials not yet received"))) } +/// Read the current connection's `Address`. +/// +/// Returns an error if `connect` has not yet been called. +pub fn address() -> Result
{ + with_credential_store(|cred_store| cred_store.address().ok_or(anyhow!("Address not yet generated"))) +} + const CREDS_FILE: &str = "credentials"; /// Load a saved `Credentials` from a file within `~/dirname`, if one exists. @@ -209,3 +217,13 @@ pub fn save_credentials(dirname: &str, creds: &Credentials) -> Result<()> { path.push(CREDS_FILE); std::fs::write(&path, creds_bytes).with_context(|| "Writing credentials to file") } + +#[doc(hidden)] +/// Designate a file to store this client's `Address`. +/// +/// If called, `use_save_address` must be called before the first call to `connect` in a process. +/// +/// If the file at `path` exists, it will be treated as a BSATN-encoded `Address` +pub fn use_saved_address(path: &str) -> Result
{ + with_credential_store(|cred_store| cred_store.use_saved_address(path)) +} diff --git a/crates/sdk/src/lib.rs b/crates/sdk/src/lib.rs index b37dd2eb85d..74a3047b790 100644 --- a/crates/sdk/src/lib.rs +++ b/crates/sdk/src/lib.rs @@ -30,7 +30,7 @@ pub mod background_connection; // We re-export `spacetimedb_lib` so the cli codegen can reference it through us, rather // than requiring downstream users to depend on it explicitly. // TODO: determine if this should be `#[doc(hidden)]` -pub use spacetimedb_lib; +pub use spacetimedb_lib::{self, Address}; // Ditto re-exporing `log`. // TODO: determine if this should be `#[doc(hidden)]`. diff --git a/crates/sdk/src/reducer.rs b/crates/sdk/src/reducer.rs index 4457cc34d32..930228b6309 100644 --- a/crates/sdk/src/reducer.rs +++ b/crates/sdk/src/reducer.rs @@ -1,6 +1,7 @@ use crate::callbacks::CallbackId; use crate::global_connection::{with_connection, with_reducer_callbacks}; use crate::identity::Identity; +use crate::Address; use anyhow::Result; use spacetimedb_sats::{de::DeserializeOwned, ser::Serialize}; use std::any::Any; @@ -14,7 +15,7 @@ pub enum Status { #[derive(Copy, Clone)] pub struct ReducerCallbackId { - id: CallbackId<(Identity, Status, R)>, + id: CallbackId<(Identity, Option
, Status, R)>, } // Any bound so these can be keys in an `AnyMap` to store callbacks. @@ -39,7 +40,9 @@ pub trait Reducer: DeserializeOwned + Serialize + Any + Send + Sync + Clone { // /// The returned `ReducerCallbackId` can be passed to `remove_on_reducer` to /// unregister the callback. - fn on_reducer(callback: impl FnMut(&Identity, &Status, &Self) + Send + 'static) -> ReducerCallbackId { + fn on_reducer( + callback: impl FnMut(&Identity, Option
, &Status, &Self) + Send + 'static, + ) -> ReducerCallbackId { let id = with_reducer_callbacks(|callbacks| callbacks.register_on_reducer::(callback)); ReducerCallbackId { id } } @@ -49,7 +52,9 @@ pub trait Reducer: DeserializeOwned + Serialize + Any + Send + Sync + Clone { /// The `callback` will run at most once, then unregister itself. /// It can also be unregistered by passing the returned `ReducerCallbackId` /// to `remove_on_reducer`. - fn once_on_reducer(callback: impl FnOnce(&Identity, &Status, &Self) + Send + 'static) -> ReducerCallbackId { + fn once_on_reducer( + callback: impl FnOnce(&Identity, Option
, &Status, &Self) + Send + 'static, + ) -> ReducerCallbackId { let id = with_reducer_callbacks(|callbacks| callbacks.register_on_reducer_oneshot::(callback)); ReducerCallbackId { id } } diff --git a/crates/sdk/src/websocket.rs b/crates/sdk/src/websocket.rs index d2ae20186b2..14aa3977980 100644 --- a/crates/sdk/src/websocket.rs +++ b/crates/sdk/src/websocket.rs @@ -5,6 +5,7 @@ use futures_channel::mpsc; use http::uri::{Parts, Scheme, Uri}; use prost::Message as ProtobufMessage; use spacetimedb_client_api_messages::client_api::Message; +use spacetimedb_lib::Address; use tokio::task::JoinHandle; use tokio::{net::TcpStream, runtime}; use tokio_tungstenite::{ @@ -30,7 +31,7 @@ fn parse_scheme(scheme: Option) -> Result { }) } -fn make_uri(host: Host, db_name: &str) -> Result +fn make_uri(host: Host, db_name: &str, client_address: Address) -> Result where Host: TryInto, >::Error: std::error::Error + Send + Sync + 'static, @@ -53,6 +54,8 @@ where } path.push_str("database/subscribe/"); path.push_str(db_name); + path.push_str("?client_address="); + path.push_str(&client_address.to_hex()); parts.path_and_query = Some(path.parse()?); Ok(Uri::try_from(parts)?) } @@ -66,12 +69,17 @@ where // rather than having Tungstenite manage its own connections. Should this library do // the same? -fn make_request(host: Host, db_name: &str, credentials: Option<&Credentials>) -> Result> +fn make_request( + host: Host, + db_name: &str, + credentials: Option<&Credentials>, + client_address: Address, +) -> Result> where Host: TryInto, >::Error: std::error::Error + Send + Sync + 'static, { - let uri = make_uri(host, db_name)?; + let uri = make_uri(host, db_name, client_address)?; let mut req = IntoClientRequest::into_client_request(uri)?; request_insert_protocol_header(&mut req); request_insert_auth_header(&mut req, credentials); @@ -115,12 +123,17 @@ fn request_insert_auth_header(req: &mut http::Request<()>, credentials: Option<& } impl DbConnection { - pub(crate) async fn connect(host: Host, db_name: &str, credentials: Option<&Credentials>) -> Result + pub(crate) async fn connect( + host: Host, + db_name: &str, + credentials: Option<&Credentials>, + client_address: Address, + ) -> Result where Host: TryInto, >::Error: std::error::Error + Send + Sync + 'static, { - let req = make_request(host, db_name, credentials)?; + let req = make_request(host, db_name, credentials, client_address)?; let (sock, _): (WebSocketStream>, _) = connect_async_with_config( req, // TODO(kim): In order to be able to replicate module WASM blobs, diff --git a/crates/sdk/tests/connect_disconnect_client/src/main.rs b/crates/sdk/tests/connect_disconnect_client/src/main.rs index ddbc53e5222..196f7b0b20a 100644 --- a/crates/sdk/tests/connect_disconnect_client/src/main.rs +++ b/crates/sdk/tests/connect_disconnect_client/src/main.rs @@ -35,7 +35,7 @@ fn main() { }; sub_applied_one_row_result(check()); }); - once_on_connect(move |_| { + once_on_connect(move |_, _| { subscribe_result(subscribe(&["SELECT * FROM Connected;"])); }); @@ -68,7 +68,7 @@ fn main() { }; sub_applied_one_row_result(check()); }); - once_on_connect(move |_| { + once_on_connect(move |_, _| { subscribe_result(subscribe(&["SELECT * FROM Disconnected;"])); }); diff --git a/crates/sdk/tests/test-client/src/main.rs b/crates/sdk/tests/test-client/src/main.rs index 13c7a219218..443487a45d1 100644 --- a/crates/sdk/tests/test-client/src/main.rs +++ b/crates/sdk/tests/test-client/src/main.rs @@ -1,6 +1,7 @@ use spacetimedb_sdk::{ - identity::{identity, load_credentials, once_on_connect, save_credentials}, - once_on_subscription_applied, + disconnect, + identity::{address, identity, load_credentials, once_on_connect, save_credentials}, + once_on_disconnect, once_on_subscription_applied, reducer::Status, subscribe, table::TableType, @@ -66,6 +67,10 @@ fn main() { "delete_identity" => exec_delete_identity(), "update_identity" => exec_update_identity(), + "insert_address" => exec_insert_address(), + "delete_address" => exec_delete_address(), + "update_address" => exec_update_address(), + "on_reducer" => exec_on_reducer(), "fail_reducer" => exec_fail_reducer(), @@ -86,6 +91,8 @@ fn main() { "should_fail" => exec_should_fail(), + "reconnect_same_address" => exec_reconnect_same_address(), + _ => panic!("Unknown test: {}", test), } } @@ -125,6 +132,7 @@ fn assert_all_tables_empty() -> anyhow::Result<()> { assert_table_empty::()?; assert_table_empty::()?; + assert_table_empty::()?; assert_table_empty::()?; assert_table_empty::()?; @@ -153,6 +161,7 @@ fn assert_all_tables_empty() -> anyhow::Result<()> { assert_table_empty::()?; assert_table_empty::()?; + assert_table_empty::()?; assert_table_empty::()?; assert_table_empty::()?; @@ -178,6 +187,7 @@ fn assert_all_tables_empty() -> anyhow::Result<()> { assert_table_empty::()?; assert_table_empty::()?; + assert_table_empty::()?; assert_table_empty::()?; assert_table_empty::()?; @@ -195,6 +205,7 @@ fn assert_all_tables_empty() -> anyhow::Result<()> { assert_table_empty::()?; assert_table_empty::()?; + assert_table_empty::()?; assert_table_empty::()?; @@ -220,6 +231,7 @@ const SUBSCRIBE_ALL: &[&str] = &[ "SELECT * FROM OneF64;", "SELECT * FROM OneString;", "SELECT * FROM OneIdentity;", + "SELECT * FROM OneAddress;", "SELECT * FROM OneSimpleEnum;", "SELECT * FROM OneEnumWithPayload;", "SELECT * FROM OneUnitStruct;", @@ -241,6 +253,7 @@ const SUBSCRIBE_ALL: &[&str] = &[ "SELECT * FROM VecF64;", "SELECT * FROM VecString;", "SELECT * FROM VecIdentity;", + "SELECT * FROM VecAddress;", "SELECT * FROM VecSimpleEnum;", "SELECT * FROM VecEnumWithPayload;", "SELECT * FROM VecUnitStruct;", @@ -260,6 +273,7 @@ const SUBSCRIBE_ALL: &[&str] = &[ "SELECT * FROM UniqueBool;", "SELECT * FROM UniqueString;", "SELECT * FROM UniqueIdentity;", + "SELECT * FROM UniqueAddress;", "SELECT * FROM PkU8;", "SELECT * FROM PkU16;", "SELECT * FROM PkU32;", @@ -273,6 +287,7 @@ const SUBSCRIBE_ALL: &[&str] = &[ "SELECT * FROM PkBool;", "SELECT * FROM PkString;", "SELECT * FROM PkIdentity;", + "SELECT * FROM PkAddress;", "SELECT * FROM LargeTable;", "SELECT * FROM TableHoldsTable;", ]; @@ -317,7 +332,7 @@ fn exec_insert_primitive() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -358,7 +373,7 @@ fn exec_delete_primitive() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -401,7 +416,7 @@ fn exec_update_primitive() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -430,7 +445,7 @@ fn exec_insert_identity() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -458,7 +473,7 @@ fn exec_delete_identity() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -488,7 +503,94 @@ fn exec_update_identity() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); + + conn_result(connect(LOCALHOST, &name, None)); + + test_counter.wait_for_all(); + + assert_all_tables_empty().unwrap(); +} + +/// This tests that we can serialize and deserialize `Address` in various contexts. +fn exec_insert_address() { + let test_counter = TestCounter::new(); + let name = db_name_or_panic(); + + let conn_result = test_counter.add_test("connect"); + + let sub_result = test_counter.add_test("subscribe"); + + let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); + + { + let test_counter = test_counter.clone(); + once_on_subscription_applied(move || { + insert_one::(&test_counter, address().unwrap()); + + sub_applied_nothing_result(assert_all_tables_empty()); + }); + } + + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); + + conn_result(connect(LOCALHOST, &name, None)); + + test_counter.wait_for_all(); +} + +/// This test doesn't add much alongside `exec_insert_address` and `exec_delete_primitive`, +/// but it's here for symmetry. +fn exec_delete_address() { + let test_counter = TestCounter::new(); + let name = db_name_or_panic(); + + let conn_result = test_counter.add_test("connect"); + + let sub_result = test_counter.add_test("subscribe"); + + let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); + + { + let test_counter = test_counter.clone(); + once_on_subscription_applied(move || { + insert_then_delete_one::(&test_counter, address().unwrap(), 0xbeef); + + sub_applied_nothing_result(assert_all_tables_empty()); + }); + } + + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); + + conn_result(connect(LOCALHOST, &name, None)); + + test_counter.wait_for_all(); + + assert_all_tables_empty().unwrap(); +} + +/// This tests that we can distinguish between `on_delete` and `on_update` events +/// for tables with `Address` primary keys. +fn exec_update_address() { + let test_counter = TestCounter::new(); + let name = db_name_or_panic(); + + let conn_result = test_counter.add_test("connect"); + + let sub_result = test_counter.add_test("subscribe"); + + let sub_applied_nothing_result = test_counter.add_test("on_subscription_applied_nothing"); + + { + let test_counter = test_counter.clone(); + once_on_subscription_applied(move || { + insert_update_delete_one::(&test_counter, address().unwrap(), 0xbeef, 0xbabe); + + sub_applied_nothing_result(assert_all_tables_empty()); + }); + } + + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -512,16 +614,23 @@ fn exec_on_reducer() { let value = 128; - once_on_insert_one_u_8(move |caller, status, arg| { + once_on_insert_one_u_8(move |caller_id, caller_addr, status, arg| { let run_checks = || { if *arg != value { anyhow::bail!("Unexpected reducer argument. Expected {} but found {}", value, *arg); } - if *caller != identity().unwrap() { + if *caller_id != identity().unwrap() { anyhow::bail!( - "Unexpected caller. Expected:\n{:?}\nFound:\n{:?}", + "Unexpected caller_id. Expected:\n{:?}\nFound:\n{:?}", identity().unwrap(), - caller + caller_id + ); + } + if caller_addr != Some(address().unwrap()) { + anyhow::bail!( + "Unexpected caller_addr. Expected:\n{:?}\nFound:\n{:?}", + address().unwrap(), + caller_addr ); } if !matches!(status, Status::Committed) { @@ -546,7 +655,7 @@ fn exec_on_reducer() { sub_applied_nothing_result(assert_all_tables_empty()); }); - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -571,7 +680,7 @@ fn exec_fail_reducer() { let initial_data = 0xbeef; let fail_data = 0xbabe; - once_on_insert_pk_u_8(move |caller, status, arg_key, arg_val| { + once_on_insert_pk_u_8(move |caller_id, caller_addr, status, arg_key, arg_val| { let run_checks = || { if *arg_key != key { anyhow::bail!("Unexpected reducer argument. Expected {} but found {}", key, *arg_key); @@ -580,14 +689,21 @@ fn exec_fail_reducer() { anyhow::bail!( "Unexpected reducer argument. Expected {} but found {}", initial_data, - *arg_val + *arg_val, ); } - if *caller != identity().unwrap() { + if *caller_id != identity().unwrap() { anyhow::bail!( - "Unexpected caller. Expected:\n{:?}\nFound:\n{:?}", + "Unexpected caller_id. Expected:\n{:?}\nFound:\n{:?}", identity().unwrap(), - caller + caller_id, + ); + } + if caller_addr != Some(address().unwrap()) { + anyhow::bail!( + "Unexpected caller_addr. Expected:\n{:?}\nFound:\n{:?}", + address().unwrap(), + caller_addr, ); } if !matches!(status, Status::Committed) { @@ -610,7 +726,7 @@ fn exec_fail_reducer() { reducer_success_result(run_checks()); - once_on_insert_pk_u_8(move |caller, status, arg_key, arg_val| { + once_on_insert_pk_u_8(move |caller_id, caller_addr, status, arg_key, arg_val| { let run_checks = || { if *arg_key != key { anyhow::bail!("Unexpected reducer argument. Expected {} but found {}", key, *arg_key); @@ -622,13 +738,20 @@ fn exec_fail_reducer() { *arg_val ); } - if *caller != identity().unwrap() { + if *caller_id != identity().unwrap() { anyhow::bail!( - "Unexpected caller. Expected:\n{:?}\nFound:\n{:?}", + "Unexpected caller_id. Expected:\n{:?}\nFound:\n{:?}", identity().unwrap(), - caller + caller_id, ); } + if caller_addr != Some(address().unwrap()) { + anyhow::bail!( + "Unexpected caller_addr. Expected:\n{:?}\nFound:\n{:?}", + address().unwrap(), + caller_addr, + ) + } if !matches!(status, Status::Failed(_)) { anyhow::bail!("Unexpected status. Expected Failed but found {:?}", status); } @@ -659,7 +782,7 @@ fn exec_fail_reducer() { sub_applied_nothing_result(assert_all_tables_empty()); }); - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -705,7 +828,7 @@ fn exec_insert_vec() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -746,6 +869,7 @@ fn exec_insert_struct() { m: -1.0, n: "string".to_string(), o: identity().unwrap(), + p: address().unwrap(), }, ); insert_one::( @@ -766,6 +890,7 @@ fn exec_insert_struct() { m: vec![0.0, -0.5, 0.5, -1.5, 1.5], n: ["vec", "of", "strings"].into_iter().map(str::to_string).collect(), o: vec![identity().unwrap()], + p: vec![address().unwrap()], }, ); @@ -789,6 +914,7 @@ fn exec_insert_struct() { m: -1.0, n: "string".to_string(), o: identity().unwrap(), + p: address().unwrap(), }], ); insert_one::( @@ -809,6 +935,7 @@ fn exec_insert_struct() { m: vec![0.0, -0.5, 0.5, -1.5, 1.5], n: ["vec", "of", "strings"].into_iter().map(str::to_string).collect(), o: vec![identity().unwrap()], + p: vec![address().unwrap()], }], ); @@ -816,7 +943,7 @@ fn exec_insert_struct() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -844,7 +971,7 @@ fn exec_insert_simple_enum() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -899,7 +1026,7 @@ fn exec_insert_enum_with_payload() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -962,6 +1089,7 @@ fn exec_insert_long_table() { m: 1.0, n: "string".to_string(), o: identity().unwrap(), + p: address().unwrap(), }; let every_vec_struct = EveryVecStruct { a: vec![0], @@ -979,6 +1107,7 @@ fn exec_insert_long_table() { m: vec![1.0], n: vec!["string".to_string()], o: vec![identity().unwrap()], + p: vec![address().unwrap()], }; let every_primitive_dup = every_primitive_struct.clone(); @@ -1044,7 +1173,7 @@ fn exec_insert_long_table() { }); } - once_on_connect(move |_| sub_result(subscribe(SUBSCRIBE_ALL))); + once_on_connect(move |_, _| sub_result(subscribe(SUBSCRIBE_ALL))); conn_result(connect(LOCALHOST, &name, None)); @@ -1067,7 +1196,7 @@ fn exec_resubscribe() { sub_applied_result(assert_all_tables_empty()); }); - once_on_connect(|_| { + once_on_connect(|_, _| { subscribe_result(subscribe(SUBSCRIBE_ALL)); }); @@ -1170,7 +1299,7 @@ fn exec_reauth_part_1() { let connect_result = test_counter.add_test("connect"); let save_result = test_counter.add_test("save-credentials"); - once_on_connect(|creds| { + once_on_connect(|creds, _| { save_result(save_credentials(".spacetime_rust_sdk_test", creds)); }); @@ -1196,7 +1325,7 @@ fn exec_reauth_part_2() { let creds_dup = creds.clone(); - once_on_connect(move |received_creds| { + once_on_connect(move |received_creds, _| { let run_checks = || { assert_eq_or_bail!(creds_dup, *received_creds); Ok(()) @@ -1209,3 +1338,51 @@ fn exec_reauth_part_2() { test_counter.wait_for_all(); } + +fn exec_reconnect_same_address() { + let test_counter = TestCounter::new(); + let name = db_name_or_panic(); + + let connect_result = test_counter.add_test("connect"); + let read_addr_result = test_counter.add_test("read_addr"); + + let name_dup = name.clone(); + once_on_connect(move |_, received_address| { + let my_address = address().unwrap(); + let run_checks = || { + assert_eq_or_bail!(my_address, received_address); + Ok(()) + }; + + read_addr_result(run_checks()); + }); + + connect_result(connect(LOCALHOST, &name, None)); + + test_counter.wait_for_all(); + + let my_address = address().unwrap(); + + let test_counter = TestCounter::new(); + let reconnect_result = test_counter.add_test("reconnect"); + let addr_after_reconnect_result = test_counter.add_test("addr_after_reconnect"); + + once_on_disconnect(move || { + once_on_connect(move |_, received_address| { + let my_address_2 = address().unwrap(); + let run_checks = || { + assert_eq_or_bail!(my_address, received_address); + assert_eq_or_bail!(my_address, my_address_2); + Ok(()) + }; + + addr_after_reconnect_result(run_checks()); + }); + + reconnect_result(connect(LOCALHOST, &name_dup, None)); + }); + + disconnect(); + + test_counter.wait_for_all(); +} diff --git a/crates/sdk/tests/test-client/src/pk_test_table.rs b/crates/sdk/tests/test-client/src/pk_test_table.rs index 023f456a670..c9781e5cdee 100644 --- a/crates/sdk/tests/test-client/src/pk_test_table.rs +++ b/crates/sdk/tests/test-client/src/pk_test_table.rs @@ -308,4 +308,16 @@ impl_pk_test_table! { update_reducer = update_pk_identity; update_reducer_event = UpdatePkIdentity; } + + PkAddress { + Key = Address; + key_field_name = a; + insert_reducer = insert_pk_address; + insert_reducer_event = InsertPkAddress; + delete_reducer = delete_pk_address; + delete_reducer_event = DeletePkAddress; + update_reducer = update_pk_address; + update_reducer_event = UpdatePkAddress; + } + } diff --git a/crates/sdk/tests/test-client/src/simple_test_table.rs b/crates/sdk/tests/test-client/src/simple_test_table.rs index 5d12c998985..7b32553770b 100644 --- a/crates/sdk/tests/test-client/src/simple_test_table.rs +++ b/crates/sdk/tests/test-client/src/simple_test_table.rs @@ -1,6 +1,6 @@ use crate::module_bindings::*; use anyhow::anyhow; -use spacetimedb_sdk::{identity::Identity, table::TableType}; +use spacetimedb_sdk::{identity::Identity, table::TableType, Address}; use std::sync::Arc; use test_counter::TestCounter; @@ -146,6 +146,13 @@ impl_simple_test_table! { insert_reducer_event = InsertOneIdentity; } + OneAddress { + Contents = Address; + field_name = a; + insert_reducer = insert_one_address; + insert_reducer_event = InsertOneAddress; + } + OneSimpleEnum { Contents = SimpleEnum; field_name = e; @@ -280,6 +287,13 @@ impl_simple_test_table! { insert_reducer_event = InsertVecIdentity; } + VecAddress { + Contents = Vec
; + field_name = a; + insert_reducer = insert_vec_address; + insert_reducer_event = InsertVecAddress; + } + VecSimpleEnum { Contents = Vec; field_name = e; diff --git a/crates/sdk/tests/test-client/src/unique_test_table.rs b/crates/sdk/tests/test-client/src/unique_test_table.rs index 6da5f697ded..31e44201601 100644 --- a/crates/sdk/tests/test-client/src/unique_test_table.rs +++ b/crates/sdk/tests/test-client/src/unique_test_table.rs @@ -1,6 +1,6 @@ use crate::module_bindings::*; use anyhow::anyhow; -use spacetimedb_sdk::{identity::Identity, table::TableType}; +use spacetimedb_sdk::{identity::Identity, table::TableType, Address}; use std::sync::Arc; use test_counter::TestCounter; @@ -235,4 +235,13 @@ impl_unique_test_table! { delete_reducer = delete_unique_identity; delete_reducer_event = DeleteUniqueIdentity; } + + UniqueAddress { + Key = Address; + key_field_name = a; + insert_reducer = insert_unique_address; + insert_reducer_event = InsertUniqueAddress; + delete_reducer = delete_unique_address; + delete_reducer_event = DeleteUniqueAddress; + } } diff --git a/crates/sdk/tests/test.rs b/crates/sdk/tests/test.rs index ec713bc2802..e0b47a928f6 100644 --- a/crates/sdk/tests/test.rs +++ b/crates/sdk/tests/test.rs @@ -45,6 +45,21 @@ fn update_identity() { make_test("delete_identity").run(); } +#[test] +fn insert_address() { + make_test("insert_address").run(); +} + +#[test] +fn delete_address() { + make_test("delete_address").run(); +} + +#[test] +fn update_address() { + make_test("delete_address").run(); +} + #[test] fn on_reducer() { make_test("on_reducer").run(); @@ -92,6 +107,11 @@ fn reauth() { make_test("reauth_part_2").run(); } +#[test] +fn reconnect_same_address() { + make_test("reconnect_same_address").run(); +} + #[test] fn connect_disconnect_callbacks() { Test::builder() diff --git a/crates/standalone/src/lib.rs b/crates/standalone/src/lib.rs index 77811c4f949..c7491bf72a7 100644 --- a/crates/standalone/src/lib.rs +++ b/crates/standalone/src/lib.rs @@ -274,6 +274,7 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { async fn publish_database( &self, identity: &Identity, + publisher_address: Option
, spec: spacetimedb_client_api::DatabaseDef, ) -> spacetimedb::control_db::Result> { let existing_db = self.control_db.get_database_by_address(&spec.address)?; @@ -283,6 +284,7 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { address: spec.address, num_replicas: spec.num_replicas, program_bytes_address, + publisher_address, ..existing.clone() }, None => Database { @@ -292,6 +294,7 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { host_type: HostType::Wasmer, num_replicas: spec.num_replicas, program_bytes_address, + publisher_address, }, }; diff --git a/crates/testing/src/modules.rs b/crates/testing/src/modules.rs index 03c5bed7957..a48203c01d8 100644 --- a/crates/testing/src/modules.rs +++ b/crates/testing/src/modules.rs @@ -125,7 +125,8 @@ impl CompiledModule { crate::set_key_env_vars(&paths); let env = spacetimedb_standalone::StandaloneEnv::init(config).await.unwrap(); let identity = env.create_identity().await.unwrap(); - let address = env.create_address().await.unwrap(); + let db_address = env.create_address().await.unwrap(); + let client_address = env.create_address().await.unwrap(); let program_bytes = self .program_bytes @@ -134,8 +135,9 @@ impl CompiledModule { env.publish_database( &identity, + Some(client_address), DatabaseDef { - address, + address: db_address, program_bytes, num_replicas: 1, }, @@ -143,11 +145,12 @@ impl CompiledModule { .await .unwrap(); - let database = env.get_database_by_address(&address).unwrap().unwrap(); + let database = env.get_database_by_address(&db_address).unwrap().unwrap(); let instance = env.get_leader_database_instance_by_database(database.id).unwrap(); let client_id = ClientActorId { identity, + address: client_address, name: env.client_actor_index().next_client_name(), }; @@ -160,7 +163,7 @@ impl CompiledModule { ModuleHandle { _env: env, client: ClientConnection::dummy(client_id, Protocol::Text, instance.id, module), - db_address: address, + db_address, } } } diff --git a/modules/sdk-test/Cargo.toml b/modules/sdk-test/Cargo.toml index d578d5d1dde..068e81f09ca 100644 --- a/modules/sdk-test/Cargo.toml +++ b/modules/sdk-test/Cargo.toml @@ -10,4 +10,5 @@ crate-type = ["cdylib"] [dependencies] spacetimedb = { path = "../../crates/bindings", version = "0.6.1" } -log = "0.4" +log.workspace = true +anyhow.workspace = true diff --git a/modules/sdk-test/src/lib.rs b/modules/sdk-test/src/lib.rs index 6b9e35a1c8c..ce2f69b85c3 100644 --- a/modules/sdk-test/src/lib.rs +++ b/modules/sdk-test/src/lib.rs @@ -6,7 +6,8 @@ // and clippy misunderstands `#[allow]` attributes in macro-expansions. #![allow(clippy::too_many_arguments)] -use spacetimedb::{spacetimedb, Identity, SpacetimeType}; +use anyhow::{Context, Result}; +use spacetimedb::{spacetimedb, Address, Identity, ReducerContext, SpacetimeType}; #[derive(SpacetimeType)] pub enum SimpleEnum { @@ -32,6 +33,7 @@ pub enum EnumWithPayload { F64(f64), Str(String), Identity(Identity), + Address(Address), Bytes(Vec), Ints(Vec), Strings(Vec), @@ -65,6 +67,7 @@ pub struct EveryPrimitiveStruct { m: f64, n: String, o: Identity, + p: Address, } #[derive(SpacetimeType)] @@ -84,6 +87,7 @@ pub struct EveryVecStruct { m: Vec, n: Vec, o: Vec, + p: Vec
, } /// Defines one or more tables, and optionally reducers alongside them. @@ -230,6 +234,7 @@ define_tables! { OneString { insert insert_one_string } s String; OneIdentity { insert insert_one_identity } i Identity; + OneAddress { insert insert_one_address } a Address; OneSimpleEnum { insert insert_one_simple_enum } e SimpleEnum; OneEnumWithPayload { insert insert_one_enum_with_payload } e EnumWithPayload; @@ -262,6 +267,7 @@ define_tables! { VecString { insert insert_vec_string } s Vec; VecIdentity { insert insert_vec_identity } i Vec; + VecAddress { insert insert_vec_address } a Vec
; VecSimpleEnum { insert insert_vec_simple_enum } e Vec; VecEnumWithPayload { insert insert_vec_enum_with_payload } e Vec; @@ -355,6 +361,12 @@ define_tables! { update_by update_unique_identity = update_by_i(i), delete_by delete_unique_identity = delete_by_i(i: Identity), } #[unique] i Identity, data i32; + + UniqueAddress { + insert_or_panic insert_unique_address, + update_by update_unique_address = update_by_a(a), + delete_by delete_unique_address = delete_by_a(a: Address), + } #[unique] a Address, data i32; } // Tables mapping a primary key to a boring i32 payload. @@ -390,14 +402,12 @@ define_tables! { delete_by delete_pk_u128 = delete_by_n(n: u128), } #[primarykey] n u128, data i32; - PkI8 { insert_or_panic insert_pk_i8, update_by update_pk_i8 = update_by_n(n), delete_by delete_pk_i8 = delete_by_n(n: i8), } #[primarykey] n i8, data i32; - PkI16 { insert_or_panic insert_pk_i16, update_by update_pk_i16 = update_by_n(n), @@ -422,7 +432,6 @@ define_tables! { delete_by delete_pk_i128 = delete_by_n(n: i128), } #[primarykey] n i128, data i32; - PkBool { insert_or_panic insert_pk_bool, update_by update_pk_bool = update_by_b(b), @@ -440,6 +449,70 @@ define_tables! { update_by update_pk_identity = update_by_i(i), delete_by delete_pk_identity = delete_by_i(i: Identity), } #[primarykey] i Identity, data i32; + + PkAddress { + insert_or_panic insert_pk_address, + update_by update_pk_address = update_by_a(a), + delete_by delete_pk_address = delete_by_a(a: Address), + } #[primarykey] a Address, data i32; +} + +#[spacetimedb(reducer)] +fn insert_caller_one_identity(ctx: ReducerContext) -> anyhow::Result<()> { + OneIdentity::insert(OneIdentity { i: ctx.sender }); + Ok(()) +} + +#[spacetimedb(reducer)] +fn insert_caller_vec_identity(ctx: ReducerContext) -> anyhow::Result<()> { + VecIdentity::insert(VecIdentity { i: vec![ctx.sender] }); + Ok(()) +} + +#[spacetimedb(reducer)] +fn insert_caller_unique_identity(ctx: ReducerContext, data: i32) -> anyhow::Result<()> { + UniqueIdentity::insert(UniqueIdentity { i: ctx.sender, data })?; + Ok(()) +} + +#[spacetimedb(reducer)] +fn insert_caller_pk_identity(ctx: ReducerContext, data: i32) -> anyhow::Result<()> { + PkIdentity::insert(PkIdentity { i: ctx.sender, data })?; + Ok(()) +} + +#[spacetimedb(reducer)] +fn insert_caller_one_address(ctx: ReducerContext) -> anyhow::Result<()> { + OneAddress::insert(OneAddress { + a: ctx.address.context("No address in reducer context")?, + }); + Ok(()) +} + +#[spacetimedb(reducer)] +fn insert_caller_vec_address(ctx: ReducerContext) -> anyhow::Result<()> { + VecAddress::insert(VecAddress { + a: vec![ctx.address.context("No address in reducer context")?], + }); + Ok(()) +} + +#[spacetimedb(reducer)] +fn insert_caller_unique_address(ctx: ReducerContext, data: i32) -> anyhow::Result<()> { + UniqueAddress::insert(UniqueAddress { + a: ctx.address.context("No address in reducer context")?, + data, + })?; + Ok(()) +} + +#[spacetimedb(reducer)] +fn insert_caller_pk_address(ctx: ReducerContext, data: i32) -> anyhow::Result<()> { + PkAddress::insert(PkAddress { + a: ctx.address.context("No address in reducer context")?, + data, + })?; + Ok(()) } // Some weird-looking tables.