From ccbd96edbecc72ac3fb11c9b2a5d2f213f0988dc Mon Sep 17 00:00:00 2001 From: Teo Stocco Date: Mon, 10 Jan 2022 09:33:53 +0100 Subject: [PATCH] feat: support struct and pointer return types (#41) Co-authored-by: Divy Srivastava --- Makefile | 2 +- codegen.ts | 63 +++++--- deno_bindgen_macro/src/derive_fn.rs | 8 +- deno_bindgen_macro/src/lib.rs | 24 +++- deno_bindgen_macro/src/meta.rs | 1 + example/bindings/bindings.ts | 216 ++++++++++++++++------------ example/bindings_test.ts | 67 ++++++++- example/src/lib.rs | 47 ++++++ 8 files changed, 312 insertions(+), 116 deletions(-) diff --git a/Makefile b/Makefile index d48a531..564391a 100644 --- a/Makefile +++ b/Makefile @@ -3,4 +3,4 @@ fmt: deno fmt --ignore=target/,example/target/,example/bindings/ test: - cd example && deno_bindgen && deno test -A --unstable + cd example && deno run -A ../cli.ts && deno test -A --unstable diff --git a/codegen.ts b/codegen.ts index 45f0bec..adb2417 100644 --- a/codegen.ts +++ b/codegen.ts @@ -42,6 +42,7 @@ const BufferTypes: Record = { str: "string", buffer: "Uint8Array", buffermut: "Uint8Array", + ptr: "Uint8Array", }; enum Encoder { @@ -53,6 +54,7 @@ const BufferTypeEncoders: Record = { str: Encoder.None, buffer: Encoder.None, buffermut: Encoder.None, + ptr: Encoder.None, }; type TypeDef = Record>; @@ -90,8 +92,12 @@ type Options = { release?: boolean; }; +function isTypeDef(p: any) { + return typeof p !== "string"; +} + function isBufferType(p: any) { - return typeof p !== "string" || BufferTypes[p] !== undefined; + return isTypeDef(p) || BufferTypes[p] !== undefined; } // @littledivy is a dumb kid! @@ -125,7 +131,10 @@ function encode(v: string | Uint8Array): Uint8Array { if (typeof v !== "string") return v; return new TextEncoder().encode(v); } -function decode(v: any): Uint8Array { +function decode(v: Uint8Array): string { + return new TextDecoder().decode(v); +} +function read_pointer(v: any): Uint8Array { const ptr = new Deno.UnsafePointerView(v as Deno.UnsafePointer) const lengthBe = new Uint8Array(4); const view = new DataView(lengthBe.buffer); @@ -156,14 +165,14 @@ const _lib = await prepare(opts, { } }); ${Object.keys(decl).map((def) => typescript[def]).join("\n")} ${ - Object.keys(signature).map((sig) => - `export function ${sig}(${ - signature[sig].parameters.map((p, i) => - `a${i}: ${resolveType(decl, p)}` - ).join(",") + Object.keys(signature).map((sig) => { + const { parameters, result, nonBlocking } = signature[sig]; + + return `export function ${sig}(${ + parameters.map((p, i) => `a${i}: ${resolveType(decl, p)}`).join(",") }) { ${ - signature[sig].parameters.map((p, i) => + parameters.map((p, i) => isBufferType(p) ? `const a${i}_buf = encode(${ BufferTypeEncoders[p] ?? Encoder.JsonStringify @@ -172,18 +181,40 @@ ${ ).filter((c) => c !== null).join("\n") } let result = _lib.symbols.${sig}(${ - signature[sig].parameters.map((p, i) => + parameters.map((p, i) => isBufferType(p) ? `a${i}_buf, a${i}_buf.byteLength` : `a${i}` ).join(", ") }) as ${ - signature[sig].nonBlocking - ? `Promise<${resolveType(decl, signature[sig].result)}>` - : resolveType(decl, signature[sig].result) + nonBlocking + ? `Promise<${ + isTypeDef(result) + ? "Uint8Array" + : resolveType(decl, result) + }>` + : isTypeDef(result) + ? "Uint8Array" + : resolveType(decl, result) } - ${isBufferType(signature[sig].result) ? `result = decode(result);` : ""} - return result; -}` - ).join("\n") + ${ + isBufferType(result) + ? nonBlocking + ? `result = result.then(read_pointer)` + : `result = read_pointer(result)` + : "" + }; + ${ + isTypeDef(result) + ? nonBlocking + ? `return result.then(r => JSON.parse(decode(r))) as Promise<${ + resolveType(decl, result) + }>` + : `return JSON.parse(decode(result)) as ${ + resolveType(decl, result) + }` + : "return result" + }; +}`; + }).join("\n") } `, ); diff --git a/deno_bindgen_macro/src/derive_fn.rs b/deno_bindgen_macro/src/derive_fn.rs index db1208a..c3863e7 100644 --- a/deno_bindgen_macro/src/derive_fn.rs +++ b/deno_bindgen_macro/src/derive_fn.rs @@ -88,6 +88,7 @@ pub fn process_function( let result = match &function.sig.output { ReturnType::Default => Type::Void, ReturnType::Type(_, ref ty) => match ty.as_ref() { + syn::Type::Ptr(_) => Type::Ptr, syn::Type::Path(ref ty) => { let segment = ty.path.segments.first().unwrap(); let ident = segment.ident.to_string(); @@ -105,7 +106,12 @@ pub fn process_function( "isize" => Type::Isize, "f32" => Type::F32, "f64" => Type::F64, - _ => panic!("{} return type not supported by Deno FFI", ident), + _ => { + match metadata.type_defs.get(&ident) { + Some(_) => Type::StructEnum { ident }, + None => panic!("{} return type not supported by Deno FFI", ident) + } + } } } syn::Type::Reference(ref ty) => match *ty.elem { diff --git a/deno_bindgen_macro/src/lib.rs b/deno_bindgen_macro/src/lib.rs index 2e29e57..50cf7c7 100644 --- a/deno_bindgen_macro/src/lib.rs +++ b/deno_bindgen_macro/src/lib.rs @@ -146,11 +146,30 @@ pub fn deno_bindgen(attr: TokenStream, input: TokenStream) -> TokenStream { let result = v.as_ptr(); // Leak the result to JS land. ::std::mem::forget(v); + result }; (ty, transformer) } - _ => (syn::Type::from(symbol.result), quote! {}), + Type::StructEnum { .. } => { + let ty = parse_quote! { *const u8 }; + let transformer = quote! { + let json = deno_bindgen::serde_json::to_string(&result).expect("Failed to serialize as JSON"); + let encoded = json.into_bytes(); + let length = (encoded.len() as u32).to_be_bytes(); + let mut v = length.to_vec(); + v.extend(encoded.clone()); + + let ret = v.as_ptr(); + // Leak the result to JS land. + ::std::mem::forget(v); + ret + }; + + (ty, transformer) + }, + Type::Ptr => (parse_quote! { *const u8 }, quote! { result }), + _ => (syn::Type::from(symbol.result), quote! { result }), }; let name = &func.sig.ident; @@ -174,7 +193,6 @@ pub fn deno_bindgen(attr: TokenStream, input: TokenStream) -> TokenStream { #overrides let result = __inner_impl(#(#input_idents, ) *); #transformer - result } }) } @@ -187,7 +205,7 @@ pub fn deno_bindgen(attr: TokenStream, input: TokenStream) -> TokenStream { .unwrap(); TokenStream::from(quote! { - #[derive(::serde::Deserialize)] + #[derive(::serde::Deserialize,::serde::Serialize)] #input }) } diff --git a/deno_bindgen_macro/src/meta.rs b/deno_bindgen_macro/src/meta.rs index 1d7cd3e..dd404fe 100644 --- a/deno_bindgen_macro/src/meta.rs +++ b/deno_bindgen_macro/src/meta.rs @@ -27,6 +27,7 @@ pub enum Type { Buffer, BufferMut, Str, + Ptr, /// Not-so straightforward types that /// `deno_bingen` maps to. diff --git a/example/bindings/bindings.ts b/example/bindings/bindings.ts index 9472910..403e5df 100644 --- a/example/bindings/bindings.ts +++ b/example/bindings/bindings.ts @@ -1,23 +1,26 @@ // Auto-generated with deno_bindgen -import { CachePolicy, prepare } from "https://deno.land/x/plug@0.4.1/plug.ts" +import { CachePolicy, prepare } from "https://deno.land/x/plug@0.4.1/plug.ts"; function encode(v: string | Uint8Array): Uint8Array { - if (typeof v !== "string") return v - return new TextEncoder().encode(v) -} -function decode(v: any): Uint8Array { - const ptr = new Deno.UnsafePointerView(v as Deno.UnsafePointer) - const lengthBe = new Uint8Array(4) - const view = new DataView(lengthBe.buffer) - ptr.copyInto(lengthBe, 0) - const buf = new Uint8Array(view.getUint32(0)) - ptr.copyInto(buf, 4) - return buf + if (typeof v !== "string") return v; + return new TextEncoder().encode(v); +} +function decode(v: Uint8Array): string { + return new TextDecoder().decode(v); +} +function read_pointer(v: any): Uint8Array { + const ptr = new Deno.UnsafePointerView(v as Deno.UnsafePointer); + const lengthBe = new Uint8Array(4); + const view = new DataView(lengthBe.buffer); + ptr.copyInto(lengthBe, 0); + const buf = new Uint8Array(view.getUint32(0)); + ptr.copyInto(buf, 4); + return buf; } const opts = { name: "deno_bindgen_test", - url: (new URL("../target/release", import.meta.url)).toString(), - policy: undefined, -} + url: (new URL("../target/debug", import.meta.url)).toString(), + policy: CachePolicy.NONE, +}; const _lib = await prepare(opts, { add: { parameters: ["i32", "i32"], result: "i32", nonblocking: false }, add2: { parameters: ["pointer", "usize"], result: "i32", nonblocking: false }, @@ -32,11 +35,22 @@ const _lib = await prepare(opts, { result: "pointer", nonblocking: false, }, + test_buffer_return_async: { + parameters: ["pointer", "usize"], + result: "pointer", + nonblocking: true, + }, test_lifetime: { parameters: ["pointer", "usize"], result: "usize", nonblocking: false, }, + test_manual_ptr: { parameters: [], result: "pointer", nonblocking: false }, + test_manual_ptr_async: { + parameters: [], + result: "pointer", + nonblocking: true, + }, test_mixed: { parameters: ["isize", "pointer", "usize"], result: "i32", @@ -52,6 +66,8 @@ const _lib = await prepare(opts, { result: "void", nonblocking: false, }, + test_output: { parameters: [], result: "pointer", nonblocking: false }, + test_output_async: { parameters: [], result: "pointer", nonblocking: true }, test_serde: { parameters: ["pointer", "usize"], result: "u8", @@ -67,32 +83,7 @@ const _lib = await prepare(opts, { result: "i32", nonblocking: false, }, -}) -export type TestLifetimes = { - text: string -} -export type TestLifetimeEnums = { - Text: { - _text: string - } -} -export type OptionStruct = { - maybe: string | undefined | null -} -export type MyStruct = { - arr: Array -} -export type PlainEnum = - | { - a: { - _a: string - } - } - | "b" - | "c" -export type TestLifetimeWrap = { - _a: TestLifetimeEnums -} +}); /** * Doc comment for `Input` struct. * ...testing multiline @@ -103,90 +94,133 @@ export type Input = { * transformed to JS doc * comments. */ - a: number - b: number -} + a: number; + b: number; +}; +export type TestLifetimeEnums = { + Text: { + _text: string; + }; +}; +export type PlainEnum = + | { + a: { + _a: string; + }; + } + | "b" + | "c"; +export type TestLifetimes = { + text: string; +}; +export type MyStruct = { + arr: Array; +}; +export type TestLifetimeWrap = { + _a: TestLifetimeEnums; +}; export type TagAndContent = | { key: "A"; value: { b: number } } - | { key: "C"; value: { d: number } } + | { key: "C"; value: { d: number } }; +export type OptionStruct = { + maybe: string | undefined | null; +}; export function add(a0: number, a1: number) { - let result = _lib.symbols.add(a0, a1) as number - - return result + let result = _lib.symbols.add(a0, a1) as number; + return result; } export function add2(a0: Input) { - const a0_buf = encode(JSON.stringify(a0)) - let result = _lib.symbols.add2(a0_buf, a0_buf.byteLength) as number - - return result + const a0_buf = encode(JSON.stringify(a0)); + let result = _lib.symbols.add2(a0_buf, a0_buf.byteLength) as number; + return result; } export function sleep(a0: number) { - let result = _lib.symbols.sleep(a0) as Promise - - return result + let result = _lib.symbols.sleep(a0) as Promise; + return result; } export function test_buf(a0: Uint8Array) { - const a0_buf = encode(a0) - let result = _lib.symbols.test_buf(a0_buf, a0_buf.byteLength) as number - - return result + const a0_buf = encode(a0); + let result = _lib.symbols.test_buf(a0_buf, a0_buf.byteLength) as number; + return result; } export function test_buffer_return(a0: Uint8Array) { - const a0_buf = encode(a0) + const a0_buf = encode(a0); let result = _lib.symbols.test_buffer_return( a0_buf, a0_buf.byteLength, - ) as Uint8Array - result = decode(result) - return result + ) as Uint8Array; + result = read_pointer(result); + return result; +} +export function test_buffer_return_async(a0: Uint8Array) { + const a0_buf = encode(a0); + let result = _lib.symbols.test_buffer_return_async( + a0_buf, + a0_buf.byteLength, + ) as Promise; + result = result.then(read_pointer); + return result; } export function test_lifetime(a0: TestLifetimes) { - const a0_buf = encode(JSON.stringify(a0)) - let result = _lib.symbols.test_lifetime(a0_buf, a0_buf.byteLength) as number - - return result + const a0_buf = encode(JSON.stringify(a0)); + let result = _lib.symbols.test_lifetime(a0_buf, a0_buf.byteLength) as number; + return result; +} +export function test_manual_ptr() { + let result = _lib.symbols.test_manual_ptr() as Uint8Array; + result = read_pointer(result); + return result; +} +export function test_manual_ptr_async() { + let result = _lib.symbols.test_manual_ptr_async() as Promise; + result = result.then(read_pointer); + return result; } export function test_mixed(a0: number, a1: Input) { - const a1_buf = encode(JSON.stringify(a1)) - let result = _lib.symbols.test_mixed(a0, a1_buf, a1_buf.byteLength) as number - - return result + const a1_buf = encode(JSON.stringify(a1)); + let result = _lib.symbols.test_mixed(a0, a1_buf, a1_buf.byteLength) as number; + return result; } export function test_mixed_order(a0: number, a1: Input, a2: number) { - const a1_buf = encode(JSON.stringify(a1)) + const a1_buf = encode(JSON.stringify(a1)); let result = _lib.symbols.test_mixed_order( a0, a1_buf, a1_buf.byteLength, a2, - ) as number - - return result + ) as number; + return result; } export function test_mut_buf(a0: Uint8Array) { - const a0_buf = encode(a0) - let result = _lib.symbols.test_mut_buf(a0_buf, a0_buf.byteLength) as null - - return result + const a0_buf = encode(a0); + let result = _lib.symbols.test_mut_buf(a0_buf, a0_buf.byteLength) as null; + return result; +} +export function test_output() { + let result = _lib.symbols.test_output() as Uint8Array; + result = read_pointer(result); + return JSON.parse(decode(result)) as Input; +} +export function test_output_async() { + let result = _lib.symbols.test_output_async() as Promise; + result = result.then(read_pointer); + return result.then((r) => JSON.parse(decode(r))) as Promise; } export function test_serde(a0: MyStruct) { - const a0_buf = encode(JSON.stringify(a0)) - let result = _lib.symbols.test_serde(a0_buf, a0_buf.byteLength) as number - - return result + const a0_buf = encode(JSON.stringify(a0)); + let result = _lib.symbols.test_serde(a0_buf, a0_buf.byteLength) as number; + return result; } export function test_str(a0: string) { - const a0_buf = encode(a0) - let result = _lib.symbols.test_str(a0_buf, a0_buf.byteLength) as null - - return result + const a0_buf = encode(a0); + let result = _lib.symbols.test_str(a0_buf, a0_buf.byteLength) as null; + return result; } export function test_tag_and_content(a0: TagAndContent) { - const a0_buf = encode(JSON.stringify(a0)) + const a0_buf = encode(JSON.stringify(a0)); let result = _lib.symbols.test_tag_and_content( a0_buf, a0_buf.byteLength, - ) as number - - return result + ) as number; + return result; } diff --git a/example/bindings_test.ts b/example/bindings_test.ts index eefe2ca..46f6d90 100644 --- a/example/bindings_test.ts +++ b/example/bindings_test.ts @@ -4,14 +4,19 @@ import { OptionStruct, sleep, test_buf, + test_buffer_return, + test_buffer_return_async, test_lifetime, + test_manual_ptr, + test_manual_ptr_async, test_mixed, test_mixed_order, test_mut_buf, + test_output, + test_output_async, test_serde, test_str, test_tag_and_content, - test_buffer_return, } from "./bindings/bindings.ts"; import { assert, assertEquals } from "https://deno.land/std/testing/asserts.ts"; @@ -111,10 +116,64 @@ Deno.test({ const buf = test_buffer_return( new Uint8Array([1, 2, 3]), ); - + assertEquals(buf.byteLength, 3); assertEquals(buf[0], 1); assertEquals(buf[1], 2); assertEquals(buf[2], 3); - } -}) \ No newline at end of file + }, +}); + +Deno.test({ + name: "test_buffer_return_async#test", + fn: async () => { + const buf = await test_buffer_return_async( + new Uint8Array([1, 2, 3]), + ); + + assertEquals(buf.byteLength, 3); + assertEquals(buf[0], 1); + assertEquals(buf[1], 2); + assertEquals(buf[2], 3); + }, +}); + +Deno.test({ + name: "test_manual_ptr#test", + fn: () => { + const buf = test_manual_ptr(); + const val = new TextDecoder().decode(buf); + + assertEquals(val, "test"); + }, +}); + +Deno.test({ + name: "test_manual_ptr_async#test", + fn: async () => { + const buf = await test_manual_ptr_async(); + const val = new TextDecoder().decode(buf); + + assertEquals(val, "test"); + }, +}); + +Deno.test({ + name: "test_output#test", + fn: () => { + const obj = test_output(); + + assertEquals(obj.a, 1); + assertEquals(obj.b, 2); + }, +}); + +Deno.test({ + name: "test_output_async#test", + fn: async () => { + const obj = await test_output_async(); + + assertEquals(obj.a, 3); + assertEquals(obj.b, 4); + }, +}); diff --git a/example/src/lib.rs b/example/src/lib.rs index a1936c4..e728edd 100644 --- a/example/src/lib.rs +++ b/example/src/lib.rs @@ -134,3 +134,50 @@ fn test_tag_and_content(arg: TagAndContent) -> i32 { fn test_buffer_return(buf: &[u8]) -> &[u8] { buf } + +#[deno_bindgen(non_blocking)] +fn test_buffer_return_async(buf: &[u8]) -> &[u8] { + buf +} + +#[deno_bindgen] +fn test_manual_ptr() -> *const u8 { + let result = String::from("test").into_bytes(); + let length = (result.len() as u32).to_be_bytes(); + let mut v = length.to_vec(); + v.extend(result.clone()); + + let ret = v.as_ptr(); + // Leak the result to JS land. + ::std::mem::forget(v); + ret +} + +#[deno_bindgen(non_blocking)] +fn test_manual_ptr_async() -> *const u8 { + let result = String::from("test").into_bytes(); + let length = (result.len() as u32).to_be_bytes(); + let mut v = length.to_vec(); + v.extend(result.clone()); + + let ret = v.as_ptr(); + // Leak the result to JS land. + ::std::mem::forget(v); + ret +} + +#[deno_bindgen] +fn test_output() -> Input { + Input { + a: 1, + b: 2 + } +} + +#[deno_bindgen(non_blocking)] +fn test_output_async() -> Input { + Input { + a: 3, + b: 4 + } +}