diff --git a/README.md b/README.md index bca4f179..5c1b9e30 100644 --- a/README.md +++ b/README.md @@ -342,6 +342,46 @@ assert(res.ok); assert.strictEqual(await res.text(), "Hello, World!"); ``` +## D1 Databases + +### Enabling D1 databases +As D1 databases are in alpha, you'll need to enable the `d1` feature on the `worker` crate. + +```toml +worker = { version = "x.y.z", features = ["d1"] } +``` + +### Example usage +```rust +use worker::*; + +#[derive(Deserialize)] +struct Thing { + thing_id: String, + desc: String, + num: u32, +} + +#[event(fetch, respond_with_errors)] +pub async fn main(request: Request, env: Env, _ctx: Context) -> Result { + Router::new() + .get_async("/:id", |_, ctx| async move { + let id = ctx.param("id").unwrap()?; + let d1 = ctx.env.d1("things-db")?; + let statement = d1.prepare("SELECT * FROM things WHERE thing_id = ?1"); + let query = statement.bind(&[id])?; + let result = query.first::(None).await?; + match result { + Some(thing) => Response::from_json(&thing), + None => Response::error("Not found", 404), + } + }) + .run(request, env) + .await +} +``` + + # Notes and FAQ It is exciting to see how much is possible with a framework like this, by expanding the options diff --git a/tsconfig.json b/tsconfig.json index 3e080513..bed35c1d 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -11,7 +11,7 @@ // "disableReferencedProjectLoad": true, /* Reduce the number of projects loaded automatically by TypeScript. */ /* Language and Environment */ - "target": "es2016", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */ + "target": "es2022", /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */ // "lib": [], /* Specify a set of bundled library declaration files that describe the target runtime environment. */ // "jsx": "preserve", /* Specify what JSX code is generated. */ // "experimentalDecorators": true, /* Enable experimental support for legacy experimental decorators. */ @@ -25,9 +25,9 @@ // "moduleDetection": "auto", /* Control what method is used to detect module-format JS files. */ /* Modules */ - "module": "commonjs", /* Specify what module code is generated. */ + "module": "ES2022", /* Specify what module code is generated. */ // "rootDir": "./", /* Specify the root folder within your source files. */ - // "moduleResolution": "node10", /* Specify how TypeScript looks up a file from a given module specifier. */ + "moduleResolution": "nodenext", /* Specify how TypeScript looks up a file from a given module specifier. */ // "baseUrl": "./", /* Specify the base directory to resolve non-relative module names. */ // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */ // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */ diff --git a/worker-sandbox/.gitignore b/worker-sandbox/.gitignore new file mode 100644 index 00000000..7310e736 --- /dev/null +++ b/worker-sandbox/.gitignore @@ -0,0 +1 @@ +.wrangler \ No newline at end of file diff --git a/worker-sandbox/Cargo.toml b/worker-sandbox/Cargo.toml index 29d01d8c..c5530b27 100644 --- a/worker-sandbox/Cargo.toml +++ b/worker-sandbox/Cargo.toml @@ -26,7 +26,7 @@ http = "0.2.9" regex = "1.8.4" serde = { version = "1.0.164", features = ["derive"] } serde_json = "1.0.96" -worker = { path = "../worker", version = "0.0.17", features= ["queue"] } +worker = { path = "../worker", version = "0.0.17", features= ["queue", "d1"] } futures-channel = "0.3.28" futures-util = { version = "0.3.28", default-features = false } rand = "0.8.5" diff --git a/worker-sandbox/src/d1.rs b/worker-sandbox/src/d1.rs new file mode 100644 index 00000000..64bd8443 --- /dev/null +++ b/worker-sandbox/src/d1.rs @@ -0,0 +1,106 @@ +use serde::Deserialize; +use worker::*; + +use crate::SomeSharedData; + +#[derive(Deserialize)] +struct Person { + id: u32, + name: String, + age: u32, +} + +pub async fn prepared_statement( + _req: Request, + ctx: RouteContext, +) -> Result { + let db = ctx.env.d1("DB")?; + let stmt = worker::query!(&db, "SELECT * FROM people WHERE name = ?", "Ryan Upton")?; + + // All rows + let results = stmt.all().await?; + let people = results.results::()?; + + assert!(results.success()); + assert_eq!(results.error(), None); + assert_eq!(people.len(), 1); + assert_eq!(people[0].name, "Ryan Upton"); + assert_eq!(people[0].age, 21); + assert_eq!(people[0].id, 6); + + // All columns of the first rows + let person = stmt.first::(None).await?.unwrap(); + assert_eq!(person.name, "Ryan Upton"); + assert_eq!(person.age, 21); + + // The name of the first row + let name = stmt.first::(Some("name")).await?.unwrap(); + assert_eq!(name, "Ryan Upton"); + + // All of the rows as column arrays of raw JSON values. + let rows = stmt.raw::().await?; + assert_eq!(rows.len(), 1); + let columns = &rows[0]; + + assert_eq!(columns[0].as_u64(), Some(6)); + assert_eq!(columns[1].as_str(), Some("Ryan Upton")); + assert_eq!(columns[2].as_u64(), Some(21)); + + Response::ok("ok") +} + +pub async fn batch(_req: Request, ctx: RouteContext) -> Result { + let db = ctx.env.d1("DB")?; + let mut results = db + .batch(vec![ + worker::query!(&db, "SELECT * FROM people WHERE id < 4"), + worker::query!(&db, "SELECT * FROM people WHERE id > 4"), + ]) + .await? + .into_iter(); + + let first_results = results.next().unwrap().results::()?; + assert_eq!(first_results.len(), 3); + assert_eq!(first_results[0].id, 1); + assert_eq!(first_results[1].id, 2); + assert_eq!(first_results[2].id, 3); + + let second_results = results.next().unwrap().results::()?; + assert_eq!(second_results.len(), 2); + assert_eq!(second_results[0].id, 5); + assert_eq!(second_results[1].id, 6); + + Response::ok("ok") +} + +pub async fn exec(mut req: Request, ctx: RouteContext) -> Result { + let db = ctx.env.d1("DB")?; + let result = db + .exec(req.text().await?.as_ref()) + .await + .expect("doesn't exist"); + + Response::ok(result.count().unwrap_or_default().to_string()) +} + +pub async fn dump(_req: Request, ctx: RouteContext) -> Result { + let db = ctx.env.d1("DB")?; + let bytes = db.dump().await?; + Response::from_bytes(bytes) +} + +pub async fn error(_req: Request, ctx: RouteContext) -> Result { + let db = ctx.env.d1("DB")?; + let error = db + .exec("THIS IS NOT VALID SQL") + .await + .expect_err("did not get error"); + + if let Error::D1(error) = error { + assert_eq!(error.cause(), "Error in line 1: THIS IS NOT VALID SQL: ERROR 9009: SQL prepare error: near \"THIS\": syntax error in THIS IS NOT VALID SQL at offset 0") + } else { + panic!("expected D1 error"); + } + + Response::ok("") +} diff --git a/worker-sandbox/src/lib.rs b/worker-sandbox/src/lib.rs index 3d5b79d4..a80ab59a 100644 --- a/worker-sandbox/src/lib.rs +++ b/worker-sandbox/src/lib.rs @@ -15,6 +15,7 @@ use worker::*; mod alarm; mod counter; +mod d1; mod r2; mod test; mod utils; @@ -497,7 +498,7 @@ pub async fn main(req: Request, env: Env, _ctx: worker::Context) -> Result { // Ensure that the cancelled future returns an AbortError. match cancelled_fut.await { - Err(e) if e.to_string().starts_with("AbortError") => { /* Yay! It worked, let's do nothing to celebrate */}, + Err(e) if e.to_string().contains("AbortError") => { /* Yay! It worked, let's do nothing to celebrate */}, Err(e) => panic!("Fetch errored with a different error than expected: {:#?}", e), Ok(text) => panic!("Fetch unexpectedly succeeded: {}", text) } @@ -698,6 +699,11 @@ pub async fn main(req: Request, env: Env, _ctx: worker::Context) -> Result = guard.clone(); Response::from_json(&messages) }) + .get_async("/d1/prepared", d1::prepared_statement) + .get_async("/d1/batch", d1::batch) + .get_async("/d1/dump", d1::dump) + .post_async("/d1/exec", d1::exec) + .get_async("/d1/error", d1::error) .get_async("/r2/list-empty", r2::list_empty) .get_async("/r2/list", r2::list) .get_async("/r2/get-empty", r2::get_empty) diff --git a/worker-sandbox/tests/d1.spec.ts b/worker-sandbox/tests/d1.spec.ts new file mode 100644 index 00000000..5e426a52 --- /dev/null +++ b/worker-sandbox/tests/d1.spec.ts @@ -0,0 +1,70 @@ +import { describe, test, expect, beforeAll } from "vitest"; + +const hasLocalDevServer = await fetch("http://localhost:8787/request") + .then((resp) => resp.ok) + .catch(() => false); + +async function exec(query: string): Promise { + const resp = await fetch("http://localhost:8787/d1/exec", { + method: "POST", + body: query.split("\n").join(""), + }); + + const body = await resp.text(); + expect(resp.status).toBe(200); + return Number(body); +} + +describe.skipIf(!hasLocalDevServer)("d1", () => { + test("create table", async () => { + const query = `CREATE TABLE IF NOT EXISTS uniqueTable ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + age INTEGER NOT NULL + );`; + + expect(await exec(query)).toBe(1); + }); + + test("insert data", async () => { + let query = `CREATE TABLE IF NOT EXISTS people ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + age INTEGER NOT NULL + );`; + + expect(await exec(query)).toBe(1); + + query = `INSERT OR IGNORE INTO people + (id, name, age) + VALUES + (1, 'Freddie Pearce', 26), + (2, 'Wynne Ogley', 67), + (3, 'Dorian Fischer', 19), + (4, 'John Smith', 92), + (5, 'Magaret Willamson', 54), + (6, 'Ryan Upton', 21);`; + + expect(await exec(query)).toBe(1); + }); + + test("prepared statement", async () => { + const resp = await fetch("http://localhost:8787/d1/prepared"); + expect(resp.status).toBe(200); + }); + + test("batch", async () => { + const resp = await fetch("http://localhost:8787/d1/batch"); + expect(resp.status).toBe(200); + }); + + test("dump", async () => { + const resp = await fetch("http://localhost:8787/d1/dump"); + expect(resp.status).toBe(200); + }); + + test("dump", async () => { + const resp = await fetch("http://localhost:8787/d1/error"); + expect(resp.status).toBe(200); + }); +}); diff --git a/worker-sandbox/wrangler.toml b/worker-sandbox/wrangler.toml index 3920ad36..60333d88 100644 --- a/worker-sandbox/wrangler.toml +++ b/worker-sandbox/wrangler.toml @@ -21,6 +21,12 @@ remote-service = "./remote-service" [durable_objects] bindings = [{ name = "COUNTER", class_name = "Counter" }, { name = "ALARM", class_name = "AlarmObject" }] +[[d1_databases]] +binding = 'DB' +database_name = 'my_db' +database_id = '.' +preview_database_id = '.' + [[queues.consumers]] queue = "my_queue" diff --git a/worker-sys/Cargo.toml b/worker-sys/Cargo.toml index feeabfbe..248ced76 100644 --- a/worker-sys/Cargo.toml +++ b/worker-sys/Cargo.toml @@ -41,4 +41,5 @@ features = [ ] [features] +d1 = [] queue = [] diff --git a/worker-sys/src/types.rs b/worker-sys/src/types.rs index 0e767b11..7356ff29 100644 --- a/worker-sys/src/types.rs +++ b/worker-sys/src/types.rs @@ -1,4 +1,6 @@ mod context; +#[cfg(feature = "d1")] +mod d1; mod durable_object; mod dynamic_dispatcher; mod fetcher; @@ -13,6 +15,8 @@ mod tls_client_auth; mod websocket_pair; pub use context::*; +#[cfg(feature = "d1")] +pub use d1::*; pub use durable_object::*; pub use dynamic_dispatcher::*; pub use fetcher::*; diff --git a/worker-sys/src/types/d1.rs b/worker-sys/src/types/d1.rs new file mode 100644 index 00000000..50a3d4d0 --- /dev/null +++ b/worker-sys/src/types/d1.rs @@ -0,0 +1,75 @@ +use ::js_sys::Object; +use wasm_bindgen::prelude::*; + +use js_sys::{Array, Promise}; + +#[wasm_bindgen] +extern "C" { + #[derive(Debug, Clone)] + pub type D1Result; + + #[wasm_bindgen(structural, method, getter, js_name=results)] + pub fn results(this: &D1Result) -> Option; + + #[wasm_bindgen(structural, method, getter, js_name=success)] + pub fn success(this: &D1Result) -> bool; + + #[wasm_bindgen(structural, method, getter, js_name=error)] + pub fn error(this: &D1Result) -> Option; + + #[wasm_bindgen(structural, method, getter, js_name=meta)] + pub fn meta(this: &D1Result) -> Object; +} + +#[wasm_bindgen] +extern "C" { + #[derive(Debug, Clone)] + pub type D1ExecResult; + + #[wasm_bindgen(structural, method, getter, js_name=count)] + pub fn count(this: &D1ExecResult) -> Option; + + #[wasm_bindgen(structural, method, getter, js_name=duration)] + pub fn duration(this: &D1ExecResult) -> Option; +} + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(extends=::js_sys::Object, js_name=D1Database)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type D1Database; + + #[wasm_bindgen(structural, method, js_class=D1Database, js_name=prepare)] + pub fn prepare(this: &D1Database, query: &str) -> D1PreparedStatement; + + #[wasm_bindgen(structural, method, js_class=D1Database, js_name=dump)] + pub fn dump(this: &D1Database) -> Promise; + + #[wasm_bindgen(structural, method, js_class=D1Database, js_name=batch)] + pub fn batch(this: &D1Database, statements: Array) -> Promise; + + #[wasm_bindgen(structural, method, js_class=D1Database, js_name=exec)] + pub fn exec(this: &D1Database, query: &str) -> Promise; +} + +#[wasm_bindgen] +extern "C" { + #[wasm_bindgen(extends=::js_sys::Object, js_name=D1PreparedStatement)] + #[derive(Debug, Clone, PartialEq, Eq)] + pub type D1PreparedStatement; + + #[wasm_bindgen(structural, method, catch, variadic, js_class=D1PreparedStatement, js_name=bind)] + pub fn bind(this: &D1PreparedStatement, values: Array) -> Result; + + #[wasm_bindgen(structural, method, js_class=D1PreparedStatement, js_name=first)] + pub fn first(this: &D1PreparedStatement, col_name: Option<&str>) -> Promise; + + #[wasm_bindgen(structural, method, js_class=D1PreparedStatement, js_name=run)] + pub fn run(this: &D1PreparedStatement) -> Promise; + + #[wasm_bindgen(structural, method, js_class=D1PreparedStatement, js_name=all)] + pub fn all(this: &D1PreparedStatement) -> Promise; + + #[wasm_bindgen(structural, method, js_class=D1PreparedStatement, js_name=raw)] + pub fn raw(this: &D1PreparedStatement) -> Promise; +} diff --git a/worker/Cargo.toml b/worker/Cargo.toml index 2df98edf..2be5cc2f 100644 --- a/worker/Cargo.toml +++ b/worker/Cargo.toml @@ -44,3 +44,4 @@ features = [ [features] queue = ["worker-macros/queue", "worker-sys/queue"] +d1 = ["worker-sys/d1"] diff --git a/worker/src/d1/macros.rs b/worker/src/d1/macros.rs new file mode 100644 index 00000000..14ed269d --- /dev/null +++ b/worker/src/d1/macros.rs @@ -0,0 +1,38 @@ +/// Prepare a D1 query from the provided D1Database, query string, and optional query parameters. +/// +/// Any parameter provided is required to implement [`serde::Serialize`] to be used. +/// +/// Using [`query`] is equivalent to using db.prepare('').bind('') in Javascript. +/// +/// # Example +/// +/// ``` +/// let query = worker::query!( +/// &d1, +/// "SELECT * FROM things WHERE num > ?1 AND num < ?2", +/// &min, +/// &max, +/// )?; +/// ``` +#[macro_export] +macro_rules! query { + // rule for simple queries + ($db:expr, $query:expr) => { + $crate::d1::D1Database::prepare($db, $query) + }; + // rule for parameterized queries + ($db:expr, $query:expr, $($args:expr),* $(,)?) => {{ + || -> $crate::Result<$crate::d1::D1PreparedStatement> { + let prepared = $crate::d1::D1Database::prepare($db, $query); + + // D1 doesn't support taking in undefined values, so we translate these missing values to NULL. + let serializer = $crate::d1::serde_wasm_bindgen::Serializer::new().serialize_missing_as_null(true); + let bindings = &[$( + ::serde::ser::Serialize::serialize(&$args, &serializer) + .map_err(|e| $crate::Error::Internal(e.into()))? + ),*]; + + $crate::d1::D1PreparedStatement::bind(prepared, bindings) + }() + }}; +} diff --git a/worker/src/d1/mod.rs b/worker/src/d1/mod.rs new file mode 100644 index 00000000..a9f0ffa1 --- /dev/null +++ b/worker/src/d1/mod.rs @@ -0,0 +1,302 @@ +use std::fmt::Display; +use std::fmt::Formatter; +use std::result::Result as StdResult; + +use js_sys::Array; +use js_sys::ArrayBuffer; +use js_sys::JsString; +use js_sys::Uint8Array; +use serde::Deserialize; +use wasm_bindgen::{JsCast, JsValue}; +use wasm_bindgen_futures::JsFuture; +use worker_sys::types::D1Database as D1DatabaseSys; +use worker_sys::types::D1ExecResult; +use worker_sys::types::D1PreparedStatement as D1PreparedStatementSys; +use worker_sys::types::D1Result as D1ResultSys; + +use crate::env::EnvBinding; +use crate::Error; +use crate::Result; + +pub use serde_wasm_bindgen; + +pub mod macros; + +// A D1 Database. +pub struct D1Database(D1DatabaseSys); + +impl D1Database { + /// Prepare a query statement from a query string. + pub fn prepare>(&self, query: T) -> D1PreparedStatement { + self.0.prepare(&query.into()).into() + } + + /// Dump the data in the database to a `Vec`. + pub async fn dump(&self) -> Result> { + let result = JsFuture::from(self.0.dump()).await; + let array_buffer = cast_to_d1_error(result)?; + let array_buffer = array_buffer.dyn_into::()?; + let array = Uint8Array::new(&array_buffer); + Ok(array.to_vec()) + } + + /// Batch execute one or more statements against the database. + /// + /// Returns the results in the same order as the provided statements. + pub async fn batch(&self, statements: Vec) -> Result> { + let statements = statements.into_iter().map(|s| s.0).collect::(); + let results = JsFuture::from(self.0.batch(statements)).await; + let results = cast_to_d1_error(results)?; + let results = results.dyn_into::()?; + let mut vec = Vec::with_capacity(results.length() as usize); + for result in results.iter() { + let result = result.unchecked_into::(); + vec.push(D1Result(result)); + } + Ok(vec) + } + + /// Execute one or more queries directly against the database. + /// + /// The input can be one or multiple queries separated by `\n`. + /// + /// # Considerations + /// + /// This method can have poorer performance (prepared statements can be reused + /// in some cases) and, more importantly, is less safe. Only use this + /// method for maintenance and one-shot tasks (example: migration jobs). + /// + /// If an error occurs, an exception is thrown with the query and error + /// messages, execution stops and further statements are not executed. + pub async fn exec(&self, query: &str) -> Result { + let result = JsFuture::from(self.0.exec(query)).await; + let result = cast_to_d1_error(result)?; + Ok(result.into()) + } +} + +impl EnvBinding for D1Database { + const TYPE_NAME: &'static str = "D1Database"; + + // Workaround for Miniflare D1 Beta + fn get(val: JsValue) -> Result { + let obj = js_sys::Object::from(val); + if obj.constructor().name() == Self::TYPE_NAME || obj.constructor().name() == "BetaDatabase" + { + Ok(obj.unchecked_into()) + } else { + Err(format!( + "Binding cannot be cast to the type {} from {}", + Self::TYPE_NAME, + obj.constructor().name() + ) + .into()) + } + } +} + +impl JsCast for D1Database { + fn instanceof(val: &JsValue) -> bool { + val.is_instance_of::() + } + + fn unchecked_from_js(val: JsValue) -> Self { + Self(val.into()) + } + + fn unchecked_from_js_ref(val: &JsValue) -> &Self { + unsafe { &*(val as *const JsValue as *const Self) } + } +} + +impl From for JsValue { + fn from(database: D1Database) -> Self { + JsValue::from(database.0) + } +} + +impl AsRef for D1Database { + fn as_ref(&self) -> &JsValue { + &self.0 + } +} + +impl From for D1Database { + fn from(inner: D1DatabaseSys) -> Self { + Self(inner) + } +} + +// A D1 prepared query statement. +pub struct D1PreparedStatement(D1PreparedStatementSys); + +impl D1PreparedStatement { + /// Bind one or more parameters to the statement. + /// Consumes the old statement and returns a new statement with the bound parameters. + /// + /// D1 follows the SQLite convention for prepared statements parameter binding. + /// + /// # Considerations + /// + /// Supports Ordered (?NNNN) and Anonymous (?) parameters - named parameters are currently not supported. + /// + pub fn bind(self, values: &[JsValue]) -> Result { + let array: Array = values.iter().collect::(); + + match self.0.bind(array) { + Ok(stmt) => Ok(D1PreparedStatement(stmt)), + Err(err) => Err(Error::from(err)), + } + } + + /// Return the first row of results. + /// + /// If `col_name` is `Some`, returns that single value, otherwise returns the entire object. + /// + /// If the query returns no rows, then this will return `None`. + /// + /// If the query returns rows, but column does not exist, then this will return an `Err`. + pub async fn first(&self, col_name: Option<&str>) -> Result> + where + T: for<'a> Deserialize<'a>, + { + let result = JsFuture::from(self.0.first(col_name)).await; + let js_value = cast_to_d1_error(result)?; + let value = serde_wasm_bindgen::from_value(js_value)?; + Ok(value) + } + + /// Executes a query against the database but only return metadata. + pub async fn run(&self) -> Result { + let result = JsFuture::from(self.0.run()).await; + let result = cast_to_d1_error(result)?; + Ok(D1Result(result.into())) + } + + /// Executes a query against the database and returns all rows and metadata. + pub async fn all(&self) -> Result { + let result = JsFuture::from(self.0.all()).await?; + Ok(D1Result(result.into())) + } + + /// Executes a query against the database and returns a `Vec` of rows instead of objects. + pub async fn raw(&self) -> Result>> + where + T: for<'a> Deserialize<'a>, + { + let result = JsFuture::from(self.0.raw()).await; + let result = cast_to_d1_error(result)?; + let result = result.dyn_into::()?; + let mut vec = Vec::with_capacity(result.length() as usize); + for value in result.iter() { + let value = serde_wasm_bindgen::from_value(value)?; + vec.push(value); + } + Ok(vec) + } +} + +impl From for D1PreparedStatement { + fn from(inner: D1PreparedStatementSys) -> Self { + Self(inner) + } +} + +// The result of a D1 query execution. +pub struct D1Result(D1ResultSys); + +impl D1Result { + /// Returns `true` if the result indicates a success, otherwise `false`. + pub fn success(&self) -> bool { + self.0.success() + } + + /// Return the error contained in this result. + /// + /// Returns `None` if the result indicates a success. + pub fn error(&self) -> Option { + self.0.error() + } + + /// Retrieve the collection of result objects, or an `Err` if an error occurred. + pub fn results(&self) -> Result> + where + T: for<'a> Deserialize<'a>, + { + if let Some(results) = self.0.results() { + let mut vec = Vec::with_capacity(results.length() as usize); + for result in results.iter() { + let result = serde_wasm_bindgen::from_value(result).unwrap(); + vec.push(result); + } + Ok(vec) + } else { + Ok(Vec::new()) + } + } +} + +#[derive(Clone)] +pub struct D1Error { + inner: js_sys::Error, +} + +impl D1Error { + /// Gets the cause of the error specific to D1. + pub fn cause(&self) -> String { + if let Ok(cause) = self.inner.cause().dyn_into::() { + cause.message().into() + } else { + "unknown error".into() + } + } +} + +impl std::fmt::Debug for D1Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let cause = self.inner.cause(); + + f.debug_struct("D1Error").field("cause", &cause).finish() + } +} + +impl Display for D1Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let cause = self.inner.cause(); + let cause = JsString::from(cause); + write!(f, "{}", cause) + } +} + +impl AsRef for D1Error { + fn as_ref(&self) -> &js_sys::Error { + &self.inner + } +} + +impl AsRef for D1Error { + fn as_ref(&self) -> &JsValue { + &self.inner + } +} + +fn cast_to_d1_error(result: StdResult) -> StdResult { + let err = match result { + Ok(value) => return Ok(value), + Err(err) => err, + }; + + let err: JsValue = match err.dyn_into::() { + Ok(err) => { + let message: String = err.message().into(); + + if message.starts_with("D1") { + return Err(D1Error { inner: err }.into()); + }; + err.into() + } + Err(err) => err, + }; + + Err(err.into()) +} diff --git a/worker/src/env.rs b/worker/src/env.rs index 54ca52b1..64a66a62 100644 --- a/worker/src/env.rs +++ b/worker/src/env.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "d1")] +use crate::d1::D1Database; use crate::error::Error; #[cfg(feature = "queue")] use crate::Queue; @@ -71,6 +73,12 @@ impl Env { pub fn bucket(&self, binding: &str) -> Result { self.get_binding(binding) } + + /// Access a D1 Database by the binding name configured in your wrangler.toml file. + #[cfg(feature = "d1")] + pub fn d1(&self, binding: &str) -> Result { + self.get_binding(binding) + } } pub trait EnvBinding: Sized + JsCast { diff --git a/worker/src/error.rs b/worker/src/error.rs index 9ca519f9..c2c0056d 100644 --- a/worker/src/error.rs +++ b/worker/src/error.rs @@ -17,6 +17,8 @@ pub enum Error { SerdeJsonError(serde_json::Error), #[cfg(feature = "queue")] SerdeWasmBindgenError(serde_wasm_bindgen::Error), + #[cfg(feature = "d1")] + D1(crate::d1::D1Error), } impl From for Error { @@ -39,6 +41,13 @@ impl From for Error { } } +#[cfg(feature = "d1")] +impl From for Error { + fn from(e: crate::d1::D1Error) -> Self { + Self::D1(e) + } +} + impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -56,18 +65,28 @@ impl std::fmt::Display for Error { Error::SerdeJsonError(e) => write!(f, "Serde Error: {e}"), #[cfg(feature = "queue")] Error::SerdeWasmBindgenError(e) => write!(f, "Serde Error: {e}"), + Error::D1(e) => write!(f, "D1: {e:#?}"), } } } impl std::error::Error for Error {} +// Not sure if the changes I've made here are good or bad... impl From for Error { fn from(v: JsValue) -> Self { - match v - .as_string() - .or_else(|| v.dyn_ref::().map(|e| e.to_string().into())) - { + match v.as_string().or_else(|| { + v.dyn_ref::().map(|e| { + format!( + "Error: {} - Cause: {}", + e.to_string(), + e.cause() + .as_string() + .or_else(|| { Some(e.to_string().into()) }) + .unwrap_or(String::from("N/A")) + ) + }) + }) { Some(s) => Self::JsError(s), None => Self::Internal(v), } diff --git a/worker/src/lib.rs b/worker/src/lib.rs index 5fe5675e..9159a5c0 100644 --- a/worker/src/lib.rs +++ b/worker/src/lib.rs @@ -25,6 +25,8 @@ pub use crate::abort::*; pub use crate::cache::{Cache, CacheDeletionOutcome, CacheKey}; pub use crate::context::Context; pub use crate::cors::Cors; +#[cfg(feature = "d1")] +pub use crate::d1::*; pub use crate::date::{Date, DateInit}; pub use crate::delay::Delay; pub use crate::durable::*; @@ -53,6 +55,9 @@ mod cache; mod cf; mod context; mod cors; +// Require pub module for macro export +#[cfg(feature = "d1")] +pub mod d1; mod date; mod delay; pub mod durable; diff --git a/worker/src/router.rs b/worker/src/router.rs index 4c1b14a7..f0593d81 100644 --- a/worker/src/router.rs +++ b/worker/src/router.rs @@ -105,6 +105,12 @@ impl RouteContext { pub fn bucket(&self, binding: &str) -> Result { self.env.bucket(binding) } + + /// Access a D1 Database by the binding name configured in your wrangler.toml file. + #[cfg(feature = "d1")] + pub fn d1(&self, binding: &str) -> Result { + self.env.d1(binding) + } } impl<'a> Router<'a, ()> {