diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e6e82d70..dd8ec360 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -9,7 +9,10 @@ on: jobs: test-macos: - runs-on: macos-11 + if: ${{ false }} # disable until macOS 12 (with concurrency) runners are available. + runs-on: macos-12 + env: + DEVELOPER_DIR: /Applications/Xcode_13.2.app/Contents/Developer steps: - uses: actions/checkout@v2 - name: Build @@ -17,13 +20,15 @@ jobs: - name: Run tests run: swift test -v test-linux: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 strategy: matrix: - swift: [5.4] + swift: [5.5] container: swift:${{ matrix.swift }} steps: - uses: actions/checkout@v2 + - name: Install sqlite + run: apt-get -q update && apt-get install -y libsqlite3-dev - name: Build run: swift build -v --enable-test-discovery - name: Run tests diff --git a/Docs/0_GettingStarted.md b/Docs/0_GettingStarted.md deleted file mode 100644 index fd912aa7..00000000 --- a/Docs/0_GettingStarted.md +++ /dev/null @@ -1,62 +0,0 @@ -# Getting Started - -- [Installation](#installation) - * [CLI](#cli) - * [Swift Package Manager](#swift-package-manager) -- [Start Coding](#start-coding) - -## Installation - -### CLI - -The Alchemy CLI is installable with [Mint](https://github.com/yonaskolb/Mint). - -```shell -mint install alchemy-swift/alchemy-cli -``` - -Creating an app with the CLI will let you pick between a backend or fullstack (`iOS` frontend, `Alchemy` backend, `Shared` library) project. - -1. `alchemy new MyNewProject` -2. `cd MyNewProject` (if you selected fullstack, `MyNewProject/Backend`) -3. `swift run` -4. view your brand new app at http://localhost:3000 - -### Swift Package Manager - -Alchemy is also installable through the [Swift Package Manager](https://github.com/apple/swift-package-manager). - -```swift -dependencies: [ - .package(url: "https://github.com/alchemy-swift/alchemy", .upToNextMinor(from: "0.2.0")) - ... -], -targets: [ - .target(name: "MyServer", dependencies: [ - .product(name: "Alchemy", package: "alchemy"), - ]), -] -``` - -From here, conform to `Application` somewhere in your target and add the `@main` attribute. - -```swift -@main -struct App: Application { - func boot() { - get("/") { _ in - return "Hello from alchemy!" - } - } -} -``` - -Run your app with `swift run` and visit `localhost:3000` in the browser to see your new server in action. - -## Start Coding! - -Congrats, you're off to the races! Check out the rest of the guides for what you can do with Alchemy. - -_Up next: [Architecture](1_Configuration.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/10_DiggingDeeper.md b/Docs/10_DiggingDeeper.md deleted file mode 100644 index be62caee..00000000 --- a/Docs/10_DiggingDeeper.md +++ /dev/null @@ -1,268 +0,0 @@ -# Digging Deeper - -- [Scheduling Tasks](#scheduling-tasks) - * [Scheduling](#scheduling) - + [Scheduling Jobs](#scheduling-jobs) - * [Schedule frequencies](#schedule-frequencies) - * [Running the Scheduler](#running-the-scheduler) -- [Logging](#logging) -- [Thread](#thread) -- [Making HTTP Requests](#making-http-requests) -- [Plot: HTML DSL](#plot--html-dsl) - * [Control Flow](#control-flow) - * [HTMLView](#htmlview) - * [Plot Docs](#plot-docs) -- [Serving Static Files](#serving-static-files) - -## Scheduling Tasks - -You'll likely want to run various recurring tasks associated with your server. In the past, this may have been done utilizing `cron`, but it can be frustrating to have your scheduling logic disconnected from your code. -To make this easy, Alchemy provides a clean API for scheduling repeated tasks & jobs. - -### Scheduling - -You can schedule recurring work for your application using the `schedule()` function. You'll probably want to do this in your `boot()` function. This returns a builder with which you can customize the frequency of the task. - -```swift -struct ExampleApp: Application { - func boot() { - schedule { print("Good morning!") } - .daily() - } -} -``` - -#### Scheduling Jobs - -You can also schedule jobs to be dispatched. Don't forget to run a worker to run the dispatched jobs. - -```swift -app.schedule(job: BackupDatabase()) - .daily(hr: 23) -``` - -### Schedule frequencies - -A variety of builder functions are offered to customize your schedule frequency. If your desired frequency is complex, you can even schedule a task using a cron expression. - -```swift -// Every week on tuesday at 8:00 pm -app.schedule { ... } - .weekly(day: .tue, hr: 20) - -// Every second -app.schedule { ... } - .secondly() - -// Every minute at 30 seconds -app.schedule { ... } - .minutely(sec: 30) - -// At 22:00 on every day-of-week from Monday through Friday.” -app.schedule { ... } - .cron("0 22 * * 1-5") -``` - -### Running the Scheduler - -Note that by default, your app won't actually schedule tasks. You'll need to pass the `--schedule` flag to either the `serve` (default) or `queue` command. - -```bash -# Serves and schedules -swift run MyServer --schedule - -# Runs a queue worker and schedules -swift run MyServer queue --schedule -``` - -## Logging - -To aid with logging, Alchemy provides a lightweight wrapper on top of [SwiftLog](https://github.com/apple/swift-log). - -You can conveniently log to the various levels via static functions on the `Log` struct. - -```swift -Log.trace("Here") -Log.debug("Testing") -Log.info("Hello") -Log.notice("FYI") -Log.warning("Hmmm") -Log.error("Uh oh") -Log.critical("Houston, we have a problem") -``` - -These log to `Log.logger`, an instance of `SwiftLog.Logger`, which defaults to a basic stdout logger. This is a settable variable so you may overwrite it to be a more complex `Logger`. See [SwiftLog](https://github.com/apple/swift-log) for advanced usage. - -## Thread - -As mentioned in [Under the Hood](12_UnderTheHood.md), you'll want to avoid blocking the current `EventLoop` as much as possible to help your server have maximum request throughput. - -Should you need to do some blocking work, such as file IO or CPU intensive work, `Thread` provides a dead simple interface for running work on a separate (non-`EventLoop`) thread. - -Initiate work with `Thread.run` which takes a closure, runs it on a separate thread, and returns the value generated by the closure back on the initiating `EventLoop`. - -```swift -Thread - .run { - // Will be run on a separate thread. - blockingWork() - } - .whenSuccess { value in - // Back on the initiating `EventLoop`, with access to any value - // returned by `blockingWork()`. - } -``` - -## Making HTTP Requests - -HTTP requests should be made with [AsyncHTTPClient](https://github.com/swift-server/async-http-client). For convenience `HTTPClient` is a `Service` and a default one is registered to your application container. - -```swift -HTTPClient.default - .get(url: "https://swift.org") - .whenComplete { result in - switch result { - case .failure(let error): - ... - case .success(let response): - ... - } - } -``` - -## Plot: HTML DSL - -Out of the box, Alchemy supports [Plot](https://github.com/JohnSundell/Plot), a Swift DSL for writing type safe HTML. With Plot, returning HTML is dead simple and elegant. You can do so straight from a `Router` handler. - -```swift -app.get("/website") { _ in - return HTML { - .head( - .title("My website"), - .stylesheet("styles.css") - ), - .body( - .div( - .h1(.class("title"), "My website"), - .p("Writing HTML in Swift is pretty great!") - ) - ) - } -} -``` - -### Control Flow - -Plot also supports inline control flow with conditionals, loops, and even unwrapping. It's the perfect, type safe substitute for a templating language. - -```swift -let animals: [String] = ... -let showSubtitle: Bool = ... -let username: String? = ... -HTML { - .head( - .title("My website"), - .stylesheet("styles.css") - ), - .body( - .div( - .h1("My favorite animals are..."), - .if(showSubtitle, - .h2("You found the subtitle") - ), - .ul(.forEach(animals) { - .li(.class("name"), .text($0)) - }), - .unwrap(username) { - .p("Hello, \(username)") - } - ) - ) -} -``` - -### HTMLView - -You can use the `HTMLView` type to help organize your projects view and pages. It is a simple protocol with a single requirement, `var content: HTML`. Like `HTML`, `HTMLView`s can be returned directly from a `Router` handler. - -```swift -struct HomeView: HTMLView { - let showSubtitle: Bool - let animals: [String] - let username: String? - - var content: HTML { - HTML { - .head( - .title("My website"), - .stylesheet("styles.css") - ), - .body( - .div( - .h1("My favorite animals are..."), - .if(self.showSubtitle, - .h2("You found the subtitle") - ), - .ul(.forEach(self.animals) { - .li(.class("name"), .text($0)) - }), - .unwrap(self.username) { - .p("Hello, \(username)") - } - ) - ) - } - } -} - -app.get("/home") { _ in - HomeView(showSubtitle: true, animals: ["Orangutan", "Axolotl", "Echidna"], username: "Kendra") -} -``` - -### Plot Docs - -Check out the [Plot docs](https://github.com/JohnSundell/Plot) for everything you can do with it. - -## Serving Static Files - -If you'd like to serve files from a static directory, there's a `Middleware` for that. It will match incoming requests to files in the directory, streaming those back to the client if they exist. By default, it serves from `Public/` but you may pass a custom path in the initializer if you like. - -Consider a `Public` directory in your project with a few files. - -``` -│ -├── Public -│ ├── css -│ │ └── style.css -│ ├── js -│ │ └── app.js -│ ├── images -│ │ └── puppy.png -│ └── index.html -│ -├── Sources -├── Tests -└── Package.swift -``` - -You could use the following code to serve files from that directory. - -```swift -app.useAll(StaticFileMiddleware()) -``` - -Now, assets in the `Public/` directory can be requested. -``` -http://localhost:3000/index.html -http://localhost:3000/css/style.css -http://localhost:3000/js/app.js -http://localhost:3000/images/puppy.png -http://localhost:3000/ (by default, will return any `index.html` file) -``` - -**Note**: The given directory is relative to your server's working directory. If you are using Xcode, be sure to [set a custom working directory](1_Configuration.md#setting-a-custom-working-directory) for your project where the static file directory is. - -_Next page: [Deploying](11_Deploying.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/11_Deploying.md b/Docs/11_Deploying.md deleted file mode 100644 index 3223563a..00000000 --- a/Docs/11_Deploying.md +++ /dev/null @@ -1,145 +0,0 @@ -# Deploying - -- [DigitalOcean](#digitalocean) - * [Install Swift](#install-swift) - * [Run Your App](#run-your-app) -- [Docker](#docker) - * [Create a Dockerfile](#create-a-dockerfile) - * [Build and deploy the image](#build-and-deploy-the-image) - -While there are many ways to deploy your Alchemy app, this guide focuses on deploying to a Linux machine with DigitalOcean and deploying with Docker. - -## DigitalOcean - -Deploying with DigitalOcean is simple and cheap. You'll just need to create a droplet, install Swift, and run your project. - -First, create a new droplet with the image of your choice, for this guide we'll use `Ubuntu 20.04 (LTS) x64`. You can see the supported Swift [platforms here](https://swift.org/download/#releases). - -### Install Swift - -Once your droplet is created, ssh into it and install Swift. Start by installing the required dependencies. - -```shell -sudo apt-get update -sudo apt-get install clang libicu-dev libatomic1 build-essential pkg-config zlib1g-dev -``` - -Next, install Swift. You can do this by right clicking the name of your droplet image on the [Swift Releases](https://swift.org/download/#releases) page and copying the link. - -Download and decompress the copied link... - -```shell -wget https://swift.org/builds/swift-5.4.2-release/ubuntu2004/swift-5.4.2-RELEASE/swift-5.4.2-RELEASE-ubuntu20.04.tar.gz -tar xzf swift-5.4.2-RELEASE-ubuntu20.04.tar.gz -``` - -Put Swift somewhere easy to link to, such as a folder `/swift/{version}`. -```swift -sudo mkdir /swift -sudo mv swift-5.4.2-RELEASE-ubuntu20.04 /swift/5.4.2 -``` - -Then create a link in `/usr/bin`. -```shell -sudo ln -s /swift/5.4.2/usr/bin/swift /usr/bin/swift -``` - -Verify that it was installed correctly. - -```shell -swift --version -``` - -### Run Your App - -Now that Swift is installed, you can just run your app. - -Start by cloning it - -```shell -git clone -``` - -Make sure to allow HTTP through your droplet's firewall -``` -sudo ufw allow http -``` - -Then run it. Note that since we're on Linux we'll need to pass `--enable-test-discovery`, the executable name of your server (`Backend` if you cloned a quickstart), and a custom host and port so that the server will listen on your droplet's IP at port 80. - -```shell -cd my-project -swift run --enable-test-discovery Backend --host --port 80 -``` - -Assuming you had something like this in your `Application.boot` -```swift -get("/hello") { - "Hello, World!" -} -``` - -Visit `/hello` in your browser and you should see - -``` -Hello, World! -``` - -Congrats, your project is live! - -**Note** When you're ready to run a production version of your app, add a couple flags to the `swift run` command to speed it up and enable debug symbols for crash traces. You might just want to run these flags every time so it's less to think about. - -```shell -swift run -c release -Xswiftc -g -``` - -## Docker - -You can use Docker to create an image that will be deployable anywhere Docker is usable. - -### Create a Dockerfile - -Start off by creating a `Dockerfile`. This is a file that tells Docker how to build & run an image with your server. - -Here's a sample one to copy and paste. Note that you may have to change `Backend` to the name of your executable product. - -This file tells docker to use a base image of `swift:latest`, build your project, and, when the image is run, run your executable on host 0.0.0.0 at port 3000 - -```dockerfile -FROM swift:latest -WORKDIR /app -COPY . . -RUN swift build -c release -Xswiftc -g -RUN mkdir /app/bin -RUN mv `swift build -c release --show-bin-path` /app/bin -EXPOSE 3000 -ENTRYPOINT ./bin/release/Backend --host 0.0.0.0 --port 3000 -``` - -### Build and deploy the image - -Now build your image. If you've been running your project from the CLI, there may be a hefty `.build` folder. You might want to nuke that before running `docker build` so that you don't need to wait to pass that unneeded directory to Docker. - -```shell -$ docker build . -... -Successfully built ab21d0f26ecd -``` - -Finally, run the built image. Pass in `-d` to tell Docker to run your image in the background and `-p 3000:3000` to tell it that your container's 3000 port should be exposed to your machine. - -```shell -docker run -d -p 3000:3000 ab21d0f26ecd -``` - -Visit `http://0.0.0.0:3000/hello` in the browser and you should see - -``` -Hello, World! -``` - -Awesome! You're ready to deploy with Docker. - -_Up next: [Under The Hood](12_UnderTheHood.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/12_UnderTheHood.md b/Docs/12_UnderTheHood.md deleted file mode 100644 index 711e4025..00000000 --- a/Docs/12_UnderTheHood.md +++ /dev/null @@ -1,55 +0,0 @@ -# Under the Hood - -- [Event Loops and You](#event-loops-and-you) - * [Caveat 1: Don't block EventLoops!](#caveat-1-dont-block-eventloops) - * [Caveat 2: Use non-blocking APIs (EventLoopFuture)](#caveat-2-use-non-blocking-apis-eventloopfuture-when-doing-async-tasks) - + [Creating a new EventLoopFuture](#creating-a-new-eventloopfuture) - * [Accessing EventLoops or EventLoopGroups](#accessing-eventloops-or-eventloopgroups) - -Alchemy is built on top of [Swift NIO](https://github.com/apple/swift-nio) which provides an "event driven architecture". This means that each request your server handles is assigned/run on an "event loop", a thread designated for handling incoming requests (represented by the `NIO.EventLoop` type). - -## Event Loops and You - -There are as many unique `EventLoop`s as there are logical cores on your machine, and as requests come in, they are distributed between them. For the most part, logic around `EventLoop`s is abstracted away for you, but there are a few caveats to be aware of when building with Alchemy. - -### Caveat 1: **Don't block `EventLoop`s!** - -The faster you finish handling a request, the sooner the `EventLoop` it's running on will be able to handle additional requests. To keep your server fast, don't block the event loops on which your router handlers are run. If you need to do some CPU intensive work, spin up another thread with `Thread.run`. This will allow the `EventLoop` of the request to handle other work while your intesive task is being completed on another thread. When the task is done, it will hop back to it's original `EventLoop` where it's handling can be finished. - -### Caveat 2: **Use non-blocking APIs (`EventLoopFuture`) when doing async tasks** - -Often, handling a request involves waiting for other servers / services to do something. This could include making a database query or making an external HTTP request. So that EventLoop threads aren't blocked, Alchemy leverages `EventLoopFuture`. `EventLoopFuture` is the Swift server world's version of a `Future`. It represents an asynchronous operation that hasn't yet completed, but will complete on a specific `EventLoop` with either an `Error` or a value of `T`. - -If you've worked with other future types before, these should be straighforward; the API reference is [here](https://apple.github.io/swift-nio/docs/current/NIO/Classes/EventLoopFuture.html). If you haven't, think of them as functional sugar around a value that you'll get in the future (i.e. is being fetched asynchronously). You can chain functions that change the value (`.map { ... }`) or change the value asynchronously (`.flatMap { ... }`) and then respond to the value (or an error) when it's finally resolved. - -#### Creating a new `EventLoopFuture` - -If needed, you can easily create a new future associated with the current `EventLoop` via `EventLoopFuture.new(error:)` or `EventLoopFuture.new(_ value:)`. These will resolve immediately on the current `EventLoop` with the value or error passed to them. - -```swift -func someHandler() -> EventLoopFuture { - .new("Hello!") -} - -func unimplementedHandler() -> EventLoopFuture { - .new(error: HTTPError(.notImplemented, message: "This endpoint isn't implemented yet")) -} -``` - -### Accessing `EventLoop`s or `EventLoopGroup`s - -In general, you won't need to access or think about any `EventLoop`s, but if you do, you can get the current one with `Loop.current`. - -```swift -let thisLoop: EventLoop = Loop.current -``` - -Should you need an `EventLoopGroup` for other `NIO` based libraries, you can access the global `EventLoopGroup` (a `MultiThreadedEventLoopGroup`) via `Loop.group`. - -```swift -let appLoopGroup: EventLoopGroup = Loop.group -``` - -Finally, should you need to run an expensive operation, you may use `Thread.run` which uses an entirely separate thread pool instead of blocking any of your app's `EventLoop`s. - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/13_Commands.md b/Docs/13_Commands.md deleted file mode 100644 index ddbb58b0..00000000 --- a/Docs/13_Commands.md +++ /dev/null @@ -1,138 +0,0 @@ -# Commands - -- [Writing a custom Command](#writing-a-custom-command) - * [Adding Options, Flags, and help info](#adding-options-flags-and-help-info) - * [Printing help info](#printing-help-info) -- [`make` Commands](#make-commands) - -Often, you'll want to run specific tasks around maintenance, cleanup or productivity for your Alchemy app. - -The `Command` interface makes this a cinche, allowing you to create custom commands to run your application with. It's built on the powerful [Swift Argument Parser](https://github.com/apple/swift-argument-parser) making it easy to add arguments, options, flags and help functionality to your custom commands. All commands have access to services registered in `Application.boot` so it's easy to interact with whatever database, queues, & other functionality that your app already has. - -## Writing a custom Command - -To create a command, conform to the `Command` protocol, implement `func start()`, and register it with `app.registerCommand(...)`. Now, when you run your Alchemy app you may pass your custom command name as an argument to execute it. - -For example, let's say you wanted a command that prints all user emails in your default database. - -```swift -final class PrintUserEmails: Command { - // see Swift Argument Parser for other configuration options - static var configuration = CommandConfiguration(commandName: "print") - - func start() -> EventLoopFuture { - User.all() - .mapEach { user in - print(user.email) - } - .voided() - } -} -``` - -Now just register the command, likely in your `Application.boot` - -```swift -app.registerCommand(PrintUserEmails.self) -``` - -and you can run your app with the `print` argument to run your command. - -``` -$ swift run MyApp print -... -jack@twitter.com -elon@tesla.com -mark@facebook.com -``` - -### Adding Options, Flags, and help info - -Because `Command` inherits from Swift Argument Parser's `ParsableCommand` you can easily add flags, options, and configurations to your commands. There's also support for adding help & discussion strings that will show if your app is run with the `help` argument. - -```swift -final class SyncUserData: Command { - static var configuration = CommandConfiguration(commandName: "sync", discussion: "Sync all data for all users.") - - @Option var id: Int? - @Flag(help: "Loaded data but don't save it.") var dry: Bool = false - - func start() -> EventLoopFuture { - if let userId = id { - // sync only a specific user's data - } else { - // sync all users' data - } - } -} -``` - -You can now pass options and flags to this command like so `swift run MyApp sync --id 2 --dry` and it run with the given arguments. - -### Printing help info - -Out of the box, your server can be run with the `help` argument to show all commands available to it, including any custom ones your may have registered. - -```bash -$ swift run MyApp help -OVERVIEW: Run an Alchemy app. - -USAGE: launch [--env ] - -OPTIONS: - -e, --env (default: env) - -h, --help Show help information. - -SUBCOMMANDS: - serve (default) - migrate - queue - make:controller - make:middleware - make:migration - make:model - make:job - make:view - sync - - See 'launch help ' for detailed help. -``` - -You can also pass a command name after help to get detailed information on that command, based on the information your provide in your `configuration`, options, flags, etc. - -```bash -$ swift run MyApp help sync -OVERVIEW: -Sync all data for all users. - -USAGE: MyApp sync [--id ] [--dry] - -OPTIONS: - -e, --env (default: env) - --id Sync data for a specific user only. - --dry Should data be loaded but not saved. - -h, --help Show help information. -``` - -Note that you can always pass `-e, --env ` to any command to have it load your environment from a custom env file before running. - -## `make` Commands - -Out of the box, Alchemy includes a variety of commands to boost your productivity and generate commonly used interfaces. These commands are prefaced with `make:`, and you can see all available ones with `swift run MyApp help`. - -For example, the `make:model` command makes it easy to generate a model with the given fields. You can event generate a full populated Migration and Controller with CRUD routes by passing the `--migration` and `--controller` flags. - -```bash -$ swift run Server make:model Todo id:increments:primary name:string is_done:bool user_id:bigint:references.users.id --migration --controller -🧪 create Sources/App/Models/Todo.swift -🧪 create Sources/App/Migrations/2021_09_24_11_07_02CreateTodos.swift - └─ remember to add migration to a Database.migrations! -🧪 create Sources/App/Controllers/TodoController.swift -``` - -Like all commands, you may view the details & arguments of each make command with `swift run MyApp help `. - - -_Next page: [Digging Deeper](10_DiggingDeeper.md)_ - -_[Table of Contents](/Docs#docs)_ \ No newline at end of file diff --git a/Docs/1_Configuration.md b/Docs/1_Configuration.md deleted file mode 100644 index ea510c8c..00000000 --- a/Docs/1_Configuration.md +++ /dev/null @@ -1,164 +0,0 @@ -# Configuration - -- [Run Commands](#run-commands) - * [`serve`](#serve) - * [`migrate`](#migrate) - * [`queue`](#queue) -- [Environment](#environment) - * [Dynamic Member Lookup](#dynamic-member-lookup) - * [.env File](#env-file) - * [Custom Environments](#custom-environments) -- [Working with Xcode](#working-with-xcode) - * [Setting a Custom Working Directory](#setting-a-custom-working-directory) - -## Run Commands - -When Alchemy is run, it takes an argument that determines how it behaves on launch. When no argument is passed, the default command is `serve` which boots the app and serves it on the machine. - -There are also `migrate` and `queue` commands which help run migrations and queue workers/schedulers respectively. - -You can run these like so. - -```shell -swift run Server migrate -``` - -Each command has options for customizing how it runs. If you're running your app from Xcode, you can configure launch arguments by editing the current scheme and navigating to `Run` -> `Arguments`. - -If you're looking to extend your Alchemy app with your own custom commands, check out [Commands](13_Commands.md). - -### Serve - -> `swift run` or `swift run Server serve` - -|Option|Default|Description| -|-|-|-| -|--host|127.0.0.1|The host to listen on| -|--port|3000|The port to listen on| -|--unixSocket|nil|The unix socket to listen on. Mutually exclusive with `host` & `port`| -|--workers|0|The number of workers to run| -|--schedule|false|Whether scheduled tasks should be scheduled| -|--migrate|false|Whether any outstanding migrations should be run before serving| -|--env|env|The environment to load| - -### Migrate - -> `swift run Server migrate` - -|Option|Default|Description| -|-|-|-| -|--rollback|false|Should migrations be rolled back instead of applied| -|--env|env|The environment to load| - -### Queue - -> `swift run Server queue` - -|Option|Default|Description| -|-|-|-| -|--name|`nil`|The queue to monitor. Leave empty to monitor `Queue.default`| -|--channels|`default`|The channels to monitor, separated by comma| -|--workers|1|The number of workers to run| -|--schedule|false|Whether scheduled tasks should be scheduled| -|--env|env|The environment to load| - -## Environment - -Often you'll need to access environment variables of the running program. To do so, use the `Env` type. - -```swift -// The type is inferred -let envBool: Bool? = Env.current.get("SOME_BOOL") -let envInt: Int? = Env.current.get("SOME_INT") -let envString: String? = Env.current.get("SOME_STRING") -``` - -### Dynamic member lookup - -If you're feeling fancy, `Env` supports dynamic member lookup. - -```swift -let db: String? = Env.DB_DATABASE -let dbUsername: String? = Env.DB_USER -let dbPass: String? = Env.DB_PASS -``` - -### .env file - -By default, environment variables are loaded from the process as well as the file `.env` if it exists in the working directory of your project. - -Inside your `.env` file, keys & values are separated with an `=`. - -```bash -# A sample .env file (a file literally titled ".env" in the working directory) - -APP_NAME=Alchemy -APP_ENV=local -APP_KEY= -APP_DEBUG=true -APP_URL=http://localhost - -DB_CONNECTION=mysql -DB_HOST=127.0.0.1 -DB_PORT=5432 -DB_DATABASE=alchemy -DB_USER=josh -DB_PASS=password - -REDIS_HOST=127.0.0.1 -REDIS_PASSWORD=null -REDIS_PORT=6379 - -AWS_ACCESS_KEY_ID= -AWS_SECRET_ACCESS_KEY= -AWS_DEFAULT_REGION=us-east-1 -AWS_BUCKET= -``` - -### Custom Environments - -You can load your environment from another location by passing your app the `--env` option. - -If you have separate environment variables for different server configurations (i.e. local dev, staging, production), you can pass your program a separate `--env` for each configuration so the right environment is loaded. - -## Configuring Your Server - -There are a couple of options available for configuring how your server is running. By default, the server runs over `HTTP/1.1`. - -### Enable TLS - -You can enable running over TLS with `useHTTPS`. - -```swift -func boot() throws { - try useHTTPS(key: "/path/to/private-key.pem", cert: "/path/to/cert.pem") -} -``` - -### Enable HTTP/2 - -You may also configure your server with `HTTP/2` upgrades (will prefer `HTTP/2` but still accept `HTTP/1.1` over TLS). To do this use `useHTTP2`. - -```swift -func boot() throws { - try useHTTP2(key: "/path/to/private-key.pem", cert: "/path/to/cert.pem") -} -``` - -Note that the `HTTP/2` protocol is only supported over TLS, and so implies using it. Thus, there's no need to call both `useHTTPS` and `useHTTP2`; `useHTTP2` sets up both TLS and `HTTP/2` support. - -## Working with Xcode - -You can use Xcode to run your project to take advantage of all the great tools built into it; debugging, breakpoints, memory graphs, testing, etc. - -When working with Xcode be sure to set a custom working directory. - -### Setting a Custom Working Directory - -By default, Xcode builds and runs your project in a **DerivedData** folder, separate from the root directory of your project. Unfortunately this means that files your running server may need to access, such as a `.env` file or a `Public` directory, will not be available. - -To solve this, edit your server target's scheme & change the working directory to your package's root folder. `Edit Scheme` -> `Run` -> `Options` -> `WorkingDirectory`. - -_Up next: [Services & Dependency Injection](2_Fusion.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/2_Fusion.md b/Docs/2_Fusion.md deleted file mode 100644 index c56fbfda..00000000 --- a/Docs/2_Fusion.md +++ /dev/null @@ -1,83 +0,0 @@ -# Services & Dependency Injection - -- [Registering and Injecting Services](#registering-and-injecting-services) - * [Registering Defaults](#registering-defaults) - * [Registering Additional Instances](#registering-additional-instances) -- [Mocking](#mocking) - -Alchemy handles dependency injection using [Fusion](https://github.com/alchemy-swift/fusion). In addition to Fusion APIs, it includes a `Service` protocol to make it easy to inject common Alchemy such as `Database`, `Redis` and `Queue`. - -## Registering and Injecting Services - -Most Alchemy services conform to the `Service` protocol, which you can use to configure and access various connections. - -For example, you likely want to use an SQL database in your app. You can use the `Service` methods to set up a default database driver. You'll probably want to do this in your `Application.boot`. - -### Registering Defaults - -Services typically have static driver functions to your configure defaults. In this case, the `.postgres()` function helps create a PostgreSQL database driver. - -```swift -Database.config( - default: .postgres( - host: "localhost", - database: "alchemy")) -``` - -Once registered, you can inject this database anywhere in your code via `@Inject`. The service container will resolve the registered configuration. - -```swift -@Inject var database: Database -``` - -You can also inject it with `Database.default`. Many Alchemy APIs default to using a service's `default` so that you don't have to pass an instance in every time. For example for loading models from Rune, Alchemy's built in ORM. - -```swift -struct User: Model { ... } - -// Fetchs all `User` models from `Database.default` -User.all() -``` - -### Registering Additional Instances - -If you have more than one instance of a service that you'd like to use, you can pass an identifier to `Service.config()` to associate it with the given configuration. - -```swift -Database.config( - "mysql", - .mysql( - host: "localhost", - database: "alchemy")) -``` - -This can now be injected by passing that identifier to `@Inject`. - -```swift -@Inject("mysql") var mysqlDB: Database -``` - -It can also be inject by using the `Service.named()` function. - -```swift -User.all(db: .named("mysql")) -``` - -## Mocking - -When it comes time to write tests for your app, you can leverage the service protocol to inject mock interfaces of various services. These mocks will now be resolved any time this service is accessed in your code. - -```swift -final class RouterTests: XCTestCase { - private var app = TestApp() - - override func setUp() { - super.setUp() - Cache.config(default: .mock()) - } -} -``` - -_Next page: [Routing: Basics](3a_RoutingBasics.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/3a_RoutingBasics.md b/Docs/3a_RoutingBasics.md deleted file mode 100644 index 296015ed..00000000 --- a/Docs/3a_RoutingBasics.md +++ /dev/null @@ -1,206 +0,0 @@ -# Routing: Basics - -- [Handling Requests](#handling-requests) -- [ResponseEncodable](#responseencodable) - * [Anything `Codable`](#anything-codable) - * [a `Response`](#a-response) - * [`Void`](#void) - * [Futures that result in a `ResponseConvertible` value](#futures-that-result-in-a-responseconvertible-value) - * [Chaining Requests](#chaining-requests) -- [Controller](#controller) -- [Errors](#errors) -- [Path parameters](#path-parameters) -- [Accessing request data](#accessing-request-data) - -## Handling Requests - -When a request comes through the host & port on which your server is listening, it immediately gets routed to your application. - -You can set up handlers in the `boot()` function of your app. - -Handlers are defined with the `.on(method:at:handler:)` function, which takes an `HTTPMethod`, a path, and a handler. The handler is a closure that accepts a `Request` and returns a type that conforms to `ResponseConvertable`. There's sugar for registering handlers for specific methods via `get()`, `post()`, `put()`, `patch()`, etc. - -```swift -struct ExampleApp: Application { - func boot() { - // GET {host}:{port}/hello - get("/hello") { request in - "Hello, World!" - } - } -} -``` - -## ResponseEncodable - -Out of the box, Alchemy conforms most types you'd need to return from a handler to `ResponseConvertible`. - -### Anything `Codable` - -```swift -/// String -app.get("/string", handler: { _ in "Howdy!" }) - -/// Int -app.on(.GET, at: "/int", handler: { _ in 42 }) - -/// Custom type - -struct Todo: Codable { - var name: String - var isDone: Bool -} - -app.get("/todo", handler: { _ in - Todo(name: "Write backend in Swift", isDone: true) -}) -``` - -### a `Response` - -```swift -app.get("/response") { _ in - Response(status: .ok, body: HTTPBody(text: "Hello from /response")) -} -``` - -### `Void` - -```swift -app.get("/testing_query") { request in - print("Got params \(request.queryItems)") -} -``` - -### Futures that result in a `ResponseConvertible` value - -```swift -app.get("/todos") { _ in - loadTodosFromDatabase() -} - -func loadTodosFromDatabase() -> EventLoopFuture<[Todo]> { - ... -} -``` - -*Note* an `EventLoopFuture` is the Swift server world's version of a future. See [Under the Hood](12_UnderTheHood.md). - -### Chaining Requests - -To keep code clean, handlers are chainable. - -```swift -let controller = UserController() -app - .post("/user", handler: controller.create) - .get("/user", handler: controller.get) - .put("/user", handler: controller.update) - .delete("/user", handler: controller.delete) -``` - -## Controller - -For convenience, a protocol `Controller` is provided to help break up your route handlers. Implement the `route(_ app: Application)` function and register it in your `Application.boot`. - -```swift -struct UserController: Controller { - func route(_ app: Application) { - app.post("/create", handler: create) - .post("/reset", handler: reset) - .post("/login", handler: login) - } - - func create(req: Request) -> String { - "Greetings from user create!" - } - - func reset(req: Request) -> String { - "Howdy from user reset!" - } - - func login(req: Request) -> String { - "Yo from user login!" - } -} - -struct App: Application { - func boot() { - ... - controller(UserController()) - } -} -``` - -## Errors - -Routing in Alchemy is heavily integrated with Swift's built in error handling. [Middleware](3b_RoutingMiddleware.md) & handlers allow for synchronous or asynchronous code to `throw`. - -If an error is thrown or an `EventLoopFuture` results in an error, it will be caught & mapped to a `Response`. - -Generic errors will result in an `Response` with a status code of 500, but if any error that conforms to `ResponseConvertible` is thrown, it will be converted as such. - -Out of the box `HTTPError` conforms to `ResponseConvertible`. If it is thrown, the response will contain the status code & message of the `HTTPError`. - -```swift -struct SomeError: Error {} - -app - .get("/foo") { _ in - // Will result in a 500 response with a generic error message. - throw SomeError() - } - .get("/bar") { _ in - // Will result in a 404 response with the custom message. - throw HTTPError(status: .notFound, message: "This endpoint doesn't exist!") - } -``` - -## Path parameters - -Dynamic path parameters can be added with a variable name prefaced by a colon (`:`). The value will be parsed and accessible in the handler. - -```swift -app.on(.GET, at: "/users/:userID") { req in - let userID: String? = req.pathParameter(named: "userID") -} -``` - -As long as they have different names, a route can have as many path parameters as you'd like. - -## Accessing request data - -Data you might need to get off of an incoming request are in the `Request` type. - -```swift -app.post("/users/:userID") { req in - // Headers - let authHeader: String? = req.headers.first(name: "Authorization") - - // Query (URL) parameters - let countParameter: QueryParameter? = req.queryItems - .filter ({ $0.name == "count" }).first - - // Path - let thePath: String? = req.path - - // Path parameters - let userID: String? = req.pathParameter(named: "userID") - - // Method - let theMethod: HTTPMethod = req.method - - // Body - let body: SomeCodable = try req.body.decodeJSON() - - // Token auth, if there is any - let basicAuth: HTTPBasicAuth? = req.basicAuth() - - // Bearer auth, if there is any - let bearerAuth: HTTPBearerAuth? = req.bearerAuth() -} -``` - -_Next page: [Routing: Middleware](3b_RoutingMiddleware.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/3b_RoutingMiddleware.md b/Docs/3b_RoutingMiddleware.md deleted file mode 100644 index 8a41efe5..00000000 --- a/Docs/3b_RoutingMiddleware.md +++ /dev/null @@ -1,151 +0,0 @@ -# Routing: Middleware - -- [Creating Middleware](#creating-middleware) - * [Accessing the `Request`](#accessing-the-request) - * [Setting Data on a Request](#setting-data-on-a-request) - * [Accessing the `Response`](#accessing-the--response-) -- [Adding Middleware to Your Application](#adding-middleware-to-your-application) - * [Global Intercepting](#global-intercepting) - * [Specific Intercepting](#specific-intercepting) - -## Creating Middleware - -A middleware is a piece of code that is run before or after a request is handled. It might modify the `Request` or `Response`. - -Create a middleware by conforming to the `Middleware` protocol. It has a single function `intercept` which takes a `Request` and `next` closure. It returns an `EventLoopFuture`. - -### Accessing the `Request` - -If you'd like to do something with the `Request` before it is handled, you can do so before calling `next`. Be sure to call and return `next` when you're finished! - -```swift -/// Logs all requests that come through this middleware. -struct LogRequestMiddleware: Middleware { - func intercept(_ request: Request, next: @escaping Next) -> EventLoopFuture { - Log.info("Got a request to \(request.path).") - return next(request) - } -} -``` - -You may also do something with the request asynchronously, just be sure to continue the chain with `next(req)` when you are finished. - -```swift -/// Runs a database query before passing a request to a handler. -struct QueryingMiddleware: Middleware { - func intercept(_ request: Request, next: @escaping Next) -> EventLoopFuture { - return User.all() - .flatMap { users in - // Do something with `users` then continue the chain - next(request) - } - } -} -``` - -### Setting Data on a Request - -Sometimes you may want a `Middleware` to add some data to a `Request`. For example, you may want to authenticate an incoming request with a `Middleware` and then add a `User` to it for handlers down the chain to access. - -You can set generic data on a `Request` using `Request.set` and then access it in subsequent `Middleware` or handlers via `Request.get`. - -For example, you might be doing some experiments with a homegrown `ExperimentConfig` type. You'd like to assign random configurations of that type on a per-request basis. You might do so with a `Middleware`: - -```swift -struct ExperimentMiddleware: Middleware { - func intercept(_ request: Request, next: @escaping Next) -> EventLoopFuture { - let config: ExperimentConfig = ... // load a random experiment config - return next(request.set(config)) - } -} -``` - -You would then intercept requests with that `Middleware` and utilize the set `ExperimentConfig` in your handlers. - -```swift -app - .use(ExperimentalMiddleware()) - .get("/experimental_endpoint") { request in - // .get() will throw an error if a value with that type hasn't been `set()` on the `Request`. - let config: ExperimentConfig = try request.get() - if config.shouldUseLoudCopy { - return "HELLO WORLD!!!!!" - } else { - return "hey, world." - } - } -``` - -### Accessing the `Response` - -If you'd like to do something with the `Response` of the handled request, you can plug into the future returned by `next`. - -```swift -/// Logs all responses that come through this middleware. -struct LogResponseMiddleware: Middleware { - func intercept(_ request: Request, next: @escaping Next) -> EventLoopFuture { - return next(request) - // Use `flatMap` if you want to do something asynchronously. - .map { response in - Log.info("Got a response \(response.status) from \(request.path).") - return response - } - } -} -``` - -## Adding Middleware to Your Application - -There are a few ways to have a `Middleware` intercept requests. - -### Global Intercepting - -If you'd like a middleware to intercept _all_ requests on an `Application`, you can add it via `Application.useAll`. - -```swift -struct ExampleApp: Application { - func boot() { - self - .useAll(LoggingMiddleware()) - // LoggingMiddleware will intercept all of these, as well as any unhandled requests. - .get("/foo") { request in "Howdy foo!" } - .post("/bar") { request in "Howdy bar!" } - .put("/baz") { request in "Howdy baz!" } - } -} -``` - -### Specific Intercepting - -A `Middleware` can be setup to only intercept requests to specific handlers via the `.use(_ middleware: Middleware)` function on an `Application`. The `Middleware` will intercept all requests to the subsequently defined handlers. - -```swift -app - .post("/password_reset", handler: ...) - // Because this middleware is provided after the /password_reset endpoint, - // it will only affect subsequent routes. In this case, only requests to - // `/user` and `/todos` would be intercepted by the LoggingMiddleware. - .use(LoggingMiddleware()) - .get("/user", handler: ...) - .get("/todos", handler: ...) -``` - -There is also a `.group` function that takes a `Middleware`. The `Middleware` will _only_ intercept requests handled by handlers defined in the closure. - -```swift -app - .post("/user", handle: ...) - .group(middleware: CustomAuthMiddleware()) { - // Each of these endpoints will be protected by the - // `CustomAuthMiddleWare`... - $0.get("/todo", handler: ...) - .put("/todo", handler: ...) - .delete("/todo", handler: ...) - } - // ...but this one will not. - .post("/reset", handler: ...) -``` - -_Next page: [Papyrus](4_Papyrus.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/4_Papyrus.md b/Docs/4_Papyrus.md deleted file mode 100644 index 2bf1b92a..00000000 --- a/Docs/4_Papyrus.md +++ /dev/null @@ -1,347 +0,0 @@ -# Papyrus - -- [Installation](#installation) - * [Server](#server) - * [Shared Library](#shared-library) - * [iOS / macOS](#ios---macos) -- [Usage](#usage) - * [Defining APIs](#defining-apis) - + [Basics](#basics) - + [Supported Methods](#supported-methods) - + [Empty Request or Reponse](#empty-request-or-reponse) - + [Custom Request Data](#custom-request-data) - - [URLQuery](#urlquery) - - [Header](#header) - - [Path Parameters](#path-parameters) - - [Body](#body) - - [Combinations](#combinations) - * [Requesting APIs](#requesting-apis) - + [Client, via Alamofire](#client-via-alamofire) - + [Server, via AsyncHTTPClient](#server-via-asynchttpclient) - * [Providing APIs](#providing-apis) - * [Interceptors](#interceptors) - -Papyrus is a helper library for defining network APIs in Swift. - -It leverages `Codable` and Property Wrappers for creating network APIs that are easy to read, easy to consume (on Server or Client) and easy to provide (on Server). When shared between a Swift client and server, it enforces type safety when requesting and handling HTTP requests. - -## Installation - -### Server - -Papyrus is included when you `import Alchemy` on the server side. - -### Shared Library - -If you're sharing code between clients and servers with a Swift library, you can add `Papyrus` as a dependency to that library via SPM. - -```swift -// in your Package.swift - -dependencies: [ - .package(url: "https://github.com/alchemy-swift/alchemy", .upToNextMinor(from: "0.2.0")) - ... -], -targets: [ - .target(name: "MySharedLibrary", dependencies: [ - .product(name: "Papyrus", package: "alchemy"), - ]), -] -``` - -### iOS / macOS - -If you want to define or request `Papyrus` APIs on a Swift client (iOS, macOS, etc) you'll add [`PapyrusAlamofire`](https://github.com/alchemy-swift/papyrus-alamofire) as a dependency via SPM. This is a light wrapper around `Papyrus` with support for requesting endpoints with [Alamofire](https://github.com/Alamofire/Alamofire). - -Since Xcode manages the `Package.swift` for iOS and macOS targets, you can add `PapyrusAlamofire` as a dependency through `File` -> `Swift Packages` -> `Add Package Dependency` -> paste `https://github.com/alchemy-swift/papyrus-alamofire` -> check `PapyrusAlamofire` to import. - -## Usage - -Papyrus is used to define, request, and provide HTTP endpoints. - -### Defining APIs - -#### Basics - -A single endpoint is defined with the `Endpoint` type. - -`Endpoint.Request` represents the data needed to make this request, and `Endpoint.Response` represents the expected return data from this request. Note that `Request` must conform to `RequestComponents` and `Response` must conform to `Codable`. - -Define an `Endpoint` on an enclosing `EndpointGroup` subclass, and wrap it with a property wrapper representing it's HTTP method and path, relative to a base URL. - -```swift -class TodosAPI: EndpointGroup { - @GET("/todos") - var getAll: Endpoint - - struct GetTodosRequest: RequestComponents { - @URLQuery - var limit: Int - - @URLQuery - var incompleteOnly: Bool - } - - struct TodoDTO: Codable { - var name: String - var isComplete: Bool - } -} -``` - -Notice a few things about the `getAll` endpoint. - -1. The `@GET("/todos")` indicates that the endpoint is at `POST {some_base_url}/todos`. -2. The endpoint expects a request object of `GetUsersRequest` which conforms to `RequestComponents` and contains two properties, wrapped by `@URLQuery`. The `URLQuery` wrappers indicate data that's expected in the query url of the request. This lets requesters of this endpoint know that the endpoint needs two query values, `limit` and `incompleteOnly`. It also lets the providers of this endpoint know that incoming requests to `GET /todo` will contain two items in their query URLs; `limit` and `incompleteOnly`. -3. The endpoint has a response type of `[TodoDTO]`, defined below it. This lets clients know what response type to expect and lets providers know what response type to return. - -This gives anyone reading or using the API all the information they would need to interact with it. - -Requesting this endpoint might look like -``` -GET {some_base_url}/todos?limit=1&incompleteOnly=0 -``` -While a response would look like -```json -[ - { - "name": "Do laundry", - "isComplete": false - }, - { - "name": "Learn Alchemy", - "isComplete": true - }, - { - "name": "Be awesome", - "isComplete": true - }, -] -``` - -**Note**: The DTO suffix of `TodoDTO` stands for `Data Transfer Object`, indicating that this type represents some data moving across the wire. It is not necesssary, but helps differentiate from local `Todo` model types that may exist on either client or server. - -#### Supported Methods - -Out of the box, Papyrus provides `@GET`, `@POST`, `@PUT`, `@PATCH`, `@DELETE` as well as a `@CUSTOM("OPTIONS", "/some/path")` that can take any method string for defining your `Endpoint`s. - -#### Empty Request or Reponse - -If you're endpoint doesn't have any request or response data that needs to be parsed, you may define the `Request` or `Response` type to be `Empty`. - -```swift -class SomeAPI: EndpointGroup { - @GET("/foo") - var noRequest: Endpoint - - @POST("/bar") - var noResponse: Endpoint -} -``` - -#### Custom Request Data - -Like `@URLQuery`, there are other property wrappers to define where on an HTTP request data should be. - -Each wrapper denotes a value in the request at the proper location with a key of the name of the property. For example `@Header var someHeader: String` indicates requests to this endpoint should have a header named `someHeader`. - -**Note**: `@Body` ignore's its property name and instead encodes it's value into the entire request body. - -##### URLQuery - -`@URLQuery` can wrap a `Bool`, `String`, `String?`, `Int`, `Int?` or `[String]`. - -Optional properties with nil values will be omitted. - -```swift -class SomeAPI: EndpointGroup { - // There will be a query1, query3 and optional query2 in the request URL. - @GET("/foo") - var queryRequest: Endpoint -} - -struct QueryRequest: RequestComponents { - @URLQuery var query1: String - @URLQuery var query2: String? - @URLQuery var query3: Int -} -``` - -##### Header - -`@Header` can wrap a `String`. It indicates that there should be a header of name `{propertyName}` on the request. - -```swift -class SomeAPI: EndpointGroup { - @POST("/foo") - var foo: Endpoint -} - -/// Defines a header "someHeader" on the request. -struct HeaderRequest: RequestComponents { - @Header var someHeader: String -} -``` - -##### Path Parameters - -`@Path` can wrap a `String`. It indicates a dynamic path parameter at `:{propertyName}` in the request path. - -```swift -class SomeAPI: EndpointGroup { - @POST("/some/:someID/value") - var foo: Endpoint -} - -struct PathRequest: RequestComponents { - @Path var someID: String -} -``` - -##### Body - -`@Body` can wrap any `Codable` type which will be encoded to the request. By default, the body is encoded as JSON, but you may override `RequestComponents.contentType` to use another encoding type. - -```swift -class SomeAPI: EndpointGroup { - @POST("/json") - var json: Endpoint - - @GET("/url") - var json: Endpoint -} - -/// Will encode `BodyData` in the request body. -struct JSONBody: RequestComponents { - @Body var body: BodyData -} - -/// Will encode `BodyData` in the request URL. -struct URLEncodedBody: RequestComponents { - static let contentType = .urlEncoded - - @Body var body: BodyData -} - -struct BodyData: Codable { - var foo: String - var baz: Int -} -``` - -You may also use `RequestBody` if the only content of the request is in the body. This will encode whatever fields are on your type into the `Request`'s body, instead of having to add a separate type and use the `@Body` property wrapper. - -```swift -struct JSONBody: RequestBody { - var foo: String - var baz: Int -} -``` - -##### Combinations - -You can combine any number of these property wrappers, except for `@Body`. There can only be a single `@Body` per request. - -```swift -struct MyCustomRequest: RequestComponents { - struct SomeCodable: Codable { - ... - } - - @Body var bodyData: SomeCodable - - @Header var someHeader: String - - @Path var userID: String - - @URLQuery var query1: Int - @URLQuery var query2: String - @URLQuery var query3: String? - @URLQuery var query3: [String] -} -``` - -### Requesting APIs - -Papyrus can be used to request endpoints on client or server targets. - -To request an endpoint, create the `EndpointGroup` with a `baseURL` and call `request` on a specific endpoint, providing the needed `Request` type. - -Requesting the the `TodosAPI.getAll` endpoint from above looks similar on both client and server. - -```swift -// `import PapyrusAlamofire` on client -import Alchemy - -let todosAPI = TodosAPI(baseURL: "http://localhost:3000") -todosAPI.getAll - .request(.init(limit: 50, incompleteOnly: true)) { response, todoResult in - switch todoResult { - case .success(let todos): - for todo in todos { - print("Got todo: \(todo.name)") - } - case .failure(let error): - print("Got error: \(error).") - } - } -``` - -This would make a request that looks like: -``` -GET http://localhost:3000/todos?limit=50&incompleteOnly=false -``` - -While the APIs are built to look similar, the client and server implementations sit on top of different HTTP libraries and are customizable in separate ways. - -#### Client, via Alamofire - -Requesting an `Endpoint` client side is built on top of [Alamofire](https://github.com/Alamofire/Alamofire). By default, requests are run on `Session.default`, but you may provide a custom `Session` for any customization, interceptors, etc. - -#### Server, via AsyncHTTPClient - -Request an `Endpoint` in an `Alchemy` server is built on top of [AsyncHTTPClient](https://github.com/swift-server/async-http-client). By default, requests are run on the default `HTTPClient`, but you may provide a custom `HTTPClient`. - -### Providing APIs - -Alchemy contains convenient extensions for registering your `Endpoint`s on a `Router`. Use `.on` to register an `Endpoint` to a router. - -```swift -let todos = TodosAPI() -router.on(todos.getAll) { (request: Request, data: GetTodosRequest) in - // when a request to `GET /todos` is handled, the `GetTodosRequest` properties will be loaded from the `Alchemy.Request`. -} -``` - -This will automatically parse the relevant `GetTodosRequest` data from the right places (URL query, headers, body, path parameters) on the incoming request. In this case, "limit" & "incompleteOnly" from the request query `String`. - -If expected data is missing, a `400` is thrown describing the missing expected fields: - -```json -400 Bad Request -{ - "message": "expected query value `limit`" -} -``` - -**Note**: Currently, only `ContentType.json` is supported for decoding request `@Body`s. - -### Interceptors - -Often you'll have some sort of request component you'd like to apply to every request in a group of endpoints. For example, you may want to add an `Authorization` header. Instead of adding `@Header var authorization: String` to each request content, you can accomplish this using the `intercept()` function of `EndpointGroup`. It gives you the raw `HTTPComponents` to modify right before sending the request. - -```swift -final class TodoAPI: EndpointGroup { - @POST("/v1/create") var create: Endpoint - @POST("/v1/create") var getAll: Endpoint - @POST("/v1/create") var delete: Endpoint - - func intercept(_ components: inout HTTPComponents) { - components.headers["Authorization"] = "Bearer \(some_token)" - } -} -``` - -_Next page: [Database: Basics](5a_DatabaseBasics.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/5a_DatabaseBasics.md b/Docs/5a_DatabaseBasics.md deleted file mode 100644 index d3ab0ccd..00000000 --- a/Docs/5a_DatabaseBasics.md +++ /dev/null @@ -1,97 +0,0 @@ -# Database: Basics - -- [Introduction](#introduction) -- [Connecting to a Database](#connecting-to-a-database) -- [Querying data](#querying-data) - * [Handling Query Responses](#handling-query-responses) - * [Transactions](#transactions) - -## Introduction - -Alchemy makes interacting with SQL databases a breeze. You can use raw SQL, the fully featured [query builder](5b_DatabaseQueryBuilder.md) or the built in ORM, [Rune](6a_RuneBasics.md). - -## Connecting to a Database - -Out of the box, Alchemy supports connecting to Postgres & MySQL databases. Database is a `Service` and so is configurable with the `config` function. - -```swift -Database.config(default: .postgres( - host: Env.DB_HOST ?? "localhost", - database: Env.DB ?? "db", - username: Env.DB_USER ?? "user", - password: Env.DB_PASSWORD ?? "password" -)) - -// Database queries are all asynchronous, using `EventLoopFuture`s in -// their API. -Database.default - .rawQuery("select * from users;") - .whenSuccess { rows in - print("Got \(rows.count) results!") - } -``` - -## Querying data - -You can query with raw SQL strings using `Database.rawQuery`. It supports bindings to protect against SQL injection. - -```swift -let email = "josh@withapollo.com" - -// Executing a raw query -database.rawQuery("select * from users where email='\(email)';") - -// Using bindings to protect against SQL injection -database.rawQuery("select * from users where email=?;", values: [.string(email)]) -``` - -**Note** regardless of SQL dialect, please use `?` as placeholders for bindings. Concrete `Database`s representing dialects that use other placeholders, such as `PostgresDatabase`, will replace `?`s with the proper placeholder. - -### Handling Query Responses - -Every query returns a future with an array of `DatabaseRow`s that you can use to parse out data. You can access all their columns with `allColumns` or try to get the value of a column with `.getField(column: String) throws -> DatabaseField`. - -```swift -dataBase.rawQuery("select * from users;") - .mapEach { (row: DatabaseRow) in - print("Got a user with columns: \(row.allColumns.join(", "))") - let email = try! row.getField(column: "email").string() - print("The email of this user was: \(email)") - } -``` - -Note that `DatabaseField` is made up of a `column: String` and a `value: DatabaseValue`. It contains functions for casting the value to a specific Swift data type, such as `.string()` above. - -```swift -let field: DatabaseField = ... - -let uuid: UUID = try field.uuid() -let string: String = try field.string() -let int: Int = try field.int() -let bool: Bool = try field.bool() -let double: Double = try field.double() -let json: Data = try field.json() -``` - -These functions will throw if the value at the given column isn't convertible to that type. - -### Transactions - -Sometimes, you'll want to run multiple database queries as a single atomic operation. For this, you can use the `transaction()` function; a wrapper around SQL transactions. You'll have exclusive access to a database connection for the lifetime of your transaction. - -```swift -database.transaction { conn in - conn.query() - .where("account" == 1) - .update(values: ["amount": 100]) - .flatMap { _ in - conn.query() - .where("account" == 2) - .update(values: ["amount": 200]) - } -} -``` - -_Next page: [Database: Query Builder](5b_DatabaseQueryBuilder.md)_ - -_[Table of Contents](/Docs#docs)_ \ No newline at end of file diff --git a/Docs/5b_DatabaseQueryBuilder.md b/Docs/5b_DatabaseQueryBuilder.md deleted file mode 100644 index 3607a7a6..00000000 --- a/Docs/5b_DatabaseQueryBuilder.md +++ /dev/null @@ -1,272 +0,0 @@ -# Database: Query Builder - -- [Running Database Queries](#running-database-queries) - * [Starting a query chain](#starting-a-query-chain) - * [Get all rows](#get-all-rows) - * [Get a single row](#get-a-single-row) -- [Select](#select) - * [Picking columns to return](#picking-columns-to-return) -- [Joins](#joins) -- [Where Clauses](#where-clauses) - * [Basic Where Clauses](#basic-where-clauses) - * [Or Where Clauses](#or-where-clauses) - * [Grouping Where Clauses](#grouping-where-clauses) - * [Additional Where Clauses](#additional-where-clauses) - + [Where Null](#where-null) - + [Where In](#where-in) -- [Ordering, Grouping, Paging](#ordering-grouping-paging) - * [Grouping](#grouping) - * [Ordering](#ordering) - * [Paging, Limits and Offsets](#paging-limits-and-offsets) -- [Inserting](#inserting) -- [Updating](#updating) -- [Deleting](#deleting) -- [Counting](#counting) - -Alchemy offers first class support for building and running database queries through a chaining query builder. It can be used for the majority of database operations, otherwise you can always run pure SQL as well. The syntax is heavily inspired by Knex and Laravel. - -## Running Database Queries - -### Starting a query chain -To start fetching records, you can begin a chain a number of different ways. Each will start a query builder chain that you can then build out. - -```swift -Query.from("users")... // Start a query on table `users` using the default database. -// or -Model.query()... // Start a query and automatically sets the table from the model. -// or -database.query().from("users") // Start a query using a database variable on table `users`. -``` - -### Get all rows -```swift -Query.from("users") - .get() -``` - -### Get a single row - -If you are only wanting to select a single row from the database table, you have a few different options. - -To select the first row only from a query, use the `first` method. -```swift -Query.from("users") - .where("name", "Steve") - .first() -``` - -If you want to get a single record based on a given column, you can use the `find` method. This will return the first record matching the criteria. -```swift -Query.from("users") - .find() -``` - -## Select - -### Picking columns to return - -Sometimes you may want to select just a subset of columns to return. While the `find` and `get` methods can take a list of columns to limit down to, you can always explicitly call `select`. - -```swift -Query.from("users") - .select(["first_name", "last_name"]) - .get() -``` - -## Joins - -You can easily join data from separate tables using the query builder. The `join` method needs the table you are joining, and a clause to match up the data. If for example you are wanting to join all of a users order data, you could do the following: - -```swift -Query.from("users") - .join(table: "orders", first: "users.id", op: .equals, second: "orders.user_id") - .get() -``` - -There are helper methods available for `leftJoin`, `rightJoin` and `crossJoin` that you can use that take the same basic parameters. - -## Where Clauses - -### Basic Where Clauses - -If you are wanting to filter down your results this can be done by using the `where` method. You can add as many where clauses to your query to continually filter down as far as needed. The simplest usage is to construct a `WhereValue` clause using some of the common operators. To do this, you would pass a column, the operator and then the value. For example if you wanted to get all users over 20 years old, you could do so as follows: - -```swift -Query.from("users") - .where("age" > 20) - .get() -``` - -The following operators are valid when constructing a `WhereValue` in this way: `==`, `!=`, `<`, `>`, `<=`, `>=`, `~=`. - -Alternatively you can manually create a `WhereValue` clause manually: - -```swift -Query.from("users") - .where(WhereValue(key: "age", op: .equals, value: 10)) - .get() -``` - -### Or Where Clauses - -By default chaining where clauses will be joined together using the `and` operator. If you ever need to switch the operator to `or` you can do so by using the `orWhere` method. - -```swift -Query.from("users") - .where("age" > 20) - .orWhere("age" < 50) - .get() -``` - -### Grouping Where Clauses - -If you need to group where clauses together, you can do so by using a closure. This will execute those clauses together within parenthesis to achieve your desired logical grouping. - -```swift -Query.from("users") - .where { - $0.where("age" < 30) - .orWhere("first_name" == "Paul") - } - .orWhere { - $0.where("age" > 50) - .orWhere("first_name" == "Karen") - } - .get() -``` - -The provided example would produce the following SQL: - -```sql -select * from users where (age < 50 or first_name = 'Paul') and (age > 50 or first_name = 'Karen') -``` - -### Additional Where Clauses - -There are some additional helper where methods available for common cases. All methods also have a corresponding `or` method as well. - -#### Where Null - -The `whereNull` method ensures that the given column is not null. - -```swift -Query.from("users") - .whereNull("last_name") - .get() -``` - -#### Where In - -The `where(key: String, in values [Parameter])` method lets you pass an array of values to match the column against. - -```swift -Query.from("users") - .where(key: "age", in: [10,20,30]) - .get() -``` - -## Ordering, Grouping, Paging - -### Grouping - -To group results together, you can use the `groupBy` method: - -```swift -Query.from("users") - .groupBy("age") - .get() -``` - -If you need to filter the grouped by rows, you can use the `having` method which performs similar to a `where` clause. - -```swift -Query.from("users") - .groupBy("age") - .having("age" > 100) - .get() -``` - -### Ordering - -You can sort results of a query by using the `orderBy` method. - -```swift -Query.from("users") - .orderBy(column: "first_name", direction: .asc) - .get() -``` - -If you need to sort by multiple columns, you can add `orderBy` as many times as needed. Sorting is based on call order. - -```swift -Query.from("users") - .orderBy(column: "first_name", direction: .asc) - .orderBy(column: "last_name", direction: .desc) - .get() -``` - -### Paging, Limits and Offsets - -If all you are looking for is to break a query down into chunks for paging, the easiest way to accomplish that is to use the `forPage` method. It will automatically set the limits and offsets appropriate for a page size you define. - -```swift -Query.from("users") - .forPage(page: 1, perPage: 25) - .get() -``` - -Otherwise, you can also define limits and offsets manually: -```swift -Query.from("users") - .offset(50) - .limit(10) - .get() -``` - -## Inserting - -You can insert records using the query builder as well. To do so, start a chain with only a table name, and then pass the record you wish to insert. You can additionally pass in an array of records to do a bulk insert. - -```swift -Query.table("users") - .insert([ - "first_name": "Steve", - "last_name": "Jobs" - ]) -``` - -## Updating - -Updating records is just as easy as inserting, however you also get the benefit of the rest of the query builder chain. Any where clauses that have been added are used to match which records you want to update. For example, if you wanted to update a single user based on an ID, you could do so as follows: - -```swift -Query.table("users") - .where("id" == 10) - .update(values: [ - "first_name": "Ashley" - ]) -``` - -## Deleting - -The `delete` method works similar to how `update` did. It uses the query builder chain to determine what records match, but then instead of updating them, it deletes them. If you wanted to delete all users whose name is Peter, you could do that as so: - -```swift -Query.table("users") - .where("name" == "Peter") - .delete() -``` - -## Counting - -To get the total number of records that match a query you can use the `count` method. - -```swift -Query.from("rentals") - .where("num_beds" >= 1) - .count(as: "rentals_count") -``` - -_Next page: [Database: Migrations](5c_DatabaseMigrations.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/5c_DatabaseMigrations.md b/Docs/5c_DatabaseMigrations.md deleted file mode 100644 index e49b9f21..00000000 --- a/Docs/5c_DatabaseMigrations.md +++ /dev/null @@ -1,187 +0,0 @@ -# Database: Migrations - -- [Creating a migration](#creating-a-migration) -- [Implementing Migrations](#implementing-migrations) -- [Schema functions](#schema-functions) -- [Creating a table](#creating-a-table) - * [Adding Columns](#adding-columns) - * [Adding Indexes](#adding-indexes) -- [Altering a Table](#altering-a-table) -- [Other schema functions](#other-schema-functions) -- [Running a Migration](#running-a-migration) - * [Via Command](#via-command) - + [Applying](#applying) - + [Rolling Back](#rolling-back) - * [Via Code](#via-code) - + [Applying](#applying-1) - + [Rolling Back](#rolling-back-1) - -Migrations are a key part of working with an SQL database. Each migration defines changes to the schema of your database that can be either applied or rolled back. You'll typically create new migrations each time you want to make a change to your database, so that you can keep track of all the changes you've made over time. - -## Creating a migration -You can create a new migration using the CLI. - -```bash -alchemy make:migration MyMigration -``` - -This will create a new migration file in `Sources/App/Migrations`. - -## Implementing Migrations - -A migration conforms to the `Migration` protocol and is implemented by filling out the `up` and `down` functions. `up` is run when a migration is applied to a database. `down` is run when a migration is rolled back. - -`up` and `down` are passed a `Schema` object representing the schema of the database to which this migration will be applied. The database schema is modified via functions on `Schema`. - -For example, this migration renames the `user_todos` table to `todos`. Notice the `down` function does the reverse. You don't _have_ to fill out the down function of a migration, but it may be useful for rolling back the operation later. - -```swift -struct RenameTodos: Migration { - func up(schema: Schema) { - schema.rename(table: "user_todos", to: "todos") - } - - func down(schema: Schema) { - schema.rename(table: "todos", to: "user_todos") - } -} -``` - -## Schema functions - -`Schema` has a variety of useful builder methods for doing various database migrations. - -## Creating a table - -You can create a new table using `Schema.create(table: String, builder: (inout CreateTableBuilder) -> Void)`. - -The `CreateTableBuilder` comes packed with a variety of functions for adding columns of various types & modifiers to the new table. - -```swift -schema.create(table: "users") { table in - table.uuid("id").primary() - table.string("name").notNull() - table.string("email").notNull().unique() - table.uuid("mom").references("id", on: "users") -} -``` - -### Adding Columns - -You may add a column onto a table builder with functions like `.string()` or `.int()`. These define a named column of the given type and return a column builder for adding modifiers to the column. - -Supported builder functions for adding columns are - -| Table Builder Functions | Column Builder Functions | -|-|-| -| `.uuid(_ column: String)` | `.default(expression: String)` | -| `.int(_ column: String)` | `.default(val: String)` | -| `.string(_ column: String)` | `.notNull()` | -| `.increments(_ column: String)` | `.unique()` | -| `.double(_ column: String)` | `.primary()` | -| `.bool(_ column: String)` | `.references(_ column: String, on table: String)` | -| `.date(_ column: String)` | -| `.json(_ column: String)` | - -### Adding Indexes - -Indexes can be added via `.addIndex`. They can be on a single column or multiple columns and can be defined as unique or not. - -```swift -schema.create(table: "users") { table in - ... - table.addIndex(columns: ["email"], unique: true) -} -``` - -Indexes are named by concatinating table name + columns + "key" if unique or "idx" if not, all joined with underscores. For example, the index defined above would be named `users_email_key`. - -## Altering a Table - -You can alter an existing table with `alter(table: String, builder: (inout AlterTableBuilder) -> Void)`. - -`AlterTableBuilder` has the exact same interface as `CreateTableBuilder` with a few extra functions for dropping columns, dropping indexes, and renaming columns. - -```swift -schema.alter(table: "users") { - $0.bool("is_expired").default(val: false) - $0.drop(column: "name") - $0.drop(index: "users_email_key") - $0.rename(column: "createdAt", to: "created_at") -} -``` - -## Other schema functions - -You can also drop tables, rename tables, or execute arbitrary SQL strings from a migration. - -```swift -schema.drop(table: "old_users") -schema.rename(table: "createdAt", to: "created_at") -schema.raw("drop schema public cascade") -``` - -## Running a Migration - -To begin, you need to ensure that your migrations are registered on `Database.default`. You can should do this in your `Application.boot` function. - -```swift -// Make sure to register a database with `Database.config(default: )` first! -Database.default.migrations = [ - CreateUsers(), - CreateTodos(), - RenameTodos() -] -``` - -### Via Command - -#### Applying - -You can then apply all outstanding migrations in a single batch by passing the `migrate` argument to your app. This will cause the app to migrate `Database.default` instead of serving. - -```bash -# Applies all outstanding migrations -swift run Server migrate -``` - -#### Rolling Back - -You can pass the `--rollback` flag to instead rollback the latest batch of migrations. - -```bash -# Rolls back the most recent batch of migrations -swift run Server migrate --rollback -``` - -#### When Serving - -If you'd prefer to avoid running a separate migration command, you may pass the `--migrate` flag when running your server to automatically run outstanding migrations before serving. - -```swift -swift run Server --migrate -``` - -**Note**: Alchemy keeps track of run migrations and the current batch in your database in the `migrations` table. You can delete this table to clear all records of migrations. - -### Via Code - -#### Applying - -You may also migrate your database in code. The future will complete when the migration is finished. - -```swift -database.migrate() -``` - -#### Rolling Back - -Rolling back the latest migration batch is also possible in code. - -```swift -database.rollbackMigrations() -``` - -_Next page: [Redis](5d_Redis.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/5d_Redis.md b/Docs/5d_Redis.md deleted file mode 100644 index e2c50e98..00000000 --- a/Docs/5d_Redis.md +++ /dev/null @@ -1,161 +0,0 @@ -# Redis - -- [Connecting to Redis](#connecting-to-redis) - * [Clusters](#clusters) -- [Interacting With Redis](#interacting-with-redis) -- [Scripting](#scripting) -- [Pub / Sub](#pub--sub) - * [Wildcard Subscriptions](#wildcard-subscriptions) -- [Transactions](#transactions) - -Redis is an open source, in-memory data store than can be used as a database, cache, and message broker. - -Alchemy provides first class Redis support out of the box, building on the extensive [RediStack](https://github.com/Mordil/RediStack) library. - -## Connecting to Redis - -You can connect to Redis using the `Redis` type. You should register this type for injection in your `Application.boot()`. It conforms to `Service` so you can do so with the `config` function. - -```swift -Redis.config(default: .connection("localhost")) -``` - -The intializer optionally takes a password and database index (if the index isn't supplied, Redis will connect to the database at index 0, the default). - -```swift -Redis.config(default: .connection( - "localhost", - port: 6379, - password: "P@ssw0rd", - database: 1 -)) -``` - -### Clusters - -If you're using a Redis cluster, your client can connect to multiple instances by passing multiple `Socket`s to the initializer. Connections will be distributed across the instances. - -```swift -Redis.config("cluster", .cluster( - .ip("localhost", port: 6379), - .ip("61.123.456.789", port: 6379), - .unix("/path/to/socket") -)) -``` - -## Interacting With Redis - -`Redis` conforms to `RediStack.RedisClient` meaning that by default, it has functions around nearly all Redis commands. - -You can easily get and set a value. - -```swift -// Get a value. -redis.get("some_key", as: String.self) // EventLoopFuture - -// Set a value. -redis.set("some_int", to: 42) // EventLoopFuture -``` - -You can also increment a value. -```swift -redis.increment("my_counter") // EventLoopFuture -``` - -There are convenient extensions for just about every command Redis supports. - -```swift -redis.lrange(from: "some_list", indices: 0...3) -``` - -Alternatively, you can always run a custom command via `command`. The first argument is the command name, all subsequent arguments are the command's arguments. - -```swift -redis.command("lrange", "some_list", 0, 3) -``` - -## Scripting - -You can run a script via `.eval(...)`. - -Scripts are written in Lua and have access to 1-based arrays `KEYS` and `ARGV` for accessing keys and arguments respectively. They also have access to a `redis` variable for calling Redis inside the script. Consult the [EVAL documentation](https://redis.io/commands/eval) for more information on scripting. - -```swift -redis.eval( - """ - local counter = redis.call("incr", KEYS[1]) - - if counter > 5 then - redis.call("incr", KEYS[2]) - end - - return counter - """, - keys: ["key1", "key2"] -) -``` - -## Pub / Sub - -Redis provides `publish` and `subscribe` commands to publish and listen to various channels. - -You can easily subscribe to a single channel or multiple channels. - -```swift -redis.subscribe(to: "my-channel") { value in - print("my-channel got: \(value)") -} - -redis.subscribe(to: ["my-channel", "other-channel"]) { channelName, value in - print("\(channelName) got: \(value)") -} -``` - -Publishing to them is just as simple. - -```swift -redis.publish("hello", to: "my-channel") -``` - -If you want to stop listening to a channel, use `unsubscribe`. - -```swift -redis.unsubscribe(from: "my-channel") -``` - -### Wildcard Subscriptions - -You may subscribe to wildcard channels using `psubscribe`. - -```swift -redis.psubscribe(to: ["*"]) { channelName, value in - print("\(channelName) got: \(value)") -} - -redis.psubscribe(to: ["subscriptions.*"]) { channelName, value in - print("\(channelName) got: \(value)") -} -``` - -Unsubscribe with `punsubscribe`. - -```swift -redis.punsubscribe(from: "*") -``` - -## Transactions - -Sometimes, you'll want to run multiple commands atomically to avoid race conditions. Alchemy makes this simple with the `transaction()` function which provides a wrapper around Redis' native `MULTI` & `EXEC` commands. - -```swift -redis.transaction { conn in - conn.increment("first_counter") - .flatMap { _ in - conn.increment("second_counter") - } -} -``` - -_Next page: [Rune Basics](6a_RuneBasics.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/6a_RuneBasics.md b/Docs/6a_RuneBasics.md deleted file mode 100644 index 6d0df369..00000000 --- a/Docs/6a_RuneBasics.md +++ /dev/null @@ -1,313 +0,0 @@ -# Rune: Basics - -- [Creating a Model](#creating-a-model) -- [Custom Table Names](#custom-table-names) - * [Custom Key Mappings](#custom-key-mappings) -- [Model Field Types](#model-field-types) - * [Basic Types](#basic-types) - * [Advanced Types](#advanced-types) - + [Enums](#enums) - + [JSON](#json) - + [Custom JSON Encoders](#custom-json-encoders) - + [Custom JSON Decoders](#custom-json-decoders) -- [Decoding from `DatabaseRow`](#decoding-from-databaserow) -- [Model Querying](#model-querying) - * [All Models](#all-models) - * [First Model](#first-model) - * [Quick Lookups](#quick-lookups) -- [Model CRUD](#model-crud) - * [Get All](#get-all) - * [Save](#save) - * [Delete](#delete) - * [Sync](#sync) - * [Bulk Operations](#bulk-operations) - -Alchemy includes Rune, an object-relational mapper (ORM) to make it simple to interact with your database. With Rune, each database table has a corresponding `Model` type that is used to interact with that table. Use this Model type for querying, inserting, updating or deleting from the table. - -## Creating a Model - -To get started, implement the Model protocol. All it requires is an `id` property. Each property of your `Model` will correspond to a table column with the same name, converted to `snake_case`. - -```swift -struct User: Model { - var id: Int? // column `id` - let firstName: String // column `first_name` - let lastName: String // column `last_name` - let age: Int // column `age` -} -``` - -**Warning**: `Model` APIs rely heavily on Swift's `Codable`. Please avoid overriding the compiler synthesized `func encode(to: Encoder)` and `init(from: Decoder)` functions. You might be able to get away with it but it could cause issues under the hood. You _can_ however, add custom `CodingKeys` if you like, just be aware of the impact it will have on the `keyMappingStrategy` described below. - -## Custom Table Names - -By default, your model will correspond to a table with the name of your model type, pluralized. For custom table names, you can override the static `tableName: String` property. - -```swift -// Corresponds to table name `users`. -struct User: Model {} - -struct Todo: Model { - static let tableName = "todo_table" -} -``` - -### Custom Key Mappings - -As mentioned, by default all `Model` property names will be converted to `snake_case`, when mapping to corresponding table columns. You may change this behavior via the `keyMapping: DatabaseKeyMapping`. You could set it to `.useDefaultKeys` to use the verbatim `CodingKey`s of the `Model` object, or `.custom((String) -> String)` to provide a custom mapping closure. - -```swift -struct User: Model { - static let keyMapping = .useDefaultKeys - - var id: Int? // column `id` - let firstName: String // column `firstName` - let lastName: String // column `lastName` - let age: Int // column `age` -} -``` - -## Model Field Types - -### Basic Types - -Models support most basic Swift types such as `String`, `Bool`, `Int`, `Double`, `UUID`, `Date`. Under the hood, these are mapped to relevant types on the concrete `Database` you are using. - -### Advanced Types - -Models also support some more advanced Swift types, such as `enum`s and `JSON`. - -#### Enums - -`String` or `Int` backed Swift `enum`s are allowed as fields on a `Model`, as long as they conform to `ModelEnum`. - -```swift -struct Todo: Model { - enum Priority: String, ModelEnum { - case low, medium, high - } - - var id: Int? - let name: String - let isComplete: Bool - let priority: Priority -} -``` - -#### JSON - -Models require all properties to be `Codable`, so any property that isn't one of the types listed above will be stored as `JSON`. - -```swift -struct Todo: Model { - struct TodoMetadata: Codable { - var createdAt: Date - var lastUpdated: Date - var colorName: String - var comment: String - } - - var id: Int? - - let name: String - let isDone: Bool - let metadata: TodoMetadata // will be stored as JSON -} -``` - -#### Custom JSON Encoders - -By default, `JSON` properties are encoded using a default `JSONEncoder()` and stored in the table column. You can use a custom `JSONEncoder` by overriding the static `Model.jsonEncoder`. - -```swift -struct Todo: Model { - static var jsonEncoder: JSONEncoder = { - let encoder = JSONEncoder() - encoder.outputFormatting = .prettyPrinted - return encoder - }() - - ... -} -``` - -#### Custom JSON Decoders - -Likewise, you can provide a custom `JSONDecoder` for decoding data from JSON columns. - -```swift -struct Todo: Model { - static var jsonDecoder: JSONDecoder = { - let decoder = JSONDecoder() - decoder.dateDecodingStrategy = .iso8601 - return decoder - }() - - ... -} -``` - -## Decoding from `DatabaseRow` - -`Model`s may be "decoded" from a `DatabaseRow` that was the result of a raw query or query builder query. The `Model`'s properties will be mapped to their relevant columns, factoring in any custom `keyMappingStrategy`. This will throw an error if there is an issue while decoding, such as a missing column. - -```swift -struct User: Model { - var id: Int? - let firstName: String - let lastName: String - let age: String -} - -database.rawQuery("select * from users") - .mapEach { try! $0.decode(User.self) } - .whenSuccess { users in - for user in users { - print("Got user named \(user.firstName) \(user.lastName).") - } - } -``` - -**Note**: For the most part, if you are using Rune you won't need to call `DatabaseRow.decode(_ type:)` because the typed ORM queries described in the next section decode it for you. - -## Model Querying - -To add some type safety to query builder queries, you can initiate a typed query off of a `Model` with the static `.query` function. - -```swift -let users = User.query().allModels() -``` - -`ModelQuery` is a subclass of the generic `Query`, with a few functions for running and automatically decoding `M` from a query. - -### All Models - -`.allModels()` returns an EventLoopFuture<[M]> containing all `Model`s that matched the query. - -```swift -User.query() - .where("name", in: ["Josh", "Chris", "Rachel"]) - .allModels() // EventLoopFuture<[User]> of all users named Josh, Chris, or Rachel -``` - -### First Model - -`.firstModel()` returns an `EventLoopFuture` containing the first `Model` that matched the query, if it exists. - -```swift -User.query() - .where("age" > 30) - .firstModel() // EventLoopFuture with the first User over age 30. -``` - -If you want to throw an error if no item is found, you would `.unwrapFirst(or error: Error)`. - -```swift -let userEmail = ... -User.query() - .where("email" == userEmail) - .unwrapFirst(or: HTTPError(.unauthorized)) -``` - -### Quick Lookups - -There are also two functions for quickly looking up a `Model`. - -`ensureNotExists(where:error:)` does a query to ensure that a `Model` matching the provided where clause doesn't exist. If it does, it throws the provided error. - -```swift -func createNewAccount(with email: String) -> EventLoopFuture { - User.ensureNotExists(where: "email" == email, else: HTTPError(.conflict)) -} -``` - -`unwrapFirstWhere(_:error:)` is essentially the opposite, finding the first `Model` that matches the provided where clause or throwing an error if one doesn't exist. - -```swift -func resetPassword(for email: String) -> EventLoopFuture { - User.unwrapFirstWhere("email" == email, or: HTTPError(.notFound)) - .flatMap { user in - // reset the user's password - } -} -``` - -## Model CRUD - -There are also convenience functions around creating, fetching, and deleting `Model`s. - -### Get All - -Fetch all records of a `Model` with the `all()` function. - -```swift -User.all() - .whenSuccess { - print("There are \($0.count) users.") - } -``` - -### Save - -Save a `Model` to the database, either inserting it or updating it depending on if it has a nil id. - -```swift -// Creates a new user -User(name: "Josh", email: "josh@example.com") - .save() - -User.unwrapFirstWhere("email" == "josh@example.com") - .flatMap { user in - user.name = "Joshua" - // Updates the User's name. - return user.save() - } -``` - -### Delete - -Delete an existing `Model` from the database with `delete()`. - -```swift -let existingUser: User = ... -existingUser.delete() - .whenSuccess { - print("The user is deleted.") - } -``` - -### Sync - -Fetch an up to date copy of this `Model`. - -```swift -let outdatedUser: User = ... -outdatedUser.sync() - .whenSuccess { upToDateUser in - print("User's name is: \(upToDateUser.name)") - } -``` - -### Bulk Operations - -You can also do bulk inserts or deletes on `[Model]`. - -```swift -let newUsers: [User] = ... -newUsers.insertAll() - .whenSuccess { users in - print("Added \(users.count) new users!") - } -``` - -```swift -let usersToDelete: [User] = ... -usersToDelete.deleteAll() - .whenSuccess { - print("Added deleted \(usersToDelete.count) users.") - } -``` - -_Next page: [Rune: Relationships](6b_RuneRelationships.md)_ - -_[Table of Contents](/Docs#docs)_ \ No newline at end of file diff --git a/Docs/6b_RuneRelationships.md b/Docs/6b_RuneRelationships.md deleted file mode 100644 index 283e285f..00000000 --- a/Docs/6b_RuneRelationships.md +++ /dev/null @@ -1,249 +0,0 @@ -# Rune: Relationships - -- [Relationship Types](#relationship-types) - * [BelongsTo](#belongsto) - * [HasMany](#hasmany) - * [HasOne](#hasone) - * [HasMany through](#hasmany-through) - * [HasOne through](#hasone-through) - * [ManyToMany](#manytomany) -- [Eager Loading Relationships](#eager-loading-relationships) - * [Nested Eager Loading](#nested-eager-loading) - -Relationships are an important part of an SQL database. Rune provides first class support for defining, keeping track of, and loading relationships between records. - -## Relationship Types - -Out of the box, Rune supports three categories of relationships, represented by property wrappers `@BelongsTo`, `@HasMany`, and `@HasOne`. - -Consider a database with tables `users`, `todos`, `tags`, `todo_tags`. - -``` -users - - id - -todos - - id - - user_id - - name - -tags - - id - - name - -todo_tags - - id - - todo_id - - tag_id -``` - -### BelongsTo - -A `BelongsTo` is the simplest kind of relationship. It represents the child of a 1-1 or 1-M relationship. The child typically has a column referencing the primary key of another table. - -```swift -struct Todo: Model { - @BelongsTo var user: User -} -``` - -Given the `@BelongsTo` property wrapper and types, Rune will infer a `user_id` key on Todo and an `id` key on `users` when eager loading. If the keys differ, for example `users` local key is `my_id` you may access the `RelationshipMapping` in `Model.mapRelations` and override either key with `to(...)` or `from(...)`. `to` overrides the key on the destination of the relation, `from` overrides the key on the model the relation is on. - -```swift -struct Todo: Model { - @BelongsTo var user: User - - static func mapRelations(_ mapper: RelationshipMapper) { - // config takes a `KeyPath` to a relationship and returns its mapping - mapper.config(\.$user).to("my_id") - } -} -``` - -### HasMany - -A "HasMany" relationship represents the Parent side of a 1-M or a M-M relationship. - -```swift -struct User: Model { - @HasMany var todos: [Todo] -} -``` - -Again, Alchemy is inferring a local key `id` on `users` and a foreign key `user_id` on `todos`. You can override either using the same `mapRelations` function. - -```swift -struct User: Model { - @HasMany var todos: [Todo] - - static func mapRelations(_ mapper: RelationshipMapper) { - mapper.config(\.$todos).from("my_id").to("parent_id") - } -} -``` - -### HasOne - -Has one, a has relationship where there is only one value, functions the same as `HasMany` except it wraps single value, not an array. Overriding keys works the same way. - -```swift -struct User: Model { - @HasOne var car: Car -} -``` - -### HasMany through - -The `.through(...)` mapping provides a convenient way to access distant relations via an intermediate relation. - -Consider tables representing a CI system `user`, `projects`, `workflows`. - -``` -users - - id - -projects - - id - - user_id - -workflows - - id - - project_id -``` - -Given a user, you could access their workflows, through the project table by using the `through(...)` function. - -```swift -struct User: Model { - @HasMany var workflows: [Workflow] - - static func mapRelations(_ mapper: RelationshipMapper) { - mapper.config(\.$workflows).through("projects") - } -} -``` - -Again, Alchemy assumes all the keys in this relationship based on the types of the relationship, and the intermediary table name. You can override this using the same `.from` & `.to` functions and you can override the intermediary table keys with the `from` and `to` parameters of `through`. - -```swift -struct User: Model { - @HasMany var workflows: [Workflow] - - static func mapRelations(_ mapper: RelationshipMapper) { - mapper.config(\.$workflows) - .from("my_id") - .through("projects", from: "the_user_id", to: "_id") - .to("my_project_id") - } -} -``` - -### HasOne through - -The `.through(...)` mapping can also be applied to a `HasOne` relationship. It functions the same, with overrides available for `from`, `throughFrom`, `throughTo`, and `to`. - -```swift -struct User: Model { - @HasOne var workflow: Workflow - - static func mapRelations(_ mapper: RelationshipMapper) { - mapper.config(\.$workflow).through("projects") - } -} -``` - -### ManyToMany - -Often you'll have relationships that are defined by a pivot table containing references to each side of the relationship. You can use the `throughPivot` function to define a `@HasMany` relationship to function this way. - -```swift -struct Todo: Model { - @HasMany var tags: [Tag] - - static func mapRelations(_ mapper: RelationshipMapper) { - mapper.config(\.$tags).throughPivot("todo_tags") - } -} -``` - -Like `through`, keys are inferred but you may specify `from` and `to` parameters to indicate the keys on the pivot table. - -```swift -struct Todo: Model { - @HasMany var tags: [Tag] - - static func mapRelations(_ mapper: RelationshipMapper) { - mapper.config(\.$tags).throughPivot("todo_tags", from: "the_todo_id", to: "the_tag_id") - } -} -``` - -## Eager Loading Relationships - -In order to access a relationship property of a queried `Model`, you need to load that relationship first. You can "eager load" it using the `.with()` function on a `ModelQuery`. Eager loading refers to preemptively, or "eagerly", loading a relationship before it is used. Eager loading also solves the N+1 problem; if N `Pet`s are returned with a query, you won't need to run N queries to find each of their `Owner`s. Instead, a single, followup query will be run that finds all `Owner`s for all `Pet`s fetched. - -This function takes a `KeyPath` to a relationship and runs a query to fetch it when the initial query is finished. - -```swift -Pet.query() - .with(\.$person) - .getAll() - .whenSuccess { pets in - for pet in pets { - print("Pet \(pet.name) has owner \(pet.person.name)") - } - } -``` - -You may chain any number of eager loads from a `Model` using `.with()`. - -```swift -Pets.query() - .with(\.$owner) - .with(\.$otherRelationship) - .with(\.$yetAnotherRelationship) - .getAll() -``` - -**Warning 1**: The `.with()` function takes a `KeyPath` to a _relationship_ not a `Model`, so be sure to preface your key path with a `$`. - -**Warning 2**: If you access a relationship before it's loaded, the program will `fatalError`. Be sure a relationship is loaded with eager loading before accessing it! - -### Nested Eager Loading - -You may want to load relationships on your eager loaded relationship `Model`s. You can do this with the second, closure argument of `with()`. - -Consider three relationships, `Homework`, `Student`, `School`. A `Homework` belongs to a `Student` and a `Student` belongs to a `School`. - -You might represent them in a database like so - -```swift -struct Homework: Model { - @BelongsTo var student: Student -} - -struct Student: Model { - @BelongsTo var school: School -} - -struct School: Model {} -``` - -To load all these relationships when querying `Homework`, you can use nested eager loading like so - -```swift -Homework.query() - .with(\.$student) { student in - student.with(\.$school) - } - .getAll() - .whenSuccess { homeworks in - for homework in homeworks { - // Can safely access `homework.student` and `homework.student.school` - } - } -``` - -_Next page: [Security](7_Security.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/7_Security.md b/Docs/7_Security.md deleted file mode 100644 index ecb6f450..00000000 --- a/Docs/7_Security.md +++ /dev/null @@ -1,160 +0,0 @@ -# Security - -- [Bcrypt](#bcrypt) -- [Request Auth](#request-auth) - * [Authorization: Basic](#authorization-basic) - * [Authorization: Bearer](#authorization-bearer) - * [Authorization: Either](#authorization-either) -- [Auth Middleware](#auth-middleware) - * [Basic Auth Middleware](#basic-auth-middleware) - * [Token Auth Middleware](#token-auth-middleware) - -Alchemy provides built in support for Bcrypt hashing and automatic authentication via Rune & `Middleware`. - -## Bcrypt - -Standard practice is to never store plain text passwords in your database. Bcrypt is a password hashing function that creates a one way hash of a plaintext password. It's an expensive process CPU-wise, so it will help protect your passwords from being easily cracked through brute forcing. - -It's simple to use. - -```swift -let hashedPassword = Bcrypt.hash("password") -let isPasswordValid = Bcrypt.verify("password", hashedPassword) // true -``` - -Because it's expensive, you may want to run this off of an `EventLoop` thread. For convenience, there's an API for that. This will run Bcrypt on a separate thread and complete back on the initiating `EventLoop`. - -```swift -Bcrypt.hashAsync("password") - .whenSuccess { hashedPassword in - // do something with the hashed password - } - -Bcrypt.verifyAsync("password", hashedPassword) - .whenSuccess { isMatch in - print("Was a match? \(isMatch).") - } -``` - -## Request Auth - -`Request` makes it easy to pull `Authorization` information off an incoming request. - -### Authorization: Basic - -You can access `Basic` auth info via `.basicAuth() -> HTTPAuth.Basic?`. - -```swift -let request: Request = ... -if let basic = request.basicAuth() { - print("Got basic auth. Username: \(basic.username) Password: \(basic.password)") -} -``` - -### Authorization: Bearer - -You can also get `Bearer` auth info via `.bearerAuth() -> HTTPAuth.Bearer?`. - -```swift -let request: Request = ... -if let bearer = request.bearerAuth() { - print("Got bearer auth with Token: \(bearer.token)") -} -``` - -### Authorization: Either - -You can also get any `Basic` or `Bearer` auth from the request. - -```swift -let request: Request = ... -if let auth = request.getAuth() { - switch auth { - case .bearer(let bearer): - print("Request had Basic auth!") - case .basic(let basic): - print("Request had Basic auth!") - } -} -``` - -## Auth Middleware - -Incoming `Request` can be automatically authorized against your Rune `Model`s by conforming your `Model`s to "authable" protocols and protecting routes with the generated `Middleware`. - -### Basic Auth Middleware - -To authenticate via the `Authorization: Basic ...` headers on incoming `Request`s, conform your Rune `Model` that stores usernames and password hashes to `BasicAuthable`. - -```swift -struct User: Model, BasicAuthable { - var id: Int? - let username: String - let password: String -} -``` - -Now, put `User.basicAuthMiddleware()` in front of any endpoints that need basic auth. When the request comes in, the `Middleware` will compare the username and password in the `Authorization: Basic ...` headers to the username and password hash of the `User` model. If the credentials are valid, the `Middleware` will set the relevant `User` instance on the `Request`, which can then be accessed via `request.get(User.self)`. - -If the credentials aren't valid, or there is no `Authorization: Basic ...` header, the Middleware will throw an `HTTPError(.unauthorized)`. - -```swift -app.use(User.basicAuthMiddleware()) -app.get("/login") { req in - let authedUser = try req.get(User.self) - // Do something with the authorized user... -} -``` - -Note that Rune is inferring a username at column `"email"` and password at column `"password"` when verifying credentials. You may set custom columns by overriding the `usernameKeyString` or `passwordKeyString` of your `Model`. - -```swift -struct User: Model, BasicAuthable { - static let usernameKeyString = "username" - static let passwordKeyString = "hashed_password" - - var id: Int? - let username: String - let hashedPassword: String -} -``` - -### Token Auth Middleware - -Similarly, to authenticate via the `Authorization: Bearer ...` headers on incoming `Request`s, conform your Rune `Model` that stores access token values to `TokenAuthable`. Note that this time, you'll need to specify a `BelongsTo` relationship to the User type this token authorizes. - -```swift -struct UserToken: Model, BasicAuthable { - var id: Int? - let value: String - - @BelongsTo var user: User -} -``` - -Like with `Basic` auth, put the `UserToken.tokenAuthMiddleware()` in front of endpoints that are protected by bearer authorization. The `Middleware` will automatically parse out tokens from incoming `Request`s and validate them via the `UserToken` type. If the token matches a `UserToken` row, the related `User` and `UserToken` will be `.set()` on the `Request` for access in a handler. - -```swift -router.middleWare(UserToken.tokenAuthMiddleware()) - .on(.GET, at: "/todos") { req in - let authedUser = try req.get(User.self) - let theToken = try req.get(UserToken.self) - } -``` - -Note that Rune is again inferring a `"value"` column on the `UserToken` to which it will compare the tokens on incoming `Request`s. This can be customized by overriding the `valueKeyString` property of your `Model`. - -```swift -struct UserToken: Model, BasicAuthable { - static let valueKeyString = "token_string" - - var id: Int? - let tokenString: String - - @BelongsTo var user: User -} -``` - -_Next page: [Queues](8_Queues.md)_ - -_[Table of Contents](/Docs#docs)_ diff --git a/Docs/8_Queues.md b/Docs/8_Queues.md deleted file mode 100644 index 598330a1..00000000 --- a/Docs/8_Queues.md +++ /dev/null @@ -1,152 +0,0 @@ -# Queues - -- [Configuring Queues](#configuring-queues) -- [Creating Jobs](#creating-jobs) -- [Dispatching Jobs](#dispatching-jobs) -- [Dequeuing and Running Jobs](#dequeuing-and-running-jobs) -- [Channels](#channels) -- [Handling Job Failures](#handling-job-failures) - -Often your app will have long running operations, such as sending emails or reading files, that take too long to run during a client request. To help with this, Alchemy makes it easy to create queued jobs that can be persisted and run in the background. Your requests will stay lightning fast and important long running operations will never be lost if your server restarts or re-deploys. - -Configure your queues with the `Queue` class. Out of the box, Alchemy provides drivers for queues backed by Redis and SQL as well as an in-memory mock queue. - -## Configuring Queues - -Like other Alchemy services, Queue conforms to the `Service` protocol. Configure it with the `config` function. - -```swift -Queue.config(default: .redis()) -``` - -If you're using the `database()` queue configuration, you'll need to add the `Queue.AddJobsMigration` migration to your database's migrations. - -```swift -Database.default.migrations = [ - Queue.AddJobsMigration(), - ... -] -``` - -## Creating Jobs - -To make a task to run on a queue, conform to the `Job` protocol. It includes a single `run` function. It also requires `Codable` conformance, so that any properties will be serialized and available when the job is run. - -```swift -struct SendWelcomeEmail: Job { - let email: String - - func run() -> EventLoopFuture { - // Send welcome email to email - } -} -``` - -Note that Rune `Model`s are Codable and can thus be included and persisted as properties of a job. - - -```swift -struct ProcessUserTransactions: Job { - let user: User - - func run() -> EventLoopFuture { - // Process user's daily transactions - } -} -``` - -## Dispatching Jobs - -Dispatching a job is as simple as calling `dispatch()`. - -```swift -SendWelcomeEmail(email: "josh@withapollo.com").dispatch() -``` - -By default, Alchemy will dispatch your job on the default queue. If you'd like to run on a different queue, you may specify it. - -```swift -ProcessUserTransactions(user: user) - .dispatch(on: .named("other_queue")) -``` - -If you'd like to run something when your job is complete, you may override the `finished` function to hook into the result of a completed job. - -```swift -struct SendWelcomeEmail: Job { - let email: String - - func run() -> EventLoopFuture { ... } - - func finished(result: Result) { - switch result { - case .success: - Log.info("Successfully sent welcome email to \(email).") - case .failure(let error): - Log.error("Failed to send welcome email to \(email). Error was: \(error).") - } - } -} -``` - -## Dequeuing and Running Jobs - -To actually have your jobs run after dispatching them to a queue, you'll need to run workers that monitor your various queues for work to be done. - -You can spin up workers as a separate process using the `queue` command. - -```bash -swift run MyApp queues -``` - -If you don't want to manage another running process, you can pass the `--workers` flag when starting your server have it run the given amount of workers in process. - -```swift -swift run MyApp --workers 2 -``` - -You can view the various options for the `queues` command in [Configuration](1_Configuration.md#queue). - -## Channels - -Sometimes you may want to prioritize running some jobs over others or have workers that only run certain kinds of jobs. Alchemy provides the concept of a "channel" to help you do so. By default, jobs run on the "default" channel, but you can specify the specific channel name to run on with the channel parameter in `dispatch()`. - -```swift -SendPasswordReset(for: user).dispatch(channel: "email") -``` - -By default, a worker will dequeue jobs from a queue's `"default"` channel, but you can tell them dequeue from another channel with the -c option. - -```shell -swift run MyServer queue -c email -``` - -You can also have them dequeue from multiple channels by separating channel names with commas. It will prioritize jobs from the first channels over subsequent ones. - -```shell -swift run MyServer queues -c email,sms,push -``` - -## Handling Job Failures - -By default, jobs that encounter an error during execution will not be retried. If you'd like to retry jobs on failure, you can add the `recoveryStrategy` property. This indicates what should happen when a job is failed. - -```swift -struct SyncSubscriptions: Job { - // Retry this job up to five times. - var recoveryStrategy: RecoveryStrategy = .retry(5) -} -``` - -You can also specify the `retryBackoff` to wait the specified time amount before retrying a job. - -```swift -struct SyncSubscriptions: Job { - // After a job failure, wait 1 minute before retrying - var retryBackoff: TimeAmount = .minutes(1) -} -``` - -_Next page: [Cache](9_Cache.md)_ - -_[Table of Contents](/Docs#docs)_ \ No newline at end of file diff --git a/Docs/9_Cache.md b/Docs/9_Cache.md deleted file mode 100644 index 035ab719..00000000 --- a/Docs/9_Cache.md +++ /dev/null @@ -1,161 +0,0 @@ -# Cache - -- [Configuration](#configuration) -- [Interacting with the Cache](#interacting-with-the-cache) - * [Storing Items in the Cache](#storing-items-in-the-cache) - + [Storing Custom Types](#storing-custom-types) - * [Retreiving Cache Items](#retreiving-cache-items) - + [Checking for item existence](#checking-for-item-existence) - + [Incrementing and Decrementing items](#incrementing-and-decrementing-items) - * [Removing Items from the Cache](#removing-items-from-the-cache) -- [Adding a Custom Cache Driver](#adding-a-custom-cache-driver) - -You'll often want to cache the results of expensive or long running operations to save CPU time and respond to future requests faster. Alchemy provides a `Cache` type for easily interacting with common caching backends. - -## Configuration - -Cache conforms to `Service` and can be configured like other Alchemy services with the `config` function. Out of the box, drivers are provided for Redis and SQL based caches as well as an in memory mock cache. - -```swift -Cache.config(default: .redis()) -``` - -If you're using the `Cache.sql()` cache configuration, you'll need to add the `Cache.AddCacheMigration` migration to your database's migrations. - -```swift -Database.default.migrations = [ - Cache.AddCacheMigration(), - ... -] -``` - -## Interacting with the Cache - -### Storing Items in the Cache - -You can store values to the cache using the `set()` function. - -```swift -cache.set("num_unique_users", 62, for: .seconds(60)) -``` - -The third parameter is optional and if not passed the value will be stored indefinitely. - -#### Storing Custom Types - -You can store any type that conforms to `CacheAllowed` in a cache. Out of the box, `Bool`, `String`, `Int`, and `Double` are supported, but you can easily store your own types as well. - -```swift -extension URL: CacheAllowed { - public var stringValue: String { - return absoluteString - } - - public init?(_ string: String) { - self.init(string: string) - } -} -``` - -### Retreiving Cache Items - -Once set, a value can be retrived using `get()`. - -```swift -cache.get("num_unique_users") -``` - -#### Checking for item existence - -You can check if a cache contains a specific item using `has()`. - -```swift -cache.has("\(user.id)_last_login") -``` - -#### Incrementing and Decrementing items - -When working with numerical cache values, you can use `increment()` and `decrement()`. - -```swift -cache.increment("key") -cache.increment("key", by: 4) -cache.decrement("key") -cache.decrement("key", by: 4) -``` - -### Removing Items from the Cache - -You can use `delete()` to clear an item from the cache. - -```swift -cache.delete(key) -``` - -Using `remove()`, you can clear and return a cache item. - -```swift -let value = cache.remove(key) -``` - -If you'd like to clear all data from a cache, you may use wipe. - -```swift -cache.wipe() -``` - -## Adding a Custom Cache Driver - -If you'd like to add a custom driver for cache, you can implement the `CacheDriver` protocol. - -```swift -struct MemcachedCache: CacheDriver { - func get(_ key: String) -> EventLoopFuture { - ... - } - - func set(_ key: String, value: C, for time: TimeAmount?) -> EventLoopFuture { - ... - } - - func has(_ key: String) -> EventLoopFuture { - ... - } - - func remove(_ key: String) -> EventLoopFuture { - ... - } - - func delete(_ key: String) -> EventLoopFuture { - ... - } - - func increment(_ key: String, by amount: Int) -> EventLoopFuture { - ... - } - - func decrement(_ key: String, by amount: Int) -> EventLoopFuture { - ... - } - - func wipe() -> EventLoopFuture { - ... - } -} -``` - -Then, add a static configuration function for using your new cache backend. - -```swift -extension Cache { - static func memcached() -> Cache { - Cache(MemcachedCache()) - } -} - -Cache.config(default: .memcached()) -``` - -_Next page: [Commands](13_Commands.md)_ - -_[Table of Contents](/Docs#docs)_ \ No newline at end of file diff --git a/Docs/README.md b/Docs/README.md deleted file mode 100644 index da06581e..00000000 --- a/Docs/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Docs - -Alchemy is an elegant, batteries included web framework for Swift. - -## Table of Contents - -|Basics|Routing & HTTP|Database & Rune ORM|Advanced| -|-|-|-|-| -|[Getting Started](0_GettingStarted.md)|[Basics](3a_RoutingBasics.md)|[Basics](5a_DatabaseBasics.md)|[Redis](5d_Redis.md)| -|[Configuration](1_Configuration.md)|[Middleware](3b_RoutingMiddleware.md)|[Query Builder](5b_DatabaseQueryBuilder.md)|[Queues](8_Queues.md)| -|[Services & DI](2_Fusion.md)|[Network Interfaces](4_Papyrus.md)|[Migrations](5c_DatabaseMigrations.md)|[Cache](9_Cache.md)| -|||[Rune: Basics](6a_RuneBasics.md)|[Commands](13_Commands.md)| -|||[Rune: Relationships](6b_RuneRelationships.md)|[Security](7_Security.md)| -||||[Digging Deeper](10_DiggingDeeper.md)| -||||[Deploying](11_Deploying.md)| -||||[Under the Hood](12_UnderTheHood.md)| diff --git a/Package.swift b/Package.swift index cd69b7ca..8af2dd99 100644 --- a/Package.swift +++ b/Package.swift @@ -1,35 +1,33 @@ -// swift-tools-version:5.4 +// swift-tools-version:5.5 import PackageDescription let package = Package( name: "alchemy", platforms: [ - .macOS(.v10_15), - .iOS(.v13), + .macOS(.v12), ], products: [ .library(name: "Alchemy", targets: ["Alchemy"]), + .library(name: "AlchemyTest", targets: ["AlchemyTest"]), ], dependencies: [ + .package(url: "https://github.com/hummingbird-project/hummingbird.git", from: "0.15.3"), + .package(url: "https://github.com/hummingbird-project/hummingbird-core.git", from: "0.13.3"), .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.0.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.6.0"), - .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.9.0"), - .package(url: "https://github.com/apple/swift-argument-parser", .upToNextMinor(from: "0.3.0")), - .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.1.0"), - .package(url: "https://github.com/vapor/mysql-nio.git", from: "1.3.0"), - .package(url: "https://github.com/vapor/postgres-kit", from: "2.0.0"), - .package(url: "https://github.com/vapor/mysql-kit", from: "4.1.0"), - .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "1.0.0-alpha"), + .package(url: "https://github.com/apple/swift-argument-parser", from: "1.0.0"), + .package(url: "https://github.com/vapor/postgres-kit", from: "2.4.0"), + .package(url: "https://github.com/vapor/mysql-kit", from: "4.3.0"), + .package(url: "https://github.com/vapor/sqlite-kit", from: "4.0.0"), + .package(url: "https://github.com/vapor/multipart-kit", from: "4.5.1"), .package(url: "https://github.com/swift-server/async-http-client.git", from: "1.0.0"), - .package(url: "https://github.com/alchemy-swift/papyrus", .upToNextMinor(from: "0.1.0")), - .package(url: "https://github.com/alchemy-swift/fusion", .upToNextMinor(from: "0.1.0")), + .package(url: "https://github.com/alchemy-swift/papyrus", .branch("main")), + .package(url: "https://github.com/alchemy-swift/fusion", .upToNextMinor(from: "0.3.0")), .package(url: "https://github.com/alchemy-swift/cron.git", from: "2.3.2"), .package(url: "https://github.com/alchemy-swift/pluralize", from: "1.0.1"), .package(url: "https://github.com/johnsundell/Plot.git", from: "0.8.0"), .package(url: "https://github.com/Mordil/RediStack.git", from: "1.0.0"), - .package(url: "https://github.com/jakeheis/SwiftCLI", .upToNextMajor(from: "6.0.3")), .package(url: "https://github.com/onevcat/Rainbow", .upToNextMajor(from: "4.0.0")), + .package(url: "https://github.com/vadymmarkov/Fakery", from: "5.0.0"), ], targets: [ .target( @@ -38,30 +36,36 @@ let package = Package( /// External dependencies .product(name: "ArgumentParser", package: "swift-argument-parser"), .product(name: "AsyncHTTPClient", package: "async-http-client"), - .product(name: "PostgresKit", package: "postgres-kit"), - .product(name: "PostgresNIO", package: "postgres-nio"), .product(name: "MySQLKit", package: "mysql-kit"), - .product(name: "MySQLNIO", package: "mysql-nio"), - .product(name: "NIO", package: "swift-nio"), - .product(name: "NIOHTTP1", package: "swift-nio"), - .product(name: "NIOHTTP2", package: "swift-nio-http2"), - .product(name: "NIOSSL", package: "swift-nio-ssl"), + .product(name: "PostgresKit", package: "postgres-kit"), + .product(name: "SQLiteKit", package: "sqlite-kit"), + .product(name: "MultipartKit", package: "multipart-kit"), + .product(name: "RediStack", package: "RediStack"), .product(name: "Logging", package: "swift-log"), .product(name: "Plot", package: "Plot"), - .product(name: "LifecycleNIOCompat", package: "swift-service-lifecycle"), - .product(name: "RediStack", package: "RediStack"), .product(name: "Papyrus", package: "papyrus"), .product(name: "Fusion", package: "fusion"), .product(name: "Cron", package: "cron"), .product(name: "Pluralize", package: "pluralize"), - .product(name: "SwiftCLI", package: "SwiftCLI"), .product(name: "Rainbow", package: "Rainbow"), + .product(name: "Fakery", package: "Fakery"), + .product(name: "HummingbirdFoundation", package: "hummingbird"), + .product(name: "HummingbirdHTTP2", package: "hummingbird-core"), + .product(name: "HummingbirdTLS", package: "hummingbird-core"), /// Internal dependencies - "CAlchemy", + .byName(name: "AlchemyC"), ] ), - .target(name: "CAlchemy", dependencies: []), - .testTarget(name: "AlchemyTests", dependencies: ["Alchemy"]), + .target(name: "AlchemyC", dependencies: []), + .target(name: "AlchemyTest", dependencies: ["Alchemy"]), + .testTarget( + name: "AlchemyTests", + dependencies: ["AlchemyTest"], + path: "Tests/Alchemy"), + .testTarget( + name: "AlchemyTestTests", + dependencies: ["AlchemyTest"], + path: "Tests/AlchemyTest"), ] ) diff --git a/README.md b/README.md index 581ea837..10fe4282 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,13 @@ -

+

-Swift Version +Swift Version Latest Release License

+> __Now fully `async/await`!__ + Welcome to Alchemy, an elegant, batteries included backend framework for Swift. You can use it to build a production ready backend for your next mobile app, cloud project or website. ```swift @@ -23,13 +25,13 @@ struct App: Application { Alchemy provides you with Swifty APIs for everything you need to build production-ready backends. It makes writing your backend in Swift a breeze by easing typical tasks, such as: -- [Simple, fast routing engine](Docs/3a_RoutingBasics.md). -- [Powerful dependency injection container](Docs/2_Fusion.md). -- Expressive, Swifty [database ORM](Docs/6a_RuneBasics.md). -- Database agnostic [query builder](Docs/5b_DatabaseQueryBuilder.md) and [schema migrations](Docs/5c_DatabaseMigrations.md). -- [Robust job queues backed by Redis or SQL](Docs/8_Queues.md). +- [Simple, fast routing engine](https://www.alchemyswift.com/essentials/routing). +- [Powerful dependency injection container](https://www.alchemyswift.com/getting-started/services). +- Expressive, Swifty [database ORM](https://www.alchemyswift.com/rune-orm/rune). +- Database agnostic [query builder](https://www.alchemyswift.com/database/query-builder) and [schema migrations](https://www.alchemyswift.com/database/migrations). +- [Robust job queues backed by Redis or SQL](https://www.alchemyswift.com/digging-deeper/queues). - First class support for [Plot](https://github.com/JohnSundell/Plot), a typesafe HTML DSL. -- [Supporting libraries to share typesafe backend APIs with Swift frontends](Docs/4_Papyrus.md). +- [Supporting libraries to share typesafe backend APIs with Swift frontends](https://www.alchemyswift.com/supporting-libraries/papyrus). ## Why Alchemy? @@ -47,47 +49,28 @@ With Routing, an ORM, advanced Redis & SQL support, Authentication, Queues, Cron APIs focus on simple syntax with lots of baked in convention so you can build much more with less code. This doesn't mean you can't customize; there's always an escape hatch to configure things your own way. -**3. Ease of Use** - -A fully documented codebase organized in a single repo make it easy to get building, extending and contributing. - -**4. Keep it Swifty** - -Swift is built to write concice, safe and elegant code. Alchemy leverages it's best parts to help you write great code faster and obviate entire classes of backend bugs. +**3. Rapid Development** -# Get Started - -The Alchemy CLI is installable with [Mint](https://github.com/yonaskolb/Mint). +Alchemy is designed to help you take apps from idea to implementation as swiftly as possible. -```shell -mint install alchemy-swift/alchemy-cli -``` +**4. Interoperability** -## Create a New App +Alchemy is built on top of the lightweight, [blazingly](https://web-frameworks-benchmark.netlify.app/result?l=swift) fast [Hummingbird](https://github.com/hummingbird-project/hummingbird) framework. It is fully compatible with existing `swift-nio` and `vapor` components like [stripe-kit](https://github.com/vapor-community/stripe-kit), [soto](https://github.com/soto-project/soto) or [jwt-kit](https://github.com/vapor/jwt-kit) so that you can easily integrate with all existing Swift on the Server work. -Creating an app with the CLI lets you pick between a backend or fullstack project. +**5. Keep it Swifty** -1. `alchemy new MyNewProject` -2. `cd MyNewProject` (if you selected fullstack, `MyNewProject/Backend`) -3. `swift run` -4. view your brand new app at http://localhost:3000 +Swift is built to write concice, safe and elegant code. Alchemy leverages it's best parts to help you write great code faster and obviate entire classes of backend bugs. With v0.4.0 and above, it's API is completely `async/await` meaning you have access to all Swift's cutting edge concurrency features. -## Swift Package Manager - -You can also add Alchemy to your project manually with the [Swift Package Manager](https://github.com/apple/swift-package-manager). - -```swift -.package(url: "https://github.com/alchemy-swift/alchemy", .upToNextMinor(from: "0.3.0")) -``` +# Get Started -Until `1.0.0` is released, minor version changes might be breaking, so you may want to use `upToNextMinor`. +To get started check out the extensive docs starting with [Setup](https://www.alchemyswift.com/getting-started/setup). # Usage -You can view example apps in the [alchemy-examples repo](https://github.com/alchemy-swift/alchemy-examples). - The [Docs](Docs#docs) provide a step by step walkthrough of everything Alchemy has to offer. They also touch on essential core backend concepts for developers new to server side development. Below are some of the core pieces. +If you'd prefer to dive into some code, check out the example apps in the [alchemy-examples repo](https://github.com/alchemy-swift/alchemy-examples). + ## Basics & Routing Each Alchemy project starts with an implemention of the `Application` protocol. It has a single function, `boot()` for you to set up your app. In `boot()` you'll define your configurations, routes, jobs, and anything else needed to set up your application. @@ -98,18 +81,32 @@ Routing is done with action functions `get()`, `post()`, `delete()`, etc on the @main struct App: Application { func boot() { - post("/say_hello") { req -> String in - let name = req.query(for: "name")! - return "Hello, \(name)!" + post("/hello") { req in + "Hello, \(req.query("name")!)!" + } + + // handlers can be async supported + get("/download") { req in + // Fetch an image from another site. + try await Http.get("https://example.com/image.jpg") } } } ``` +Route handlers can also be async using Swift's new concurrency features. + +```swift +get("/download") { req in + // Fetch an image from another site. + try await Http.get("https://example.com/image.jpg") +} +``` + Route handlers will automatically convert returned `Codable` types to JSON. You can also return a `Response` if you'd like full control over the returned content & it's encoding. ```swift -struct Todo { +struct Todo: Codable { let name: String let isComplete: Bool let created: Date @@ -130,8 +127,8 @@ app.get("/xml") { req -> Response in """.data(using: .utf8)! return Response( status: .accepted, - headers: ["Some-Header": "value"], - body: HTTPBody(data: xmlData, mimeType: .xml) + headers: ["Content-Type": "application/xml"], + body: .data(xmlData) ) } ``` @@ -147,9 +144,9 @@ struct TodoController: Controller { .patch("/todo/:id", updateTodo) } - func getAllTodos(req: Request) -> [Todo] { ... } - func createTodo(req: Request) -> Todo { ... } - func updateTodo(req: Request) -> Todo { ... } + func getAllTodos(req: Request) async throws -> [Todo] { ... } + func createTodo(req: Request) async throws -> Todo { ... } + func updateTodo(req: Request) async throws -> Todo { ... } } // Register the controller @@ -183,87 +180,35 @@ let dbUsername: String = Env.DB_USER let dbPass: String = Env.DB_PASS ``` -Choose what env file your app uses by setting APP_ENV, your program will load it's environment from the file at `.{APP_ENV} `. - -## Services & DI - -Alchemy makes DI a breeze to keep your services pluggable and swappable in tests. Most services in Alchemy conform to `Service`, a protocol built on top of [Fusion](https://github.com/alchemy-swift/fusion), which you can use to set sensible default configurations for your services. - -You can use `Service.config(default: ...)` to configure the default instance of a service for the app. `Service.configure("key", ...)` lets you configure another, named instance. To keep you writing less code, most functions that interact with a `Service` will default to running on your `Service`'s default configuration. - -```swift -// Set the default database for the app. -Database.config( - default: .postgres( - host: "localhost", - database: "alchemy", - username: "user", - password: "password" - ) -) - -// Set the database identified by the "mysql" key. -Database.config("mysql", .mysql(host: "localhost", database: "alchemy")) - -// Get's all `User`s from the default Database (postgres). -Todo.all() - -// Get's all `User`s from the "mysql" database. -Todo.all(db: .named("mysql")) -``` - -In this way, you can easily configure as many `Database`s as you need while having Alchemy use the Postgres one by default. When it comes time for testing, injecting a mock service is easy. - -```swift -final class MyTests: XCTestCase { - func setup() { - Queue.configure(default: .mock()) - } -} -``` - -Since Service wraps [Fusion](https://github.com/alchemy-swift/fusion), you can also access default and named configurations via the @Inject property wrapper. A variety of services can be set up and accessed this way including `Database`, `Redis`, `Router`, `Queue`, `Cache`, `HTTPClient`, `Scheduler`, `NIOThreadPool`, and `ServiceLifecycle`. - -```swift -@Inject var postgres: Database -@Inject("mysql") var mysql: Database -@Inject var redis: Redis - -postgres.rawQuery("select * from users") -mysql.rawQuery("select * from some_table") -redis.get("cached_data_key") -``` +You can choose a custom env file by passing -e {env} or setting APP_ENV when running your program. The app will load it's environment from the file at `.env.{env}`. ## SQL queries -Alchemy comes with a powerful query builder that makes it easy to interact with SQL databases. In addition, you can always run raw SQL strings on a `Database` instance. +Alchemy comes with a powerful query builder that makes it easy to interact with SQL databases. You can always run raw queries as well. `DB` is a shortcut to injecting the default database. ```swift -// Runs on Database.default -Query.from("users").select("id").where("age" > 30) +try await DB.from("users").select("id").where("age" > 30) -database.rawQuery("SELECT * FROM users WHERE id = 1") +try await DB.raw("SELECT * FROM users WHERE id = 1") ``` Most SQL operations are supported, including nested `WHERE`s and atomic transactions. ```swift // The first user named Josh with age NULL or less than 28 -Query.from("users") +try await DB.from("users") .where("name" == "Josh") .where { $0.whereNull("age").orWhere("age" < 28) } .first() -// Wraps all inner queries in an atomic transaction. -database.transaction { conn in - conn.query() - .where("account" == 1) +// Wraps all inner queries in an atomic transaction, will rollback if an error is thrown. +try await DB.transaction { conn in + try await conn.from("accounts") + .where("id" == 1) .update(values: ["amount": 100]) - .flatMap { _ in - conn.query() - .where("account" == 2) - .update(values: ["amount": 200]) - } + try await conn.from("accounts") + .where("id" == 2) + .update(values: ["amount": 200]) } ``` @@ -280,17 +225,20 @@ struct User: Model { let age: Int } -let newUser = User(firstName: "Josh", lastName: "Wright", age: 28) -newUser.insert() +try await User(firstName: "Josh", lastName: "Wright", age: 28).insert() ``` You can easily query directly on your type using the same query builder syntax. Your model type is automatically decoded from the result of the query for you. ```swift -User.where("id" == 1).firstModel() +try await User.find(1) + +// equivalent to + +try await User.where("id" == 1).first() ``` -If your database naming convention is different than Swift's, for example `snake_case`, you can set the static `keyMapping` property on your Model to automatially convert from Swift `camelCase`. +If your database naming convention is different than Swift's, for example `snake_case` instead of `camelCase`, you can set the static `keyMapping` property on your Model to automatially convert to the proper case. ```swift struct User: Model { @@ -308,10 +256,13 @@ struct Todo: Model { } // Queries all `Todo`s with their related `User`s also loaded. -Todo.all().with(\.$user) +let todos = try await Todo.all().with(\.$user) +for todo in todos { + print("\(todo.title) is owned by \(user.name)") +} ``` -You can customize advanced relationship loading behavior, such as "has many through" by overriding `mapRelations()`. +You can customize advanced relationship loading behavior, such as "has many through" by overriding the static `mapRelations()` function. ```swift struct User: Model { @@ -329,16 +280,14 @@ Middleware lets you intercept requests coming in and responses coming out of you ```swift struct LoggingMiddleware: Middleware { - func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture { + func intercept(_ request: Request, next: @escaping Next) async throws -> Response { let start = Date() - let requestInfo = "\(request.head.method.rawValue) \(request.path)" - Log.info("Incoming Request: \(requestInfo)") - return next(request) - .map { response in - let elapsedTime = String(format: "%.2fs", Date().timeIntervalSince(start)) - Log.info("Outgoing Response: \(response.status.code) \(requestInfo) after \(elapsedTime)") - return response - } + let requestInfo = "\(request.head.method) \(request.path)" + Log.info("Received request: \(requestInfo)") + let response = try await next(request) + let elapsedTime = String(format: "%.2fs", Date().timeIntervalSince(start)) + Log.info("Sending response: \(response.status.code) \(requestInfo) after \(elapsedTime)") + return response } } @@ -349,6 +298,15 @@ app.use(LoggingMiddleware()) app.useAll(OtherMiddleware()) ``` +You may also add anonymous middlewares with a closure. + +```swift +app.use { req, next -> Response in + Log.info("\(req.method) \(req.path)") + return next(req) +} +``` + ## Authentication You'll often want to authenticate incoming requests using your database models. Alchemy provides out of the box middlewares for authorizing requests against your ORM models using Basic & Token based auth. @@ -364,9 +322,7 @@ struct UserToken: Model, TokenAuthable { app.use(UserToken.tokenAuthMiddleware()) app.get("/user") { req -> User in - let user = req.get(User.self) - // Do something with the authorized user... - return user + req.get(User.self) // The User is now accessible on the request } ``` @@ -379,7 +335,7 @@ Also note that, in this case, because `Model` descends from `Codable` you can re Working with Redis is powered by the excellent [RedisStack](https://github.com/Mordil/RediStack) package. Once you register a configuration, the `Redis` type has most Redis commands, including pub/sub, as functions you can access. ```swift -Redis.config(default: .connection("localhost")) +Redis.bind(.connection("localhost")) // Elsewhere @Inject var redis: Redis @@ -394,17 +350,17 @@ redis.subscribe(to: "my_channel") { val in If the function you want isn't available, you can always send a raw command. Atomic `MULTI`/`EXEC` transactions are supported with `.transaction()`. ```swift -redis.send(command: "GET my_key") +try await redis.send(command: "GET my_key") -redis.transaction { redisConn in - redisConn.increment("foo") - .flatMap { _ in redisConn.increment("bar") } +try await redis.transaction { redisConn in + try await redisConn.increment("foo").get() + try await redisConn.increment("bar").get() } ``` ## Queues -Alchemy offers `Queue` as a unified API around various queue backends. Queues allow your application to dispatch or schedule lightweight background tasks called `Job`s to be executed by a separate worker. Out of the box, `Redis` and relational databases are supported, but you can easily write your own driver by conforming to the `QueueDriver` protocol. +Alchemy offers `Queue` as a unified API around various queue backends. Queues allow your application to dispatch or schedule lightweight background tasks called `Job`s to be executed by a separate worker. Out of the box, `Redis`, relational databases, and memory backed queues are supported, but you can easily write your own provider by conforming to the `QueueProvider` protocol. To get started, configure the default `Queue` and `dispatch()` a `Job`. You can add any `Codable` fields to `Job`, such as a database `Model`, and they will be stored and decoded when it's time to run the job. @@ -415,18 +371,18 @@ Queue.config(default: .redis()) struct ProcessNewUser: Job { let user: User - func run() -> EventLoopFuture { + func run() async throws { // do something with the new user } } -ProcessNewUser(user: someUser).dispatch() +try await ProcessNewUser(user: someUser).dispatch() ``` Note that no jobs will be dequeued and run until you run a worker to do so. You can spin up workers by separately running your app with the `queue` argument. ```shell -swift run MyApp queue +swift run MyApp worker ``` If you'd like, you can run a worker as part of your main server by passing the `--workers` flag. @@ -441,7 +397,7 @@ When a job is successfully run, you can optionally run logic by overriding the ` struct EmailJob: Job { let email: String - func run() -> EventLoopFuture { ... } + func run() async throws { ... } func finished(result: Result) { switch result { @@ -454,45 +410,52 @@ struct EmailJob: Job { } ``` -For advanced queue usage including channels, queue priorities, backoff times, and retry policies, check out the [Queues guide](Docs/8_Queues.md). +For advanced queue usage including channels, queue priorities, backoff times, and retry policies, check out the [Queues guide](https://www.alchemyswift.com/digging-deeper/queues). ## Scheduling tasks -Alchemy contains a built in task scheduler so that you don't need to generate cron entries for repetitive work, and can instead schedule recurring tasks right from your code. You can schedule code or jobs from your `Application` instance. +Alchemy contains a built in task scheduler so that you don't need to generate cron entries for repetitive work, and can instead schedule recurring tasks right from your code. You can schedule code or jobs from the `scheudle()` method of your `Application` instance. ```swift -// Say good morning every day at 9:00 am. -app.schedule { print("Good morning!") } - .daily(hour: 9) +@main +struct MyApp: Application { + ... -// Run `SendInvoices` job on the first of every month at 9:30 am. -app.schedule(job: SendInvoices()) - .monthly(day: 1, hour: 9, min: 30) + func schedule(schedule: Scheduler) { + // Say good morning every day at 9:00 am. + schedule.run { print("Good morning!") } + .daily(hour: 9) + + // Run `SendInvoices` job on the first of every month at 9:30 am. + schedule.job(SendInvoices()) + .monthly(day: 1, hour: 9, min: 30) + } +} ``` A variety of builder functions are offered to customize your schedule frequency. If your desired frequency is complex, you can even schedule a task using a cron expression. ```swift // Every week on tuesday at 8:00 pm -app.schedule { ... } +schedule.run { ... } .weekly(day: .tue, hour: 20) // Every second -app.schedule { ... } +schedule.run { ... } .secondly() // Every minute at 30 seconds -app.schedule { ... } +schedule.run { ... } .minutely(sec: 30) -// At 22:00 on every day-of-week from Monday through Friday.” -app.schedule { ... } +// At 22:00 on every day from Monday through Friday.” +schedule.run { ... } .cron("0 22 * * 1-5") ``` ## ...and more! -Check out [the docs](Docs#docs) for more advanced guides on all of the above as well as [Migrations](Docs/5c_DatabaseMigrations.md), [Caching](Docs/9_Cache.md), [Custom Commands](Docs/13_Commands.md), [Logging](Docs/10_DiggingDeeper.md#logging), [making HTTP Requests](Docs/10_DiggingDeeper.md#making-http-requests), using the [HTML DSL](Docs/10_DiggingDeeper.md#plot--html-dsl), [advanced Request / Response usage](Docs/3a_RoutingBasics.md#responseencodable), [sharing API interfaces](Docs/4_Papyrus.md) between client and server, [deploying your apps to Linux or Docker](Docs/11_Deploying.md), and more. +Check out [the docs](https://www.alchemyswift.com/getting-started/setup) for more advanced guides on all of the above as well as [Migrations](https://www.alchemyswift.com/database/migrations), [Caching](https://www.alchemyswift.com/digging-deeper/cache), [Custom Commands](https://www.alchemyswift.com/digging-deeper/commands), [Logging](https://www.alchemyswift.com/essentials/logging), [making HTTP Requests](https://www.alchemyswift.com/digging-deeper/http-client), using the [HTML DSL](https://www.alchemyswift.com/essentials/views), advanced [Request](https://www.alchemyswift.com/essentials/requests) / [Response](https://www.alchemyswift.com/essentials/responses) usage, [typesafe APIs](https://www.alchemyswift.com/supporting-libraries/papyrus) between client and server, [deploying your apps to Linux or Docker](https://www.alchemyswift.com/getting-started/deploying), and more. # Contributing diff --git a/Sources/Alchemy/Alchemy+Papyrus/Application+Endpoint.swift b/Sources/Alchemy/Alchemy+Papyrus/Application+Endpoint.swift new file mode 100644 index 00000000..8f39eab0 --- /dev/null +++ b/Sources/Alchemy/Alchemy+Papyrus/Application+Endpoint.swift @@ -0,0 +1,105 @@ +import Foundation +import Papyrus +import NIO +import NIOHTTP1 + +extension RawResponse: ResponseConvertible { + public func response() async throws -> Response { + var headers: HTTPHeaders = [:] + headers.add(contentsOf: self.headers.map { $0 }) + return Response(status: .ok, headers: headers, body: body.map { .data($0) }) + } +} + +public extension Application { + /// Registers a `Papyrus.Endpoint`. When an incoming request + /// matches the path of the `Endpoint`, the `Endpoint.Request` + /// will automatically be decoded from the incoming + /// `HTTPRequest` for use in the provided handler. + /// + /// - Parameters: + /// - endpoint: The endpoint to register on this router. + /// - handler: The handler for handling incoming requests that + /// match this endpoint's path. This handler returns an + /// instance of the endpoint's response type. + /// - Returns: `self`, for chaining more requests. + @discardableResult + func on(_ endpoint: Endpoint, options: Router.RouteOptions = [], use handler: @escaping (Request, Req) async throws -> Res) -> Self where Res: Codable { + on(endpoint.nioMethod, at: endpoint.path, options: options) { request -> RawResponse in + let input = try endpoint.decodeRequest(method: request.method.rawValue, path: request.path, headers: request.headerDict, parameters: request.parameterDict, query: request.urlComponents.query ?? "", body: request.body?.data()) + let output = try await handler(request, input) + return try endpoint.rawResponse(with: output) + } + } + + /// Registers a `Papyrus.Endpoint` that has an `Empty` request + /// type. + /// + /// - Parameters: + /// - endpoint: The endpoint to register on this application. + /// - handler: The handler for handling incoming requests that + /// match this endpoint's path. This handler returns an + /// instance of the endpoint's response type. + /// - Returns: `self`, for chaining more requests. + @discardableResult + func on(_ endpoint: Endpoint, options: Router.RouteOptions = [], use handler: @escaping (Request) async throws -> Res) -> Self { + on(endpoint.nioMethod, at: endpoint.path, options: options) { request -> RawResponse in + let output = try await handler(request) + return try endpoint.rawResponse(with: output) + } + } + + /// Registers a `Papyrus.Endpoint` that has an `Empty` response + /// type. + /// + /// - Parameters: + /// - endpoint: The endpoint to register on this application. + /// - handler: The handler for handling incoming requests that + /// match this endpoint's path. This handler returns Void. + /// - Returns: `self`, for chaining more requests. + @discardableResult + func on(_ endpoint: Endpoint, options: Router.RouteOptions = [], use handler: @escaping (Request, Req) async throws -> Void) -> Self { + on(endpoint.nioMethod, at: endpoint.path, options: options) { request -> Response in + let input = try endpoint.decodeRequest(method: request.method.rawValue, path: request.path, headers: request.headerDict, parameters: request.parameterDict, query: request.urlComponents.query ?? "", body: request.body?.data()) + try await handler(request, input) + return Response() + } + } + + /// Registers a `Papyrus.Endpoint` that has an `Empty` request and + /// response type. + /// + /// - Parameters: + /// - endpoint: The endpoint to register on this application. + /// - handler: The handler for handling incoming requests that + /// match this endpoint's path. This handler returns Void. + /// - Returns: `self`, for chaining more requests. + @discardableResult + func on(_ endpoint: Endpoint, options: Router.RouteOptions = [], use handler: @escaping (Request) async throws -> Void) -> Self { + on(endpoint.nioMethod, at: endpoint.path, options: options) { request -> Response in + try await handler(request) + return Response() + } + } +} + +extension Request { + fileprivate var parameterDict: [String: String] { + var dict: [String: String] = [:] + for param in parameters { dict[param.key] = param.value } + return dict + } + + fileprivate var headerDict: [String: String] { + var dict: [String: String] = [:] + for header in headers { dict[header.name] = header.value } + return dict + } +} + +extension Endpoint { + /// Converts the Papyrus HTTP verb type to it's NIO equivalent. + fileprivate var nioMethod: HTTPMethod { + HTTPMethod(rawValue: method) + } +} diff --git a/Sources/Alchemy/Alchemy+Papyrus/Endpoint+Request.swift b/Sources/Alchemy/Alchemy+Papyrus/Endpoint+Request.swift index 72da08f6..f485b0d3 100644 --- a/Sources/Alchemy/Alchemy+Papyrus/Endpoint+Request.swift +++ b/Sources/Alchemy/Alchemy+Papyrus/Endpoint+Request.swift @@ -4,158 +4,60 @@ import NIO import NIOHTTP1 import Papyrus -/// An error that occurred when requesting a `Papyrus.Endpoint`. -public struct PapyrusClientError: Error { - /// What went wrong. - public let message: String - /// The `HTTPClient.Response` of the failed response. - public let response: HTTPClient.Response - /// The response body, converted to a String, if there is one. - public var bodyString: String? { - guard let body = response.body else { - return nil - } - - var copy = body - return copy.readString(length: copy.writerIndex) - } -} - -extension PapyrusClientError: CustomStringConvertible { - public var description: String { - """ - \(message) - Response: \(response.headers) - Status: \(response.status.code) \(response.status.reasonPhrase) - Body: \(bodyString ?? "N/A") - """ - } -} - extension Endpoint { - /// Requests a `Papyrus.Endpoint`, returning a future with the - /// decoded `Endpoint.Response`. + /// Requests a `Papyrus.Endpoint`, returning a decoded `Endpoint.Response`. /// /// - Parameters: /// - dto: An instance of the request DTO; `Endpoint.Request`. - /// - client: The HTTPClient to request this with. Defaults to - /// `Client.default`. - /// - Returns: A future containing the decoded `Endpoint.Response` - /// as well as the raw response of the `HTTPClient`. - public func request( - _ dto: Request, - with client: HTTPClient = .default - ) -> EventLoopFuture<(content: Response, response: HTTPClient.Response)> { - return catchError { - client.performRequest( - baseURL: baseURL, - parameters: try parameters(dto: dto), - encoder: jsonEncoder, - decoder: jsonDecoder - ) - } + /// - client: The client to request with. Defaults to `Client.default`. + /// - Returns: A raw `ClientResponse` and decoded `Response`. + public func request(_ dto: Request, with client: Client = .default) async throws -> (clientResponse: Client.Response, response: Response) { + try await client.request(endpoint: self, request: dto) } } extension Endpoint where Request == Empty { /// Requests a `Papyrus.Endpoint` where the `Request` type is - /// `Empty`, returning a future with the decoded - /// `Endpoint.Response`. + /// `Empty`, returning a decoded `Endpoint.Response`. /// - /// - Parameters: - /// - client: The HTTPClient to request this with. Defaults to - /// `Client.default`. - /// - decoder: The decoder with which to decode response data to - /// `Endpoint.Response`. Defaults to `JSONDecoder()`. - /// - Returns: A future containing the decoded `Endpoint.Response` - /// as well as the raw response of the `HTTPClient`. - public func request( - with client: HTTPClient = .default - ) -> EventLoopFuture<(content: Response, response: HTTPClient.Response)> { - return catchError { - client.performRequest( - baseURL: baseURL, - parameters: try parameters(dto: .value), - encoder: jsonEncoder, - decoder: jsonDecoder - ) - } + /// - Parameter client: The client to request with. Defaults to + /// `Client.default`. + /// - Returns: A raw `ClientResponse` and decoded `Response`. + public func request(with client: Client = .default) async throws -> (clientResponse: Client.Response, response: Response) { + try await client.request(endpoint: self, request: .value) } } -extension HTTPClient { +extension Client { /// Performs a request with the given request information. /// /// - Parameters: - /// - baseURL: The base URL of the endpoint to request. - /// - parameters: Information needed to make a request such as - /// method, body, headers, etc. - /// - encoder: The encoder with which to encode - /// `Endpoint.Request` to request data to Defaults to - /// `JSONEncoder()`. - /// - decoder: A decoder with which to decode the response type, - /// `Response`, from the `HTTPClient.Response`. - /// - Returns: A future containing the decoded response and the - /// raw `HTTPClient.Response`. - fileprivate func performRequest( - baseURL: String, - parameters: HTTPComponents, - encoder: JSONEncoder, - decoder: JSONDecoder - ) -> EventLoopFuture<(content: Response, response: HTTPClient.Response)> { - catchError { - var fullURL = baseURL + parameters.fullPath - var headers = HTTPHeaders(parameters.headers.map { $0 }) - var bodyData: Data? - - if parameters.bodyEncoding == .json { - headers.add(name: "Content-Type", value: "application/json") - bodyData = try parameters.body.map { try encoder.encode($0) } - } else if parameters.bodyEncoding == .urlEncoded, - let urlParams = try parameters.urlParams() { - headers.add(name: "Content-Type", value: "application/x-www-form-urlencoded") - bodyData = urlParams.data(using: .utf8) - fullURL = baseURL + parameters.basePath + parameters.query - } - - let request = try HTTPClient.Request( - url: fullURL, - method: HTTPMethod(rawValue: parameters.method), - headers: headers, - body: bodyData.map { HTTPClient.Body.data($0) } - ) - - return execute(request: request) - .flatMapThrowing { response in - guard (200...299).contains(response.status.code) else { - throw PapyrusClientError( - message: "The response code was not successful", - response: response - ) - } - - if Response.self == Empty.self { - return (Empty.value as! Response, response) - } - - guard let bodyBuffer = response.body else { - throw PapyrusClientError( - message: "Unable to decode response type `\(Response.self)`; the body of the response was empty!", - response: response - ) - } - - // Decode - do { - let responseJSON = try HTTPBody(buffer: bodyBuffer).decodeJSON(as: Response.self, with: decoder) - return (responseJSON, response) - } catch { - throw PapyrusClientError( - message: "Error decoding `\(Response.self)` from the response. \(error)", - response: response - ) - } - } + /// - endpoint: The Endpoint to request. + /// - request: An instance of the Endpoint's Request. + /// - Returns: A raw `ClientResponse` and decoded `Response`. + fileprivate func request( + endpoint: Endpoint, + request: Request + ) async throws -> (clientResponse: Client.Response, response: Response) { + let rawRequest = try endpoint.rawRequest(with: request) + var builder = builder() + if let body = rawRequest.body { + builder = builder.withBody(data: body) } + + builder = builder.withHeaders(rawRequest.headers) + + let method = HTTPMethod(rawValue: rawRequest.method) + let fullUrl = try rawRequest.fullURL() + let clientResponse = try await builder.request(method, uri: fullUrl).validateSuccessful() + + guard Response.self != Empty.self else { + return (clientResponse, Empty.value as! Response) + } + + var dict: [String: String] = [:] + clientResponse.headers.forEach { dict[$0] = $1 } + let response = try endpoint.decodeResponse(headers: dict, body: clientResponse.data) + return (clientResponse, response) } } diff --git a/Sources/Alchemy/Alchemy+Papyrus/Router+Endpoint.swift b/Sources/Alchemy/Alchemy+Papyrus/Router+Endpoint.swift deleted file mode 100644 index e32493a4..00000000 --- a/Sources/Alchemy/Alchemy+Papyrus/Router+Endpoint.swift +++ /dev/null @@ -1,130 +0,0 @@ -import Foundation -import Papyrus -import NIO - -public extension Application { - /// Registers a `Papyrus.Endpoint`. When an incoming request - /// matches the path of the `Endpoint`, the `Endpoint.Request` - /// will automatically be decoded from the incoming - /// `HTTPRequest` for use in the provided handler. - /// - /// - Parameters: - /// - endpoint: The endpoint to register on this router. - /// - handler: The handler for handling incoming requests that - /// match this endpoint's path. This handler expects a - /// future containing an instance of the endpoint's - /// response type. - /// - Returns: `self`, for chaining more requests. - @discardableResult - func on( - _ endpoint: Endpoint, - use handler: @escaping (Request, Req) throws -> EventLoopFuture - ) -> Self where Res: Codable { - self.on(endpoint.nioMethod, at: endpoint.path) { - return try handler($0, try Req(from: $0)) - .flatMapThrowing { Response(status: .ok, body: try HTTPBody(json: $0, encoder: endpoint.jsonEncoder)) } - } - } - - /// Registers a `Papyrus.Endpoint` that has an `Empty` request - /// type. - /// - /// - Parameters: - /// - endpoint: The endpoint to register on this application. - /// - handler: The handler for handling incoming requests that - /// match this endpoint's path. This handler expects a future - /// containing an instance of the endpoint's response type. - /// - Returns: `self`, for chaining more requests. - @discardableResult - func on( - _ endpoint: Endpoint, - use handler: @escaping (Request) throws -> EventLoopFuture - ) -> Self { - self.on(endpoint.nioMethod, at: endpoint.path) { - return try handler($0) - .flatMapThrowing { Response(status: .ok, body: try HTTPBody(json: $0, encoder: endpoint.jsonEncoder)) } - } - } -} - -extension EventLoopFuture { - /// Changes the `Value` of this future to `Empty`. Used for - /// interaction with Papyrus APIs. - /// - /// - Returns: An "empty" `EventLoopFuture`. - public func emptied() -> EventLoopFuture { - self.map { _ in Empty.value } - } -} - -// Provide a custom response for when `PapyrusValidationError`s are -// thrown. -extension PapyrusValidationError: ResponseConvertible { - public func convert() throws -> EventLoopFuture { - let body = try HTTPBody(json: ["validation_error": self.message]) - return .new(Response(status: .badRequest, body: body)) - } -} - -extension Request: DecodableRequest { - public func header(for key: String) -> String? { - self.headers.first(name: key) - } - - public func query(for key: String) -> String? { - self.queryItems - .filter ({ $0.name == key }) - .first? - .value - } - - public func pathComponent(for key: String) -> String? { - self.pathParameters.first(where: { $0.parameter == key })? - .stringValue - } - - /// Returns the first `PathParameter` for the given key, - /// converting the value to the given type. Throws if the value is - /// not there or not convertible to the given type. - /// - /// Use this to fetch any parameters from the path. - /// ```swift - /// app.post("/users/:user_id") { request in - /// let userID: String = try request.pathComponent("user_id") - /// ... - /// } - /// ``` - public func parameter(_ key: String) throws -> T { - guard let stringValue = pathParameters.first(where: { $0.parameter == "key" })?.stringValue else { - throw PapyrusValidationError("Missing parameter `\(key)` from path.") - } - - return try T(stringValue) - .unwrap(or: PapyrusValidationError("Path parameter `\(key)` was not convertible to a `\(name(of: T.self))`")) - } - - public func decodeBody(as: T.Type = T.self, with decoder: JSONDecoder = JSONDecoder()) throws -> T { - let body = try body.unwrap(or: PapyrusValidationError("Expecting a request body.")) - do { - return try body.decodeJSON(as: T.self, with: decoder) - } catch let DecodingError.keyNotFound(key, _) { - throw PapyrusValidationError("Missing field `\(key.stringValue)` from request body.") - } catch let DecodingError.typeMismatch(type, context) { - let key = context.codingPath.last?.stringValue ?? "unknown" - throw PapyrusValidationError("Request body field `\(key)` should be a `\(type)`.") - } catch { - throw PapyrusValidationError("Invalid request body.") - } - } - - public func decodeBody(encoding: BodyEncoding = .json) throws -> T where T: Decodable { - return try decodeBody(as: T.self) - } -} - -extension Endpoint { - /// Converts the Papyrus HTTP verb type to it's NIO equivalent. - fileprivate var nioMethod: HTTPMethod { - HTTPMethod(rawValue: method) - } -} diff --git a/Sources/Alchemy/Alchemy+Plot/HTMLView.swift b/Sources/Alchemy/Alchemy+Plot/HTMLView.swift index 6da96707..0c904d0a 100644 --- a/Sources/Alchemy/Alchemy+Plot/HTMLView.swift +++ b/Sources/Alchemy/Alchemy+Plot/HTMLView.swift @@ -1,4 +1,4 @@ -import Foundation +import Plot /// A protocol for defining HTML views to return to a client. /// @@ -41,8 +41,8 @@ public protocol HTMLView: ResponseConvertible { extension HTMLView { // MARK: ResponseConvertible - public func convert() throws -> EventLoopFuture { - let body = HTTPBody(text: self.content.render(), mimeType: .html) - return .new(Response(status: .ok, body: body)) + public func response() -> Response { + Response(status: .ok) + .withString(content.render(), type: .html) } } diff --git a/Sources/Alchemy/Alchemy+Plot/Plot+ResponseConvertible.swift b/Sources/Alchemy/Alchemy+Plot/Plot+ResponseConvertible.swift index 8de5867c..23e07999 100644 --- a/Sources/Alchemy/Alchemy+Plot/Plot+ResponseConvertible.swift +++ b/Sources/Alchemy/Alchemy+Plot/Plot+ResponseConvertible.swift @@ -1,15 +1,15 @@ import Plot extension HTML: ResponseConvertible { - public func convert() throws -> EventLoopFuture { - let body = HTTPBody(text: self.render(), mimeType: .html) - return .new(Response(status: .ok, body: body)) + public func response() -> Response { + Response(status: .ok) + .withString(render(), type: .html) } } extension XML: ResponseConvertible { - public func convert() throws -> EventLoopFuture { - let body = HTTPBody(text: self.render(), mimeType: .xml) - return .new(Response(status: .ok, body: body)) + public func response() -> Response { + Response(status: .ok) + .withString(render(), type: .xml) } } diff --git a/Sources/Alchemy/Application/Application+Commands.swift b/Sources/Alchemy/Application/Application+Commands.swift deleted file mode 100644 index 18a764a6..00000000 --- a/Sources/Alchemy/Application/Application+Commands.swift +++ /dev/null @@ -1,9 +0,0 @@ -extension Application { - /// Registers a command to your application. You can run a command - /// by passing it's argument when you launch your app. - /// - /// - Parameter commandType: The type of the command to register. - public func registerCommand(_ commandType: C.Type) { - Launch.userCommands.append(commandType) - } -} diff --git a/Sources/Alchemy/Application/Application+Configuration.swift b/Sources/Alchemy/Application/Application+Configuration.swift deleted file mode 100644 index a57cc4ca..00000000 --- a/Sources/Alchemy/Application/Application+Configuration.swift +++ /dev/null @@ -1,56 +0,0 @@ -import NIOSSL - -/// Settings for how this server should talk to clients. -public final class ApplicationConfiguration: Service { - /// Any TLS configuration for serving over HTTPS. - public var tlsConfig: TLSConfiguration? - /// The HTTP protocol versions supported. Defaults to `HTTP/1.1`. - public var httpVersions: [HTTPVersion] = [.http1_1] -} - -extension Application { - /// Use HTTPS when serving. - /// - /// - Parameters: - /// - key: The path to the private key. - /// - cert: The path of the cert. - /// - Throws: Any errors encountered when accessing the certs. - public func useHTTPS(key: String, cert: String) throws { - let config = Container.resolve(ApplicationConfiguration.self) - config.tlsConfig = TLSConfiguration - .makeServerConfiguration( - certificateChain: try NIOSSLCertificate - .fromPEMFile(cert) - .map { NIOSSLCertificateSource.certificate($0) }, - privateKey: .file(key)) - } - - /// Use HTTPS when serving. - /// - /// - Parameter tlsConfig: A raw NIO `TLSConfiguration` to use. - public func useHTTPS(tlsConfig: TLSConfiguration) { - let config = Container.resolve(ApplicationConfiguration.self) - config.tlsConfig = tlsConfig - } - - /// Use HTTP/2 when serving, over TLS with the given key and cert. - /// - /// - Parameters: - /// - key: The path to the private key. - /// - cert: The path of the cert. - /// - Throws: Any errors encountered when accessing the certs. - public func useHTTP2(key: String, cert: String) throws { - let config = Container.resolve(ApplicationConfiguration.self) - config.httpVersions = [.http2, .http1_1] - try useHTTPS(key: key, cert: cert) - } - - /// Use HTTP/2 when serving, over TLS with the given tls config. - /// - /// - Parameter tlsConfig: A raw NIO `TLSConfiguration` to use. - public func useHTTP2(tlsConfig: TLSConfiguration) { - let config = Container.resolve(ApplicationConfiguration.self) - config.httpVersions = [.http2, .http1_1] - useHTTPS(tlsConfig: tlsConfig) - } -} diff --git a/Sources/Alchemy/Application/Application+Controller.swift b/Sources/Alchemy/Application/Application+Controller.swift index 77166015..cc801ac2 100644 --- a/Sources/Alchemy/Application/Application+Controller.swift +++ b/Sources/Alchemy/Application/Application+Controller.swift @@ -15,8 +15,11 @@ extension Application { /// this router. /// - Returns: This router for chaining. @discardableResult - public func controller(_ controller: Controller) -> Self { - controller.route(self) + public func controller(_ controllers: Controller...) -> Self { + controllers.forEach { c in + _ = snapshotMiddleware { c.route($0) } + } + return self } } diff --git a/Sources/Alchemy/Application/Application+ErrorRoutes.swift b/Sources/Alchemy/Application/Application+ErrorRoutes.swift new file mode 100644 index 00000000..16b00612 --- /dev/null +++ b/Sources/Alchemy/Application/Application+ErrorRoutes.swift @@ -0,0 +1,25 @@ +extension Application { + /// Set a custom handler for when a handler isn't found for a + /// request. + /// + /// - Parameter handler: The handler that returns a custom not + /// found response. + /// - Returns: This application for chaining handlers. + @discardableResult + public func notFound(use handler: @escaping Handler) -> Self { + router.notFoundHandler = handler + return self + } + + /// Set a custom handler for when an internal error happens while + /// handling a request. + /// + /// - Parameter handler: The handler that returns a custom + /// internal error response. + /// - Returns: This application for chaining handlers. + @discardableResult + public func internalError(use handler: @escaping Router.ErrorHandler) -> Self { + router.internalErrorHandler = handler + return self + } +} diff --git a/Sources/Alchemy/Application/Application+HTTP2.swift b/Sources/Alchemy/Application/Application+HTTP2.swift new file mode 100644 index 00000000..8bdbeec4 --- /dev/null +++ b/Sources/Alchemy/Application/Application+HTTP2.swift @@ -0,0 +1,24 @@ +import NIOSSL +import NIOHTTP1 +import Hummingbird +import HummingbirdHTTP2 + +extension Application { + /// Use HTTP/2 when serving, over TLS with the given key and cert. + /// + /// - Parameters: + /// - key: The path to the private key. + /// - cert: The path of the cert. + /// - Throws: Any errors encountered when accessing the certs. + public func useHTTP2(key: String, cert: String) throws { + try useHTTP2(tlsConfig: .makeServerConfiguration(key: key, cert: cert)) + } + + /// Use HTTP/2 when serving, over TLS with the given tls config. + /// + /// - Parameter tlsConfig: A raw NIO `TLSConfiguration` to use. + public func useHTTP2(tlsConfig: TLSConfiguration) throws { + @Inject var app: HBApplication + try app.server.addHTTP2Upgrade(tlsConfiguration: tlsConfig) + } +} diff --git a/Sources/Alchemy/Application/Application+Jobs.swift b/Sources/Alchemy/Application/Application+Jobs.swift index cc89fd32..462b2f74 100644 --- a/Sources/Alchemy/Application/Application+Jobs.swift +++ b/Sources/Alchemy/Application/Application+Jobs.swift @@ -1,10 +1,15 @@ extension Application { /// Registers a job to be handled by your application. If you - /// don't register a job type, `QueueWorker`s won't be able to - /// handle jobs of that type. + /// don't register a job type, `QueueWorker`s won't be able + /// to handle jobs of that type. /// /// - Parameter jobType: The type of Job to register. public func registerJob(_ jobType: J.Type) { JobDecoding.register(jobType) } + + /// All custom Job types registered to this application. + public var registeredJobs: [Job.Type] { + JobDecoding.registeredJobs + } } diff --git a/Sources/Alchemy/Application/Application+Launch.swift b/Sources/Alchemy/Application/Application+Launch.swift deleted file mode 100644 index f89c91e2..00000000 --- a/Sources/Alchemy/Application/Application+Launch.swift +++ /dev/null @@ -1,35 +0,0 @@ -import Lifecycle -import LifecycleNIOCompat - -extension Application { - /// Lifecycle logs quite a bit by default, this quiets it's `info` - /// level logs by default. To output messages lower than `notice`, - /// you can override this property to `.info` or lower. - public var lifecycleLogLevel: Logger.Level { .notice } - - /// Launch this application. By default it serves, see `Launch` - /// for subcommands and options. Call this in the `main.swift` - /// of your project. - public static func main() { - loadEnv() - - do { - let app = Self() - app.bootServices() - try app.boot() - Launch.main() - try ServiceLifecycle.default.startAndWait() - } catch { - Launch.exit(withError: error) - } - } - - private static func loadEnv() { - let args = CommandLine.arguments - if let index = args.firstIndex(of: "--env"), let value = args[safe: index + 1] { - Env.defaultLocation = value - } else if let index = args.firstIndex(of: "-e"), let value = args[safe: index + 1] { - Env.defaultLocation = value - } - } -} diff --git a/Sources/Alchemy/Application/Application+Main.swift b/Sources/Alchemy/Application/Application+Main.swift new file mode 100644 index 00000000..18049596 --- /dev/null +++ b/Sources/Alchemy/Application/Application+Main.swift @@ -0,0 +1,80 @@ +import Hummingbird +import Lifecycle +import LifecycleNIOCompat + +extension Application { + /// The current application for easy access. + public static var current: Self { Container.resolveAssert() } + /// The application's lifecycle. + public var lifecycle: ServiceLifecycle { Container.resolveAssert() } + /// The underlying hummingbird application. + public var _application: HBApplication { Container.resolveAssert() } + /// The underlying router. + var router: Router { Container.resolveAssert() } + /// The underlying scheduler. + var scheduler: Scheduler { Container.resolveAssert() } + + /// Setup and launch this application. By default it serves, see `Launch` + /// for subcommands and options. Call this in the `main.swift` + /// of your project. + public static func main() throws { + let app = Self() + try app.setup() + try app.start() + app.wait() + } + + /// Sets up this application for running. + public func setup(testing: Bool = Env.isRunningTests) throws { + bootServices(testing: testing) + try boot() + services(container: .main) + schedule(schedule: Container.resolveAssert()) + } + + /// Starts the application with the given arguments. + public func start(_ args: String...) throws { + try start(args: args) + } + + /// Blocks until the application receives a shutdown signal. + public func wait() { + lifecycle.wait() + } + + /// Stops your application from running. + public func stop() throws { + var shutdownError: Error? = nil + let semaphore = DispatchSemaphore(value: 0) + lifecycle.shutdown { + shutdownError = $0 + semaphore.signal() + } + + semaphore.wait() + if let shutdownError = shutdownError { + throw shutdownError + } + } + + public func start(args: [String]) throws { + // When running tests, don't use the command line args as the default; + // they are irrelevant to running the app and may contain a bunch of + // options that will cause `ParsableCommand` parsing to fail. + let fallbackArgs = Env.isRunningTests ? [] : Array(CommandLine.arguments.dropFirst()) + Launch.customCommands.append(contentsOf: commands) + Launch.main(args.isEmpty ? fallbackArgs : args) + + var startupError: Error? = nil + let semaphore = DispatchSemaphore(value: 0) + lifecycle.start { + startupError = $0 + semaphore.signal() + } + + semaphore.wait() + if let startupError = startupError { + throw startupError + } + } +} diff --git a/Sources/Alchemy/Application/Application+Middleware.swift b/Sources/Alchemy/Application/Application+Middleware.swift index d240cde4..1e3753a8 100644 --- a/Sources/Alchemy/Application/Application+Middleware.swift +++ b/Sources/Alchemy/Application/Application+Middleware.swift @@ -1,25 +1,61 @@ -// Passthroughs on application to `Services.router`. extension Application { + /// A closure that represents an anonymous middleware. + public typealias MiddlewareClosure = (Request, (Request) async throws -> Response) async throws -> Response + /// Applies a middleware to all requests that come through the /// application, whether they are handled or not. /// - /// - Parameter middleware: The middleware which will intercept + /// - Parameter middlewares: The middlewares which will intercept /// all requests to this application. /// - Returns: This Application for chaining. @discardableResult - public func useAll(_ middleware: M) -> Self { - Router.default.globalMiddlewares.append(middleware) + public func useAll(_ middlewares: Middleware...) -> Self { + router.globalMiddlewares.append(contentsOf: middlewares) + return self + } + + /// Applies an middleware to all requests that come through the + /// application, whether they are handled or not. + /// + /// - Parameter middleware: The middleware closure which will intercept + /// all requests to this application. + /// - Returns: This Application for chaining. + @discardableResult + public func useAll(_ middleware: @escaping MiddlewareClosure) -> Self { + router.globalMiddlewares.append(AnonymousMiddleware(action: middleware)) + return self + } + + /// Adds middleware that will intercept before all subsequent + /// handlers. + /// + /// - Parameter middlewares: The middlewares. + /// - Returns: This application for chaining. + @discardableResult + public func use(_ middlewares: Middleware...) -> Self { + router.middlewares.append(contentsOf: middlewares) return self } - /// Adds a middleware that will intercept before all subsequent + /// Adds middleware that will intercept before all subsequent /// handlers. /// - /// - Parameter middleware: The middleware. + /// - Parameter middlewares: The middlewares. /// - Returns: This application for chaining. @discardableResult - public func use(_ middleware: M) -> Self { - Router.default.middlewares.append(middleware) + public func use(_ middlewares: [Middleware]) -> Self { + router.middlewares.append(contentsOf: middlewares) + return self + } + + /// Adds a middleware that will intercept before all subsequent handlers. + /// + /// - Parameter middlewares: The middleware closure which will intercept + /// all requests to this application. + /// - Returns: This application for chaining. + @discardableResult + public func use(_ middleware: @escaping MiddlewareClosure) -> Self { + router.middlewares.append(AnonymousMiddleware(action: middleware)) return self } @@ -35,10 +71,50 @@ extension Application { /// intercepted by the given `Middleware`. /// - Returns: This application for chaining handlers. @discardableResult - public func group(middleware: M, configure: (Application) -> Void) -> Self { - Router.default.middlewares.append(middleware) - configure(self) - _ = Router.default.middlewares.popLast() + public func group(_ middlewares: Middleware..., configure: (Application) -> Void) -> Self { + snapshotMiddleware { + $0.use(middlewares) + configure(self) + } + } + + /// Groups a set of endpoints by a middleware. This middleware + /// will intercept all endpoints added in the `configure` + /// closure, but none in the handler chain that + /// continues after the `.group`. + /// + /// - Parameters: + /// - middleware: The middleware closure which will intercept + /// all requests to this application. + /// - configure: A closure for adding endpoints that will be + /// intercepted by the given `Middleware`. + /// - Returns: This application for chaining handlers. + @discardableResult + public func group(middleware: @escaping MiddlewareClosure, configure: (Application) -> Void) -> Self { + snapshotMiddleware { + $0.use(AnonymousMiddleware(action: middleware)) + configure($0) + } + } +} + +extension Application { + /// Runs the action on this application. When the closure is finished, this + /// reverts the router middleware stack back to what it was before running + /// the action. + @discardableResult + func snapshotMiddleware(_ action: (Application) -> Void) -> Self { + let middlewaresBefore = router.middlewares.count + action(self) + router.middlewares = Array(router.middlewares.prefix(middlewaresBefore)) return self } } + +fileprivate struct AnonymousMiddleware: Middleware { + let action: Application.MiddlewareClosure + + func intercept(_ request: Request, next: (Request) async throws -> Response) async throws -> Response { + try await action(request, next) + } +} diff --git a/Sources/Alchemy/Application/Application+Routing.swift b/Sources/Alchemy/Application/Application+Routing.swift index 49538b59..68a79ea5 100644 --- a/Sources/Alchemy/Application/Application+Routing.swift +++ b/Sources/Alchemy/Application/Application+Routing.swift @@ -1,10 +1,10 @@ -import NIO import NIOHTTP1 +import Papyrus extension Application { /// A basic route handler closure. Most types you'll need conform /// to `ResponseConvertible` out of the box. - public typealias Handler = (Request) throws -> ResponseConvertible + public typealias Handler = (Request) async throws -> ResponseConvertible /// Adds a handler at a given method and path. /// @@ -12,59 +12,55 @@ extension Application { /// - method: The method of requests this handler will handle. /// - path: The path this handler expects. Dynamic path /// parameters should be prefaced with a `:` - /// (See `PathParameter`). + /// (See `Parameter`). /// - handler: The handler to respond to the request with. /// - Returns: This application for building a handler chain. @discardableResult - public func on( - _ method: HTTPMethod, - at path: String = "", - handler: @escaping Handler - ) -> Self { - Router.default.add(handler: handler, for: method, path: path) + public func on(_ method: HTTPMethod, at path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + router.add(handler: handler, for: method, path: path, options: options) return self } /// `GET` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func get(_ path: String = "", handler: @escaping Handler) -> Self { - self.on(.GET, at: path, handler: handler) + public func get(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + on(.GET, at: path, options: options, use: handler) } /// `POST` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func post(_ path: String = "", handler: @escaping Handler) -> Self { - self.on(.POST, at: path, handler: handler) + public func post(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + on(.POST, at: path, options: options, use: handler) } /// `PUT` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func put(_ path: String = "", handler: @escaping Handler) -> Self { - self.on(.PUT, at: path, handler: handler) + public func put(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + on(.PUT, at: path, options: options, use: handler) } /// `PATCH` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func patch(_ path: String = "", handler: @escaping Handler) -> Self { - self.on(.PATCH, at: path, handler: handler) + public func patch(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + on(.PATCH, at: path, options: options, use: handler) } /// `DELETE` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func delete(_ path: String = "", handler: @escaping Handler) -> Self { - self.on(.DELETE, at: path, handler: handler) + public func delete(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + on(.DELETE, at: path, options: options, use: handler) } /// `OPTIONS` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func options(_ path: String = "", handler: @escaping Handler) -> Self { - self.on(.OPTIONS, at: path, handler: handler) + public func options(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + on(.OPTIONS, at: path, options: options, use: handler) } /// `HEAD` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func head(_ path: String = "", handler: @escaping Handler) -> Self { - self.on(.HEAD, at: path, handler: handler) + public func head(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping Handler) -> Self { + on(.HEAD, at: path, options: options, use: handler) } } @@ -72,86 +68,14 @@ extension Application { /// not possible to conform all handler return types we wish to /// support to `ResponseConvertible`. /// -/// Specifically, these extensions support having `Void`, -/// `EventLoopFuture`, `E: Encodable`, and -/// `EventLoopFuture` as handler return types. -/// -/// This extension is pretty bulky because we need each of these four -/// for `on` & each method. +/// Specifically, these extensions support having `Void` and +/// `Encodable` as handler return types. extension Application { // MARK: - Void /// A route handler that returns `Void`. - public typealias VoidHandler = (Request) throws -> Void - - /// Adds a handler at a given method and path. - /// - /// - Parameters: - /// - method: The method of requests this handler will handle. - /// - path: The path this handler expects. Dynamic path - /// parameters should be prefaced with a `:` - /// (See `PathParameter`). - /// - handler: The handler to respond to the request with. - /// - Returns: This application for building a handler chain. - @discardableResult - public func on( - _ method: HTTPMethod, - at path: String = "", - handler: @escaping VoidHandler - ) -> Self { - self.on(method, at: path, handler: { out -> VoidResponse in - try handler(out) - return VoidResponse() - }) - } - - /// `GET` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func get(_ path: String = "", handler: @escaping VoidHandler) -> Self { - self.on(.GET, at: path, handler: handler) - } - - /// `POST` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func post(_ path: String = "", handler: @escaping VoidHandler) -> Self { - self.on(.POST, at: path, handler: handler) - } - - /// `PUT` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func put(_ path: String = "", handler: @escaping VoidHandler) -> Self { - self.on(.PUT, at: path, handler: handler) - } - - /// `PATCH` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func patch(_ path: String = "", handler: @escaping VoidHandler) -> Self { - self.on(.PATCH, at: path, handler: handler) - } - - /// `DELETE` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func delete(_ path: String = "", handler: @escaping VoidHandler) -> Self { - self.on(.DELETE, at: path, handler: handler) - } - - /// `OPTIONS` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func options(_ path: String = "", handler: @escaping VoidHandler) -> Self { - self.on(.OPTIONS, at: path, handler: handler) - } - - /// `HEAD` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func head(_ path: String = "", handler: @escaping VoidHandler) -> Self { - self.on(.HEAD, at: path, handler: handler) - } - - // MARK: - EventLoopFuture - - /// A route handler that returns an `EventLoopFuture`. - public typealias VoidFutureHandler = (Request) throws -> EventLoopFuture + public typealias VoidHandler = (Request) async throws -> Void /// Adds a handler at a given method and path. /// @@ -159,64 +83,63 @@ extension Application { /// - method: The method of requests this handler will handle. /// - path: The path this handler expects. Dynamic path /// parameters should be prefaced with a `:` - /// (See `PathParameter`). + /// (See `Parameter`). /// - handler: The handler to respond to the request with. /// - Returns: This application for building a handler chain. @discardableResult - public func on( - _ method: HTTPMethod, - at path: String = "", - handler: @escaping VoidFutureHandler - ) -> Self { - self.on(method, at: path, handler: { try handler($0).map { VoidResponse() } }) + public func on(_ method: HTTPMethod, at path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(method, at: path, options: options) { request -> Response in + try await handler(request) + return Response(status: .ok, body: nil) + } } /// `GET` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func get(_ path: String = "", handler: @escaping VoidFutureHandler) -> Self { - self.on(.GET, at: path, handler: handler) + public func get(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(.GET, at: path, options: options, use: handler) } /// `POST` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func post(_ path: String = "", handler: @escaping VoidFutureHandler) -> Self { - self.on(.POST, at: path, handler: handler) + public func post(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(.POST, at: path, options: options, use: handler) } /// `PUT` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func put(_ path: String = "", handler: @escaping VoidFutureHandler) -> Self { - self.on(.PUT, at: path, handler: handler) + public func put(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(.PUT, at: path, options: options, use: handler) } /// `PATCH` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func patch(_ path: String = "", handler: @escaping VoidFutureHandler) -> Self { - self.on(.PATCH, at: path, handler: handler) + public func patch(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(.PATCH, at: path, options: options, use: handler) } /// `DELETE` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func delete(_ path: String = "", handler: @escaping VoidFutureHandler) -> Self { - self.on(.DELETE, at: path, handler: handler) + public func delete(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(.DELETE, at: path, options: options, use: handler) } /// `OPTIONS` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func options(_ path: String = "", handler: @escaping VoidFutureHandler) -> Self { - self.on(.OPTIONS, at: path, handler: handler) + public func options(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(.OPTIONS, at: path, options: options, use: handler) } /// `HEAD` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func head(_ path: String = "", handler: @escaping VoidFutureHandler) -> Self { - self.on(.HEAD, at: path, handler: handler) + public func head(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping VoidHandler) -> Self { + on(.HEAD, at: path, options: options, use: handler) } - + // MARK: - E: Encodable /// A route handler that returns some `Encodable`. - public typealias EncodableHandler = (Request) throws -> E + public typealias EncodableHandler = (Request) async throws -> E /// Adds a handler at a given method and path. /// @@ -224,130 +147,61 @@ extension Application { /// - method: The method of requests this handler will handle. /// - path: The path this handler expects. Dynamic path /// parameters should be prefaced with a `:` - /// (See `PathParameter`). + /// (See `Parameter`). /// - handler: The handler to respond to the request with. /// - Returns: This application for building a handler chain. @discardableResult - public func on( - _ method: HTTPMethod, at path: String = "", handler: @escaping EncodableHandler - ) -> Self { - self.on(method, at: path, handler: { try handler($0).encode() }) - } - - /// `GET` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func get(_ path: String = "", handler: @escaping EncodableHandler) -> Self { - self.on(.GET, at: path, handler: handler) - } - - /// `POST` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func post(_ path: String = "", handler: @escaping EncodableHandler) -> Self { - self.on(.POST, at: path, handler: handler) - } - - /// `PUT` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func put(_ path: String = "", handler: @escaping EncodableHandler) -> Self { - self.on(.PUT, at: path, handler: handler) - } - - /// `PATCH` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func patch(_ path: String = "", handler: @escaping EncodableHandler) -> Self { - self.on(.PATCH, at: path, handler: handler) - } - - /// `DELETE` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func delete(_ path: String = "", handler: @escaping EncodableHandler) -> Self { - self.on(.DELETE, at: path, handler: handler) - } - - /// `OPTIONS` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func options(_ path: String = "", handler: @escaping EncodableHandler) -> Self { - self.on(.OPTIONS, at: path, handler: handler) - } - - /// `HEAD` wrapper of `Application.on(method:path:handler:)`. - @discardableResult - public func head(_ path: String = "", handler: @escaping EncodableHandler) -> Self { - self.on(.HEAD, at: path, handler: handler) - } - - - // MARK: - EventLoopFuture - - /// A route handler that returns an `EventLoopFuture`. - public typealias EncodableFutureHandler = (Request) throws -> EventLoopFuture - - /// Adds a handler at a given method and path. - /// - /// - Parameters: - /// - method: The method of requests this handler will handle. - /// - path: The path this handler expects. Dynamic path - /// parameters should be prefaced with a `:` - /// (See `PathParameter`). - /// - handler: The handler to respond to the request with. - /// - Returns: This application for building a handler chain. - @discardableResult - public func on( - _ method: HTTPMethod, - at path: String = "", - handler: @escaping (Request) throws -> EventLoopFuture - ) -> Self { - self.on(method, at: path, handler: { try handler($0).flatMapThrowing { try $0.encode() } }) + public func on(_ method: HTTPMethod, at path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + on(method, at: path, options: options, use: { req -> Response in + let value = try await handler(req) + if let convertible = value as? ResponseConvertible { + return try await convertible.response() + } else { + return try value.response() + } + }) } /// `GET` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func get(_ path: String = "", handler: @escaping EncodableFutureHandler) -> Self { - self.on(.GET, at: path, handler: handler) + public func get(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + self.on(.GET, at: path, options: options, use: handler) } /// `POST` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func post(_ path: String = "", handler: @escaping EncodableFutureHandler) -> Self { - self.on(.POST, at: path, handler: handler) + public func post(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + self.on(.POST, at: path, options: options, use: handler) } /// `PUT` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func put(_ path: String = "", handler: @escaping EncodableFutureHandler) -> Self { - self.on(.PUT, at: path, handler: handler) + public func put(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + self.on(.PUT, at: path, options: options, use: handler) } /// `PATCH` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func patch(_ path: String = "", handler: @escaping EncodableFutureHandler) -> Self { - self.on(.PATCH, at: path, handler: handler) + public func patch(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + self.on(.PATCH, at: path, options: options, use: handler) } /// `DELETE` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func delete(_ path: String = "", handler: @escaping EncodableFutureHandler) -> Self { - self.on(.DELETE, at: path, handler: handler) + public func delete(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + self.on(.DELETE, at: path, options: options, use: handler) } /// `OPTIONS` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func options(_ path: String = "", handler: @escaping EncodableFutureHandler) -> Self { - self.on(.OPTIONS, at: path, handler: handler) + public func options(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + self.on(.OPTIONS, at: path, options: options, use: handler) } /// `HEAD` wrapper of `Application.on(method:path:handler:)`. @discardableResult - public func head(_ path: String = "", handler: @escaping EncodableFutureHandler) -> Self { - self.on(.HEAD, at: path, handler: handler) - } -} - -/// Used as the response for a handler returns `Void` or -/// `EventLoopFuture`. -private struct VoidResponse: ResponseConvertible { - func convert() throws -> EventLoopFuture { - .new(Response(status: .ok, body: nil)) + public func head(_ path: String = "", options: Router.RouteOptions = [], use handler: @escaping EncodableHandler) -> Self { + self.on(.HEAD, at: path, options: options, use: handler) } } @@ -366,10 +220,10 @@ extension Application { @discardableResult public func grouped(_ pathPrefix: String, configure: (Application) -> Void) -> Self { let prefixes = pathPrefix.split(separator: "/").map(String.init) - Router.default.pathPrefixes.append(contentsOf: prefixes) + router.pathPrefixes.append(contentsOf: prefixes) configure(self) for _ in prefixes { - _ = Router.default.pathPrefixes.popLast() + _ = router.pathPrefixes.popLast() } return self } diff --git a/Sources/Alchemy/Application/Application+Scheduler.swift b/Sources/Alchemy/Application/Application+Scheduler.swift deleted file mode 100644 index 70682a47..00000000 --- a/Sources/Alchemy/Application/Application+Scheduler.swift +++ /dev/null @@ -1,48 +0,0 @@ -import NIO - -extension Application { - /// Schedule a recurring `Job`. - /// - /// - Parameters: - /// - job: The job to schedule. - /// - queue: The queue to schedule it on. - /// - channel: The queue channel to schedule it on. - /// - Returns: A builder for customizing the scheduling frequency. - public func schedule(job: Job, queue: Queue = .default, channel: String = Queue.defaultChannel) -> ScheduleBuilder { - ScheduleBuilder(.default) { - _ = $0.flatSubmit { () -> EventLoopFuture in - return job.dispatch(on: queue, channel: channel) - .flatMapErrorThrowing { - Log.error("[Scheduler] error scheduling Job: \($0)") - throw $0 - } - } - } - } - - /// Schedule a recurring asynchronous task. - /// - /// - Parameter future: The async task to run. - /// - Returns: A builder for customizing the scheduling frequency. - public func schedule(future: @escaping () -> EventLoopFuture) -> ScheduleBuilder { - ScheduleBuilder(.default) { - _ = $0.flatSubmit(future) - } - } - - /// Schedule a recurring synchronous task. - /// - /// - Parameter future: The async task to run. - /// - Returns: A builder for customizing the scheduling frequency. - public func schedule(task: @escaping () throws -> Void) -> ScheduleBuilder { - ScheduleBuilder(.default) { _ in try task() } - } -} - -private extension ScheduleBuilder { - init(_ scheduler: Scheduler = .default, work: @escaping (EventLoop) throws -> Void) { - self.init { - scheduler.addWork(schedule: $0, work: work) - } - } -} diff --git a/Sources/Alchemy/Application/Application+Services.swift b/Sources/Alchemy/Application/Application+Services.swift index d49dc9cf..7cb9153a 100644 --- a/Sources/Alchemy/Application/Application+Services.swift +++ b/Sources/Alchemy/Application/Application+Services.swift @@ -1,52 +1,65 @@ import Fusion import Lifecycle +import Logging extension Application { /// Register core services to `Container.default`. - func bootServices() { - // Setup app lifecycle - var lifecycleLogger = Log.logger - lifecycleLogger.logLevel = lifecycleLogLevel - ServiceLifecycle.config( - default: ServiceLifecycle( - configuration: ServiceLifecycle.Configuration( - logger: lifecycleLogger, - installBacktrace: true - ))) + /// + /// - Parameter testing: If `true`, default services will be configured in a + /// manner appropriate for tests. + func bootServices(testing: Bool = false) { + if testing { + Container.main = Container() + Log.logger.logLevel = .notice + } + + Env.boot() + Container.bind(value: Env.current) + + // Register as Self & Application + Container.bind(.singleton, to: Application.self, value: self) + Container.bind(.singleton, value: self) - Loop.config() + // Setup app lifecycle + Container.bind(.singleton, value: ServiceLifecycle( + configuration: ServiceLifecycle.Configuration( + logger: Log.logger.withLevel(.notice), + installBacktrace: !testing))) // Register all services - ApplicationConfiguration.config(default: ApplicationConfiguration()) - Router.config(default: Router()) - Scheduler.config(default: Scheduler()) - NIOThreadPool.config(default: NIOThreadPool(numberOfThreads: System.coreCount)) - HTTPClient.config(default: HTTPClient(eventLoopGroupProvider: .shared(Loop.group))) - - // Start threadpool - NIOThreadPool.default.start() - } - - /// Mocks many common services. Can be called in the `setUp()` - /// function of test cases. - public func mockServices() { - Container.default = Container() - ServiceLifecycle.config(default: ServiceLifecycle()) - Router.config(default: Router()) - Loop.mock() - } -} + + if testing { + Loop.mock() + } else { + Loop.config() + } + + Container.bind(.singleton, value: Router()) + Container.bind(.singleton, value: Scheduler()) + Container.bind(.singleton) { container -> NIOThreadPool in + let threadPool = NIOThreadPool(numberOfThreads: System.coreCount) + threadPool.start() + container + .resolve(ServiceLifecycle.self)? + .registerShutdown(label: "\(name(of: NIOThreadPool.self))", .sync(threadPool.syncShutdownGracefully)) + return threadPool + } + + Client.bind(Client()) + + if testing { + FileCreator.mock() + } -extension HTTPClient: Service { - public func shutdown() throws { - try syncShutdown() + // Set up any configurable services. + ConfigurableServices.configureDefaults() } } -extension NIOThreadPool: Service { - public func shutdown() throws { - try syncShutdownGracefully() +extension Logger { + fileprivate func withLevel(_ level: Logger.Level) -> Logger { + var copy = self + copy.logLevel = level + return copy } } - -extension ServiceLifecycle: Service {} diff --git a/Sources/Alchemy/Application/Application+TLS.swift b/Sources/Alchemy/Application/Application+TLS.swift new file mode 100644 index 00000000..57bc7d19 --- /dev/null +++ b/Sources/Alchemy/Application/Application+TLS.swift @@ -0,0 +1,24 @@ +import NIOSSL +import NIOHTTP1 +import HummingbirdTLS +import Hummingbird + +extension Application { + /// Use HTTPS when serving. + /// + /// - Parameters: + /// - key: The path to the private key. + /// - cert: The path of the cert. + /// - Throws: Any errors encountered when accessing the certs. + public func useHTTPS(key: String, cert: String) throws { + try useHTTPS(tlsConfig: .makeServerConfiguration(key: key, cert: cert)) + } + + /// Use HTTPS when serving. + /// + /// - Parameter tlsConfig: A raw NIO `TLSConfiguration` to use. + public func useHTTPS(tlsConfig: TLSConfiguration) throws { + @Inject var app: HBApplication + try app.server.addTLS(tlsConfiguration: tlsConfig) + } +} diff --git a/Sources/Alchemy/Application/Application.swift b/Sources/Alchemy/Application/Application.swift index 26eb0a95..17fcb65d 100644 --- a/Sources/Alchemy/Application/Application.swift +++ b/Sources/Alchemy/Application/Application.swift @@ -1,25 +1,41 @@ +import Lifecycle +import Hummingbird + /// The core type for an Alchemy application. Implement this & it's /// `boot` function, then add the `@main` attribute to mark it as /// the entrypoint for your application. /// -/// ```swift -/// @main -/// struct App: Application { -/// func boot() { -/// get("/hello") { _ in -/// "Hello, world!" +/// @main +/// struct App: Application { +/// func boot() { +/// get("/hello") { _ in +/// "Hello, world!" +/// } /// } -/// ... /// } -/// } -/// ``` +/// public protocol Application { - /// Called before any launch command is run. Called AFTER any - /// environment is loaded and the global `EventLoopGroup` is - /// set. Called on an event loop, so `Loop.current` is - /// available for use if needed. - func boot() throws + /// Any custom commands provided by your application. + var commands: [Command.Type] { get } + /// The configuration of the underlying application. + var configuration: HBApplication.Configuration { get } + /// Setup your application here. Called after the environment + /// and services are loaded. + func boot() throws + /// Register your custom services to the application's service container + /// here + func services(container: Container) + /// Schedule any recurring jobs or tasks here. + func schedule(schedule: Scheduler) /// Required empty initializer. init() } + +// No-op defaults +extension Application { + public var commands: [Command.Type] { [] } + public var configuration: HBApplication.Configuration { HBApplication.Configuration(logLevel: .notice) } + public func services(container: Container) {} + public func schedule(schedule: Scheduler) {} +} diff --git a/Sources/Alchemy/Authentication/BasicAuthable.swift b/Sources/Alchemy/Auth/BasicAuthable.swift similarity index 75% rename from Sources/Alchemy/Authentication/BasicAuthable.swift rename to Sources/Alchemy/Auth/BasicAuthable.swift index 0f5a8840..0102af5a 100644 --- a/Sources/Alchemy/Authentication/BasicAuthable.swift +++ b/Sources/Alchemy/Auth/BasicAuthable.swift @@ -74,7 +74,7 @@ extension BasicAuthable { /// - Returns: A `Bool` indicating if `password` matched /// `passwordHash`. public static func verify(password: String, passwordHash: String) throws -> Bool { - try Bcrypt.verify(password, created: passwordHash) + try Bcrypt.verifySync(password, created: passwordHash) } /// A `Middleware` configured to validate the @@ -94,31 +94,24 @@ extension BasicAuthable { /// - password: The password to authenticate with. /// - error: An error to throw if the username password combo /// doesn't have a match. - /// - Returns: A future containing the authenticated - /// `BasicAuthable`, if there was one. The future will result in - /// `error` if the model is not found, or the password doesn't - /// match. - public static func authenticate( - username: String, - password: String, - else error: Error = HTTPError(.unauthorized) - ) -> EventLoopFuture { - return query() + /// - Returns: A the authenticated `BasicAuthable`, if there was + /// one. Throws `error` if the model is not found, or the + /// password doesn't match. + public static func authenticate(username: String, password: String, else error: Error = HTTPError(.unauthorized)) async throws -> Self { + let rows = try await query() .where(usernameKeyString == username) - .get(["\(tableName).*", passwordKeyString]) - .flatMapThrowing { rows -> Self in - guard let firstRow = rows.first else { - throw error - } - - let passwordHash = try firstRow.getField(column: passwordKeyString).string() - - guard try verify(password: password, passwordHash: passwordHash) else { - throw error - } - - return try firstRow.decode(Self.self) - } + .getRows(["\(tableName).*", passwordKeyString]) + + guard let firstRow = rows.first else { + throw error + } + + let passwordHash = try firstRow.get(passwordKeyString).value.string() + guard try verify(password: password, passwordHash: passwordHash) else { + throw error + } + + return try firstRow.decode(Self.self) } } @@ -130,17 +123,12 @@ extension BasicAuthable { /// basic auth values don't match a row in the database, an /// `HTTPError(.unauthorized)` will be thrown. public struct BasicAuthMiddleware: Middleware { - public func intercept( - _ request: Request, - next: @escaping Next - ) -> EventLoopFuture { - catchError { - guard let basicAuth = request.basicAuth() else { - throw HTTPError(.unauthorized) - } - - return B.authenticate(username: basicAuth.username, password: basicAuth.password) - .flatMap { next(request.set($0)) } + public func intercept(_ request: Request, next: Next) async throws -> Response { + guard let basicAuth = request.basicAuth() else { + throw HTTPError(.unauthorized) } + + let model = try await B.authenticate(username: basicAuth.username, password: basicAuth.password) + return try await next(request.set(model)) } } diff --git a/Sources/Alchemy/Authentication/TokenAuthable.swift b/Sources/Alchemy/Auth/TokenAuthable.swift similarity index 75% rename from Sources/Alchemy/Authentication/TokenAuthable.swift rename to Sources/Alchemy/Auth/TokenAuthable.swift index 8453f31a..6f2eb860 100644 --- a/Sources/Alchemy/Authentication/TokenAuthable.swift +++ b/Sources/Alchemy/Auth/TokenAuthable.swift @@ -10,9 +10,9 @@ import Foundation /// /// ```swift /// // Start with a Rune `Model`. -/// struct MyToken: TokenAuthable { +/// struct Token: TokenAuthable { /// // `KeyPath` to the relation of the `User`. -/// static var userKey: KeyPath> = \.$user +/// static var userKey = \Token.$user /// /// var id: Int? /// let value: String @@ -75,28 +75,23 @@ extension TokenAuthable { /// header, or the token value isn't valid, an /// `HTTPError(.unauthorized)` will be thrown. public struct TokenAuthMiddleware: Middleware { - public func intercept( - _ request: Request, - next: @escaping Next - ) -> EventLoopFuture { - catchError { - guard let bearerAuth = request.bearerAuth() else { - throw HTTPError(.unauthorized) - } - - return T.query() - .where(T.valueKeyString == bearerAuth.token) - .with(T.userKey) - .firstModel() - .flatMapThrowing { try $0.unwrap(or: HTTPError(.unauthorized)) } - .flatMap { - request - // Set the token - .set($0) - // Set the user - .set($0[keyPath: T.userKey].wrappedValue) - return next(request) - } + public func intercept(_ request: Request, next: Next) async throws -> Response { + guard let bearerAuth = request.bearerAuth() else { + throw HTTPError(.unauthorized) } + + let model = try await T.query() + .where(T.valueKeyString == bearerAuth.token) + .with(T.userKey) + .first() + .unwrap(or: HTTPError(.unauthorized)) + + return try await next( + request + // Set the token + .set(model) + // Set the user + .set(model[keyPath: T.userKey].wrappedValue) + ) } } diff --git a/Sources/Alchemy/Cache/Cache.swift b/Sources/Alchemy/Cache/Cache.swift index 6e8b7d7a..4a5d1eef 100644 --- a/Sources/Alchemy/Cache/Cache.swift +++ b/Sources/Alchemy/Cache/Cache.swift @@ -1,25 +1,32 @@ import Foundation -/// A type for accessing a persistant cache. Supported drivers are -/// `RedisCache`, `DatabaseCache` and, for testing, `MockCache`. +/// A type for accessing a persistant cache. Supported providers are +/// `RedisCache`, `DatabaseCache`, and `MemoryCache`. public final class Cache: Service { - private let driver: CacheDriver + public struct Identifier: ServiceIdentifier { + private let hashable: AnyHashable + public init(hashable: AnyHashable) { self.hashable = hashable } + } + + private let provider: CacheProvider - /// Initializer this cache with a driver. Prefer static functions + /// Initializer this cache with a provider. Prefer static functions /// like `.database()` or `.redis()` when configuring your /// application's cache. /// - /// - Parameter driver: A driver to back this cache with. - public init(_ driver: CacheDriver) { - self.driver = driver + /// - Parameter provider: A provider to back this cache with. + public init(provider: CacheProvider) { + self.provider = provider } /// Get the value for `key`. /// - /// - Parameter key: The key of the cache record. - /// - Returns: A future containing the value, if it exists. - public func get(_ key: String) -> EventLoopFuture { - driver.get(key) + /// - Parameters: + /// - key: The key of the cache record. + /// - type: The type to coerce fetched key to for return. + /// - Returns: The value for the key, if it exists. + public func get(_ key: String, as type: L.Type = L.self) async throws -> L? { + try await provider.get(key) } /// Set a record for `key`. @@ -28,33 +35,33 @@ public final class Cache: Service { /// - Parameter value: The value to set. /// - Parameter time: How long the cache record should live. /// Defaults to nil, indicating the record has no expiry. - /// - Returns: A future indicating the record has been set. - public func set(_ key: String, value: C, for time: TimeAmount? = nil) -> EventLoopFuture { - driver.set(key, value: value, for: time) + public func set(_ key: String, value: L, for time: TimeAmount? = nil) async throws { + try await provider.set(key, value: value, for: time) } /// Determine if a record for the given key exists. /// /// - Parameter key: The key to check. - /// - Returns: A future indicating if the record exists. - public func has(_ key: String) -> EventLoopFuture { - driver.has(key) + /// - Returns: Whether the record exists. + public func has(_ key: String) async throws -> Bool { + try await provider.has(key) } /// Delete and return a record at `key`. /// - /// - Parameter key: The key to delete. - /// - Returns: A future with the deleted record, if it existed. - public func remove(_ key: String) -> EventLoopFuture { - driver.remove(key) + /// - Parameters: + /// - key: The key to delete. + /// - type: The type to coerce the removed key to for return. + /// - Returns: The deleted record, if it existed. + public func remove(_ key: String, as type: L.Type = L.self) async throws -> L? { + try await provider.remove(key) } /// Delete a record at `key`. /// /// - Parameter key: The key to delete. - /// - Returns: A future that completes when the record is deleted. - public func delete(_ key: String) -> EventLoopFuture { - driver.delete(key) + public func delete(_ key: String) async throws { + try await provider.delete(key) } /// Increment the record at `key` by the give `amount`. @@ -62,9 +69,9 @@ public final class Cache: Service { /// - Parameters: /// - key: The key to increment. /// - amount: The amount to increment by. Defaults to 1. - /// - Returns: A future containing the new value of the record. - public func increment(_ key: String, by amount: Int = 1) -> EventLoopFuture { - driver.increment(key, by: amount) + /// - Returns: The new value of the record. + public func increment(_ key: String, by amount: Int = 1) async throws -> Int { + try await provider.increment(key, by: amount) } /// Decrement the record at `key` by the give `amount`. @@ -72,16 +79,13 @@ public final class Cache: Service { /// - Parameters: /// - key: The key to decrement. /// - amount: The amount to decrement by. Defaults to 1. - /// - Returns: A future containing the new value of the record. - public func decrement(_ key: String, by amount: Int = 1) -> EventLoopFuture { - driver.decrement(key, by: amount) + /// - Returns: The new value of the record. + public func decrement(_ key: String, by amount: Int = 1) async throws -> Int { + try await provider.decrement(key, by: amount) } /// Clear the entire cache. - /// - /// - Returns: A future that completes when the cache has been - /// wiped. - public func wipe() -> EventLoopFuture { - driver.wipe() + public func wipe() async throws { + try await provider.wipe() } } diff --git a/Sources/Alchemy/Cache/Drivers/CacheDriver.swift b/Sources/Alchemy/Cache/Drivers/CacheDriver.swift deleted file mode 100644 index a27a4042..00000000 --- a/Sources/Alchemy/Cache/Drivers/CacheDriver.swift +++ /dev/null @@ -1,87 +0,0 @@ -import Foundation - -public protocol CacheDriver { - /// Get the value for `key`. - /// - /// - Parameter key: The key of the cache record. - /// - Returns: A future containing the value, if it exists. - func get(_ key: String) -> EventLoopFuture - - /// Set a record for `key`. - /// - /// - Parameter key: The key. - /// - Parameter value: The value to set. - /// - Parameter time: How long the cache record should live. - /// Defaults to nil, indicating the record has no expiry. - /// - Returns: A future indicating the record has been set. - func set(_ key: String, value: C, for time: TimeAmount?) -> EventLoopFuture - - /// Determine if a record for the given key exists. - /// - /// - Parameter key: The key to check. - /// - Returns: A future indicating if the record exists. - func has(_ key: String) -> EventLoopFuture - - /// Delete and return a record at `key`. - /// - /// - Parameter key: The key to delete. - /// - Returns: A future with the deleted record, if it existed. - func remove(_ key: String) -> EventLoopFuture - - /// Delete a record at `key`. - /// - /// - Parameter key: The key to delete. - /// - Returns: A future that completes when the record is deleted. - func delete(_ key: String) -> EventLoopFuture - - /// Increment the record at `key` by the give `amount`. - /// - /// - Parameters: - /// - key: The key to increment. - /// - amount: The amount to increment by. Defaults to 1. - /// - Returns: A future containing the new value of the record. - func increment(_ key: String, by amount: Int) -> EventLoopFuture - - /// Decrement the record at `key` by the give `amount`. - /// - /// - Parameters: - /// - key: The key to decrement. - /// - amount: The amount to decrement by. Defaults to 1. - /// - Returns: A future containing the new value of the record. - func decrement(_ key: String, by amount: Int) -> EventLoopFuture - /// Clear the entire cache. - /// - /// - Returns: A future that completes when the cache has been - /// wiped. - func wipe() -> EventLoopFuture -} - -/// A type that can be set in a Cache. Must be convertible to and from -/// a `String`. -public protocol CacheAllowed { - /// Initialize this type with a string. - /// - /// - Parameter string: The string representing this object. - init?(_ string: String) - - /// The string value of this instance. - var stringValue: String { get } -} - -// MARK: - default CacheAllowed conformances - -extension Bool: CacheAllowed { - public var stringValue: String { "\(self)" } -} - -extension String: CacheAllowed { - public var stringValue: String { self } -} - -extension Int: CacheAllowed { - public var stringValue: String { "\(self)" } -} - -extension Double: CacheAllowed { - public var stringValue: String { "\(self)" } -} diff --git a/Sources/Alchemy/Cache/Drivers/DatabaseCache.swift b/Sources/Alchemy/Cache/Drivers/DatabaseCache.swift deleted file mode 100644 index dca51736..00000000 --- a/Sources/Alchemy/Cache/Drivers/DatabaseCache.swift +++ /dev/null @@ -1,174 +0,0 @@ -import Foundation -import NIO - -/// A SQL based driver for `Cache`. -final class DatabaseCache: CacheDriver { - private let db: Database - - /// Initialize this cache with a Database. - /// - /// - Parameter db: The database to cache with. - init(_ db: Database = .default) { - self.db = db - } - - /// Get's the item, deleting it and returning nil if it's expired. - private func getItem(key: String) -> EventLoopFuture { - CacheItem.query(database: self.db) - .where("_key" == key) - .firstModel() - .flatMap { item in - guard let item = item else { - return .new(nil) - } - - if item.isValid { - return .new(item) - } else { - return CacheItem.query() - .where("_key" == key) - .delete() - .map { _ in nil } - } - } - } - - // MARK: Cache - - func get(_ key: String) -> EventLoopFuture { - self.getItem(key: key) - .flatMapThrowing { try $0?.cast() } - } - - func set(_ key: String, value: C, for time: TimeAmount?) -> EventLoopFuture { - self.getItem(key: key) - .flatMap { item in - let expiration = time.map { Date().adding(time: $0) } - if var item = item { - item.text = value.stringValue - item.expiration = expiration ?? -1 - return item.save(db: self.db) - .voided() - } else { - return CacheItem(_key: key, text: value.stringValue, expiration: expiration ?? -1) - .save(db: self.db) - .voided() - } - } - } - - func has(_ key: String) -> EventLoopFuture { - self.getItem(key: key) - .map { $0?.isValid ?? false } - } - - func remove(_ key: String) -> EventLoopFuture { - self.getItem(key: key) - .flatMap { item in - catchError { - if let item = item { - let value: C = try item.cast() - return item - .delete() - .transform(to: item.isValid ? value : nil) - } else { - return .new(nil) - } - } - } - } - - func delete(_ key: String) -> EventLoopFuture { - CacheItem.query(database: self.db) - .where("_key" == key) - .delete() - .voided() - } - - func increment(_ key: String, by amount: Int) -> EventLoopFuture { - self.getItem(key: key) - .flatMap { item in - if var item = item { - return catchError { - let value: Int = try item.cast() - let newVal = value + amount - item.text = "\(value + amount)" - return item.save().transform(to: newVal) - } - } else { - return CacheItem(_key: key, text: "\(amount)") - .save(db: self.db) - .transform(to: amount) - } - } - } - - func decrement(_ key: String, by amount: Int) -> EventLoopFuture { - self.increment(key, by: -amount) - } - - func wipe() -> EventLoopFuture { - CacheItem.deleteAll(db: self.db) - } -} - -extension Cache { - /// Create a cache backed by an SQL database. - /// - /// - Parameter database: The database to drive your cache with. - /// Defaults to your default `Database`. - /// - Returns: A cache. - public static func database(_ database: Database = .default) -> Cache { - Cache(DatabaseCache(database)) - } -} - -/// Model for storing cache data -private struct CacheItem: Model { - static var tableName: String { "cache" } - - var id: Int? - let _key: String - var text: String - var expiration: Int = -1 - - var isValid: Bool { - guard expiration >= 0 else { - return true - } - - return expiration > Int(Date().timeIntervalSince1970) - } - - func validate() -> Self? { - self.isValid ? self : nil - } - - func cast(_ type: C.Type = C.self) throws -> C { - try C(self.text).unwrap(or: CacheError("Unable to cast cache item `\(self._key)` to \(C.self).")) - } -} - -extension Cache { - /// Migration for adding a cache table to your database. Don't - /// forget to apply this to your database before using a - /// database backed cache. - public struct AddCacheMigration: Alchemy.Migration { - public var name: String { "AddCacheMigration" } - - public init() {} - - public func up(schema: Schema) { - schema.create(table: "cache") { - $0.increments("id").primary() - $0.string("_key").notNull().unique() - $0.string("text", length: .unlimited).notNull() - $0.int("expiration").notNull() - } - } - - public func down(schema: Schema) { - schema.drop(table: "cache") - } - } -} diff --git a/Sources/Alchemy/Cache/Drivers/MockCache.swift b/Sources/Alchemy/Cache/Drivers/MockCache.swift deleted file mode 100644 index 4c13bc1c..00000000 --- a/Sources/Alchemy/Cache/Drivers/MockCache.swift +++ /dev/null @@ -1,123 +0,0 @@ -import Foundation - -/// An in memory driver for `Cache` for testing. -final class MockCacheDriver: CacheDriver { - private var data: [String: MockCacheItem] = [:] - - /// Create this cache populated with the given data. - /// - /// - Parameter defaultData: The initial items in the Cache. - init(_ defaultData: [String: MockCacheItem] = [:]) { - self.data = defaultData - } - - /// Gets an item and validates that it isn't expired, deleting it - /// if it is. - private func getItem(_ key: String) -> MockCacheItem? { - guard let item = self.data[key] else { - return nil - } - - if !item.isValid { - self.data[key] = nil - return nil - } else { - return item - } - } - - // MARK: Cache - - func get(_ key: String) -> EventLoopFuture where C : CacheAllowed { - catchError { - try .new(self.getItem(key)?.cast()) - } - } - - func set(_ key: String, value: C, for time: TimeAmount?) -> EventLoopFuture where C : CacheAllowed { - .new(self.data[key] = .init( - text: value.stringValue, - expiration: time.map { Date().adding(time: $0) }) - ) - } - - func has(_ key: String) -> EventLoopFuture { - .new(self.getItem(key) != nil) - } - - func remove(_ key: String) -> EventLoopFuture where C : CacheAllowed { - catchError { - let val: C? = try self.getItem(key)?.cast() - self.data.removeValue(forKey: key) - return .new(val) - } - } - - func delete(_ key: String) -> EventLoopFuture { - self.data.removeValue(forKey: key) - return .new() - } - - func increment(_ key: String, by amount: Int) -> EventLoopFuture { - catchError { - if let existing = self.getItem(key) { - let currentVal: Int = try existing.cast() - let newVal = currentVal + amount - self.data[key]?.text = "\(newVal)" - return .new(newVal) - } else { - self.data[key] = .init(text: "\(amount)") - return .new(amount) - } - } - } - - func decrement(_ key: String, by amount: Int) -> EventLoopFuture { - self.increment(key, by: -amount) - } - - func wipe() -> EventLoopFuture { - .new(self.data = [:]) - } -} - -/// An in memory cache item. -public struct MockCacheItem { - fileprivate var text: String - fileprivate var expiration: Int? - - fileprivate var isValid: Bool { - guard let expiration = self.expiration else { - return true - } - - return expiration > Int(Date().timeIntervalSince1970) - } - - /// Create a mock cache item. - /// - /// - Parameters: - /// - text: The text of the item. - /// - expiration: An optional expiration time, in seconds since - /// epoch. - public init(text: String, expiration: Int? = nil) { - self.text = text - self.expiration = expiration - } - - fileprivate func cast() throws -> C { - try C(self.text).unwrap(or: CacheError("Unable to cast '\(self.text)' to \(C.self)")) - } -} - -extension Cache { - /// Create a cache backed by an in memory dictionary. Useful for - /// tests. - /// - /// - Parameter data: Optional mock data to initialize your cache - /// with. Defaults to an empty dict. - /// - Returns: A mock cache. - public static func mock(_ data: [String: MockCacheItem] = [:]) -> Cache { - Cache(MockCacheDriver(data)) - } -} diff --git a/Sources/Alchemy/Cache/Drivers/RedisCache.swift b/Sources/Alchemy/Cache/Drivers/RedisCache.swift deleted file mode 100644 index f57c8aad..00000000 --- a/Sources/Alchemy/Cache/Drivers/RedisCache.swift +++ /dev/null @@ -1,69 +0,0 @@ -import Foundation -import RediStack - -/// A Redis based driver for `Cache`. -final class RedisCacheDriver: CacheDriver { - private let redis: Redis - - /// Initialize this cache with a Redis client. - /// - /// - Parameter redis: The client to cache with. - init(_ redis: Redis = .default) { - self.redis = redis - } - - // MARK: Cache - - func get(_ key: String) -> EventLoopFuture { - self.redis.get(RedisKey(key), as: String.self).map { $0.map(C.init) ?? nil } - } - - func set(_ key: String, value: C, for time: TimeAmount?) -> EventLoopFuture { - if let time = time { - return self.redis.setex(RedisKey(key), to: value.stringValue, expirationInSeconds: time.seconds) - } else { - return self.redis.set(RedisKey(key), to: value.stringValue) - } - } - - func has(_ key: String) -> EventLoopFuture { - self.redis.exists(RedisKey(key)).map { $0 > 0 } - } - - func remove(_ key: String) -> EventLoopFuture { - self.get(key).flatMap { (value: C?) -> EventLoopFuture in - guard let value = value else { - return .new(nil) - } - - return self.redis.delete(RedisKey(key)).transform(to: value) - } - } - - func delete(_ key: String) -> EventLoopFuture { - self.redis.delete(RedisKey(key)).voided() - } - - func increment(_ key: String, by amount: Int) -> EventLoopFuture { - self.redis.increment(RedisKey(key), by: amount) - } - - func decrement(_ key: String, by amount: Int) -> EventLoopFuture { - self.redis.decrement(RedisKey(key), by: amount) - } - - func wipe() -> EventLoopFuture { - self.redis.command("FLUSHDB").voided() - } -} - -public extension Cache { - /// Create a cache backed by Redis. - /// - /// - Parameter redis: The redis instance to drive your cache - /// with. Defaults to your default `Redis` configuration. - /// - Returns: A cache. - static func redis(_ redis: Redis = Redis.default) -> Cache { - Cache(RedisCacheDriver(redis)) - } -} diff --git a/Sources/Alchemy/Cache/Providers/CacheProvider.swift b/Sources/Alchemy/Cache/Providers/CacheProvider.swift new file mode 100644 index 00000000..0aa616d6 --- /dev/null +++ b/Sources/Alchemy/Cache/Providers/CacheProvider.swift @@ -0,0 +1,53 @@ +import Foundation + +public protocol CacheProvider { + /// Get the value for `key`. + /// + /// - Parameter key: The key of the cache record. + /// - Returns: The value, if it exists. + func get(_ key: String) async throws -> L? + + /// Set a record for `key`. + /// + /// - Parameter key: The key. + /// - Parameter value: The value to set. + /// - Parameter time: How long the cache record should live. + /// Defaults to nil, indicating the record has no expiry. + func set(_ key: String, value: L, for time: TimeAmount?) async throws + + /// Determine if a record for the given key exists. + /// + /// - Parameter key: The key to check. + /// - Returns: Whether the record exists. + func has(_ key: String) async throws -> Bool + + /// Delete and return a record at `key`. + /// + /// - Parameter key: The key to delete. + /// - Returns: The deleted record, if it existed. + func remove(_ key: String) async throws -> L? + + /// Delete a record at `key`. + /// + /// - Parameter key: The key to delete. + func delete(_ key: String) async throws + + /// Increment the record at `key` by the give `amount`. + /// + /// - Parameters: + /// - key: The key to increment. + /// - amount: The amount to increment by. Defaults to 1. + /// - Returns: The new value of the record. + func increment(_ key: String, by amount: Int) async throws -> Int + + /// Decrement the record at `key` by the give `amount`. + /// + /// - Parameters: + /// - key: The key to decrement. + /// - amount: The amount to decrement by. Defaults to 1. + /// - Returns: The new value of the record. + func decrement(_ key: String, by amount: Int) async throws -> Int + + /// Clear the entire cache. + func wipe() async throws +} diff --git a/Sources/Alchemy/Cache/Providers/DatabaseCache.swift b/Sources/Alchemy/Cache/Providers/DatabaseCache.swift new file mode 100644 index 00000000..efc62b19 --- /dev/null +++ b/Sources/Alchemy/Cache/Providers/DatabaseCache.swift @@ -0,0 +1,146 @@ +import Foundation +import NIO + +/// A SQL based provider for `Cache`. +final class DatabaseCache: CacheProvider { + private let db: Database + + /// Initialize this cache with a Database. + /// + /// - Parameter db: The database to cache with. + init(_ db: Database = DB) { + self.db = db + } + + /// Get's the item, deleting it and returning nil if it's expired. + private func getItem(key: String) async throws -> CacheItem? { + let item = try await CacheItem.query(database: db).where("_key" == key).first() + guard let item = item else { + return nil + } + + guard item.isValid else { + try await CacheItem.query(database: db).where("_key" == key).delete() + return nil + } + + return item + } + + // MARK: Cache + + func get(_ key: String) async throws -> L? { + try await getItem(key: key)?.cast() + } + + func set(_ key: String, value: L, for time: TimeAmount?) async throws { + let item = try await getItem(key: key) + let expiration = time.map { Date().adding(time: $0) } + if var item = item { + item.text = value.description + item.expiration = expiration ?? -1 + _ = try await item.save(db: db) + } else { + _ = try await CacheItem(_key: key, text: value.description, expiration: expiration ?? -1).save(db: db) + } + } + + func has(_ key: String) async throws -> Bool { + try await getItem(key: key)?.isValid ?? false + } + + func remove(_ key: String) async throws -> L? { + guard let item = try await getItem(key: key) else { + return nil + } + + let value: L = try item.cast() + _ = try await item.delete() + return item.isValid ? value : nil + } + + func delete(_ key: String) async throws { + _ = try await CacheItem.query(database: db).where("_key" == key).delete() + } + + func increment(_ key: String, by amount: Int) async throws -> Int { + if let item = try await getItem(key: key) { + let newVal = try item.cast() + amount + _ = try await item.update { $0.text = "\(newVal)" } + return newVal + } + + _ = try await CacheItem(_key: key, text: "\(amount)").save(db: db) + return amount + } + + func decrement(_ key: String, by amount: Int) async throws -> Int { + try await increment(key, by: -amount) + } + + func wipe() async throws { + try await CacheItem.deleteAll(db: db) + } +} + +extension Cache { + /// Create a cache backed by an SQL database. + /// + /// - Parameter database: The database to drive your cache with. + /// Defaults to your default `Database`. + /// - Returns: A cache. + public static func database(_ database: Database = DB) -> Cache { + Cache(provider: DatabaseCache(database)) + } + + /// Create a cache backed by the default SQL database. + public static var database: Cache { + .database() + } +} + +/// Model for storing cache data +private struct CacheItem: Model { + static var tableName: String { "cache" } + + var id: Int? + let _key: String + var text: String + var expiration: Int = -1 + + var isValid: Bool { + guard expiration >= 0 else { + return true + } + + return expiration > Int(Date().timeIntervalSince1970) + } + + func cast(_ type: L.Type = L.self) throws -> L { + try L(text).unwrap(or: CacheError("Unable to cast cache item `\(_key)` to \(L.self).")) + } +} + +extension Cache { + /// Migration for adding a cache table to your database. Don't + /// forget to apply this to your database before using a + /// database backed cache. + public struct AddCacheMigration: Alchemy.Migration { + public var name: String { "AddCacheMigration" } + + public init() {} + + public func up(schema: Schema) { + schema.create(table: "cache") { + $0.increments("id").primary() + $0.string("_key").notNull().unique() + $0.string("text", length: .unlimited).notNull() + $0.int("expiration").notNull() + } + } + + public func down(schema: Schema) { + schema.drop(table: "cache") + } + } +} diff --git a/Sources/Alchemy/Cache/Providers/MemoryCache.swift b/Sources/Alchemy/Cache/Providers/MemoryCache.swift new file mode 100644 index 00000000..61455b1d --- /dev/null +++ b/Sources/Alchemy/Cache/Providers/MemoryCache.swift @@ -0,0 +1,133 @@ +import Foundation + +/// An in memory provider for `Cache` for testing. +public final class MemoryCache: CacheProvider { + var data: [String: MemoryCacheItem] = [:] + + /// Create this cache populated with the given data. + /// + /// - Parameter defaultData: The initial items in the Cache. + init(_ defaultData: [String: MemoryCacheItem] = [:]) { + data = defaultData + } + + /// Gets an item and validates that it isn't expired, deleting it + /// if it is. + private func getItem(_ key: String) -> MemoryCacheItem? { + guard let item = self.data[key] else { + return nil + } + + guard item.isValid else { + self.data[key] = nil + return nil + } + + return item + } + + // MARK: Cache + + public func get(_ key: String) throws -> L? { + try getItem(key)?.cast() + } + + public func set(_ key: String, value: L, for time: TimeAmount?) { + data[key] = MemoryCacheItem(text: value.description, expiration: time.map { Date().adding(time: $0) }) + } + + public func has(_ key: String) -> Bool { + getItem(key) != nil + } + + public func remove(_ key: String) throws -> L? { + let val: L? = try getItem(key)?.cast() + data.removeValue(forKey: key) + return val + } + + public func delete(_ key: String) async throws { + data.removeValue(forKey: key) + } + + public func increment(_ key: String, by amount: Int) throws -> Int { + guard let existing = getItem(key) else { + self.data[key] = .init(text: "\(amount)") + return amount + } + + + let currentVal: Int = try existing.cast() + let newVal = currentVal + amount + self.data[key]?.text = "\(newVal)" + return newVal + } + + public func decrement(_ key: String, by amount: Int) throws -> Int { + try increment(key, by: -amount) + } + + public func wipe() { + data = [:] + } +} + +/// An in memory cache item. +public struct MemoryCacheItem { + fileprivate var text: String + fileprivate var expiration: Int? + + fileprivate var isValid: Bool { + guard let expiration = self.expiration else { + return true + } + + return expiration > Int(Date().timeIntervalSince1970) + } + + /// Create a mock cache item. + /// + /// - Parameters: + /// - text: The text of the item. + /// - expiration: An optional expiration time, in seconds since + /// epoch. + public init(text: String, expiration: Int? = nil) { + self.text = text + self.expiration = expiration + } + + fileprivate func cast() throws -> L { + try L(text).unwrap(or: CacheError("Unable to cast '\(text)' to \(L.self)")) + } +} + +extension Cache { + /// Create a cache backed by an in memory dictionary. Useful for + /// tests. + /// + /// - Parameter data: Any data to initialize your cache with. + /// Defaults to an empty dict. + /// - Returns: A memory backed cache. + public static func memory(_ data: [String: MemoryCacheItem] = [:]) -> Cache { + Cache(provider: MemoryCache(data)) + } + + /// A cache backed by an in memory dictionary. Useful for tests. + public static var memory: Cache { + .memory() + } + + /// Fakes a cache using by a memory based cache. Useful for tests. + /// + /// - Parameters: + /// - id: The identifier of the cache to fake. Defaults to `default`. + /// - data: Any data to initialize your cache with. Defaults to + /// an empty dict. + /// - Returns: A `MemoryCache` for verifying test expectations. + @discardableResult + public static func fake(_ identifier: Identifier = .default, _ data: [String: MemoryCacheItem] = [:]) -> MemoryCache { + let provider = MemoryCache(data) + bind(identifier, Cache(provider: provider)) + return provider + } +} diff --git a/Sources/Alchemy/Cache/Providers/RedisCache.swift b/Sources/Alchemy/Cache/Providers/RedisCache.swift new file mode 100644 index 00000000..6203d75d --- /dev/null +++ b/Sources/Alchemy/Cache/Providers/RedisCache.swift @@ -0,0 +1,80 @@ +import Foundation +import RediStack + +/// A Redis based provider for `Cache`. +final class RedisCache: CacheProvider { + private let redis: RedisClient + + /// Initialize this cache with a Redis client. + /// + /// - Parameter redis: The client to cache with. + init(_ redis: RedisClient = Redis) { + self.redis = redis + } + + // MARK: Cache + + func get(_ key: String) async throws -> L? { + guard let value = try await redis.get(RedisKey(key), as: String.self).get() else { + return nil + } + + return try L(value).unwrap(or: CacheError("Unable to cast cache item `\(key)` to \(L.self).")) + } + + func set(_ key: String, value: L, for time: TimeAmount?) async throws { + if let time = time { + _ = try await redis.transaction { conn in + try await conn.set(RedisKey(key), to: value.description).get() + _ = try await conn.send(command: "EXPIRE", with: [.init(from: key), .init(from: time.seconds)]).get() + } + } else { + try await redis.set(RedisKey(key), to: value.description).get() + } + } + + func has(_ key: String) async throws -> Bool { + try await redis.exists(RedisKey(key)).get() > 0 + } + + func remove(_ key: String) async throws -> L? { + guard let value: L = try await get(key) else { + return nil + } + + _ = try await redis.delete(RedisKey(key)).get() + return value + } + + func delete(_ key: String) async throws { + _ = try await redis.delete(RedisKey(key)).get() + } + + func increment(_ key: String, by amount: Int) async throws -> Int { + try await redis.increment(RedisKey(key), by: amount).get() + } + + func decrement(_ key: String, by amount: Int) async throws -> Int { + try await redis.decrement(RedisKey(key), by: amount).get() + } + + func wipe() async throws { + _ = try await redis.command("FLUSHDB") + } +} + +extension Cache { + /// Create a cache backed by Redis. + /// + /// - Parameter redis: The redis instance to drive your cache + /// with. Defaults to your default `Redis` configuration. + /// - Returns: A cache. + public static func redis(_ redis: RedisClient = Redis) -> Cache { + Cache(provider: RedisCache(redis)) + } + + /// A cache backed by the default Redis instance. + public static var redis: Cache { + .redis() + } +} diff --git a/Sources/Alchemy/Cache/Store+Config.swift b/Sources/Alchemy/Cache/Store+Config.swift new file mode 100644 index 00000000..3d2cab9c --- /dev/null +++ b/Sources/Alchemy/Cache/Store+Config.swift @@ -0,0 +1,13 @@ +extension Cache { + public struct Config { + public let caches: [Identifier: Cache] + + public init(caches: [Cache.Identifier : Cache]) { + self.caches = caches + } + } + + public static func configure(with config: Config) { + config.caches.forEach { Cache.bind($0, $1) } + } +} diff --git a/Sources/Alchemy/Client/Client.swift b/Sources/Alchemy/Client/Client.swift new file mode 100644 index 00000000..109cc98a --- /dev/null +++ b/Sources/Alchemy/Client/Client.swift @@ -0,0 +1,376 @@ +import AsyncHTTPClient +import NIOCore +import NIOHTTP1 + +/// A convenient client for making http requests from your app. Backed by +/// `AsyncHTTPClient`. +/// +/// The `Http` alias can be used to access your app's default client. +/// +/// let response = try await Http.get("https://swift.org") +/// +/// See `Client.Builder` for the request builder interface. +public final class Client: Service { + public struct Identifier: ServiceIdentifier { + private let hashable: AnyHashable + public init(hashable: AnyHashable) { self.hashable = hashable } + } + + /// A type for making http requests with a `Client`. Supports static or + /// streamed content. + public struct Request { + /// The url components. + public var urlComponents: URLComponents = URLComponents() + /// The request method. + public var method: HTTPMethod = .GET + /// Any headers for this request. + public var headers: HTTPHeaders = [:] + /// The body of this request, either a static buffer or byte stream. + public var body: ByteContent? = nil + /// The url of this request. + public var url: URL { urlComponents.url ?? URL(string: "/")! } + /// Remote host, resolved from `URL`. + public var host: String { urlComponents.host ?? "" } + /// The path of this request. + public var path: String { urlComponents.path } + /// How long until this request times out. + public var timeout: TimeAmount? = nil + /// Whether to stream the response. If false, the response body will be + /// fully accumulated before returning. + public var streamResponse: Bool = false + /// Custom config override when making this request. + public var config: HTTPClient.Configuration? = nil + /// Allows for extending storage on this type. + public var extensions = Extensions() + + public init(url: String = "", method: HTTPMethod = .GET, headers: HTTPHeaders = [:], body: ByteContent? = nil, timeout: TimeAmount? = nil) { + self.urlComponents = URLComponents(string: url) ?? URLComponents() + self.method = method + self.headers = headers + self.body = body + self.timeout = timeout + } + + /// The underlying `AsyncHTTPClient.HTTPClient.Request`. + fileprivate var _request: HTTPClient.Request { + get throws { + guard let url = urlComponents.url else { throw HTTPClientError.invalidURL } + let body: HTTPClient.Body? = { + switch self.body { + case .buffer(let buffer): + return .byteBuffer(buffer) + case .stream(let stream): + func writeStream(writer: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + Loop.current.asyncSubmit { + try await stream.readAll { + try await writer.write(.byteBuffer($0)).get() + } + } + } + + return .stream(length: headers.contentLength, writeStream) + case .none: + return nil + } + }() + + return try HTTPClient.Request(url: url, method: method, headers: headers, body: body) + } + } + } + + /// The response type of a request made with client. Supports static or + /// streamed content. + public struct Response: ResponseInspector, ResponseConvertible { + /// The request that resulted in this response + public var request: Client.Request + /// Remote host of the request. + public var host: String + /// Response HTTP status. + public let status: HTTPResponseStatus + /// Response HTTP version. + public let version: HTTPVersion + /// Reponse HTTP headers. + public let headers: HTTPHeaders + /// Response body. + public var body: ByteContent? + /// Allows for extending storage on this type. + public var extensions = Extensions() + + /// Create a stubbed response with the given info. It will be returned + /// for any incoming request that matches the stub pattern. + public static func stub( + _ status: HTTPResponseStatus = .ok, + version: HTTPVersion = .http1_1, + headers: HTTPHeaders = [:], + body: ByteContent? = nil + ) -> Client.Response { + Client.Response(request: Request(url: ""), host: "", status: status, version: version, headers: headers, body: body) + } + + // MARK: ResponseConvertible + + public func response() async throws -> Alchemy.Response { + Alchemy.Response(status: status, headers: headers, body: body) + } + } + + public struct Builder: RequestBuilder { + public var client: Client + public var urlComponents: URLComponents { get { request.urlComponents } set { request.urlComponents = newValue} } + public var method: HTTPMethod { get { request.method } set { request.method = newValue} } + public var headers: HTTPHeaders { get { request.headers } set { request.headers = newValue} } + public var body: ByteContent? { get { request.body } set { request.body = newValue} } + private var request: Client.Request + + init(client: Client) { + self.client = client + self.request = Request() + } + + public func execute() async throws -> Client.Response { + try await client.execute(req: request) + } + + /// Sets an `HTTPClient.Configuration` for this request only. See the + /// `swift-server/async-http-client` package for configuration + /// options. + public func withClientConfig(_ config: HTTPClient.Configuration) -> Builder { + with { $0.request.config = config } + } + + /// Timeout if the request doesn't finish in the given time amount. + public func withTimeout(_ timeout: TimeAmount) -> Builder { + with { $0.request.timeout = timeout } + } + + /// Allow the response to be streamed. + public func withStream() -> Builder { + with { $0.request.streamResponse = true } + } + + /// Stub this builder's client, causing it to respond to all incoming + /// requests with a stub matching the request url or a default `200` + /// stub. + public func stub(_ stubs: Stubs = [:]) { + self.client.stubs = stubs + } + + /// Stub this builder's client, causing it to respond to all incoming + /// requests using the provided handler. + public func stub(_ handler: @escaping Stubs.Handler) { + self.client.stubs = Stubs(handler: handler) + } + } + + /// Represents stubbed responses for a client. + public final class Stubs: ExpressibleByDictionaryLiteral { + public typealias Handler = (Client.Request) -> Client.Response + private typealias Patterns = [(pattern: String, response: Client.Response)] + + private enum Kind { + case patterns(Patterns) + case handler(Handler) + } + + private static let wildcard: Character = "*" + private let kind: Kind + private(set) var stubbedRequests: [Client.Request] = [] + + init(handler: @escaping Handler) { + self.kind = .handler(handler) + } + + public init(dictionaryLiteral elements: (String, Client.Response)...) { + self.kind = .patterns(elements) + } + + func response(for req: Request) -> Response { + stubbedRequests.append(req) + + switch kind { + case .patterns(let patterns): + let match = patterns.first { pattern, _ in doesPattern(pattern, match: req) } + var stub: Client.Response = match?.response ?? .stub() + stub.request = req + stub.host = req.url.host ?? "" + return stub + case .handler(let handler): + return handler(req) + } + } + + private func doesPattern(_ pattern: String, match request: Request) -> Bool { + let requestUrl = [ + request.url.host, + request.url.port.map { ":\($0)" }, + request.url.path, + ] + .compactMap { $0 } + .joined() + + let patternUrl = pattern + .droppingPrefix("https://") + .droppingPrefix("http://") + + for (hostChar, patternChar) in zip(requestUrl, patternUrl) { + guard patternChar != Stubs.wildcard else { return true } + guard hostChar == patternChar else { return false } + } + + return requestUrl.count == patternUrl.count + } + } + + /// The underlying `AsyncHTTPClient.HTTPClient` used for making requests. + public var httpClient: HTTPClient + var stubs: Stubs? + + /// Create a client backed by the given `AsyncHTTPClient` client. Defaults + /// to a client using the default config and app `EventLoopGroup`. + public init(httpClient: HTTPClient = HTTPClient(eventLoopGroupProvider: .shared(Loop.group))) { + self.httpClient = httpClient + self.stubs = nil + } + + public func builder() -> Builder { + Builder(client: self) + } + + /// Shut down the underlying http client. + public func shutdown() throws { + try httpClient.syncShutdown() + } + + /// Stub this client, causing it to respond to all incoming requests with a + /// stub matching the request url or a default `200` stub. + public func stub(_ stubs: Stubs = [:]) { + self.stubs = stubs + } + + /// Stub this client, causing it to respond to all incoming requests using + /// the provided handler. + public func stub(_ handler: @escaping Stubs.Handler) { + self.stubs = Stubs(handler: handler) + } + + /// Execute a request. + /// + /// - Parameters: + /// - req: The request to execute. + /// - config: A custom configuration for the client that will execute the + /// request + /// - Returns: The request's response. + private func execute(req: Request) async throws -> Response { + if let stubs = stubs { + return stubs.response(for: req) + } else { + let deadline: NIODeadline? = req.timeout.map { .now() + $0 } + let httpClientOverride = req.config.map { HTTPClient(eventLoopGroupProvider: .shared(httpClient.eventLoopGroup), configuration: $0) } + defer { try? httpClientOverride?.syncShutdown() } + let _request = try req._request + let promise = Loop.group.next().makePromise(of: Response.self) + let delegate = ResponseDelegate(request: req, promise: promise, allowStreaming: req.streamResponse) + let client = httpClientOverride ?? httpClient + _ = client.execute(request: _request, delegate: delegate, deadline: deadline, logger: Log.logger) + return try await promise.futureResult.get() + } + } +} + +/// Converts an AsyncHTTPClient response into a `Client.Response`. +private class ResponseDelegate: HTTPClientResponseDelegate { + typealias Response = Void + + enum State { + case idle + case head(HTTPResponseHead) + case body(HTTPResponseHead, ByteBuffer) + case stream(HTTPResponseHead, ByteStream) + case error(Error) + } + + private let responsePromise: EventLoopPromise + private let request: Client.Request + private let allowStreaming: Bool + private var state = State.idle + + init(request: Client.Request, promise: EventLoopPromise, allowStreaming: Bool) { + self.request = request + self.responsePromise = promise + self.allowStreaming = allowStreaming + } + + func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { + switch self.state { + case .idle: + self.state = .head(head) + return task.eventLoop.makeSucceededFuture(()) + case .head: + preconditionFailure("head already set") + case .body: + preconditionFailure("no head received before body") + case .stream: + preconditionFailure("no head received before body") + case .error: + return task.eventLoop.makeSucceededFuture(()) + } + } + + func didReceiveBodyPart(task: HTTPClient.Task, _ part: ByteBuffer) -> EventLoopFuture { + switch self.state { + case .idle: + preconditionFailure("no head received before body") + case .head(let head): + self.state = .body(head, part) + return task.eventLoop.makeSucceededFuture(()) + case .body(let head, var body): + if allowStreaming { + let stream = ByteStream(eventLoop: task.eventLoop) + let response = Client.Response(request: request, host: request.host, status: head.status, version: head.version, headers: head.headers, body: .stream(stream)) + self.responsePromise.succeed(response) + self.state = .stream(head, stream) + + // Write the previous part, followed by this part, to the stream. + return stream._write(chunk: body) + .flatMap { stream._write(chunk: part) } + } else { + // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's + // a cross-module call in the way) so we need to drop the original reference to `body` in + // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.idle` (which + // has no associated data). We'll fix it at the bottom of this block. + self.state = .idle + var part = part + body.writeBuffer(&part) + self.state = .body(head, body) + return task.eventLoop.makeSucceededVoidFuture() + } + case .stream(_, let stream): + return stream._write(chunk: part) + case .error: + return task.eventLoop.makeSucceededFuture(()) + } + } + + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.state = .error(error) + responsePromise.fail(error) + } + + func didFinishRequest(task: HTTPClient.Task) throws { + switch self.state { + case .idle: + preconditionFailure("no head received before end") + case .head(let head): + let response = Client.Response(request: request, host: request.host, status: head.status, version: head.version, headers: head.headers, body: nil) + responsePromise.succeed(response) + case .body(let head, let body): + let response = Client.Response(request: request, host: request.host, status: head.status, version: head.version, headers: head.headers, body: .buffer(body)) + responsePromise.succeed(response) + case .stream(_, let stream): + _ = stream._write(chunk: nil) + case .error: + break + } + } +} diff --git a/Sources/Alchemy/Client/ClientResponse+Helpers.swift b/Sources/Alchemy/Client/ClientResponse+Helpers.swift new file mode 100644 index 00000000..52659cff --- /dev/null +++ b/Sources/Alchemy/Client/ClientResponse+Helpers.swift @@ -0,0 +1,103 @@ +import AsyncHTTPClient + +extension Client.Response { + // MARK: Status Information + + public var isOk: Bool { status == .ok } + public var isSuccessful: Bool { (200...299).contains(status.code) } + public var isFailed: Bool { isClientError || isServerError } + public var isClientError: Bool { (400...499).contains(status.code) } + public var isServerError: Bool { (500...599).contains(status.code) } + + public func validateSuccessful() throws -> Self { + guard isSuccessful else { + throw ClientError(message: "The response code was not successful", request: request, response: self) + } + + return self + } + + // MARK: Headers + + public func header(_ name: String) -> String? { headers.first(name: name) } + + // MARK: Body + + public var data: Data? { body?.data() } + public var string: String? { body?.string() } + + public func decode(_ type: D.Type = D.self, using decoder: ContentDecoder = ByteContent.defaultDecoder) throws -> D { + guard let buffer = body?.buffer else { + throw ClientError(message: "The response had no body to decode from.", request: request, response: self) + } + + do { + return try decoder.decodeContent(D.self, from: buffer, contentType: headers.contentType) + } catch { + throw ClientError(message: "Error decoding `\(D.self)`. \(error)", request: request, response: self) + } + } +} + +/// An error encountered when making a `Client` request. +public struct ClientError: Error, CustomStringConvertible { + /// What went wrong. + public let message: String + /// The associated `HTTPClient.Request`. + public let request: Client.Request + /// The associated `HTTPClient.Response`. + public let response: Client.Response + + // MARK: - CustomStringConvertible + + public var description: String { + return """ + *** HTTP Client Error *** + \(message) + + *** Request *** + URL: \(request.method.rawValue) \(request.url.absoluteString) + Headers: [ + \(request.headers.debugString) + ] + Body: \(request.body?.debugString ?? "nil") + + *** Response *** + Status: \(response.status.code) \(response.status.reasonPhrase) + Headers: [ + \(response.headers.debugString) + ] + Body: \(response.body?.debugString ?? "nil") + """ + } +} + +extension HTTPHeaders { + fileprivate var debugString: String { + if Env.LOG_FULL_CLIENT_ERRORS ?? false { + return map { "\($0): \($1)" }.joined(separator: "\n ") + } else { + return map { "\($0.name)" }.joined(separator: "\n ") + } + } +} + +extension ByteContent { + fileprivate var debugString: String { + if Env.LOG_FULL_CLIENT_ERRORS ?? false { + switch self { + case .buffer(let buffer): + return buffer.string + case .stream: + return "" + } + } else { + switch self { + case .buffer(let buffer): + return "<\(buffer.readableBytes) bytes>" + case .stream: + return "" + } + } + } +} diff --git a/Sources/Alchemy/Commands/Command.swift b/Sources/Alchemy/Commands/Command.swift index 0f11411a..e3ff5dce 100644 --- a/Sources/Alchemy/Commands/Command.swift +++ b/Sources/Alchemy/Commands/Command.swift @@ -19,7 +19,7 @@ import ArgumentParser /// @Flag(help: "Should data be loaded but not saved.") /// var dry: Bool = false /// -/// func start() -> EventLoopFuture { +/// func start() async throws { /// if let userId = id { /// // sync only a specific user's data /// } else { @@ -41,6 +41,10 @@ import ArgumentParser /// $ swift run MyApp sync --id 2 --dry /// ``` public protocol Command: ParsableCommand { + /// The name of this command. Run it in the command line by passing this + /// name as an argument. Defaults to the type name. + static var name: String { get } + /// When running the app with this command, should the app /// shut down after the command `start()` is finished. /// Defaults to `true`. @@ -50,63 +54,69 @@ public protocol Command: ParsableCommand { /// worker or running the server. static var shutdownAfterRun: Bool { get } - /// Should the start and finish of this command be logged. - /// Defaults to true. + /// Should the start and finish of this command be logged. Defaults to true. static var logStartAndFinish: Bool { get } - /// Start the command. Your command's main logic should be here. - /// - /// - Returns: A future signalling the end of the command's - /// execution. - func start() -> EventLoopFuture + /// Run the command. Your command's main logic should be here. + func start() async throws /// An optional function to run when your command receives a /// shutdown signal. You likely don't need this unless your /// command runs indefinitely. Defaults to a no-op. - /// - /// - Returns: A future that finishes when shutdown finishes. - func shutdown() -> EventLoopFuture + func shutdown() async throws } extension Command { public static var shutdownAfterRun: Bool { true } public static var logStartAndFinish: Bool { true } - + + /// Registers this command with the application lifecycle. public func run() throws { - if Self.logStartAndFinish { - Log.info("[Command] running \(commandName)") - } - // By default, register self to lifecycle - registerToLifecycle() + registerWithLifecycle() } - public func shutdown() -> EventLoopFuture { - if Self.logStartAndFinish { - Log.info("[Command] finished \(commandName)") - } - return .new() - } + public func shutdown() {} /// Registers this command to the application lifecycle; useful /// for running the app with this command. - func registerToLifecycle() { - let lifecycle = ServiceLifecycle.default + func registerWithLifecycle() { + @Inject var lifecycle: ServiceLifecycle + lifecycle.register( - label: Self.configuration.commandName ?? name(of: Self.self), + label: Self.configuration.commandName ?? Alchemy.name(of: Self.self), start: .eventLoopFuture { Loop.group.next() - .flatSubmit(start) + .asyncSubmit { + if Self.logStartAndFinish { + Log.info("[Command] running \(Self.name)") + } + + try await start() + } .map { if Self.shutdownAfterRun { lifecycle.shutdown() } } }, - shutdown: .eventLoopFuture { Loop.group.next().flatSubmit(shutdown) } + shutdown: .eventLoopFuture { + Loop.group.next() + .asyncSubmit { + if Self.logStartAndFinish { + Log.info("[Command] finished \(Self.name)") + } + + try await shutdown() + } + } ) } - private var commandName: String { - name(of: Self.self) + public static var name: String { + Alchemy.name(of: Self.self) + } + + public static var configuration: CommandConfiguration { + CommandConfiguration(commandName: name) } } diff --git a/Sources/Alchemy/Commands/CommandError.swift b/Sources/Alchemy/Commands/CommandError.swift index 13ebd6f2..6fc7e93f 100644 --- a/Sources/Alchemy/Commands/CommandError.swift +++ b/Sources/Alchemy/Commands/CommandError.swift @@ -1,3 +1,15 @@ -struct CommandError: Error { +/// An error encountered when running a Command. +public struct CommandError: Error, CustomDebugStringConvertible { + /// What went wrong. let message: String + + /// Initialize a `CommandError` with a message detailing what + /// went wrong. + init(_ message: String) { + self.message = message + } + + public var debugDescription: String { + message + } } diff --git a/Sources/Alchemy/Commands/Launch.swift b/Sources/Alchemy/Commands/Launch.swift index 8bbad310..e47a032c 100644 --- a/Sources/Alchemy/Commands/Launch.swift +++ b/Sources/Alchemy/Commands/Launch.swift @@ -3,7 +3,7 @@ import Lifecycle /// Command to launch a given application. struct Launch: ParsableCommand { - @Locked static var userCommands: [Command.Type] = [] + static var customCommands: [Command.Type] = [] static var configuration: CommandConfiguration { CommandConfiguration( abstract: "Run an Alchemy app.", @@ -11,7 +11,10 @@ struct Launch: ParsableCommand { // Running RunServe.self, RunMigrate.self, - RunQueue.self, + RunWorker.self, + + // Database + SeedDatabase.self, // Make MakeController.self, @@ -20,15 +23,14 @@ struct Launch: ParsableCommand { MakeModel.self, MakeJob.self, MakeView.self, - ] + userCommands, + ] + customCommands, defaultSubcommand: RunServe.self ) } /// The environment file to load. Defaults to `env`. /// - /// This is a bit hacky since the env is actually parsed and set - /// in App.main, but this adds the validation for it being - /// entered properly. + /// This is a bit hacky since the env is actually parsed and set in Env, + /// but this adds the validation for it being entered properly. @Option(name: .shortAndLong) var env: String = "env" } diff --git a/Sources/Alchemy/Commands/Make/ColumnData.swift b/Sources/Alchemy/Commands/Make/ColumnData.swift index d32fc11c..14b1ef34 100644 --- a/Sources/Alchemy/Commands/Make/ColumnData.swift +++ b/Sources/Alchemy/Commands/Make/ColumnData.swift @@ -1,4 +1,4 @@ -struct ColumnData: Codable { +struct ColumnData: Codable, Equatable { let name: String let type: String let modifiers: [String] @@ -12,7 +12,7 @@ struct ColumnData: Codable { init(from input: String) throws { let components = input.split(separator: ":").map(String.init) guard components.count >= 2 else { - throw CommandError(message: "Invalid field: \(input). Need at least name and type, such as `name:string`") + throw CommandError("Invalid field: \(input). Need at least name and type, such as `name:string`") } let name = components[0] @@ -25,7 +25,7 @@ struct ColumnData: Codable { case "bigint": type = "bigInt" default: - throw CommandError(message: "Unknown field type `\(type)`") + throw CommandError("Unknown field type `\(type)`") } self.name = name @@ -36,7 +36,7 @@ struct ColumnData: Codable { extension Array where Element == ColumnData { static var defaultData: [ColumnData] = [ - ColumnData(name: "id", type: "increments", modifiers: ["notNull"]), + ColumnData(name: "id", type: "increments", modifiers: ["primary"]), ColumnData(name: "name", type: "string", modifiers: ["notNull"]), ColumnData(name: "email", type: "string", modifiers: ["notNull", "unique"]), ColumnData(name: "password", type: "string", modifiers: ["notNull"]), diff --git a/Sources/Alchemy/Commands/Make/FileCreator.swift b/Sources/Alchemy/Commands/Make/FileCreator.swift index d278ce8b..b2bdb07d 100644 --- a/Sources/Alchemy/Commands/Make/FileCreator.swift +++ b/Sources/Alchemy/Commands/Make/FileCreator.swift @@ -1,14 +1,18 @@ import Foundation import Rainbow -import SwiftCLI +/// Used to generate files related to an alchemy project. struct FileCreator { - static let shared = FileCreator() + static var shared = FileCreator(rootPath: "Sources/App/") - func create(fileName: String, contents: String, in directory: String, comment: String? = nil) throws { + /// The root path where files should be created, relative to the apps + /// working directory. + let rootPath: String + + func create(fileName: String, extension: String = "swift", contents: String, in directory: String, comment: String? = nil) throws { let migrationLocation = try folderPath(for: directory) - let filePath = "\(migrationLocation)/\(fileName).swift" + let filePath = "\(migrationLocation)/\(fileName).\(`extension`)" let destinationURL = URL(fileURLWithPath: filePath) try contents.write(to: destinationURL, atomically: true, encoding: .utf8) print("🧪 create \(filePath.green)") @@ -17,13 +21,22 @@ struct FileCreator { } } + func fileExists(at path: String) -> Bool { + FileManager.default.fileExists(atPath: rootPath + path) + } + private func folderPath(for name: String) throws -> String { - let locations = try Task.capture(bash: "find Sources/App -type d -name '\(name)'").stdout.split(separator: "\n") - if let folder = locations.first { - return String(folder) - } else { - try FileManager.default.createDirectory(at: URL(fileURLWithPath: "Sources/App/\(name)"), withIntermediateDirectories: true) - return "Sources/App/\(name)" + let folder = rootPath + name + guard FileManager.default.fileExists(atPath: folder) else { + try FileManager.default.createDirectory(at: URL(fileURLWithPath: folder), withIntermediateDirectories: true) + return folder } + + return folder + } + + static func mock() { + shared = FileCreator(rootPath: NSTemporaryDirectory()) } } + diff --git a/Sources/Alchemy/Commands/Make/MakeController.swift b/Sources/Alchemy/Commands/Make/MakeController.swift index 5240fa6c..f24a311c 100644 --- a/Sources/Alchemy/Commands/Make/MakeController.swift +++ b/Sources/Alchemy/Commands/Make/MakeController.swift @@ -18,14 +18,7 @@ struct MakeController: Command { self.model = model } - func start() -> EventLoopFuture { - catchError { - try createController() - return .new() - } - } - - private func createController() throws { + func start() throws { let template = model.map(modelControllerTemplate) ?? controllerTemplate() let fileName = model.map { "\($0)Controller" } ?? name try FileCreator.shared.create(fileName: "\(fileName)", contents: template, in: "Controllers") @@ -37,7 +30,7 @@ struct MakeController: Command { struct \(name): Controller { func route(_ app: Application) { - app.get("/index", handler: index) + app.get("/index", use: index) } private func index(req: Request) -> String { @@ -57,33 +50,32 @@ struct MakeController: Command { struct \(name)Controller: Controller { func route(_ app: Application) { app - .get("/\(resourcePath)", handler: index) - .post("/\(resourcePath)", handler: create) - .get("/\(resourcePath)/:id", handler: show) - .patch("/\(resourcePath)", handler: update) - .delete("/\(resourcePath)/:id", handler: delete) + .get("/\(resourcePath)", use: index) + .post("/\(resourcePath)", use: create) + .get("/\(resourcePath)/:id", use: show) + .patch("/\(resourcePath)", use: update) + .delete("/\(resourcePath)/:id", use: delete) } - private func index(req: Request) -> EventLoopFuture<[\(name)]> { - \(name).all() + private func index(req: Request) async throws -> [\(name)] { + try await \(name).all() } - private func create(req: Request) throws -> EventLoopFuture<\(name)> { - try req.decodeBody(as: \(name).self).insert() + private func create(req: Request) async throws -> \(name) { + try await req.decode(\(name).self).insertReturn() } - private func show(req: Request) throws -> EventLoopFuture<\(name)> { - \(name).find(try req.parameter("id")) - .unwrap(orError: HTTPError(.notFound)) + private func show(req: Request) async throws -> \(name) { + try await \(name).find(req.parameter("id")).unwrap(or: HTTPError(.notFound)) } - private func update(req: Request) throws -> EventLoopFuture<\(name)> { - \(name).update(try req.parameter("id"), with: try req.bodyDict()) - .unwrap(orError: HTTPError(.notFound)) + private func update(req: Request) async throws -> \(name) { + try await \(name).update(req.parameter("id"), with: req.body?.decodeJSONDictionary() ?? [:]) + .unwrap(or: HTTPError(.notFound)) } - private func delete(req: Request) throws -> EventLoopFuture { - \(name).delete(try req.parameter("id")) + private func delete(req: Request) async throws { + try await \(name).delete(req.parameter("id")) } } """ diff --git a/Sources/Alchemy/Commands/Make/MakeJob.swift b/Sources/Alchemy/Commands/Make/MakeJob.swift index e44d468a..e376be0a 100644 --- a/Sources/Alchemy/Commands/Make/MakeJob.swift +++ b/Sources/Alchemy/Commands/Make/MakeJob.swift @@ -9,11 +9,13 @@ struct MakeJob: Command { @Argument var name: String - func start() -> EventLoopFuture { - catchError { - try FileCreator.shared.create(fileName: name, contents: jobTemplate(), in: "Jobs") - return .new() - } + init() {} + init(name: String) { + self.name = name + } + + func start() throws { + try FileCreator.shared.create(fileName: name, contents: jobTemplate(), in: "Jobs") } private func jobTemplate() -> String { @@ -21,9 +23,8 @@ struct MakeJob: Command { import Alchemy struct \(name): Job { - func run() -> EventLoopFuture { + func run() async throws { // Write some code! - return .new() } } """ diff --git a/Sources/Alchemy/Commands/Make/MakeMiddleware.swift b/Sources/Alchemy/Commands/Make/MakeMiddleware.swift index 597924c6..1672a555 100644 --- a/Sources/Alchemy/Commands/Make/MakeMiddleware.swift +++ b/Sources/Alchemy/Commands/Make/MakeMiddleware.swift @@ -9,11 +9,13 @@ struct MakeMiddleware: Command { @Argument var name: String - func start() -> EventLoopFuture { - catchError { - try FileCreator.shared.create(fileName: name, contents: middlewareTemplate(), in: "Middleware") - return .new() - } + init() {} + init(name: String) { + self.name = name + } + + func start() throws { + try FileCreator.shared.create(fileName: name, contents: middlewareTemplate(), in: "Middleware") } private func middlewareTemplate() -> String { @@ -21,9 +23,9 @@ struct MakeMiddleware: Command { import Alchemy struct \(name): Middleware { - func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture { + func intercept(_ request: Request, next: Next) async throws -> Response { // Write some code! - return next(request) + return try await next(request) } } """ diff --git a/Sources/Alchemy/Commands/Make/MakeMigration.swift b/Sources/Alchemy/Commands/Make/MakeMigration.swift index 41b020ef..a0ccdc75 100644 --- a/Sources/Alchemy/Commands/Make/MakeMigration.swift +++ b/Sources/Alchemy/Commands/Make/MakeMigration.swift @@ -14,45 +14,40 @@ struct MakeMigration: Command { @Option(name: .shortAndLong) var table: String - private var columns: [ColumnData] = [] + @IgnoreDecoding + private var columns: [ColumnData]? init() {} - - init(name: String, table: String, columns: [ColumnData]) { + init(name: String, table: String, columns: [ColumnData]) { self.name = name self.table = table self.columns = columns + self.fields = [] } - func start() -> EventLoopFuture { - catchError { - guard !name.contains(":") else { - throw CommandError(message: "Invalid migration name `\(name)`. Perhaps you forgot to pass a name?") - } - - var migrationColumns: [ColumnData] = columns - - // Initialize rows - if migrationColumns.isEmpty { - migrationColumns = try fields.map(ColumnData.init) - if migrationColumns.isEmpty { migrationColumns = .defaultData } - } - - // Create files - try createMigration(columns: migrationColumns) - return .new() + func start() throws { + guard !name.contains(":") else { + throw CommandError("Invalid migration name `\(name)`. Perhaps you forgot to pass a name?") } + + var migrationColumns: [ColumnData] = columns ?? [] + + // Initialize rows + if migrationColumns.isEmpty { + migrationColumns = try fields.map(ColumnData.init) + if migrationColumns.isEmpty { migrationColumns = .defaultData } + } + + // Create files + try createMigration(columns: migrationColumns) } private func createMigration(columns: [ColumnData]) throws { - let dateFormatter = DateFormatter() - dateFormatter.dateFormat = "yyyy_MM_dd_HH_mm_ss" - let fileName = "\(dateFormatter.string(from: Date()))\(name)" try FileCreator.shared.create( - fileName: fileName, + fileName: name, contents: migrationTemplate(name: name, columns: columns), - in: "Migrations", - comment: "remember to add migration to a Database.migrations!") + in: "Database/Migrations", + comment: "remember to add migration to your database config!") } private func migrationTemplate(name: String, columns: [ColumnData]) throws -> String { @@ -83,7 +78,7 @@ private extension ColumnData { for modifier in modifiers.map({ String($0) }) { let splitComponents = modifier.split(separator: ".") guard let modifier = splitComponents.first else { - throw CommandError(message: "There was an empty field modifier.") + throw CommandError("There was an empty field modifier.") } switch modifier.lowercased() { @@ -98,12 +93,12 @@ private extension ColumnData { let table = splitComponents[safe: 1], let key = splitComponents[safe: 2] else { - throw CommandError(message: "Invalid references format `\(modifier)` expected `references.table.key`") + throw CommandError("Invalid references format `\(modifier)` expected `references.table.key`") } returnString.append(".references(\"\(key)\", on: \"\(table)\")") default: - throw CommandError(message: "Unknown column modifier \(modifier)") + throw CommandError("Unknown column modifier \(modifier)") } } diff --git a/Sources/Alchemy/Commands/Make/MakeModel.swift b/Sources/Alchemy/Commands/Make/MakeModel.swift index ca657738..81b532ff 100644 --- a/Sources/Alchemy/Commands/Make/MakeModel.swift +++ b/Sources/Alchemy/Commands/Make/MakeModel.swift @@ -5,7 +5,7 @@ import Papyrus typealias Flag = ArgumentParser.Flag typealias Option = ArgumentParser.Option -struct MakeModel: Command { +final class MakeModel: Command { static var logStartAndFinish: Bool = false static var configuration = CommandConfiguration( commandName: "make:model", @@ -28,27 +28,43 @@ struct MakeModel: Command { @Flag(name: .shortAndLong, help: "Also make a migration file for this model.") var migration: Bool = false @Flag(name: .shortAndLong, help: "Also make a controller with CRUD operations for this model.") var controller: Bool = false - func start() -> EventLoopFuture { - catchError { - guard !name.contains(":") else { - throw CommandError(message: "Invalid model name `\(name)`. Perhaps you forgot to pass a name?") - } - - // Initialize rows - var columns = try fields.map(ColumnData.init) - if columns.isEmpty { columns = .defaultData } - - // Create files - try createModel(columns: columns) - - let migrationFuture = migration ? MakeMigration( + @IgnoreDecoding + private var columns: [ColumnData]? + + init() {} + init(name: String, columns: [ColumnData] = [], migration: Bool = false, controller: Bool = false) { + self.name = name + self.columns = columns + self.fields = [] + self.migration = migration + self.controller = controller + } + + func start() throws { + guard !name.contains(":") else { + throw CommandError("Invalid model name `\(name)`. Perhaps you forgot to pass a name?") + } + + // Initialize rows + if (columns ?? []).isEmpty && fields.isEmpty { + columns = .defaultData + } else if (columns ?? []).isEmpty { + columns = try fields.map(ColumnData.init) + } + + // Create files + try createModel(columns: columns ?? []) + + if migration { + try MakeMigration( name: "Create\(name.pluralized)", table: name.camelCaseToSnakeCase().pluralized, - columns: columns - ).start() : .new() - - let controllerFuture = controller ? MakeController(model: name).start() : .new() - return migrationFuture.flatMap { controllerFuture } + columns: columns ?? [] + ).start() + } + + if controller { + try MakeController(model: name).start() } } @@ -95,11 +111,8 @@ private extension ColumnData { swiftType += "?" } - if name == "id" { - return "var \(name.snakeCaseToCamelCase()): \(swiftType)" - } else { - return "let \(name.snakeCaseToCamelCase()): \(swiftType)" - } + let declaration = name == "id" ? "var" : "let" + return "\(declaration) \(name.snakeCaseToCamelCase()): \(swiftType)" } } diff --git a/Sources/Alchemy/Commands/Make/MakeView.swift b/Sources/Alchemy/Commands/Make/MakeView.swift index 6b14a30a..3941f5eb 100644 --- a/Sources/Alchemy/Commands/Make/MakeView.swift +++ b/Sources/Alchemy/Commands/Make/MakeView.swift @@ -9,11 +9,13 @@ struct MakeView: Command { @Argument var name: String - func start() -> EventLoopFuture { - catchError { - try FileCreator.shared.create(fileName: name, contents: viewTemplate(), in: "Views") - return .new() - } + init() {} + init(name: String) { + self.name = name + } + + func start() throws { + try FileCreator.shared.create(fileName: name, contents: viewTemplate(), in: "Views") } private func viewTemplate() -> String { diff --git a/Sources/Alchemy/Commands/Migrate/RunMigrate.swift b/Sources/Alchemy/Commands/Migrate/RunMigrate.swift index 284d3f53..1e583057 100644 --- a/Sources/Alchemy/Commands/Migrate/RunMigrate.swift +++ b/Sources/Alchemy/Commands/Migrate/RunMigrate.swift @@ -18,17 +18,23 @@ struct RunMigrate: Command { @Flag(help: "Should migrations be rolled back") var rollback: Bool = false + init() {} + init(rollback: Bool) { + self.rollback = rollback + } + // MARK: Command - func start() -> EventLoopFuture { - // Run on event loop - Loop.group.next() - .flatSubmit(rollback ? Database.default.rollbackMigrations : Database.default.migrate) + func start() async throws { + if rollback { + try await DB.rollbackMigrations() + } else { + try await DB.migrate() + } } - func shutdown() -> EventLoopFuture { + func shutdown() async throws { let action = rollback ? "migration rollback" : "migrations" Log.info("[Migration] \(action) finished, shutting down.") - return .new() } } diff --git a/Sources/Alchemy/Commands/Queue/RunQueue.swift b/Sources/Alchemy/Commands/Queue/RunWorker.swift similarity index 66% rename from Sources/Alchemy/Commands/Queue/RunQueue.swift rename to Sources/Alchemy/Commands/Queue/RunWorker.swift index 0c9cdf41..3bd73aee 100644 --- a/Sources/Alchemy/Commands/Queue/RunQueue.swift +++ b/Sources/Alchemy/Commands/Queue/RunWorker.swift @@ -1,11 +1,10 @@ import ArgumentParser import Lifecycle -/// Command to serve on launched. This is a subcommand of `Launch`. -/// The app will route with the singleton `HTTPRouter`. -struct RunQueue: Command { +/// Command to run queue workers. +struct RunWorker: Command { static var configuration: CommandConfiguration { - CommandConfiguration(commandName: "queue") + CommandConfiguration(commandName: "worker") } static var shutdownAfterRun: Bool = false @@ -28,27 +27,38 @@ struct RunQueue: Command { /// work. @Flag var schedule: Bool = false + init() {} + init(name: String?, channels: String = Queue.defaultChannel, workers: Int = 1, schedule: Bool = false) { + self.name = name + self.channels = channels + self.workers = workers + self.schedule = schedule + } + // MARK: Command func run() throws { - let queue: Queue = name.map { .named($0) } ?? .default - ServiceLifecycle.default - .registerWorkers(workers, on: queue, channels: channels.components(separatedBy: ",")) + let queue: Queue = name.map { .id(.init(hashable: $0)) } ?? Q + + @Inject var lifecycle: ServiceLifecycle + lifecycle.registerWorkers(workers, on: queue, channels: channels.components(separatedBy: ",")) if schedule { - ServiceLifecycle.default.registerScheduler() + lifecycle.registerScheduler() } - + let schedulerText = schedule ? "scheduler and " : "" Log.info("[Queue] started \(schedulerText)\(workers) workers.") } - func start() -> EventLoopFuture { .new() } + func start() {} } extension ServiceLifecycle { + private var scheduler: Scheduler { Container.resolveAssert() } + /// Start the scheduler when the app starts. func registerScheduler() { - register(label: "Scheduler", start: .sync { Scheduler.default.start() }, shutdown: .none) + register(label: "Scheduler", start: .sync { scheduler.start() }, shutdown: .none) } /// Start queue workers when the app starts. @@ -62,17 +72,12 @@ extension ServiceLifecycle { for worker in 0.. EventLoopFuture -} - -/// Responds to incoming `HTTPRequests` with an `Response` generated -/// by the `HTTPRouter`. -final class HTTPHandler: ChannelInboundHandler { - typealias InboundIn = HTTPServerRequestPart - typealias OutboundOut = HTTPServerResponsePart - - // Indicates that the TCP connection needs to be closed after a - // response has been sent. - private var keepAlive = true - - /// A temporary local Request that is used to accumulate data - /// into. - private var request: Request? - - /// The responder to all requests. - private let router: HTTPRouter - - /// Initialize with a responder to handle all requests. - /// - /// - Parameter responder: The object to respond to all incoming - /// `Request`s. - init(router: HTTPRouter) { - self.router = router - } - - /// Received incoming `InboundIn` data, writing a response based - /// on the `Responder`. - /// - /// - Parameters: - /// - context: The context of the handler. - /// - data: The inbound data received. - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let part = self.unwrapInboundIn(data) - - switch part { - case .head(let requestHead): - // If the part is a `head`, a new Request is received - keepAlive = requestHead.isKeepAlive - - let contentLength: Int - - // We need to check the content length to reserve memory - // for the body - if let length = requestHead.headers["content-length"].first { - contentLength = Int(length) ?? 0 - } else { - contentLength = 0 - } - - let body: ByteBuffer? - - // Allocates the memory for accumulation - if contentLength > 0 { - body = context.channel.allocator.buffer(capacity: contentLength) - } else { - body = nil - } - - self.request = Request( - head: requestHead, - bodyBuffer: body - ) - case .body(var newData): - // Appends new data to the already reserved buffer - self.request?.bodyBuffer?.writeBuffer(&newData) - case .end: - guard let request = request else { return } - - // Responds to the request - let response = router.handle(request: request) - // Ensure we're on the right ELF or NIO will assert. - .hop(to: context.eventLoop) - self.request = nil - - // Writes the response when done - self.writeResponse(version: request.head.version, response: response, to: context) - } - } - - /// Writes the `Responder`'s `Response` to a - /// `ChannelHandlerContext`. - /// - /// - Parameters: - /// - version: The HTTP version of the connection. - /// - response: The reponse to write to the handler context. - /// - context: The context to write to. - /// - Returns: An future that completes when the response is - /// written. - @discardableResult - private func writeResponse(version: HTTPVersion, response: EventLoopFuture, to context: ChannelHandlerContext) -> EventLoopFuture { - return response.flatMap { response in - let responseWriter = HTTPResponseWriter(version: version, handler: self, context: context) - responseWriter.completionPromise.futureResult.whenComplete { _ in - if !self.keepAlive { - context.close(promise: nil) - } - } - - response.write(to: responseWriter) - return responseWriter.completionPromise.futureResult - } - } - - /// Handler for when the channel read is complete. - /// - /// - Parameter context: the context to send events to. - func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } -} - -/// Used for writing a response to a remote peer with an -/// `HTTPHandler`. -private struct HTTPResponseWriter: ResponseWriter { - /// A promise to hook into for when the writing is finished. - let completionPromise: EventLoopPromise - - /// The HTTP version we're working with. - private var version: HTTPVersion - - /// The handler in which this writer is writing. - private let handler: HTTPHandler - - /// The context that should be written to. - private let context: ChannelHandlerContext - - /// Initialize - /// - Parameters: - /// - version: The HTTPVersion of this connection. - /// - handler: The handler in which this response is writing - /// inside. - /// - context: The context to write responses to. - init(version: HTTPVersion, handler: HTTPHandler, context: ChannelHandlerContext) { - self.version = version - self.handler = handler - self.context = context - self.completionPromise = context.eventLoop.makePromise() - } - - // MARK: ResponseWriter - - func writeHead(status: HTTPResponseStatus, _ headers: HTTPHeaders) { - let head = HTTPResponseHead(version: version, status: status, headers: headers) - context.write(handler.wrapOutboundOut(.head(head)), promise: nil) - } - - func writeBody(_ body: ByteBuffer) { - context.writeAndFlush(handler.wrapOutboundOut(.body(IOData.byteBuffer(body))), promise: nil) - } - - func writeEnd() { - context.writeAndFlush(handler.wrapOutboundOut(.end(nil)), promise: completionPromise) - } -} diff --git a/Sources/Alchemy/Commands/Serve/RunServe.swift b/Sources/Alchemy/Commands/Serve/RunServe.swift index fbbc5e31..8ae43286 100644 --- a/Sources/Alchemy/Commands/Serve/RunServe.swift +++ b/Sources/Alchemy/Commands/Serve/RunServe.swift @@ -4,14 +4,12 @@ import NIOSSL import NIOHTTP1 import NIOHTTP2 import Lifecycle +import Hummingbird /// Command to serve on launched. This is a subcommand of `Launch`. /// The app will route with the singleton `HTTPRouter`. final class RunServe: Command { - static var configuration: CommandConfiguration { - CommandConfiguration(commandName: "serve") - } - + static let configuration = CommandConfiguration(commandName: "serve") static var shutdownAfterRun: Bool = false static var logStartAndFinish: Bool = false @@ -36,159 +34,113 @@ final class RunServe: Command { /// Should migrations be run before booting. Defaults to `false`. @Flag var migrate: Bool = false - @IgnoreDecoding - private var channel: Channel? + init() {} + init(host: String = "127.0.0.1", port: Int = 3000, workers: Int = 0, schedule: Bool = false, migrate: Bool = false) { + self.host = host + self.port = port + self.unixSocket = nil + self.workers = workers + self.schedule = schedule + self.migrate = migrate + } // MARK: Command func run() throws { - let lifecycle = ServiceLifecycle.default + @Inject var lifecycle: ServiceLifecycle + @Inject var app: Application + if migrate { lifecycle.register( label: "Migrate", start: .eventLoopFuture { Loop.group.next() - .flatSubmit(Database.default.migrate) + .asyncSubmit(DB.migrate) }, shutdown: .none ) } - registerToLifecycle() + var config = app.configuration + if let unixSocket = unixSocket { + config = config.with(address: .unixDomainSocket(path: unixSocket)) + } else { + config = config.with(address: .hostname(host, port: port)) + } + + let server = HBApplication(configuration: config, eventLoopGroupProvider: .shared(Loop.group)) + server.router = app.router + Container.bind(.singleton, value: server) + + registerWithLifecycle() if schedule { lifecycle.registerScheduler() } if workers > 0 { - lifecycle.registerWorkers(workers, on: .default) + lifecycle.registerWorkers(workers, on: Q) } } - func start() -> EventLoopFuture { - func childChannelInitializer(_ channel: Channel) -> EventLoopFuture { - channel.pipeline - .addAnyTLS() - .flatMap { channel.addHTTP() } - } + func start() throws { + @Inject var server: HBApplication - let serverBootstrap = ServerBootstrap(group: Loop.group) - .serverChannelOption(ChannelOptions.backlog, value: 256) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer(childChannelInitializer) - .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) - .childChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) + try server.start() + if let unixSocket = unixSocket { + Log.info("[Server] listening on \(unixSocket).") + } else { + Log.info("[Server] listening on \(host):\(port).") + } + } + + func shutdown() throws { + @Inject var server: HBApplication - let channel = { () -> EventLoopFuture in - if let unixSocket = unixSocket { - return serverBootstrap.bind(unixDomainSocketPath: unixSocket) + let promise = server.eventLoopGroup.next().makePromise(of: Void.self) + server.lifecycle.shutdown { error in + if let error = error { + promise.fail(error) } else { - return serverBootstrap.bind(host: host, port: port) + promise.succeed(()) } - }() + } - return channel - .map { boundChannel in - guard let channelLocalAddress = boundChannel.localAddress else { - fatalError("Address was unable to bind. Please check that the socket was not closed or that the address family was understood.") - } - - self.channel = boundChannel - Log.info("[Server] listening on \(channelLocalAddress.prettyName)") - } - } - - func shutdown() -> EventLoopFuture { - channel?.close() ?? .new() + try promise.futureResult.wait() } } -@propertyWrapper -private struct IgnoreDecoding: Decodable { - var wrappedValue: T? - - init(from decoder: Decoder) throws { - wrappedValue = nil +extension Router: HBRouter { + public func respond(to request: HBRequest) -> EventLoopFuture { + request.eventLoop + .asyncSubmit { await self.handle(request: Request(hbRequest: request)) } + .map { HBResponse(status: $0.status, headers: $0.headers, body: $0.hbResponseBody) } } - init() { - wrappedValue = nil - } + public func add(_ path: String, method: HTTPMethod, responder: HBResponder) { /* using custom router funcs */ } } -extension SocketAddress { - /// A human readable description for this socket. - var prettyName: String { - switch self { - case .unixDomainSocket: - return pathname ?? "" - case .v4: - let address = ipAddress ?? "" - let port = port ?? 0 - return "\(address):\(port)" - case .v6: - let address = ipAddress ?? "" - let port = port ?? 0 - return "\(address):\(port)" +extension Response { + var hbResponseBody: HBResponseBody { + switch body { + case .buffer(let buffer): + return .byteBuffer(buffer) + case .stream(let stream): + return .stream(stream) + case .none: + return .empty } } } -extension ChannelPipeline { - /// Configures this pipeline with any TLS config in the - /// `ApplicationConfiguration`. - /// - /// - Returns: A future that completes when the config completes. - fileprivate func addAnyTLS() -> EventLoopFuture { - let config = Container.resolve(ApplicationConfiguration.self) - if var tls = config.tlsConfig { - if config.httpVersions.contains(.http2) { - tls.applicationProtocols.append("h2") - } - if config.httpVersions.contains(.http1_1) { - tls.applicationProtocols.append("http/1.1") - } - let sslContext = try! NIOSSLContext(configuration: tls) - let sslHandler = NIOSSLServerHandler(context: sslContext) - return addHandler(sslHandler) - } else { - return .new() - } +extension ByteStream: HBResponseBodyStreamer { + public func read(on eventLoop: EventLoop) -> EventLoopFuture { + _read(on: eventLoop).map { $0.map { .byteBuffer($0) } ?? .end } } } -extension Channel { - /// Configures this channel to handle whatever HTTP versions the - /// server should be speaking over. - /// - /// - Returns: A future that completes when the config completes. - fileprivate func addHTTP() -> EventLoopFuture { - let config = Container.resolve(ApplicationConfiguration.self) - if config.httpVersions.contains(.http2) { - return configureHTTP2SecureUpgrade( - h2ChannelConfigurator: { h2Channel in - h2Channel.configureHTTP2Pipeline( - mode: .server, - inboundStreamInitializer: { channel in - channel.pipeline - .addHandlers([ - HTTP2FramePayloadToHTTP1ServerCodec(), - HTTPHandler(router: Router.default) - ]) - }) - .voided() - }, - http1ChannelConfigurator: { http1Channel in - http1Channel.pipeline - .configureHTTPServerPipeline(withErrorHandling: true) - .flatMap { self.pipeline.addHandler(HTTPHandler(router: Router.default)) } - } - ) - } else { - return pipeline - .configureHTTPServerPipeline(withErrorHandling: true) - .flatMap { self.pipeline.addHandler(HTTPHandler(router: Router.default)) } - } +extension HBHTTPError: ResponseConvertible { + public func response() -> Response { + Response(status: status, headers: headers, body: body.map { .string($0) }) } } diff --git a/Sources/Alchemy/Config/Configurable.swift b/Sources/Alchemy/Config/Configurable.swift new file mode 100644 index 00000000..9c30de3e --- /dev/null +++ b/Sources/Alchemy/Config/Configurable.swift @@ -0,0 +1,42 @@ +/// A service that's configurable with a custom configuration +public protocol Configurable: AnyConfigurable { + associatedtype Config + + static var config: Config { get } + static func configure(with config: Config) +} + +/// Register services that the user may provide configurations for here. +/// Services registered here will have their default configurations run +/// before the main application boots. +public struct ConfigurableServices { + private static var configurableTypes: [Any.Type] = [ + Database.self, + Cache.self, + Queue.self, + Filesystem.self + ] + + public static func register(_ type: T.Type) { + configurableTypes.append(type) + } + + static func configureDefaults() { + for type in configurableTypes { + if let type = type as? AnyConfigurable.Type { + type.configureDefaults() + } + } + } +} + +/// An erased configurable. +public protocol AnyConfigurable { + static func configureDefaults() +} + +extension Configurable { + public static func configureDefaults() { + configure(with: Self.config) + } +} diff --git a/Sources/Alchemy/Config/Service.swift b/Sources/Alchemy/Config/Service.swift new file mode 100644 index 00000000..39c92f58 --- /dev/null +++ b/Sources/Alchemy/Config/Service.swift @@ -0,0 +1,77 @@ +import Lifecycle + +public protocol Service { + /// An identifier, unique to the service. + associatedtype Identifier: ServiceIdentifier + /// Start this service. Will be called when this service is first resolved. + func startup() + /// Shutdown this service. Will be called when the application your + /// service is registered to shuts down. + func shutdown() throws +} + +public protocol ServiceIdentifier: Hashable, ExpressibleByStringLiteral, ExpressibleByIntegerLiteral { + static var `default`: Self { get } + init(hashable: AnyHashable) +} + +extension ServiceIdentifier { + public static var `default`: Self { Self(hashable: AnyHashable(nil as AnyHashable?)) } + + // MARK: - ExpressibleByStringLiteral + + public init(stringLiteral value: String) { + self.init(hashable: value) + } + + // MARK: - ExpressibleByIntegerLiteral + + public init(integerLiteral value: Int) { + self.init(hashable: value) + } +} + +// By default, startup and shutdown are no-ops. +extension Service { + public func startup() {} + public func shutdown() throws {} +} + +extension Service { + + // MARK: Resolve shorthand + + public static var `default`: Self { + Container.resolveAssert(Self.self, identifier: Database.Identifier.default) + } + + public static func id(_ identifier: Identifier) -> Self { + Container.resolveAssert(Self.self, identifier: identifier) + } + + // MARK: Bind shorthand + + public static func bind(_ value: @escaping @autoclosure () -> Self) { + bind(.default, value()) + } + + public static func bind(_ identifier: Identifier = .default, _ value: Self) { + // Register as a singleton to the default container. + Container.bind(.singleton, identifier: identifier) { container -> Self in + value.startup() + return value + } + + // Need to register shutdown before lifecycle starts, but need to shutdown EACH singleton, + Container.resolveAssert(ServiceLifecycle.self) + .registerShutdown(label: "\(name(of: Self.self)):\(identifier)", .sync { + try value.shutdown() + }) + } +} + +extension Inject where Service: Alchemy.Service { + public convenience init(_ identifier: Service.Identifier) { + self.init(identifier: identifier) + } +} diff --git a/Sources/Alchemy/Config/ServiceIdentifier.swift b/Sources/Alchemy/Config/ServiceIdentifier.swift new file mode 100644 index 00000000..778b4154 --- /dev/null +++ b/Sources/Alchemy/Config/ServiceIdentifier.swift @@ -0,0 +1,23 @@ +///// Used to identify different instances of common services in Alchemy. +//public struct ServiceIdentifier: Hashable, ExpressibleByStringLiteral, ExpressibleByIntegerLiteral { +// /// The default identifier for a service. +// public static var `default`: Self { ServiceIdentifier(nil) } +// +// private var identifier: AnyHashable? +// +// public init(_ identifier: AnyHashable?) { +// self.identifier = identifier +// } +// +// // MARK: - ExpressibleByStringLiteral +// +// public init(stringLiteral value: String) { +// self.init(value) +// } +// +// // MARK: - ExpressibleByIntegerLiteral +// +// public init(integerLiteral value: Int) { +// self.init(value) +// } +//} diff --git a/Sources/Alchemy/Env/Env.swift b/Sources/Alchemy/Env/Env.swift index 2bd61b23..306bbe9b 100644 --- a/Sources/Alchemy/Env/Env.swift +++ b/Sources/Alchemy/Env/Env.swift @@ -18,15 +18,39 @@ private let kEnvVariable = "APP_ENV" /// let otherVariable: Int? = Env.OTHER_KEY /// ``` @dynamicMemberLookup -public struct Env: Equatable { - /// The default env file path (will be prefixed by a .). - static var defaultLocation = "env" +public struct Env: Equatable, ExpressibleByStringLiteral { + public static let test: Env = "test" + public static let dev: Env = "dev" + public static let prod: Env = "prod" + + /// The current environment containing all variables loaded from + /// the environment file. + public internal(set) static var current: Env = Env.isRunningTests ? .test : .dev + + private static var didManuallyLoadDotEnv = false /// The environment file location of this application. Additional /// env variables are pulled from the file at '.{name}'. This - /// defaults to `env` or `APP_ENV` if that is set. + /// defaults to `env`, `APP_ENV`, or `-e` / `--env` command + /// line arguments. public let name: String + /// All environment variables available to the application. + public var dotEnvVariables: [String: String] = [:] + + /// All environment variables available to the application. + public var processVariables: [String: String] = [:] + + public init(stringLiteral value: String) { + self.init(name: value) + } + + init(name: String, dotEnvVariables: [String: String] = [:], processVariables: [String: String] = [:]) { + self.name = name + self.dotEnvVariables = dotEnvVariables + self.processVariables = processVariables + } + /// Returns any environment variables loaded from the environment /// file as type `T: EnvAllowed`. Supports `String`, `Int`, /// `Double`, and `Bool`. @@ -34,31 +58,99 @@ public struct Env: Equatable { /// - Parameter key: The name of the environment variable. /// - Returns: The variable converted to type `S`. `nil` if the /// variable doesn't exist or it cannot be converted as `S`. - public func get(_ key: String) -> S? { - if let val = getenv(key) { - let stringValue = String(validatingUTF8: val) - return stringValue.map { S($0) } ?? nil + public func get(_ key: String, as: L.Type = L.self) -> L? { + guard let val = processVariables[key] ?? dotEnvVariables[key] else { + return nil } - return nil + + return L(val) + } + + /// Returns any environment variables from `Env.current` as type + /// `T: StringInitializable`. Supports `String`, `Int`, + /// `Double`, `Bool`, and `UUID`. + /// + /// - Parameter key: The name of the environment variable. + /// - Returns: The variable converted to type `S`. `nil` if no fallback is + /// provided and the variable doesn't exist or cannot be converted as + /// `S`. + public static func get(_ key: String, as: L.Type = L.self) -> L? { + current.get(key) } /// Required for dynamic member lookup. - public static subscript(dynamicMember member: String) -> T? { - return Env.current.get(member) + public static subscript(dynamicMember member: String) -> L? { + Env.get(member) } - /// All environment variables available to the program. - public var all: [String: String] { - return ProcessInfo.processInfo.environment + /// Boots the environment with the given arguments. Loads additional + /// environment variables from a `.env` file. + /// + /// - Parameter args: The command line args of the program. -e or --env will + /// indicate a custom envfile location. + static func boot(args: [String] = CommandLine.arguments, processEnv: [String: String] = ProcessInfo.processInfo.environment) { + loadEnv(args: args, processEnv: processEnv) + loadDotEnv() } - /// The current environment containing all variables loaded from - /// the environment file. - public static var current: Env = { - let appEnvPath = ProcessInfo.processInfo.environment[kEnvVariable] ?? defaultLocation - Env.loadDotEnvFile(path: ".\(appEnvPath)") - return Env(name: appEnvPath) - }() + static func loadEnv(args: [String] = CommandLine.arguments, processEnv: [String: String] = ProcessInfo.processInfo.environment) { + var env: Env = isRunningTests ? .test : .dev + if let index = args.firstIndex(of: "--env"), let value = args[safe: index + 1] { + env = Env(name: value) + } else if let index = args.firstIndex(of: "-e"), let value = args[safe: index + 1] { + env = Env(name: value) + } else if let value = processEnv[kEnvVariable] { + env = Env(name: value) + } + + env.processVariables = processEnv + current = env + } + + public static func loadDotEnv(_ paths: String...) { + guard paths.isEmpty else { + for path in paths { + guard let values = loadDotEnvFile(path: path) else { + continue + } + + for (key, value) in values { + current.dotEnvVariables[key] = value + } + } + + didManuallyLoadDotEnv = true + return + } + + guard !didManuallyLoadDotEnv else { + return + } + + let defaultPath = ".env" + var overridePath: String? = nil + if current != .dev { + overridePath = ".env.\(current.name)" + } + + if let overridePath = overridePath { + if let values = loadDotEnvFile(path: overridePath) { + Log.info("[Environment] loaded env from `\(overridePath)`.") + current.dotEnvVariables = values + } else { + Log.error("[Environment] couldnt find dotenv at `\(overridePath)`.") + } + } else if let values = loadDotEnvFile(path: defaultPath) { + Log.info("[Environment] loaded env from `\(defaultPath)`.") + current.dotEnvVariables = values + } else { + Log.info("[Environment] no dotenv file found.") + } + } + + public static func == (lhs: Env, rhs: Env) -> Bool { + lhs.name == rhs.name + } } extension Env { @@ -67,18 +159,20 @@ extension Env { /// /// - Parameter path: The path of the file from which to load the /// variables. - private static func loadDotEnvFile(path: String) { - let absolutePath = path.starts(with: "/") ? path : self.getAbsolutePath(relativePath: "/\(path)") + private static func loadDotEnvFile(path: String) -> [String: String]? { + let absolutePath = path.starts(with: "/") ? path : getAbsolutePath(relativePath: "/\(path)") guard let pathString = absolutePath else { - return Log.info("[Environment] no environment file found at '\(path)'") + return nil } - guard let contents = try? NSString(contentsOfFile: pathString, encoding: String.Encoding.utf8.rawValue) else { - return Log.info("[Environment] unable to load contents of file at '\(pathString)'") + guard let contents = try? String(contentsOfFile: pathString, encoding: .utf8) else { + Log.info("[Environment] unable to load contents of file at '\(pathString)'") + return [:] } - let lines = String(describing: contents).split { $0 == "\n" || $0 == "\r\n" }.map(String.init) + var values: [String: String] = [:] + let lines = contents.split { $0 == "\n" || $0 == "\r\n" }.map(String.init) for line in lines { // ignore comments if line[line.startIndex] == "#" { @@ -92,11 +186,6 @@ extension Env { // extract key and value which are separated by an equals sign let parts = line.split(separator: "=", maxSplits: 1).map(String.init) - - guard parts.count > 0 else { - continue - } - let key = parts[0].trimmingCharacters(in: NSCharacterSet.whitespacesAndNewlines) let val = parts[safe: 1]?.trimmingCharacters(in: NSCharacterSet.whitespacesAndNewlines) guard var value = val else { @@ -107,10 +196,12 @@ extension Env { if value[value.startIndex] == "\"" && value[value.index(before: value.endIndex)] == "\"" { value.remove(at: value.startIndex) value.remove(at: value.index(before: value.endIndex)) - value = value.replacingOccurrences(of:"\\\"", with: "\"") } - setenv(key, value, 1) + + values[key] = value } + + return values } /// Determines the absolute path of the given argument relative to @@ -121,22 +212,38 @@ extension Env { /// - Returns: The absolute path of the `relativePath`, if it /// exists. private static func getAbsolutePath(relativePath: String) -> String? { + warnIfUsingDerivedData() + let fileManager = FileManager.default - let currentPath = fileManager.currentDirectoryPath - if currentPath.contains("/Library/Developer/Xcode/DerivedData") { + let filePath = fileManager.currentDirectoryPath + relativePath + return fileManager.fileExists(atPath: filePath) ? filePath : nil + } + + static func warnIfUsingDerivedData(_ directory: String = FileManager.default.currentDirectoryPath) { + if directory.contains("/DerivedData") { Log.warning(""" **WARNING** Your project is running in Xcode's `DerivedData` data directory. We _highly_ recommend setting a custom working directory, otherwise `.env` and `Public/` files won't be accessible. This takes ~9 seconds to fix. Here's how: https://github.com/alchemy-swift/alchemy/blob/main/Docs/1_Configuration.md#setting-a-custom-working-directory. - """) - } - let filePath = currentPath + relativePath - if fileManager.fileExists(atPath: filePath) { - return filePath - } else { - return nil + """.yellow) } } } + +extension Env { + public static var isProd: Bool { + current.name == Env.prod.name + } + + public static var isTest: Bool { + current.name == Env.test.name + } + + /// Whether the current program is running in a test suite. This is not the + /// same as `isTest` which returns whether the current env is `Env.test` + public static var isRunningTests: Bool { + CommandLine.arguments.contains { $0.contains("xctest") } + } +} diff --git a/Sources/Alchemy/Env/EnvAllowed.swift b/Sources/Alchemy/Env/EnvAllowed.swift deleted file mode 100644 index 12086498..00000000 --- a/Sources/Alchemy/Env/EnvAllowed.swift +++ /dev/null @@ -1,17 +0,0 @@ -/// Protocol representing a type that can be created from a `String`. -public protocol StringInitializable { - /// Create this type from a string. - /// - /// - Parameter value: The string to create this type from. - init?(_ value: String) -} - -extension String: StringInitializable {} -extension Int: StringInitializable {} -extension Double: StringInitializable {} -extension Bool: StringInitializable {} -extension UUID: StringInitializable { - public init?(_ value: String) { - self.init(uuidString: value) - } -} diff --git a/Sources/Alchemy/Exports.swift b/Sources/Alchemy/Exports.swift index f6688473..a64fa99b 100644 --- a/Sources/Alchemy/Exports.swift +++ b/Sources/Alchemy/Exports.swift @@ -2,15 +2,10 @@ // Alchemy related @_exported import Fusion -@_exported import Papyrus // Argument Parser @_exported import ArgumentParser -// AsyncHTTPClient -@_exported import class AsyncHTTPClient.HTTPClient -@_exported import struct AsyncHTTPClient.HTTPClientError - // Foundation @_exported import Foundation @@ -22,13 +17,9 @@ // NIO @_exported import struct NIO.ByteBuffer -@_exported import struct NIO.ByteBufferAllocator @_exported import class NIO.EmbeddedEventLoop @_exported import protocol NIO.EventLoop -@_exported import class NIO.EventLoopFuture @_exported import protocol NIO.EventLoopGroup -@_exported import struct NIO.EventLoopPromise -@_exported import class NIO.MultiThreadedEventLoopGroup @_exported import struct NIO.NonBlockingFileIO @_exported import class NIO.NIOThreadPool @_exported import enum NIO.System @@ -39,6 +30,3 @@ @_exported import enum NIOHTTP1.HTTPMethod @_exported import struct NIOHTTP1.HTTPVersion @_exported import enum NIOHTTP1.HTTPResponseStatus - -// Plot -@_exported import Plot diff --git a/Sources/Alchemy/Filesystem/File.swift b/Sources/Alchemy/Filesystem/File.swift new file mode 100644 index 00000000..70db2020 --- /dev/null +++ b/Sources/Alchemy/Filesystem/File.swift @@ -0,0 +1,91 @@ +import MultipartKit +import Papyrus + +/// Represents a file with a name and binary contents. +public struct File: Codable, ResponseConvertible { + // The name of the file, including the extension. + public var name: String + // The size of the file, in bytes. + public let size: Int + // The binary contents of the file. + public var content: ByteContent + /// The path extension of this file. + public var `extension`: String { name.components(separatedBy: ".").last ?? "" } + /// The content type of this file, based on it's extension. + public let contentType: ContentType + + public init(name: String, contentType: ContentType? = nil, size: Int, content: ByteContent) { + self.name = name + self.size = size + self.content = content + let _extension = name.components(separatedBy: ".").last ?? "" + self.contentType = contentType ?? ContentType(fileExtension: _extension) ?? .octetStream + } + + /// Returns a copy of this file with a new name. + public func named(_ name: String) -> File { + var copy = self + copy.name = name + return copy + } + + // MARK: - ResponseConvertible + + public func response() async throws -> Response { + Response(status: .ok, headers: ["content-disposition":"inline; filename=\"\(name)\""]) + .withBody(content, type: contentType, length: size) + } + + public func download() async throws -> Response { + Response(status: .ok, headers: ["content-disposition":"attachment; filename=\"\(name)\""]) + .withBody(content, type: contentType, length: size) + } + + // MARK: - Decodable + + enum CodingKeys: String, CodingKey { + case name, size, content + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.name = try container.decode(String.self, forKey: .name) + self.size = try container.decode(Int.self, forKey: .size) + self.content = .data(try container.decode(Data.self, forKey: .content)) + let _extension = name.components(separatedBy: ".").last ?? "" + self.contentType = ContentType(fileExtension: _extension) ?? .octetStream + } + + // MARK: - Encodable + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(name, forKey: .name) + try container.encode(size, forKey: .size) + try container.encode(content.data(), forKey: .content) + } +} + +// As of now, streamed files aren't possible over request multipart. +extension File: MultipartPartConvertible { + public var multipart: MultipartPart? { + var headers: HTTPHeaders = [:] + headers.contentType = ContentType(fileExtension: `extension`) + headers.contentDisposition = HTTPHeaders.ContentDisposition(value: "form-data", name: nil, filename: name) + headers.contentLength = size + return MultipartPart(headers: headers, body: content.buffer) + } + + public init?(multipart: MultipartPart) { + let fileExtension = multipart.headers.contentType?.fileExtension.map { ".\($0)" } ?? "" + let fileName = multipart.headers.contentDisposition?.filename ?? multipart.headers.contentDisposition?.name + let fileSize = multipart.headers.contentLength ?? multipart.body.writerIndex + + if multipart.headers.contentDisposition?.filename == nil { + Log.warning("A multipart part had no name or filename in the Content-Disposition header, using a random UUID for the file name.") + } + + // If there is no filename in the content disposition included (technically not required via RFC 7578) set to a random UUID. + self.init(name: (fileName ?? UUID().uuidString) + fileExtension, contentType: multipart.headers.contentType, size: fileSize, content: .buffer(multipart.body)) + } +} diff --git a/Sources/Alchemy/Filesystem/Filesystem+Config.swift b/Sources/Alchemy/Filesystem/Filesystem+Config.swift new file mode 100644 index 00000000..d15353fa --- /dev/null +++ b/Sources/Alchemy/Filesystem/Filesystem+Config.swift @@ -0,0 +1,13 @@ +extension Filesystem { + public struct Config { + public let disks: [Identifier: Filesystem] + + public init(disks: [Identifier : Filesystem]) { + self.disks = disks + } + } + + public static func configure(with config: Config) { + config.disks.forEach { Filesystem.bind($0, $1) } + } +} diff --git a/Sources/Alchemy/Filesystem/Filesystem.swift b/Sources/Alchemy/Filesystem/Filesystem.swift new file mode 100644 index 00000000..eaa9a457 --- /dev/null +++ b/Sources/Alchemy/Filesystem/Filesystem.swift @@ -0,0 +1,59 @@ +import Foundation + +/// An abstraction around local or remote file storage. +public struct Filesystem: Service { + public struct Identifier: ServiceIdentifier { + private let hashable: AnyHashable + public init(hashable: AnyHashable) { self.hashable = hashable } + } + + private let provider: FilesystemProvider + + /// The root directory for storing and fetching files. + public var root: String { provider.root } + + public init(provider: FilesystemProvider) { + self.provider = provider + } + + /// Create a file in this storage. + /// - Parameters: + /// - filename: The name of the file, including extension, to create. + /// - directory: The directory to put the file in. If nil, goes in root. + /// - contents: the binary contents of the file. + /// - Returns: The newly created file. + @discardableResult + public func create(_ filepath: String, content: ByteContent) async throws -> File { + try await provider.create(filepath, content: content) + } + + /// Returns whether a file with the given path exists. + public func exists(_ filepath: String) async throws -> Bool { + try await provider.exists(filepath) + } + + /// Gets a file with the given path. + public func get(_ filepath: String) async throws -> File { + try await provider.get(filepath) + } + + /// Delete a file at the given path. + public func delete(_ filepath: String) async throws { + try await provider.delete(filepath) + } + + public func put(_ file: File, in directory: String? = nil) async throws { + guard let directory = directory, let directoryUrl = URL(string: directory) else { + try await create(file.name, content: file.content) + return + } + + try await create(directoryUrl.appendingPathComponent(file.name).path, content: file.content) + } +} + +extension File { + public func store(in directory: String? = nil, on filesystem: Filesystem = Storage) async throws { + try await filesystem.put(self, in: directory) + } +} diff --git a/Sources/Alchemy/Filesystem/FilesystemError.swift b/Sources/Alchemy/Filesystem/FilesystemError.swift new file mode 100644 index 00000000..993c637d --- /dev/null +++ b/Sources/Alchemy/Filesystem/FilesystemError.swift @@ -0,0 +1,5 @@ +public enum FileError: Error { + case invalidFileUrl + case fileDoesntExist + case filenameAlreadyExists +} diff --git a/Sources/Alchemy/Filesystem/Providers/FilesystemProvider.swift b/Sources/Alchemy/Filesystem/Providers/FilesystemProvider.swift new file mode 100644 index 00000000..c37850f8 --- /dev/null +++ b/Sources/Alchemy/Filesystem/Providers/FilesystemProvider.swift @@ -0,0 +1,23 @@ +public protocol FilesystemProvider { + /// The root directory for storing and fetching files. + var root: String { get } + + /// Create a file in this filesystem. + /// + /// - Parameters: + /// - filename: The name of the file, including extension, to create. + /// - directory: The directory to put the file in. If nil, goes in root. + /// - contents: the binary contents of the file. + /// - Returns: The newly created file. + @discardableResult + func create(_ filepath: String, content: ByteContent) async throws -> File + + /// Returns whether a file with the given path exists. + func exists(_ filepath: String) async throws -> Bool + + /// Gets a file with the given path. + func get(_ filepath: String) async throws -> File + + /// Delete a file at the given path. + func delete(_ filepath: String) async throws +} diff --git a/Sources/Alchemy/Filesystem/Providers/LocalFilesystem.swift b/Sources/Alchemy/Filesystem/Providers/LocalFilesystem.swift new file mode 100644 index 00000000..522fa93b --- /dev/null +++ b/Sources/Alchemy/Filesystem/Providers/LocalFilesystem.swift @@ -0,0 +1,111 @@ +import NIOCore + +extension Filesystem { + /// Create a filesystem backed by the local filesystem at the given root + /// directory. + public static func local(root: String = "Public/") -> Filesystem { + Filesystem(provider: LocalFilesystem(root: root)) + } + + /// Create a filesystem backed by the local filesystem in the "Public/" + /// directory. + public static var local: Filesystem { + .local() + } +} + +struct LocalFilesystem: FilesystemProvider { + /// The file IO helper for streaming files. + private let fileIO = NonBlockingFileIO(threadPool: Thread.pool) + /// Used for allocating buffers when pulling out file data. + private let bufferAllocator = ByteBufferAllocator() + + var root: String + + // MARK: - FilesystemProvider + + init(root: String) { + self.root = root + } + + func get(_ filepath: String) async throws -> File { + guard try await exists(filepath) else { + throw FileError.fileDoesntExist + } + + let url = try url(for: filepath) + let fileInfo = try FileManager.default.attributesOfItem(atPath: url.path) + guard let fileSizeBytes = (fileInfo[.size] as? NSNumber)?.intValue else { + Log.error("[Storage] attempted to access file at `\(url.path)` but it didn't have a size.") + throw HTTPError(.internalServerError) + } + + return File( + name: url.lastPathComponent, + size: fileSizeBytes, + content: .stream { writer in + // Load the file in chunks, streaming it. + let fileHandle = try NIOFileHandle(path: url.path) + defer { try? fileHandle.close() } + try await fileIO.readChunked( + fileHandle: fileHandle, + byteCount: fileSizeBytes, + chunkSize: NonBlockingFileIO.defaultChunkSize, + allocator: bufferAllocator, + eventLoop: Loop.current, + chunkHandler: { chunk in + Loop.current.asyncSubmit { try await writer.write(chunk) } + } + ).get() + }) + } + + func create(_ filepath: String, content: ByteContent) async throws -> File { + let url = try url(for: filepath) + guard try await !exists(filepath) else { + throw FileError.filenameAlreadyExists + } + + let fileHandle = try NIOFileHandle(path: url.path, mode: .write, flags: .allowFileCreation()) + defer { try? fileHandle.close() } + + // Stream and write + var offset: Int64 = 0 + try await content.stream.readAll { buffer in + try await fileIO.write(fileHandle: fileHandle, toOffset: offset, buffer: buffer, eventLoop: Loop.current).get() + offset += Int64(buffer.writerIndex) + } + + return try await get(filepath) + } + + func exists(_ filepath: String) async throws -> Bool { + let url = try url(for: filepath, createDirectories: false) + var isDirectory: ObjCBool = false + return FileManager.default.fileExists(atPath: url.path, isDirectory: &isDirectory) && !isDirectory.boolValue + } + + func delete(_ filepath: String) async throws { + guard try await exists(filepath) else { + throw FileError.fileDoesntExist + } + + try FileManager.default.removeItem(atPath: url(for: filepath).path) + } + + private func url(for filepath: String, createDirectories: Bool = true) throws -> URL { + guard let rootUrl = URL(string: root) else { + throw FileError.invalidFileUrl + } + + let url = rootUrl.appendingPathComponent(filepath.trimmingForwardSlash) + + // Ensure directory exists. + let directory = url.deletingLastPathComponent().path + if createDirectories && !FileManager.default.fileExists(atPath: directory) { + try FileManager.default.createDirectory(atPath: directory, withIntermediateDirectories: true) + } + + return url + } +} diff --git a/Sources/Alchemy/HTTP/Content/ByteContent.swift b/Sources/Alchemy/HTTP/Content/ByteContent.swift new file mode 100644 index 00000000..2efbd96e --- /dev/null +++ b/Sources/Alchemy/HTTP/Content/ByteContent.swift @@ -0,0 +1,307 @@ +import AsyncHTTPClient +import NIO +import Foundation +import NIOHTTP1 +import HummingbirdCore + +/// A collection of bytes that is either a single buffer or a stream of buffers. +public enum ByteContent: ExpressibleByStringLiteral { + /// The default decoder for reading content from an incoming request. + public static var defaultDecoder: ContentDecoder = .json + /// The default encoder for writing content to an outgoing response. + public static var defaultEncoder: ContentEncoder = .json + + case buffer(ByteBuffer) + case stream(ByteStream) + + public var buffer: ByteBuffer { + switch self { + case .stream: + preconditionFailure("Can't synchronously access data from streaming body, try `collect()` instead.") + case .buffer(let buffer): + return buffer + } + } + + public var stream: ByteStream { + switch self { + case .stream(let stream): + return stream + case .buffer(let buffer): + return .new { try await $0.write(buffer) } + } + } + + public var length: Int? { + switch self { + case .stream: + return nil + case .buffer(let buffer): + return buffer.writerIndex + } + } + + public init(stringLiteral value: StringLiteralType) { + self = .buffer(ByteBuffer(string: value)) + } + + /// Returns the contents of the entire buffer or stream as a single buffer. + public func collect() async throws -> ByteBuffer { + switch self { + case .buffer(let byteBuffer): + return byteBuffer + case .stream(let byteStream): + var collection = ByteBuffer() + try await byteStream.readAll { buffer in + var chunk = buffer + collection.writeBuffer(&chunk) + } + + return collection + } + } + + public static func stream(_ stream: @escaping ByteStream.Closure) -> ByteContent { + return .stream(.new(startStream: stream)) + } +} + +extension File { + @discardableResult + mutating func collect() async throws -> File { + self.content = .buffer(try await content.collect()) + return self + } +} + +extension Client.Response { + @discardableResult + public mutating func collect() async throws -> Client.Response { + self.body = (try await body?.collect()).map { .buffer($0) } + return self + } +} + +extension Response { + @discardableResult + public func collect() async throws -> Response { + self.body = (try await body?.collect()).map { .buffer($0) } + return self + } +} + +extension Request { + @discardableResult + public func collect() async throws -> Request { + self.hbRequest.body = .byteBuffer(try await body?.collect()) + return self + } +} + +public final class ByteStream: AsyncSequence { + public typealias Element = ByteBuffer + public struct Writer { + fileprivate let stream: ByteStream + + func write(_ chunk: Element) async throws { + try await stream._write(chunk: chunk).get() + } + } + + public typealias Closure = (Writer) async throws -> Void + + private let eventLoop: EventLoop + private let onFirstRead: ((ByteStream) -> Void)? + private var didFirstRead: Bool + + var _streamer: HBByteBufferStreamer? + + init(eventLoop: EventLoop, onFirstRead: ((ByteStream) -> Void)? = nil) { + self.eventLoop = eventLoop + self.onFirstRead = onFirstRead + self.didFirstRead = false + } + + private func createStreamerIfNotExists() -> EventLoopFuture { + eventLoop.submit { + guard let _streamer = self._streamer else { + /// Don't give a max size to the underlying streamer; that will be handled elsewhere. + let created = HBByteBufferStreamer(eventLoop: self.eventLoop, maxSize: .max, maxStreamingBufferSize: nil) + self._streamer = created + return created + } + + return _streamer + } + } + + func _write(chunk: Element?) -> EventLoopFuture { + createStreamerIfNotExists() + .flatMap { + if let chunk = chunk { + return $0.feed(buffer: chunk) + } else { + $0.feed(.end) + return self.eventLoop.makeSucceededVoidFuture() + } + } + } + + func _write(error: Error) { + _ = createStreamerIfNotExists().map { $0.feed(.error(error)) } + } + + func _read(on eventLoop: EventLoop) -> EventLoopFuture { + createStreamerIfNotExists() + .flatMap { + if !self.didFirstRead { + self.didFirstRead = true + self.onFirstRead?(self) + } + + return $0.consume(on: eventLoop).map { output in + switch output { + case .byteBuffer(let buffer): + return buffer + case .end: + return nil + } + } + } + } + + public func readAll(chunkHandler: (Element) async throws -> Void) async throws { + for try await chunk in self { + try await chunkHandler(chunk) + } + } + + public static func new(startStream: @escaping Closure) -> ByteStream { + ByteStream(eventLoop: Loop.current) { stream in + Task { + do { + try await startStream(Writer(stream: stream)) + try await stream._write(chunk: nil).get() + } catch { + stream._write(error: error) + } + } + } + } + + // MARK: - AsycIterator + + public struct AsyncIterator: AsyncIteratorProtocol { + let stream: ByteStream + let eventLoop: EventLoop + + mutating public func next() async throws -> Element? { + try await stream._read(on: eventLoop).get() + } + } + + __consuming public func makeAsyncIterator() -> AsyncIterator { + AsyncIterator(stream: self, eventLoop: eventLoop) + } +} + +extension Response { + /// Used to create new ByteBuffers. + private static let allocator = ByteBufferAllocator() + + public func withBody(_ byteContent: ByteContent, type: ContentType? = nil, length: Int? = nil) -> Response { + body = byteContent + headers.contentType = type + headers.contentLength = length + return self + } + + /// Creates a new body from a binary `NIO.ByteBuffer`. + /// + /// - Parameters: + /// - buffer: The buffer holding the data in the body. + /// - type: The content type of data in the body. + public func withBuffer(_ buffer: ByteBuffer, type: ContentType? = nil) -> Response { + withBody(.buffer(buffer), type: type, length: buffer.writerIndex) + } + + /// Creates a new body containing the text of the given string. + /// + /// - Parameter string: The string contents of the body. + /// - Parameter type: The media type of this text. Defaults to + /// `.plainText` ("text/plain"). + public func withString(_ string: String, type: ContentType = .plainText) -> Response { + var buffer = Response.allocator.buffer(capacity: string.utf8.count) + buffer.writeString(string) + return withBuffer(buffer, type: type) + } + + /// Creates a new body from a binary `Foundation.Data`. + /// + /// - Parameters: + /// - data: The data in the body. + /// - type: The content type of the body. + public func withData(_ data: Data, type: ContentType? = nil) -> Response { + var buffer = Response.allocator.buffer(capacity: data.count) + buffer.writeBytes(data) + return withBuffer(buffer, type: type) + } + + /// Creates a new body from an `Encodable`. + /// + /// - Parameters: + /// - data: The data in the body. + /// - type: The content type of the body. + public func withValue(_ value: E, encoder: ContentEncoder = ByteContent.defaultEncoder) throws -> Response { + let (buffer, type) = try encoder.encodeContent(value) + return withBuffer(buffer, type: type) + } +} + +extension ByteContent { + /// The contents of this body. + public func data() -> Data { + guard case let .buffer(buffer) = self else { + preconditionFailure("Can't synchronously access data from streaming body, try `collect()` instead.") + } + + return buffer.withUnsafeReadableBytes { buffer -> Data in + let buffer = buffer.bindMemory(to: UInt8.self) + return Data.init(buffer: buffer) + } + } + + /// Decodes the body as a `String`. + /// + /// - Parameter encoding: The `String.Encoding` value to decode + /// with. Defaults to `.utf8`. + /// - Returns: The string decoded from the contents of this body. + public func string(with encoding: String.Encoding = .utf8) -> String? { + String(data: data(), encoding: encoding) + } + + public static func string(_ string: String) -> ByteContent { + .buffer(ByteBuffer(string: string)) + } + + public static func data(_ data: Data) -> ByteContent { + .buffer(ByteBuffer(data: data)) + } + + public static func value(_ value: E, encoder: ContentEncoder = ByteContent.defaultEncoder) throws -> ByteContent { + .buffer(try encoder.encodeContent(value).buffer) + } + + public static func json(_ dict: [String: Any?]) throws -> ByteContent { + .buffer(ByteBuffer(data: try JSONSerialization.data(withJSONObject: dict))) + } + + /// Decodes the body as a JSON dictionary. + /// + /// - Throws: If there's a error decoding the dictionary. + /// - Returns: The dictionary decoded from the contents of this + /// body. + public func decodeJSONDictionary() throws -> [String: Any]? { + try JSONSerialization.jsonObject(with: data(), options: []) as? [String: Any] + } +} diff --git a/Sources/Alchemy/HTTP/Content/Content.swift b/Sources/Alchemy/HTTP/Content/Content.swift new file mode 100644 index 00000000..5edc26e3 --- /dev/null +++ b/Sources/Alchemy/HTTP/Content/Content.swift @@ -0,0 +1,367 @@ +import Foundation +import Papyrus + +public protocol ContentValue { + var string: String? { get } + var bool: Bool? { get } + var double: Double? { get } + var int: Int? { get } + var file: File? { get } +} + +struct AnyContentValue: ContentValue { + let value: Any + + var string: String? { value as? String } + var bool: Bool? { value as? Bool } + var int: Int? { value as? Int } + var double: Double? { value as? Double } + var file: File? { nil } +} + +/// Utility making it easy to set or modify http content +@dynamicMemberLookup +public final class Content: Buildable { + public enum Node { + case array([Node]) + case dict([String: Node]) + case value(ContentValue) + case null + + static func dict(_ dict: [String: Any]) -> Node { + .dict(dict.mapValues(Node.any)) + } + + static func array(_ array: [Any]) -> Node { + .array(array.map(Node.any)) + } + + static func any(_ value: Any) -> Node { + if let array = value as? [Any] { + return .array(array) + } else if let dict = value as? [String: Any] { + return .dict(dict) + } else if case Optional.none = value { + return .null + } else { + return .value(AnyContentValue(value: value)) + } + } + } + + enum Operator { + case field(String) + case index(Int) + case flatten + } + + enum State { + case node(Node) + case error(Error) + } + + let state: State + // The path taken to get here. + let path: [Operator] + + public var string: String? { try? stringThrowing } + public var stringThrowing: String { get throws { try unwrap(convertValue().string) } } + public var int: Int? { try? intThrowing } + public var intThrowing: Int { get throws { try unwrap(convertValue().int) } } + public var bool: Bool? { try? boolThrowing } + public var boolThrowing: Bool { get throws { try unwrap(convertValue().bool) } } + public var double: Double? { try? doubleThrowing } + public var doubleThrowing: Double { get throws { try unwrap(convertValue().double) } } + public var file: File? { try? fileThrowing } + public var fileThrowing: File { get throws { try unwrap(convertValue().file) } } + public var array: [Content]? { try? convertArray() } + public var arrayThrowing: [Content] { get throws { try convertArray() } } + + public var exists: Bool { (try? decode(Empty.self)) != nil } + public var isNull: Bool { self == nil } + + public var error: Error? { + guard case .error(let error) = state else { return nil } + return error + } + + var node: Node? { + guard case .node(let node) = state else { return nil } + return node + } + + var value: ContentValue? { + guard let node = node, case .value(let value) = node else { return nil } + return value + } + + init(root: Node, path: [Operator] = []) { + self.state = .node(root) + self.path = path + } + + init(error: Error, path: [Operator] = []) { + self.state = .error(error) + self.path = path + } + + // MARK: - Subscripts + + subscript(index: Int) -> Content { + let newPath = path + [.index(index)] + switch state { + case .node(let node): + guard case .array(let array) = node else { + return Content(error: ContentError.notArray, path: newPath) + } + + return Content(root: array[index], path: newPath) + case .error(let error): + return Content(error: error, path: newPath) + } + } + + subscript(field: String) -> Content { + let newPath = path + [.field(field)] + switch state { + case .node(let node): + guard case .dict(let dict) = node else { + return Content(error: ContentError.notDictionary, path: newPath) + } + + return Content(root: dict[field] ?? .null, path: newPath) + case .error(let error): + return Content(error: error, path: newPath) + } + } + + public subscript(dynamicMember member: String) -> Content { + self[member] + } + + subscript(operator: (Content, Content) -> Void) -> [Content] { + let newPath = path + [.flatten] + switch state { + case .node(let node): + switch node { + case .null, .value: + return [Content(error: ContentError.cantFlatten, path: newPath)] + case .dict(let dict): + return Array(dict.values).map { Content(root: $0, path: newPath) } + case .array(let array): + return array + .flatMap { content -> [Node] in + if case .array(let array) = content { + return array + } else if case .dict = content { + return [content] + } else { + return [.null] + } + } + .map { Content(root: $0, path: newPath) } + } + case .error(let error): + return [Content(error: error, path: newPath)] + } + } + + static func *(lhs: Content, rhs: Content) {} + + static func ==(lhs: Content, rhs: Void?) -> Bool { + switch lhs.state { + case .node(let node): + if case .null = node { + return true + } else { + return false + } + case .error: + return false + } + } + + private func convertArray() throws -> [Content] { + switch state { + case .node(let node): + guard case .array(let array) = node else { + throw ContentError.typeMismatch + } + + return array.enumerated().map { Content(root: $1, path: path + [.index($0)]) } + case .error(let error): + throw error + } + } + + private func convertValue() throws -> ContentValue { + switch state { + case .node(let node): + guard case .value(let val) = node else { + throw ContentError.typeMismatch + } + + return val + case .error(let error): + throw error + } + } + + private func unwrap(_ value: T?) throws -> T { + try value.unwrap(or: ContentError.typeMismatch) + } + + public func decode(_ type: D.Type = D.self) throws -> D { + try D(from: GenericDecoder(delegate: self)) + } +} + +enum ContentError: Error { + case unknownContentType(ContentType?) + case emptyBody + case cantFlatten + case notDictionary + case notArray + case doesntExist + case wasNull + case typeMismatch + case notSupported(String) +} + +extension Content: DecoderDelegate { + + private func require(_ optional: T?, key: CodingKey?) throws -> T { + try optional.unwrap(or: DecodingError.valueNotFound(T.self, .init(codingPath: [key].compactMap { $0 }, debugDescription: "Value wasn`t available."))) + } + + func decodeString(for key: CodingKey?) throws -> String { + let value = key.map { self[$0.stringValue] } ?? self + return try require(value.string, key: key) + } + + func decodeDouble(for key: CodingKey?) throws -> Double { + let value = key.map { self[$0.stringValue] } ?? self + return try require(value.double, key: key) + } + + func decodeInt(for key: CodingKey?) throws -> Int { + let value = key.map { self[$0.stringValue] } ?? self + return try require(value.int, key: key) + } + + func decodeBool(for key: CodingKey?) throws -> Bool { + let value = key.map { self[$0.stringValue] } ?? self + return try require(value.bool, key: key) + } + + func decodeNil(for key: CodingKey?) -> Bool { + let value = key.map { self[$0.stringValue] } ?? self + return value == nil + } + + var allKeys: [String] { + guard case .node(let node) = state, case .dict(let dict) = node else { + return [] + } + + return Array(dict.keys) + } + + func contains(key: CodingKey) -> Bool { + guard case .node(let node) = state, case .dict(let dict) = node else { + return false + } + + return dict.keys.contains(key.stringValue) + } + + func map(for key: CodingKey) -> DecoderDelegate { + self[key.stringValue] + } + + func array(for key: CodingKey?) throws -> [DecoderDelegate] { + let val = key.map { self[$0.stringValue] } ?? self + return try val.arrayThrowing.map { $0 } + } +} + +extension Array where Element == Content { + var string: [String]? { try? stringThrowing } + var stringThrowing: [String] { get throws { try map { try $0.stringThrowing } } } + var int: [Int]? { try? intThrowing } + var intThrowing: [Int] { get throws { try map { try $0.intThrowing } } } + var bool: [Bool]? { try? boolThrowing } + var boolThrowing: [Bool] { get throws { try map { try $0.boolThrowing } } } + var double: [Double]? { try? doubleThrowing } + var doubleThrowing: [Double] { get throws { try map { try $0.doubleThrowing } } } + + subscript(field: String) -> [Content] { + return map { $0[field] } + } + + subscript(dynamicMember member: String) -> [Content] { + self[member] + } + + func decode(_ type: D.Type = D.self) throws -> [D] { + try map { try D(from: GenericDecoder(delegate: $0)) } + } +} + +extension Content: CustomStringConvertible { + public var description: String { + switch state { + case .error(let error): + return "Content(error: \(error)" + case .node(let node): + return createString(root: node) + } + } + + private func createString(root: Node?, tabs: String = "") -> String { + var string = "" + var tabs = tabs + switch root { + case .array(let array): + tabs += "\t" + if array.isEmpty { + string.append("[]") + } else { + string.append("[\n") + for (index, node) in array.enumerated() { + let comma = index == array.count - 1 ? "" : "," + string.append(tabs + createString(root: node, tabs: tabs) + "\(comma)\n") + } + tabs = String(tabs.dropLast(1)) + string.append("\(tabs)]") + } + case .value(let value): + if let file = value.file { + string.append("<\(file.name)>") + } else if let bool = value.bool { + string.append("\(bool)") + } else if let int = value.int { + string.append("\(int)") + } else if let double = value.double { + string.append("\(double)") + } else if let stringVal = value.string { + string.append("\"\(stringVal)\"") + } else { + string.append("\(value)") + } + case .dict(let dict): + tabs += "\t" + string.append("{\n") + for (index, (key, node)) in dict.enumerated() { + let comma = index == dict.count - 1 ? "" : "," + string.append(tabs + "\"\(key)\": " + createString(root: node, tabs: tabs) + "\(comma)\n") + } + tabs = String(tabs.dropLast(1)) + string.append("\(tabs)}") + case .null, .none: + string.append("null") + } + + return string + } +} diff --git a/Sources/Alchemy/HTTP/Content/ContentCoding+FormURL.swift b/Sources/Alchemy/HTTP/Content/ContentCoding+FormURL.swift new file mode 100644 index 00000000..505842e0 --- /dev/null +++ b/Sources/Alchemy/HTTP/Content/ContentCoding+FormURL.swift @@ -0,0 +1,71 @@ +import HummingbirdFoundation + +extension ContentEncoder where Self == URLEncodedFormEncoder { + public static var urlForm: URLEncodedFormEncoder { URLEncodedFormEncoder() } +} + +extension ContentDecoder where Self == URLEncodedFormDecoder { + public static var urlForm: URLEncodedFormDecoder { URLEncodedFormDecoder() } +} + +extension URLEncodedFormEncoder: ContentEncoder { + public func encodeContent(_ value: E) throws -> (buffer: ByteBuffer, contentType: ContentType?) where E : Encodable { + return (buffer: ByteBuffer(string: try encode(value)), contentType: .urlForm) + } +} + +extension URLEncodedFormDecoder: ContentDecoder { + public func decodeContent(_ type: D.Type, from buffer: ByteBuffer, contentType: ContentType?) throws -> D where D : Decodable { + try decode(type, from: buffer.string) + } + + public func content(from buffer: ByteBuffer, contentType: ContentType?) -> Content { + do { + let topLevel = try decode(URLEncodedNode.self, from: buffer.string) + return Content(root: parse(node: topLevel)) + } catch { + return Content(error: error) + } + } + + private func parse(node: URLEncodedNode) -> Content.Node { + switch node { + case .dict(let dict): + return .dict(dict.mapValues { parse(node: $0) }) + case .array(let array): + return .array(array.map { parse(node: $0) }) + case .value(let string): + return .value(URLValue(value: string)) + } + } + + private struct URLValue: ContentValue { + let value: String + + var string: String? { value } + var bool: Bool? { Bool(value) } + var int: Int? { Int(value) } + var double: Double? { Double(value) } + var file: File? { nil } + } +} + +enum URLEncodedNode: Decodable { + case dict([String: URLEncodedNode]) + case array([URLEncodedNode]) + case value(String) + + init(from decoder: Decoder) throws { + if let array = try? [URLEncodedNode](from: decoder) { + self = .array(array) + } else if let dict = try? [String: URLEncodedNode](from: decoder) { + self = .dict(dict) + } else { + self = .value(try String(from: decoder)) + } + } +} + +extension URLEncodedNode { + +} diff --git a/Sources/Alchemy/HTTP/Content/ContentCoding+JSON.swift b/Sources/Alchemy/HTTP/Content/ContentCoding+JSON.swift new file mode 100644 index 00000000..f8905b6d --- /dev/null +++ b/Sources/Alchemy/HTTP/Content/ContentCoding+JSON.swift @@ -0,0 +1,52 @@ +import Foundation + +extension ContentEncoder where Self == JSONEncoder { + public static var json: JSONEncoder { JSONEncoder() } +} + +extension ContentDecoder where Self == JSONDecoder { + public static var json: JSONDecoder { JSONDecoder() } +} + +extension JSONEncoder: ContentEncoder { + public func encodeContent(_ value: E) throws -> (buffer: ByteBuffer, contentType: ContentType?) where E : Encodable { + (buffer: ByteBuffer(data: try encode(value)), contentType: .json) + } +} + +extension JSONDecoder: ContentDecoder { + public func decodeContent(_ type: D.Type, from buffer: ByteBuffer, contentType: ContentType?) throws -> D where D : Decodable { + try decode(type, from: buffer.data) + } + + public func content(from buffer: ByteBuffer, contentType: ContentType?) -> Content { + do { + let topLevel = try JSONSerialization.jsonObject(with: buffer, options: .fragmentsAllowed) + return Content(root: parse(val: topLevel)) + } catch { + return Content(error: error) + } + } + + private func parse(val: Any) -> Content.Node { + if let dict = val as? [String: Any] { + return .dict(dict.mapValues { parse(val: $0) }) + } else if let array = val as? [Any] { + return .array(array.map { parse(val: $0) }) + } else if (val as? NSNull) != nil { + return .null + } else { + return .value(JSONValue(value: val)) + } + } + + private struct JSONValue: ContentValue { + let value: Any + + var string: String? { value as? String } + var bool: Bool? { value as? Bool } + var int: Int? { value as? Int } + var double: Double? { value as? Double } + var file: File? { nil } + } +} diff --git a/Sources/Alchemy/HTTP/Content/ContentCoding+Multipart.swift b/Sources/Alchemy/HTTP/Content/ContentCoding+Multipart.swift new file mode 100644 index 00000000..3fd38bc4 --- /dev/null +++ b/Sources/Alchemy/HTTP/Content/ContentCoding+Multipart.swift @@ -0,0 +1,77 @@ +import MultipartKit + +extension ContentEncoder where Self == FormDataEncoder { + public static var multipart: FormDataEncoder { FormDataEncoder() } +} + +extension ContentDecoder where Self == FormDataDecoder { + public static var multipart: FormDataDecoder { FormDataDecoder() } +} + +extension FormDataEncoder: ContentEncoder { + static var boundary: () -> String = { "AlchemyFormBoundary" + .randomAlphaNumberic(15) } + + public func encodeContent(_ value: E) throws -> (buffer: ByteBuffer, contentType: ContentType?) where E : Encodable { + let boundary = FormDataEncoder.boundary() + return (buffer: ByteBuffer(string: try encode(value, boundary: boundary)), contentType: .multipart(boundary: boundary)) + } +} + +extension FormDataDecoder: ContentDecoder { + public func decodeContent(_ type: D.Type, from buffer: ByteBuffer, contentType: ContentType?) throws -> D where D : Decodable { + guard let boundary = contentType?.parameters["boundary"] else { + throw HTTPError(.notAcceptable, message: "Attempted to decode multipart/form-data but couldn't find a `boundary` in the `Content-Type` header.") + } + + return try decode(type, from: buffer, boundary: boundary) + } + + public func content(from buffer: ByteBuffer, contentType: ContentType?) -> Content { + guard contentType == .multipart else { + return Content(error: ContentError.unknownContentType(contentType)) + } + + guard let boundary = contentType?.parameters["boundary"] else { + return Content(error: ContentError.unknownContentType(contentType)) + } + + let parser = MultipartParser(boundary: boundary) + var parts: [MultipartPart] = [] + var headers: HTTPHeaders = .init() + var body: ByteBuffer = ByteBuffer() + + parser.onHeader = { headers.replaceOrAdd(name: $0, value: $1) } + parser.onBody = { body.writeBuffer(&$0) } + parser.onPartComplete = { + parts.append(MultipartPart(headers: headers, body: body)) + headers = [:] + body = ByteBuffer() + } + + do { + try parser.execute(buffer) + let dict = Dictionary(uniqueKeysWithValues: parts.compactMap { part in part.name.map { ($0, part) } }) + return Content(root: .dict(dict.mapValues { .value($0) })) + } catch { + return Content(error: error) + } + } +} + +extension MultipartPart: ContentValue { + public var string: String? { body.string } + public var int: Int? { Int(body.string) } + public var bool: Bool? { Bool(body.string) } + public var double: Double? { Double(body.string) } + + public var file: File? { + guard let disposition = headers.contentDisposition, let filename = disposition.filename else { return nil } + return File(name: filename, size: body.writerIndex, content: .buffer(body)) + } +} + +extension String { + static func randomAlphaNumberic(_ length: Int) -> String { + String((1...length).compactMap { _ in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".randomElement() }) + } +} diff --git a/Sources/Alchemy/HTTP/Content/ContentCoding.swift b/Sources/Alchemy/HTTP/Content/ContentCoding.swift new file mode 100644 index 00000000..60dd48ac --- /dev/null +++ b/Sources/Alchemy/HTTP/Content/ContentCoding.swift @@ -0,0 +1,10 @@ +import NIOCore + +public protocol ContentDecoder { + func decodeContent(_ type: D.Type, from buffer: ByteBuffer, contentType: ContentType?) throws -> D + func content(from buffer: ByteBuffer, contentType: ContentType?) -> Content +} + +public protocol ContentEncoder { + func encodeContent(_ value: E) throws -> (buffer: ByteBuffer, contentType: ContentType?) +} diff --git a/Sources/Alchemy/HTTP/Content/ContentType.swift b/Sources/Alchemy/HTTP/Content/ContentType.swift new file mode 100644 index 00000000..ccbcc35e --- /dev/null +++ b/Sources/Alchemy/HTTP/Content/ContentType.swift @@ -0,0 +1,220 @@ +import Foundation + +/// An HTTP content type. It has a `value: String` appropriate for +/// putting into `Content-Type` headers. +public struct ContentType: Equatable { + /// The name of this content type + public var value: String + /// Any parameters to go along with the content type value. + public var parameters: [String: String] = [:] + /// The entire string for the Content-Type header including name and parameters. + public var string: String { + ([value] + parameters.map { "\($0)=\($1)" }).joined(separator: "; ") + } + /// A file extension that matches this content type, if one exists. + public var fileExtension: String? { + ContentType.fileExtensionMapping.first { _, value in value == self }?.key + } + + /// Create with a string. + /// + /// - Parameter value: The string of the content type. + public init(_ value: String) { + let components = value.components(separatedBy: ";").map { $0.trimmingCharacters(in: .whitespaces) } + self.value = components.first! + components[1...] + .compactMap { (string: String) -> (String, String)? in + let split = string.components(separatedBy: "=") + guard let first = split[safe: 0], let second = split[safe: 1] else { + return nil + } + + return (first, second) + } + .forEach { parameters[$0] = $1 } + } + + /// Creates based off of a known file extension that can be mapped + /// to an appropriate `Content-Type` header value. Returns nil if + /// no content type is known. + /// + /// The `.` in front of the file extension is optional. + /// + /// Usage: + /// ```swift + /// let mt = ContentType(fileExtension: "html")! + /// print(mt.value) // "text/html" + /// ``` + /// + /// - Parameter fileExtension: The file extension to look up a + /// content type for. + public init?(fileExtension: String) { + var noDot = fileExtension + if noDot.hasPrefix(".") { + noDot = String(noDot.dropFirst()) + } + + guard let type = ContentType.fileExtensionMapping[noDot] else { + return nil + } + + self = type + } + + // MARK: Common content types + + /// image/bmp + public static let bmp = ContentType("image/bmp") + /// text/css + public static let css = ContentType("text/css") + /// text/csv + public static let csv = ContentType("text/csv") + /// application/epub+zip + public static let epub = ContentType("application/epub+zip") + /// application/gzip + public static let gzip = ContentType("application/gzip") + /// image/gif + public static let gif = ContentType("image/gif") + /// text/html + public static let html = ContentType("text/html") + /// text/calendar + public static let calendar = ContentType("text/calendar") + /// image/jpeg + public static let jpeg = ContentType("image/jpeg") + /// text/javascript + public static let javascript = ContentType("text/javascript") + /// application/json + public static let json = ContentType("application/json") + /// audio/midi + public static let mid = ContentType("audio/midi") + /// audio/mpeg + public static let mp3 = ContentType("audio/mpeg") + /// video/mpeg + public static let mpeg = ContentType("video/mpeg") + /// application/octet-stream + public static let octetStream = ContentType("application/octet-stream") + /// audio/ogg + public static let oga = ContentType("audio/ogg") + /// video/ogg + public static let ogv = ContentType("video/ogg") + /// font/otf + public static let otf = ContentType("font/otf") + /// application/pdf + public static let pdf = ContentType("application/pdf") + /// application/x-httpd-php + public static let php = ContentType("application/x-httpd-php") + /// text/plain + public static let plainText = ContentType("text/plain") + /// image/png + public static let png = ContentType("image/png") + /// application/rtf + public static let rtf = ContentType("application/rtf") + /// image/svg+xml + public static let svg = ContentType("image/svg+xml") + /// application/x-tar + public static let tar = ContentType("application/x-tar") + /// image/tiff + public static let tiff = ContentType("image/tiff") + /// font/ttf + public static let ttf = ContentType("font/ttf") + /// audio/wav + public static let wav = ContentType("audio/wav") + /// application/xhtml+xml + public static let xhtml = ContentType("application/xhtml+xml") + /// application/xml + public static let xml = ContentType("application/xml") + /// application/zip + public static let zip = ContentType("application/zip") + /// application/x-www-form-urlencoded + public static let urlForm = ContentType("application/x-www-form-urlencoded") + /// multipart/form-data + public static let multipart = ContentType("multipart/form-data") + + /// multipart/form-data + public static func multipart(boundary: String) -> ContentType { + ContentType("multipart/form-data; boundary=\(boundary)") + } + + /// A non exhaustive mapping of file extensions to known content + /// types. + private static let fileExtensionMapping = [ + "aac": ContentType("audio/aac"), + "abw": ContentType("application/x-abiword"), + "arc": ContentType("application/x-freearc"), + "avi": ContentType("video/x-msvideo"), + "azw": ContentType("application/vnd.amazon.ebook"), + "bin": ContentType("application/octet-stream"), + "bmp": ContentType("image/bmp"), + "bz": ContentType("application/x-bzip"), + "bz2": ContentType("application/x-bzip2"), + "csh": ContentType("application/x-csh"), + "css": ContentType("text/css"), + "csv": ContentType("text/csv"), + "doc": ContentType("application/msword"), + "docx": ContentType("application/vnd.openxmlformats-officedocument.wordprocessingml.document"), + "eot": ContentType("application/vnd.ms-fontobject"), + "epub": ContentType("application/epub+zip"), + "gz": ContentType("application/gzip"), + "gif": ContentType("image/gif"), + "htm": ContentType("text/html"), + "html": ContentType("text/html"), + "ico": ContentType("image/vnd.microsoft.icon"), + "ics": ContentType("text/calendar"), + "jar": ContentType("application/java-archive"), + "jpeg": ContentType("image/jpeg"), + "jpg": ContentType("image/jpeg"), + "js": ContentType("text/javascript"), + "json": ContentType("application/json"), + "jsonld": ContentType("application/ld+json"), + "mid" : ContentType("audio/midi"), + "midi": ContentType("audio/midi"), + "mjs": ContentType("text/javascript"), + "mp3": ContentType("audio/mpeg"), + "mpeg": ContentType("video/mpeg"), + "mpkg": ContentType("application/vnd.apple.installer+xml"), + "odp": ContentType("application/vnd.oasis.opendocument.presentation"), + "ods": ContentType("application/vnd.oasis.opendocument.spreadsheet"), + "odt": ContentType("application/vnd.oasis.opendocument.text"), + "oga": ContentType("audio/ogg"), + "ogv": ContentType("video/ogg"), + "ogx": ContentType("application/ogg"), + "opus": ContentType("audio/opus"), + "otf": ContentType("font/otf"), + "png": ContentType("image/png"), + "pdf": ContentType("application/pdf"), + "php": ContentType("application/x-httpd-php"), + "ppt": ContentType("application/vnd.ms-powerpoint"), + "pptx": ContentType("application/vnd.openxmlformats-officedocument.presentationml.presentation"), + "rar": ContentType("application/vnd.rar"), + "rtf": ContentType("application/rtf"), + "sh": ContentType("application/x-sh"), + "svg": ContentType("image/svg+xml"), + "swf": ContentType("application/x-shockwave-flash"), + "tar": ContentType("application/x-tar"), + "tif": ContentType("image/tiff"), + "tiff": ContentType("image/tiff"), + "ts": ContentType("video/mp2t"), + "ttf": ContentType("font/ttf"), + "txt": ContentType("text/plain"), + "vsd": ContentType("application/vnd.visio"), + "wav": ContentType("audio/wav"), + "weba": ContentType("audio/webm"), + "webm": ContentType("video/webm"), + "webp": ContentType("image/webp"), + "woff": ContentType("font/woff"), + "woff2": ContentType("font/woff2"), + "xhtml": ContentType("application/xhtml+xml"), + "xls": ContentType("application/vnd.ms-excel"), + "xlsx": ContentType("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"), + "xml": ContentType("application/xml"), + "xul": ContentType("application/vnd.mozilla.xul+xml"), + "zip": ContentType("application/zip"), + "7z": ContentType("application/x-7z-compressed"), + ] + + // MARK: - Equatable + + public static func == (lhs: ContentType, rhs: ContentType) -> Bool { + lhs.value == rhs.value + } +} diff --git a/Sources/Alchemy/HTTP/HTTPBody.swift b/Sources/Alchemy/HTTP/HTTPBody.swift deleted file mode 100644 index 8fbbbb70..00000000 --- a/Sources/Alchemy/HTTP/HTTPBody.swift +++ /dev/null @@ -1,116 +0,0 @@ -import AsyncHTTPClient -import NIO -import Foundation -import NIOHTTP1 - -/// The contents of an HTTP request or response. -public struct HTTPBody: ExpressibleByStringLiteral { - /// Used to create new ByteBuffers. - private static let allocator = ByteBufferAllocator() - - /// The binary data in this body. - public let buffer: ByteBuffer - - /// The mime type of the data stored in this body. Used to set the - /// `content-type` header when sending back a response. - public let mimeType: MIMEType? - - /// Creates a new body from a binary `NIO.ByteBuffer`. - /// - /// - Parameters: - /// - buffer: The buffer holding the data in the body. - /// - mimeType: The MIME type of data in the body. - public init(buffer: ByteBuffer, mimeType: MIMEType? = nil) { - self.buffer = buffer - self.mimeType = mimeType - } - - /// Creates a new body containing the text with MIME type - /// `text/plain`. - /// - /// - Parameter text: The string contents of the body. - /// - Parameter mimeType: The media type of this text. Defaults to - /// `.plainText` ("text/plain"). - public init(text: String, mimeType: MIMEType = .plainText) { - var buffer = HTTPBody.allocator.buffer(capacity: text.utf8.count) - buffer.writeString(text) - self.buffer = buffer - self.mimeType = mimeType - } - - /// Creates a new body from a binary `Foundation.Data`. - /// - /// - Parameters: - /// - data: The data in the body. - /// - mimeType: The MIME type of the body. - public init(data: Data, mimeType: MIMEType? = nil) { - var buffer = HTTPBody.allocator.buffer(capacity: data.count) - buffer.writeBytes(data) - self.buffer = buffer - self.mimeType = mimeType - } - - /// Creates a body with a JSON object. - /// - /// - Parameters: - /// - json: The object to encode into the body. - /// - encoder: A customer encoder to encoder the JSON with. - /// Defaults to `Response.defaultJSONEncoder`. - /// - Throws: Any error thrown during encoding. - public init(json: E, encoder: JSONEncoder = Response.defaultJSONEncoder) throws { - let data = try encoder.encode(json) - self.init(data: data, mimeType: .json) - } - - /// Create a body via a string literal. - /// - /// - Parameter value: The string literal contents of the body. - public init(stringLiteral value: String) { - self.init(text: value) - } - - /// The contents of this body. - public var data: Data { - return buffer.withUnsafeReadableBytes { buffer -> Data in - let buffer = buffer.bindMemory(to: UInt8.self) - return Data.init(buffer: buffer) - } - } -} - -extension HTTPBody { - /// Decodes the body as a `String`. - /// - /// - Parameter encoding: The `String.Encoding` value to decode - /// with. Defaults to `.utf8`. - /// - Returns: The string decoded from the contents of this body. - public func decodeString(with encoding: String.Encoding = .utf8) -> String? { - String(data: self.data, encoding: encoding) - } - - /// Decodes the body as a JSON dictionary. - /// - /// - Throws: If there's a error decoding the dictionary. - /// - Returns: The dictionary decoded from the contents of this - /// body. - public func decodeJSONDictionary() throws -> [String: Any]? { - try JSONSerialization.jsonObject(with: self.data, options: []) - as? [String: Any] - } - - /// Decodes the body as JSON into the provided Decodable type. - /// - /// - Parameters: - /// - type: The Decodable type to which the body should be - /// decoded. - /// - decoder: The Decoder with which to decode. Defaults to - /// `Request.defaultJSONEncoder`. - /// - Throws: Any errors encountered during decoding. - /// - Returns: The decoded object of type `type`. - public func decodeJSON( - as type: D.Type = D.self, - with decoder: JSONDecoder = Request.defaultJSONDecoder - ) throws -> D { - return try decoder.decode(type, from: data) - } -} diff --git a/Sources/Alchemy/HTTP/HTTPError.swift b/Sources/Alchemy/HTTP/HTTPError.swift index 4649a2f7..f69a22f3 100644 --- a/Sources/Alchemy/HTTP/HTTPError.swift +++ b/Sources/Alchemy/HTTP/HTTPError.swift @@ -16,7 +16,7 @@ import NIOHTTP1 /// throw HTTPError(.notImplemented, "This endpoint isn't implemented yet") /// } /// ``` -public struct HTTPError: Error, ResponseConvertible { +public struct HTTPError: Error { /// The status code of this error. public let status: HTTPResponseStatus /// An optional message to include in a @@ -33,16 +33,11 @@ public struct HTTPError: Error, ResponseConvertible { self.status = status self.message = message } - - // MARK: ResponseConvertible - - public func convert() throws -> EventLoopFuture { - let response = Response( - status: self.status, - body: try self.message.map { - try HTTPBody(json: ["message": $0]) - } - ) - return .new(response) +} + +extension HTTPError: ResponseConvertible { + public func response() throws -> Response { + try Response(status: status) + .withValue(["message": message ?? status.reasonPhrase]) } } diff --git a/Sources/Alchemy/HTTP/MIMEType.swift b/Sources/Alchemy/HTTP/MIMEType.swift deleted file mode 100644 index fff051cd..00000000 --- a/Sources/Alchemy/HTTP/MIMEType.swift +++ /dev/null @@ -1,189 +0,0 @@ -import Foundation - -/// An HTTP Media Type. It has a `value: String` appropriate for -/// putting into `Content-Type` headers. -public struct MIMEType { - /// The value of this MIME type, appropriate for `Content-Type` - /// headers. - public var value: String - - /// Create with a string. - /// - /// - Parameter value: The string of the MIME type. - public init(_ value: String) { - self.value = value - } - - // MARK: Common MIME types - - /// image/bmp - public static let bmp = MIMEType("image/bmp") - /// text/css - public static let css = MIMEType("text/css") - /// text/csv - public static let csv = MIMEType("text/csv") - /// application/epub+zip - public static let epub = MIMEType("application/epub+zip") - /// application/gzip - public static let gzip = MIMEType("application/gzip") - /// image/gif - public static let gif = MIMEType("image/gif") - /// text/html - public static let html = MIMEType("text/html") - /// text/calendar - public static let calendar = MIMEType("text/calendar") - /// image/jpeg - public static let jpeg = MIMEType("image/jpeg") - /// text/javascript - public static let javascript = MIMEType("text/javascript") - /// application/json - public static let json = MIMEType("application/json") - /// audio/midi - public static let mid = MIMEType("audio/midi") - /// audio/mpeg - public static let mp3 = MIMEType("audio/mpeg") - /// video/mpeg - public static let mpeg = MIMEType("video/mpeg") - /// application/octet-stream - public static let octetStream = MIMEType("application/octet-stream") - /// audio/ogg - public static let oga = MIMEType("audio/ogg") - /// video/ogg - public static let ogv = MIMEType("video/ogg") - /// font/otf - public static let otf = MIMEType("font/otf") - /// application/pdf - public static let pdf = MIMEType("application/pdf") - /// application/x-httpd-php - public static let php = MIMEType("application/x-httpd-php") - /// text/plain - public static let plainText = MIMEType("text/plain") - /// image/png - public static let png = MIMEType("image/png") - /// application/rtf - public static let rtf = MIMEType("application/rtf") - /// image/svg+xml - public static let svg = MIMEType("image/svg+xml") - /// application/x-tar - public static let tar = MIMEType("application/x-tar") - /// image/tiff - public static let tiff = MIMEType("image/tiff") - /// font/ttf - public static let ttf = MIMEType("font/ttf") - /// audio/wav - public static let wav = MIMEType("audio/wav") - /// application/xhtml+xml - public static let xhtml = MIMEType("application/xhtml+xml") - /// application/xml - public static let xml = MIMEType("application/xml") - /// application/zip - public static let zip = MIMEType("application/zip") - -} - -// Map of file extensions -extension MIMEType { - /// Creates based off of a known file extension that can be mapped - /// to an appropriate `Content-Type` header value. Returns nil if - /// no MIME type is known. - /// - /// The `.` in front of the file extension is optional. - /// - /// Usage: - /// ```swift - /// let mt = MediaType(fileExtension: "html")! - /// print(mt.value) // "text/html" - /// ``` - /// - /// - Parameter fileExtension: The file extension to look up a - /// MIME type for. - public init?(fileExtension: String) { - var noDot = fileExtension - if noDot.hasPrefix(".") { - noDot = String(noDot.dropFirst()) - } - - guard let type = MIMEType.fileExtensionMapping[noDot] else { - return nil - } - - self = type - } - - /// A non exhaustive mapping of file extensions to known MIME - /// types. - private static let fileExtensionMapping = [ - "aac": MIMEType("audio/aac"), - "abw": MIMEType("application/x-abiword"), - "arc": MIMEType("application/x-freearc"), - "avi": MIMEType("video/x-msvideo"), - "azw": MIMEType("application/vnd.amazon.ebook"), - "bin": MIMEType("application/octet-stream"), - "bmp": MIMEType("image/bmp"), - "bz": MIMEType("application/x-bzip"), - "bz2": MIMEType("application/x-bzip2"), - "csh": MIMEType("application/x-csh"), - "css": MIMEType("text/css"), - "csv": MIMEType("text/csv"), - "doc": MIMEType("application/msword"), - "docx": MIMEType("application/vnd.openxmlformats-officedocument.wordprocessingml.document"), - "eot": MIMEType("application/vnd.ms-fontobject"), - "epub": MIMEType("application/epub+zip"), - "gz": MIMEType("application/gzip"), - "gif": MIMEType("image/gif"), - "htm": MIMEType("text/html"), - "html": MIMEType("text/html"), - "ico": MIMEType("image/vnd.microsoft.icon"), - "ics": MIMEType("text/calendar"), - "jar": MIMEType("application/java-archive"), - "jpeg": MIMEType("image/jpeg"), - "jpg": MIMEType("image/jpeg"), - "js": MIMEType("text/javascript"), - "json": MIMEType("application/json"), - "jsonld": MIMEType("application/ld+json"), - "mid" : MIMEType("audio/midi"), - "midi": MIMEType("audio/midi"), - "mjs": MIMEType("text/javascript"), - "mp3": MIMEType("audio/mpeg"), - "mpeg": MIMEType("video/mpeg"), - "mpkg": MIMEType("application/vnd.apple.installer+xml"), - "odp": MIMEType("application/vnd.oasis.opendocument.presentation"), - "ods": MIMEType("application/vnd.oasis.opendocument.spreadsheet"), - "odt": MIMEType("application/vnd.oasis.opendocument.text"), - "oga": MIMEType("audio/ogg"), - "ogv": MIMEType("video/ogg"), - "ogx": MIMEType("application/ogg"), - "opus": MIMEType("audio/opus"), - "otf": MIMEType("font/otf"), - "png": MIMEType("image/png"), - "pdf": MIMEType("application/pdf"), - "php": MIMEType("application/x-httpd-php"), - "ppt": MIMEType("application/vnd.ms-powerpoint"), - "pptx": MIMEType("application/vnd.openxmlformats-officedocument.presentationml.presentation"), - "rar": MIMEType("application/vnd.rar"), - "rtf": MIMEType("application/rtf"), - "sh": MIMEType("application/x-sh"), - "svg": MIMEType("image/svg+xml"), - "swf": MIMEType("application/x-shockwave-flash"), - "tar": MIMEType("application/x-tar"), - "tif": MIMEType("image/tiff"), - "tiff": MIMEType("image/tiff"), - "ts": MIMEType("video/mp2t"), - "ttf": MIMEType("font/ttf"), - "txt": MIMEType("text/plain"), - "vsd": MIMEType("application/vnd.visio"), - "wav": MIMEType("audio/wav"), - "weba": MIMEType("audio/webm"), - "webm": MIMEType("video/webm"), - "webp": MIMEType("image/webp"), - "woff": MIMEType("font/woff"), - "woff2": MIMEType("font/woff2"), - "xhtml": MIMEType("application/xhtml+xml"), - "xls": MIMEType("application/vnd.ms-excel"), - "xlsx": MIMEType("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"), - "xml": MIMEType("application/xml"), - "xul": MIMEType("application/vnd.mozilla.xul+xml"), - "zip": MIMEType("application/zip"), - "7z": MIMEType("application/x-7z-compressed"), - ] -} diff --git a/Sources/Alchemy/HTTP/Protocols/ContentBuilder.swift b/Sources/Alchemy/HTTP/Protocols/ContentBuilder.swift new file mode 100644 index 00000000..f4340cf9 --- /dev/null +++ b/Sources/Alchemy/HTTP/Protocols/ContentBuilder.swift @@ -0,0 +1,95 @@ +import NIOHTTP1 +import HummingbirdFoundation +import MultipartKit + +public protocol ContentBuilder: Buildable { + var headers: HTTPHeaders { get set } + var body: ByteContent? { get set } +} + +extension ContentBuilder { + // MARK: - Headers + + public func withHeader(_ name: String, value: String) -> Self { + with { $0.headers.add(name: name, value: value) } + } + + public func withHeaders(_ dict: [String: String]) -> Self { + dict.reduce(self) { $0.withHeader($1.key, value: $1.value) } + } + + public func withBasicAuth(username: String, password: String) -> Self { + let basicAuthString = Data("\(username):\(password)".utf8).base64EncodedString() + return withHeader("Authorization", value: "Basic \(basicAuthString)") + } + + public func withToken(_ token: String) -> Self { + withHeader("Authorization", value: "Bearer \(token)") + } + + public func withContentType(_ contentType: ContentType) -> Self { + withHeader("Content-Type", value: contentType.string) + } + + // MARK: - Body + + public func withBody(_ content: ByteContent, type: ContentType? = nil, length: Int? = nil) -> Self { + guard body == nil else { + preconditionFailure("A request body should only be set once.") + } + + return with { + $0.body = content + $0.headers.contentType = type + $0.headers.contentLength = length ?? content.length + } + } + + public func withBody(data: Data) -> Self { + withBody(.data(data)) + } + + public func withBody(buffer: ByteBuffer) -> Self { + withBody(.buffer(buffer)) + } + + public func withBody(_ value: E, encoder: ContentEncoder = .json) throws -> Self { + let (buffer, type) = try encoder.encodeContent(value) + return withBody(.buffer(buffer), type: type) + } + + public func withJSON(_ dict: [String: Any?]) throws -> Self { + withBody(try .json(dict), type: .json) + } + + public func withJSON(_ json: E, encoder: JSONEncoder = JSONEncoder()) throws -> Self { + try withBody(json, encoder: encoder) + } + + public func withForm(_ dict: [String: Any?]) throws -> Self { + withBody(try .json(dict), type: .urlForm) + } + + public func withForm(_ form: E, encoder: URLEncodedFormEncoder = URLEncodedFormEncoder()) throws -> Self { + try withBody(form, encoder: encoder) + } + + public func attach(_ name: String, contents: ByteBuffer, filename: String? = nil, encoder: FormDataEncoder = FormDataEncoder()) async throws -> Self { + let file = File(name: filename ?? name, size: contents.writerIndex, content: .buffer(contents)) + return try withBody([name: file], encoder: encoder) + } + + public func attach(_ name: String, file: File, encoder: FormDataEncoder = FormDataEncoder()) async throws -> Self { + var copy = file + return try withBody([name: await copy.collect()], encoder: encoder) + } + + public func attach(_ files: [String: File], encoder: FormDataEncoder = FormDataEncoder()) async throws -> Self { + var collectedFiles: [String: File] = [:] + for (name, var file) in files { + collectedFiles[name] = try await file.collect() + } + + return try withBody(files, encoder: encoder) + } +} diff --git a/Sources/Alchemy/HTTP/Protocols/ContentInspector.swift b/Sources/Alchemy/HTTP/Protocols/ContentInspector.swift new file mode 100644 index 00000000..f5e264b9 --- /dev/null +++ b/Sources/Alchemy/HTTP/Protocols/ContentInspector.swift @@ -0,0 +1,142 @@ +import Hummingbird +import MultipartKit + +public protocol ContentInspector: Extendable { + var headers: HTTPHeaders { get } + var body: ByteContent? { get } +} + +extension ContentInspector { + + // MARK: Files + + /// Get any attached file with the given name from this request. + public func file(_ name: String) -> File? { + files()[name] + } + + /// Any files attached to this content, keyed by their multipart name + /// (separate from filename). Only populated if this content is + /// associated with a multipart request containing files. + /// + /// Async since the request may need to finish streaming before we get the + /// files. + public func files() -> [String: File] { + guard !content().allKeys.isEmpty else { + return [:] + } + + let content = content() + let files = Set(content.allKeys).compactMap { key -> (String, File)? in + guard let file = content[key].value?.file else { + return nil + } + + return (key, file) + } + + return Dictionary(uniqueKeysWithValues: files) + } + + // MARK: Partial Content + + public subscript(dynamicMember member: String) -> Content { + if let int = Int(member) { + return self[int] + } else { + return self[member] + } + } + + public subscript(index: Int) -> Content { + content()[index] + } + + public subscript(field: String) -> Content { + content()[field] + } + + func content() -> Content { + if let content = _content { + return content + } else { + guard let body = body else { + return Content(error: ContentError.emptyBody) + } + + guard let decoder = preferredDecoder() else { + return Content(error: ContentError.unknownContentType(headers.contentType)) + } + + let content = decoder.content(from: body.buffer, contentType: headers.contentType) + _content = content + return content + } + } + + private var _content: Content? { + get { extensions.get(\._content) } + nonmutating set { extensions.set(\._content, value: newValue) } + } + + // MARK: Content + + /// Decodes the content as a decodable, based on it's content type or with + /// the given content decoder. + /// + /// - Parameters: + /// - type: The Decodable type to which the body should be decoded. + /// - decoder: The decoder with which to decode. Defaults to + /// `Content.defaultDecoder`. + /// - Throws: Any errors encountered during decoding. + /// - Returns: The decoded object of type `type`. + public func decode(_ type: D.Type = D.self, with decoder: ContentDecoder? = nil) throws -> D { + guard let buffer = body?.buffer else { + throw ValidationError("expecting a request body") + } + + guard let decoder = decoder else { + guard let preferredDecoder = preferredDecoder() else { + throw HTTPError(.notAcceptable) + } + + return try preferredDecoder.decodeContent(type, from: buffer, contentType: headers.contentType) + } + + do { + return try decoder.decodeContent(type, from: buffer, contentType: headers.contentType) + } catch let DecodingError.keyNotFound(key, context) { + let path = context.codingPath.map(\.stringValue).joined(separator: ".") + let pathWithKey = path.isEmpty ? key.stringValue : "\(path).\(key.stringValue)" + throw ValidationError("Missing field `\(pathWithKey)` from request body.") + } catch let DecodingError.typeMismatch(type, context) { + let key = context.codingPath.last?.stringValue ?? "unknown" + throw ValidationError("Request body field `\(key)` should be a `\(type)`.") + } catch { + throw ValidationError("Invalid request body.") + } + } + + public func preferredDecoder() -> ContentDecoder? { + guard let contentType = headers.contentType else { + return ByteContent.defaultDecoder + } + + switch contentType { + case .json: + return .json + case .urlForm: + return .urlForm + case .multipart(boundary: ""): + return .multipart + default: + return nil + } + } +} + +extension Array { + func removingFirst() -> [Element] { + Array(dropFirst()) + } +} diff --git a/Sources/Alchemy/HTTP/Protocols/RequestBuilder.swift b/Sources/Alchemy/HTTP/Protocols/RequestBuilder.swift new file mode 100644 index 00000000..101c4ce3 --- /dev/null +++ b/Sources/Alchemy/HTTP/Protocols/RequestBuilder.swift @@ -0,0 +1,87 @@ +import Foundation +import NIOHTTP1 + +public protocol RequestBuilder: ContentBuilder { + associatedtype Res + + var urlComponents: URLComponents { get set } + var method: HTTPMethod { get set } + + func execute() async throws -> Res +} + +extension RequestBuilder { + + // MARK: Queries + + public func withQuery(_ name: String, value: CustomStringConvertible?) -> Self { + with { request in + let newItem = URLQueryItem(name: name, value: value?.description) + if let existing = request.urlComponents.queryItems { + request.urlComponents.queryItems = existing + [newItem] + } else { + request.urlComponents.queryItems = [newItem] + } + } + } + + public func withQueries(_ dict: [String: CustomStringConvertible]) -> Self { + dict.reduce(self) { $0.withQuery($1.key, value: $1.value) } + } + + // MARK: Methods & URL + + public func withBaseUrl(_ url: String) -> Self { + with { + var newComponents = URLComponents(string: url) + if let oldQueryItems = $0.urlComponents.queryItems { + let newQueryItems = newComponents?.queryItems ?? [] + newComponents?.queryItems = newQueryItems + oldQueryItems + } + + $0.urlComponents = newComponents ?? URLComponents() + } + } + + public func withMethod(_ method: HTTPMethod) -> Self { + with { $0.method = method } + } + + // MARK: Execution + + public func execute() async throws -> Res { + try await execute() + } + + public func request(_ method: HTTPMethod, uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(method).execute() + } + + public func get(_ uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(.GET).execute() + } + + public func post(_ uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(.POST).execute() + } + + public func put(_ uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(.PUT).execute() + } + + public func patch(_ uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(.PATCH).execute() + } + + public func delete(_ uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(.DELETE).execute() + } + + public func options(_ uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(.OPTIONS).execute() + } + + public func head(_ uri: String) async throws -> Res { + try await withBaseUrl(uri).withMethod(.HEAD).execute() + } +} diff --git a/Sources/Alchemy/HTTP/Protocols/RequestInspector.swift b/Sources/Alchemy/HTTP/Protocols/RequestInspector.swift new file mode 100644 index 00000000..71c450bc --- /dev/null +++ b/Sources/Alchemy/HTTP/Protocols/RequestInspector.swift @@ -0,0 +1,17 @@ +import Foundation +import NIOHTTP1 + +public protocol RequestInspector: ContentInspector { + var method: HTTPMethod { get } + var urlComponents: URLComponents { get } +} + +extension RequestInspector { + public func query(_ key: String) -> String? { + urlComponents.queryItems?.first(where: { $0.name == key })?.value + } + + public func query(_ key: String, as: L.Type = L.self) -> L? { + query(key).map { L($0) } ?? nil + } +} diff --git a/Sources/Alchemy/HTTP/Protocols/ResponseBuilder.swift b/Sources/Alchemy/HTTP/Protocols/ResponseBuilder.swift new file mode 100644 index 00000000..a9fce7f8 --- /dev/null +++ b/Sources/Alchemy/HTTP/Protocols/ResponseBuilder.swift @@ -0,0 +1,5 @@ +import NIOHTTP1 + +public protocol ResponseBuilder: ContentBuilder { + var status: HTTPResponseStatus { get set } +} diff --git a/Sources/Alchemy/HTTP/Protocols/ResponseInspector.swift b/Sources/Alchemy/HTTP/Protocols/ResponseInspector.swift new file mode 100644 index 00000000..187a6636 --- /dev/null +++ b/Sources/Alchemy/HTTP/Protocols/ResponseInspector.swift @@ -0,0 +1,5 @@ +import NIOHTTP1 + +public protocol ResponseInspector: ContentInspector { + var status: HTTPResponseStatus { get } +} diff --git a/Sources/Alchemy/HTTP/Request.swift b/Sources/Alchemy/HTTP/Request.swift deleted file mode 100644 index ae536489..00000000 --- a/Sources/Alchemy/HTTP/Request.swift +++ /dev/null @@ -1,164 +0,0 @@ -import Foundation -import NIO -import NIOHTTP1 - -/// A simplified Request type as you'll come across in many web -/// frameworks -public final class Request { - /// The default JSONDecoder with which to decode HTTP request - /// bodies. - public static var defaultJSONDecoder = JSONDecoder() - - /// The head contains all request "metadata" like the URI and - /// request method. - /// - /// The headers are also found in the head, and they are often - /// used to describe the body as well. - public let head: HTTPRequestHead - - /// The url components of this request. - public let components: URLComponents? - - /// The any parameters inside the path. - public var pathParameters: [PathParameter] = [] - - /// The bodyBuffer is internal because the HTTPBody API is exposed - /// for simpler access. - var bodyBuffer: ByteBuffer? - - /// Any information set by a middleware. - var middlewareData: [ObjectIdentifier: Any] = [:] - - /// This initializer is necessary because the `bodyBuffer` is a - /// private property. - init(head: HTTPRequestHead, bodyBuffer: ByteBuffer?) { - self.head = head - self.bodyBuffer = bodyBuffer - self.components = URLComponents(string: head.uri) - } -} - -extension Request { - /// The HTTPMethod of the request. - public var method: HTTPMethod { - self.head.method - } - - /// The path of the request. Does not include the query string. - public var path: String { - self.components?.path ?? "" - } - - /// Any headers associated with the request. - public var headers: HTTPHeaders { - self.head.headers - } - - /// Any query items parsed from the URL. These are not percent - /// encoded. - public var queryItems: [URLQueryItem] { - self.components?.queryItems ?? [] - } - - /// Returns the first `PathParameter` for the given key, if there - /// is one. - /// - /// Use this to fetch any parameters from the path. - /// ```swift - /// app.post("/users/:user_id") { request in - /// let theUserID = request.pathParameter(named: "user_id")?.stringValue - /// ... - /// } - /// ``` - public func pathParameter(named key: String) -> PathParameter? { - self.pathParameters.first(where: { $0.parameter == "key" }) - } - - /// A dictionary with the contents of this Request's body. - /// - Throws: Any errors from decoding the body. - /// - Returns: A [String: Any] with the contents of this Request's - /// body. - func bodyDict() throws -> [String: Any]? { - try body?.decodeJSONDictionary() - } - - /// The body is a wrapper used to provide simple access to any - /// body data, such as JSON. - public var body: HTTPBody? { - guard let bodyBuffer = bodyBuffer else { - return nil - } - - return HTTPBody(buffer: bodyBuffer) - } - - /// Sets a value associated with this request. Useful for setting - /// objects with middleware. - /// - /// Usage: - /// ```swift - /// struct ExampleMiddleware: Middleware { - /// func intercept(_ request: Request) -> EventLoopFuture { - /// let someData: SomeData = ... - /// request.set(someData) - /// return .new(value: request) - /// } - /// } - /// - /// app - /// .use(ExampleMiddleware()) - /// .on(.GET, at: "/example") { request in - /// let theData = try request.get(SomeData.self) - /// } - /// - /// ``` - /// - /// - Parameter value: The value to set. - /// - Returns: `self`, with the new value set internally for - /// access with `self.get(Value.self)`. - @discardableResult - public func set(_ value: T) -> Self { - middlewareData[ObjectIdentifier(T.self)] = value - return self - } - - /// Gets a value associated with this request, throws if there is - /// not a value of type `T` already set. - /// - /// - Parameter type: The type of the associated value to get from - /// the request. - /// - Throws: An `AssociatedValueError` if there isn't a value of - /// type `T` found associated with the request. - /// - Returns: The value of type `T` from the request. - public func get(_ type: T.Type = T.self) throws -> T { - let error = AssociatedValueError(message: "Couldn't find type `\(name(of: type))` on this request") - return try middlewareData[ObjectIdentifier(T.self)] - .unwrap(as: type, or: error) - } -} - -/// Error thrown when the user tries to `.get` an assocaited value -/// from an `Request` but one isn't set. -struct AssociatedValueError: Error { - /// What went wrong. - let message: String -} - -private extension Optional { - /// Unwraps an optional as the provided type or throws the - /// provided error. - /// - /// - Parameters: - /// - as: The type to unwrap to. - /// - error: The error to be thrown if `self` is unable to be - /// unwrapped as the provided type. - /// - Throws: An error if unwrapping as the provided type fails. - /// - Returns: `self` unwrapped and cast as the provided type. - func unwrap(as: T.Type = T.self, or error: Error) throws -> T { - guard let wrapped = self as? T else { - throw error - } - - return wrapped - } -} diff --git a/Sources/Alchemy/HTTP/PathParameter.swift b/Sources/Alchemy/HTTP/Request/Parameter.swift similarity index 72% rename from Sources/Alchemy/HTTP/PathParameter.swift rename to Sources/Alchemy/HTTP/Request/Parameter.swift index df24f009..a6a94b93 100644 --- a/Sources/Alchemy/HTTP/PathParameter.swift +++ b/Sources/Alchemy/HTTP/Request/Parameter.swift @@ -1,10 +1,10 @@ import Foundation -/// Represents a dynamic parameter inside the URL. Parameter +/// Represents a dynamic parameter inside the path. Parameter /// placeholders should be prefaced with a colon (`:`) in /// the route string. Something like `:user_id` in the /// path `/v1/users/:user_id`. -public struct PathParameter: Equatable { +public struct Parameter: Equatable { /// An error encountered while decoding a path parameter value /// string to a specific type such as `UUID` or `Int`. public struct DecodingError: Error { @@ -14,36 +14,36 @@ public struct PathParameter: Equatable { /// The escaped parameter that was matched, _without_ the colon. /// Something like `user_id` if `:user_id` was in the path. - public let parameter: String + public let key: String /// The actual string value of the parameter. - public let stringValue: String + public let value: String /// Decodes a `UUID` from this parameter's value or throws if the /// string is an invalid `UUID`. /// - /// - Throws: A `PathParameter.DecodingError` if the value string + /// - Throws: A `Parameter.DecodingError` if the value string /// is not convertible to a `UUID`. /// - Returns: The decoded `UUID`. public func uuid() throws -> UUID { - try UUID(uuidString: self.stringValue) - .unwrap(or: DecodingError("Unable to decode UUID for '\(self.parameter)'. Value was '\(self.stringValue)'.")) + try UUID(uuidString: value) + .unwrap(or: DecodingError("Unable to decode UUID for '\(key)'. Value was '\(value)'.")) } /// Returns the `String` value of this parameter. /// /// - Returns: the value of this parameter. public func string() -> String { - self.stringValue + value } /// Decodes an `Int` from this parameter's value or throws if the /// string can't be converted to an `Int`. /// - /// - Throws: a `PathParameter.DecodingError` if the value string + /// - Throws: a `Parameter.DecodingError` if the value string /// is not convertible to a `Int`. /// - Returns: the decoded `Int`. public func int() throws -> Int { - try Int(self.stringValue) - .unwrap(or: DecodingError("Unable to decode Int for '\(self.parameter)'. Value was '\(self.stringValue)'.")) + try Int(value) + .unwrap(or: DecodingError("Unable to decode Int for '\(key)'. Value was '\(value)'.")) } } diff --git a/Sources/Alchemy/HTTP/Request/Request+AssociatedValue.swift b/Sources/Alchemy/HTTP/Request/Request+AssociatedValue.swift new file mode 100644 index 00000000..ec32e50e --- /dev/null +++ b/Sources/Alchemy/HTTP/Request/Request+AssociatedValue.swift @@ -0,0 +1,80 @@ +extension Request { + private var associatedValues: [ObjectIdentifier: Any]? { + get { extensions.get(\.associatedValues) } + set { extensions.set(\.associatedValues, value: newValue) } + } + + /// Sets a value associated with this request. Useful for setting + /// objects with middleware. + /// + /// Usage: + /// + /// struct ExampleMiddleware: Middleware { + /// func intercept(_ request: Request, next: Next) async throws -> Response { + /// let someData: SomeData = ... + /// return try await next(request.set(someData)) + /// } + /// } + /// + /// app + /// .use(ExampleMiddleware()) + /// .on(.GET, at: "/example") { request in + /// let theData = try request.get(SomeData.self) + /// } + /// + /// - Parameter value: The value to set. + /// - Returns: This reqeust, with the new value set internally for access + /// with `get(Value.self)`. + @discardableResult + public func set(_ value: T) -> Self { + if associatedValues != nil { + associatedValues?[id(of: T.self)] = value + } else { + associatedValues = [id(of: T.self): value] + } + + return self + } + + /// Gets a value associated with this request, throws if there is + /// not a value of type `T` already set. + /// + /// - Parameter type: The type of the associated value to get from + /// the request. + /// - Throws: An `AssociatedValueError` if there isn't a value of + /// type `T` found associated with the request. + /// - Returns: The value of type `T` from the request. + public func get(_ type: T.Type = T.self, or error: Error = AssociatedValueError(message: "Couldn't find type `\(name(of: T.self))` on this request")) throws -> T { + try (associatedValues?[id(of: T.self)]).unwrap(as: type, or: error) + } +} + +/// Error thrown when the user tries to `.get` an assocaited value +/// from an `Request` but one isn't set. +public struct AssociatedValueError: Error { + /// What went wrong. + public let message: String + + public init(message: String) { + self.message = message + } +} + +extension Optional { + /// Unwraps an optional as the provided type or throws the + /// provided error. + /// + /// - Parameters: + /// - as: The type to unwrap to. + /// - error: The error to be thrown if `self` is unable to be + /// unwrapped as the provided type. + /// - Throws: An error if unwrapping as the provided type fails. + /// - Returns: `self` unwrapped and cast as the provided type. + fileprivate func unwrap(as: T.Type = T.self, or error: Error) throws -> T { + guard let wrapped = self as? T else { + throw error + } + + return wrapped + } +} diff --git a/Sources/Alchemy/HTTP/Request+Auth.swift b/Sources/Alchemy/HTTP/Request/Request+Auth.swift similarity index 84% rename from Sources/Alchemy/HTTP/Request+Auth.swift rename to Sources/Alchemy/HTTP/Request/Request+Auth.swift index f7b1f50a..eb0c0b92 100644 --- a/Sources/Alchemy/HTTP/Request+Auth.swift +++ b/Sources/Alchemy/HTTP/Request/Request+Auth.swift @@ -15,29 +15,27 @@ extension Request { if authString.starts(with: "Basic ") { authString.removeFirst(6) - guard let base64Data = Data(base64Encoded: authString), - let authString = String(data: base64Data, encoding: .utf8) else - { - // Or maybe we should throw error? + guard + let base64Data = Data(base64Encoded: authString), + let authString = String(data: base64Data, encoding: .utf8) + else { return nil } - - let components = authString.components(separatedBy: ":") - guard let username = components.first else { + + guard !authString.isEmpty else { return nil } + let components = authString.components(separatedBy: ":") + let username = components[0] let password = components.dropFirst().joined() - - return .basic( - HTTPAuth.Basic(username: username, password: password) - ) + return .basic(HTTPAuth.Basic(username: username, password: password)) } else if authString.starts(with: "Bearer ") { authString.removeFirst(7) return .bearer(HTTPAuth.Bearer(token: authString)) - } else { - return nil } + + return nil } /// Gets any `Basic` authorization data from this request. @@ -51,9 +49,9 @@ extension Request { if case let .basic(authData) = auth { return authData - } else { - return nil } + + return nil } /// Gets any `Bearer` authorization data from this request. @@ -67,19 +65,19 @@ extension Request { if case let .bearer(authData) = auth { return authData - } else { - return nil } + + return nil } } /// A type representing any auth that may be on an HTTP request. /// Supports `Basic` and `Bearer`. -public enum HTTPAuth { +public enum HTTPAuth: Equatable { /// The basic auth of an Request. Corresponds to a header that /// looks like /// `Authorization: Basic `. - public struct Basic { + public struct Basic: Equatable { /// The username of this authorization. Comes before the colon /// in the decoded `Authorization` header value i.e. /// `Basic :`. @@ -92,7 +90,7 @@ public enum HTTPAuth { /// The bearer auth of an Request. Corresponds to a header that /// looks like `Authorization: Bearer `. - public struct Bearer { + public struct Bearer: Equatable { /// The token in the `Authorization` header value. /// i.e. `Bearer `. public let token: String diff --git a/Sources/Alchemy/HTTP/Request/Request.swift b/Sources/Alchemy/HTTP/Request/Request.swift new file mode 100644 index 00000000..2f42875d --- /dev/null +++ b/Sources/Alchemy/HTTP/Request/Request.swift @@ -0,0 +1,90 @@ +import Foundation +import NIO +import NIOHTTP1 +import Hummingbird + +/// A type that represents inbound requests to your application. +public final class Request: RequestInspector { + /// The request body. + public var body: ByteContent? { hbRequest.byteContent } + /// The byte buffer of this request's body, if there is one. + public var buffer: ByteBuffer? { body?.buffer } + /// The stream of this request's body, if there is one. + public var stream: ByteStream? { body?.stream } + /// The remote address where this request came from. + public var remoteAddress: SocketAddress? { hbRequest.remoteAddress } + /// The remote address where this request came from. + public var ip: String { remoteAddress?.ipAddress ?? "" } + /// The event loop this request is being handled on. + public var loop: EventLoop { hbRequest.eventLoop } + /// The HTTPMethod of the request. + public var method: HTTPMethod { hbRequest.method } + /// Any headers associated with the request. + public var headers: HTTPHeaders { hbRequest.headers } + /// The complete url of the request. + public var url: URL { urlComponents.url ?? URL(fileURLWithPath: "") } + /// The path of the request. Does not include the query string. + public var path: String { urlComponents.path } + /// Any query items parsed from the URL. These are not percent encoded. + public var queryItems: [URLQueryItem]? { urlComponents.queryItems } + /// The underlying hummingbird request + public var hbRequest: HBRequest + /// Allows for extending storage on this type. + public var extensions: Extensions + /// The url components of this request. + public let urlComponents: URLComponents + /// Parameters parsed from the path. + public var parameters: [Parameter] { + get { extensions.get(\.parameters) } + set { extensions.set(\.parameters, value: newValue) } + } + + init(hbRequest: HBRequest, parameters: [Parameter] = []) { + self.hbRequest = hbRequest + self.urlComponents = URLComponents(string: hbRequest.uri.string) ?? URLComponents() + self.extensions = Extensions() + self.parameters = parameters + } + + /// Returns the first parameter for the given key, if there is one. + /// + /// Use this to fetch any parameters from the path. + /// ```swift + /// app.post("/users/:user_id") { request in + /// let userId: Int = try request.parameter("user_id") + /// ... + /// } + /// ``` + public func parameter(_ key: String, as: L.Type = L.self) throws -> L { + guard let parameterString: String = parameters.first(where: { $0.key == key })?.value else { + throw ValidationError("expected parameter \(key)") + } + + guard let converted = L(parameterString) else { + throw ValidationError("parameter \(key) was \(parameterString) which couldn't be converted to \(name(of: L.self))") + } + + return converted + } +} + +extension HBRequest { + fileprivate var byteContent: ByteContent? { + switch body { + case .byteBuffer(let bytes): + return bytes.map { .buffer($0) } + case .stream(let streamer): + return .stream(streamer.byteStream(eventLoop)) + } + } +} + +extension HBStreamerProtocol { + func byteStream(_ loop: EventLoop) -> ByteStream { + return .new { reader in + try await self.consumeAll(on: loop) { buffer in + return loop.asyncSubmit { try await reader.write(buffer) } + }.get() + } + } +} diff --git a/Sources/Alchemy/HTTP/Response.swift b/Sources/Alchemy/HTTP/Response.swift deleted file mode 100644 index 775da719..00000000 --- a/Sources/Alchemy/HTTP/Response.swift +++ /dev/null @@ -1,123 +0,0 @@ -import NIO -import NIOHTTP1 - -/// A type representing the response from an HTTP endpoint. This -/// response can be a failure or success case depending on the -/// status code in the `head`. -public final class Response { - /// The default `JSONEncoder` with which to encode JSON responses. - public static var defaultJSONEncoder = JSONEncoder() - - /// The success or failure status response code. - public var status: HTTPResponseStatus - - /// The HTTP headers. - public var headers: HTTPHeaders - - /// The body which contains any data you want to send back to the - /// client This can be HTML, an image or JSON among many other - /// data types. - public let body: HTTPBody? - - /// This will be called when this `Response` writes data to a - /// remote peer. - internal var writerClosure: (ResponseWriter) -> Void { - get { self._writerClosure ?? self.defaultWriterClosure } - } - - /// Closure for deferring writing. - private var _writerClosure: ((ResponseWriter) -> Void)? - - /// Creates a new response using a status code, headers and body. - /// If the headers do not contain `content-length` or - /// `content-type`, those will be appended based on - /// the supplied `HTTPBody`. - /// - /// - Parameters: - /// - status: The status code of this response. - /// - headers: Any headers to return in the response. Defaults - /// to empty headers. - /// - body: The body of this response. See `HTTPBody` for - /// initializing with various data. - public init(status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders(), body: HTTPBody?) { - var headers = headers - headers.replaceOrAdd(name: "content-length", value: String(body?.buffer.writerIndex ?? 0)) - body?.mimeType.map { headers.replaceOrAdd(name: "content-type", value: $0.value) } - - self.status = status - self.headers = headers - self.body = body - } - - /// Initialize this response with a closure that will be called, - /// allowing you to directly write headers, body, and end to - /// the response. The request connection will be left open - /// until you `.writeEnd()` to the closure's - /// `ResponseWriter`. - /// - /// Usage: - /// ```swift - /// app.get("/stream") { - /// Response { writer in - /// writer.writeHead(...) - /// writer.writeBody(...) - /// writer.writeEnd() - /// } - /// } - /// ``` - /// - /// - Parameter writer: A closure take a `ResponseWriter` and - /// using it to write response data to a remote peer. - public init(_ writer: @escaping (ResponseWriter) -> Void) { - self.status = .ok - self.headers = HTTPHeaders() - self.body = nil - self._writerClosure = writer - } - - /// Writes this response to an remote peer via a `ResponseWriter`. - /// - /// - Parameter writer: An abstraction around writing data to a - /// remote peer. - func write(to writer: ResponseWriter) { - self.writerClosure(writer) - } - - /// Provides default writing behavior for a `Response`. - /// - /// - Parameter writer: An abstraction around writing data to a - /// remote peer. - private func defaultWriterClosure(writer: ResponseWriter) { - writer.writeHead(status: status, headers) - if let body = body { - writer.writeBody(body.buffer) - } - writer.writeEnd() - } -} - -/// An abstraction around writing data to a remote peer. Conform to -/// this protocol and inject it into the `Response` for responding -/// to a remote peer at a later point in time. -/// -/// Be sure to call `writeEnd` when you are finished writing data or -/// the client response will never complete. -public protocol ResponseWriter { - /// Write the status and head of a response. Should only be called - /// once. - /// - /// - Parameters: - /// - status: The status code of the response. - /// - headers: Any headers of this response. - func writeHead(status: HTTPResponseStatus, _ headers: HTTPHeaders) - - /// Write some body data to the remote peer. May be called 0 or - /// more times. - /// - /// - Parameter body: The buffer of data to write. - func writeBody(_ body: ByteBuffer) - - /// Write the end of the response. Needs to be called once per - /// response, when all data has been written. - func writeEnd() -} diff --git a/Sources/Alchemy/HTTP/Response/Response.swift b/Sources/Alchemy/HTTP/Response/Response.swift new file mode 100644 index 00000000..c6bc8bab --- /dev/null +++ b/Sources/Alchemy/HTTP/Response/Response.swift @@ -0,0 +1,67 @@ +import Hummingbird +import NIO +import NIOHTTP1 + +/// A type representing the response from an HTTP endpoint. This +/// response can be a failure or success case depending on the +/// status code in the `head`. +public final class Response: ResponseBuilder { + /// The success or failure status response code. + public var status: HTTPResponseStatus + /// The HTTP headers. + public var headers: HTTPHeaders + /// The body of this response. + public var body: ByteContent? + /// Allows for extending storage on this type. + public var extensions: Extensions + + /// Creates a new response using a status code, headers and body. If the + /// body is of type `.buffer()` or `nil`, the `Content-Length` header + /// will be set, if not already, in the headers. + /// + /// - Parameters: + /// - status: The status of this response. + /// - headers: Any headers for this response. + /// - body: Any response body, either a buffer or streamed. + public init(status: HTTPResponseStatus = .ok, headers: HTTPHeaders = [:], body: ByteContent? = nil) { + self.status = status + self.headers = headers + self.body = body + self.extensions = Extensions() + + switch body { + case .buffer(let buffer): + self.headers.contentLength = buffer.writerIndex + case .none: + self.headers.contentLength = 0 + default: + break + } + } + + /// Initialize this response with a closure that will be called, + /// allowing you to directly write headers, body, and end to + /// the response. The request connection will be left open + /// until you `.writeEnd()` to the closure's + /// `ResponseWriter`. + /// + /// Usage: + /// ```swift + /// app.get("/stream") { + /// Response(status: .ok, headers: ["Content-Length": "248"]) { writer in + /// writer.writeHead(...) + /// writer.writeBody(...) + /// writer.writeEnd() + /// } + /// } + /// ``` + /// + /// - Parameter writer: A closure take a `ResponseWriter` and + /// using it to write response data to a remote peer. + public init(status: HTTPResponseStatus = .ok, headers: HTTPHeaders = [:], stream: @escaping ByteStream.Closure) { + self.status = .ok + self.headers = HTTPHeaders() + self.body = .stream(stream) + self.extensions = Extensions() + } +} diff --git a/Sources/Alchemy/HTTP/ValidationError.swift b/Sources/Alchemy/HTTP/ValidationError.swift new file mode 100644 index 00000000..d47ab6f4 --- /dev/null +++ b/Sources/Alchemy/HTTP/ValidationError.swift @@ -0,0 +1,22 @@ +import Foundation + +/// An error related to decoding a type from a `DecodableRequest`. +public struct ValidationError: Error { + /// What went wrong. + public let message: String + + /// Create an error with the specified message. + /// + /// - Parameter message: What went wrong. + public init(_ message: String) { + self.message = message + } +} + +// Provide a custom response for when `ValidationError`s are thrown. +extension ValidationError: ResponseConvertible { + public func response() throws -> Response { + try Response(status: .badRequest) + .withValue(["validation_error": message]) + } +} diff --git a/Sources/Alchemy/Middleware/CORSMiddleware.swift b/Sources/Alchemy/Middleware/Concrete/CORSMiddleware.swift similarity index 76% rename from Sources/Alchemy/Middleware/CORSMiddleware.swift rename to Sources/Alchemy/Middleware/Concrete/CORSMiddleware.swift index f79b882f..55092959 100644 --- a/Sources/Alchemy/Middleware/CORSMiddleware.swift +++ b/Sources/Alchemy/Middleware/Concrete/CORSMiddleware.swift @@ -60,15 +60,15 @@ public final class CORSMiddleware: Middleware { /// header should be created. /// - Returns: Header string to be used in response for /// allowed origin. - public func header(forRequest req: Request) -> String { + public func header(forOrigin origin: String) -> String { switch self { - case .none: return "" - case .originBased: return req.headers["Origin"].first ?? "" - case .all: return "*" + case .none: + return "" + case .originBased: + return origin + case .all: + return "*" case .any(let origins): - guard let origin = req.headers["Origin"].first else { - return "" - } return origins.contains(origin) ? origin : "" case .custom(let string): return string @@ -88,7 +88,7 @@ public final class CORSMiddleware: Middleware { /// - Allow Headers: `Accept`, `Authorization`, /// `Content-Type`, `Origin`, `X-Requested-With` public static func `default`() -> Configuration { - return .init( + Configuration( allowedOrigin: .originBased, allowedMethods: [.GET, .POST, .PUT, .OPTIONS, .DELETE, .PATCH], allowedHeaders: ["Accept", "Authorization", "Content-Type", "Origin", "X-Requested-With"] @@ -165,56 +165,53 @@ public final class CORSMiddleware: Middleware { // MARK: Middleware - public func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture { + public func intercept(_ request: Request, next: Next) async throws -> Response { // Check if it's valid CORS request - guard request.headers["Origin"].first != nil else { - return next(request) + guard let origin = request.headers["Origin"].first else { + return try await next(request) } // Determine if the request is pre-flight. If it is, create // empty response otherwise get response from the responder // chain. - let response = request.isPreflight ? .new(Response(status: .ok, body: nil)) : next(request) + let response = request.isPreflight ? Response(status: .ok, body: nil) : try await next(request) - return response.map { response in - // Modify response headers based on CORS settings - response.headers.replaceOrAdd( - name: "Access-Control-Allow-Origin", - value: self.configuration.allowedOrigin.header(forRequest: request) - ) - response.headers.replaceOrAdd( - name: "Access-Control-Allow-Headers", - value: self.configuration.allowedHeaders - ) + // Modify response headers based on CORS settings + response.headers.replaceOrAdd( + name: "Access-Control-Allow-Origin", + value: self.configuration.allowedOrigin.header(forOrigin: origin) + ) + response.headers.replaceOrAdd( + name: "Access-Control-Allow-Headers", + value: self.configuration.allowedHeaders + ) + response.headers.replaceOrAdd( + name: "Access-Control-Allow-Methods", + value: self.configuration.allowedMethods + ) + + if let exposedHeaders = self.configuration.exposedHeaders { + response.headers.replaceOrAdd(name: "Access-Control-Expose-Headers", value: exposedHeaders) + } + + if let cacheExpiration = self.configuration.cacheExpiration { + response.headers.replaceOrAdd(name: "Access-Control-Max-Age", value: String(cacheExpiration)) + } + + if self.configuration.allowCredentials { response.headers.replaceOrAdd( - name: "Access-Control-Allow-Methods", - value: self.configuration.allowedMethods + name: "Access-Control-Allow-Credentials", + value: "true" ) - - if let exposedHeaders = self.configuration.exposedHeaders { - response.headers.replaceOrAdd(name: "Access-Control-Expose-Headers", value: exposedHeaders) - } - - if let cacheExpiration = self.configuration.cacheExpiration { - response.headers.replaceOrAdd(name: "Access-Control-Max-Age", value: String(cacheExpiration)) - } - - if self.configuration.allowCredentials { - response.headers.replaceOrAdd( - name: "Access-Control-Allow-Credentials", - value: "true" - ) - } - - return response } + + return response } } -private extension Request { +extension Request { /// Returns `true` if the request is a pre-flight CORS request. - var isPreflight: Bool { - return self.method.rawValue == "OPTIONS" - && self.headers["Access-Control-Request-Method"].first != nil + fileprivate var isPreflight: Bool { + method.rawValue == "OPTIONS" && headers["Access-Control-Request-Method"].first != nil } } diff --git a/Sources/Alchemy/Middleware/Concrete/FileMiddleware.swift b/Sources/Alchemy/Middleware/Concrete/FileMiddleware.swift new file mode 100644 index 00000000..b94ff58f --- /dev/null +++ b/Sources/Alchemy/Middleware/Concrete/FileMiddleware.swift @@ -0,0 +1,59 @@ +/// Middleware for serving static files from a given directory. +/// +/// Usage: +/// +/// app.useAll(FileMiddleware(from: "resources")) +/// +/// Now your app will serve the files that are in the `resources` directory. +public struct FileMiddleware: Middleware { + /// The filesystem for getting files. + private let filesystem: Filesystem + /// Additional extensions to try if a file with the exact name isn't found. + private let extensions: [String] + + /// Creates a new middleware to serve static files from a given directory. + /// + /// - Parameters: + /// - directory: The directory to server static files from. Defaults to + /// "Public/". + /// - extensions: File extension fallbacks. When set, if a file is not + /// found, the given extensions will be added to the file name and + /// searched for. The first that exists will be served. Defaults + /// to []. Example: ["html", "htm"]. + public init(from directory: String = "Public/", extensions: [String] = []) { + self.filesystem = .local(root: directory) + self.extensions = extensions + } + + // MARK: Middleware + + public func intercept(_ request: Request, next: Next) async throws -> Response { + // Ignore non `GET` requests. + guard request.method == .GET else { + return try await next(request) + } + + // Ensure path doesn't contain any parent directories. + guard !request.path.contains("../") else { + throw HTTPError(.forbidden) + } + + // Trim forward slashes + var sanitizedPath = request.path.trimmingForwardSlash + + // Route / to + if sanitizedPath.isEmpty { + sanitizedPath = "index.html" + } + + // See if there's a file at any possible extension + let allPossiblePaths = [sanitizedPath] + extensions.map { sanitizedPath + ".\($0)" } + for possiblePath in allPossiblePaths { + if try await filesystem.exists(possiblePath) { + return try await filesystem.get(possiblePath).response() + } + } + + return try await next(request) + } +} diff --git a/Sources/Alchemy/Middleware/Middleware.swift b/Sources/Alchemy/Middleware/Middleware.swift index ba3965dc..1b0fdb45 100644 --- a/Sources/Alchemy/Middleware/Middleware.swift +++ b/Sources/Alchemy/Middleware/Middleware.swift @@ -1,51 +1,50 @@ import NIO /// A `Middleware` is used to intercept either incoming `Request`s or -/// outgoing `Response`s. Using futures, they can do something -/// with those, either synchronously or asynchronously. +/// outgoing `Response`s. The can intercept either synchronously or +/// asynchronously. /// /// Usage: /// ```swift -/// // Example synchronous middleware -/// struct SyncMiddleware: Middleware { -/// func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture -/// ... // Do something with `request`. -/// // Then continue the chain. Could hook into this future to -/// // do something with the `Response`. -/// return next(request) +/// // Log all requests and responses to the server +/// struct RequestLoggingMiddleware: Middleware { +/// func intercept(_ request: Request, next: Next) async throws -> Response { +/// // log the request +/// Log.info("\(request.head.method.rawValue) \(request.path)") +/// +/// // await and log the response +/// let response = try await next(request) +/// Log.info("\(response.status.code) \(request.head.method.rawValue) \(request.path)") +/// return response /// } /// } /// -/// // Example asynchronous middleware -/// struct AsyncMiddleware: Middleware { -/// func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture -/// // Run some async operation -/// Database.default -/// .rawQuery(...) -/// .flatMap { someData in -/// // Set some data on the request for access in -/// // subsequent Middleware or request handlers. -/// // See `HTTPRequst.set` for more detail. -/// request.set(someData) -/// return next(request) -/// } +/// // Find and set a user on a Request if the request path has a +/// // `user_id` parameter +/// struct FindUserMiddleware: Middleware { +/// func intercept(_ request: Request, next: Next) async throws -> Response { +/// let userId = request.parameter(for: "user_id") +/// let user = try await User.find(userId) +/// // Set some data on the request for access in subsequent +/// // Middleware or request handlers. See `HTTPRequst.set` +/// // for more detail. +/// return try await next(request.set(user)) /// } /// } /// ``` public protocol Middleware { /// Passes a request to the next piece of the handler chain. It is - /// a closure that expects a request and returns a future - /// containing a response. - typealias Next = (Request) -> EventLoopFuture + /// a closure that expects a request and returns a response. + typealias Next = (Request) async throws -> Response - /// Intercept a requst, returning a future with a Response - /// representing the result of the subsequent handlers. + /// Intercept a requst, returning a Response representing from + /// the subsequent handlers. /// - /// Be sure to call next when returning, unless you don't want the - /// request to be handled. + /// Be sure to call `next` when returning, unless you don't want + /// the request to be handled. /// /// - Parameter request: The incoming request to intercept, then /// pass along the handler chain. /// - Throws: Any error encountered when intercepting the request. - func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture + func intercept(_ request: Request, next: Next) async throws -> Response } diff --git a/Sources/Alchemy/Middleware/StaticFileMiddleware.swift b/Sources/Alchemy/Middleware/StaticFileMiddleware.swift deleted file mode 100644 index 341c5027..00000000 --- a/Sources/Alchemy/Middleware/StaticFileMiddleware.swift +++ /dev/null @@ -1,128 +0,0 @@ -import Foundation -import NIO -import NIOHTTP1 - -/// Middleware for serving static files from a given directory. -/// -/// Usage: -/// ```swift -/// /// Will server static files from the 'public' directory of -/// /// your project. -/// app.useAll(StaticFileMiddleware(from: "public")) -/// ``` -/// Now your router will serve the files that are in the `Public` -/// directory. -public struct StaticFileMiddleware: Middleware { - /// The directory from which static files will be served. - private let directory: String - - /// The file IO helper for streaming files. - private let fileIO = NonBlockingFileIO(threadPool: .default) - - /// Used for allocating buffers when pulling out file data. - private let bufferAllocator = ByteBufferAllocator() - - /// Creates a new middleware to serve static files from a given - /// directory. Directory defaults to "public/". - /// - /// - Parameter directory: The directory to server static files - /// from. Defaults to "Public/". - public init(from directory: String = "Public/") { - self.directory = directory.hasSuffix("/") ? directory : "\(directory)/" - } - - // MARK: Middleware - - public func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture { - // Ignore non `GET` requests. - guard request.method == .GET else { - return next(request) - } - - let filePath = try self.directory + self.sanitizeFilePath(request.path) - - // See if there's a file at the given path - var isDirectory: ObjCBool = false - let exists = FileManager.default.fileExists(atPath: filePath, isDirectory: &isDirectory) - - if exists && !isDirectory.boolValue { - let fileInfo = try FileManager.default.attributesOfItem(atPath: filePath) - guard let fileSizeBytes = (fileInfo[.size] as? NSNumber)?.intValue else { - Log.error("[StaticFileMiddleware] attempted to access file at `\(filePath)` but it didn't have a size.") - throw HTTPError(.internalServerError) - } - - let fileHandle = try NIOFileHandle(path: filePath) - let response = Response { responseWriter in - // Set any relevant headers based off the file info. - var headers: HTTPHeaders = ["content-length": "\(fileSizeBytes)"] - if let ext = filePath.components(separatedBy: ".").last, - let mediaType = MIMEType(fileExtension: ext) { - headers.add(name: "content-type", value: mediaType.value) - } - responseWriter.writeHead(status: .ok, headers) - - // Load the file in chunks, streaming it. - self.fileIO.readChunked( - fileHandle: fileHandle, - byteCount: fileSizeBytes, - chunkSize: NonBlockingFileIO.defaultChunkSize, - allocator: self.bufferAllocator, - eventLoop: Loop.current, - chunkHandler: { buffer in - responseWriter.writeBody(buffer) - return .new(()) - } - ) - .flatMapThrowing { - try fileHandle.close() - } - .whenComplete { result in - try? fileHandle.close() - switch result { - case .failure(let error): - // Not a ton that can be done in the case of - // an error, not sure what else can be done - // besides logging and ending the request. - Log.error("[StaticFileMiddleware] Encountered an error loading a static file: \(error)") - responseWriter.writeEnd() - case .success: - responseWriter.writeEnd() - } - } - } - - return .new(response) - } else { - // No file, continue to handlers. - return next(request) - } - } - - /// Sanitize a file path, returning the new sanitized path. - /// - /// - Parameter path: The path to sanitize for file access. - /// - Throws: An error if the path is forbidden. - /// - Returns: The sanitized path, appropriate for loading files - /// from. - private func sanitizeFilePath(_ path: String) throws -> String { - var sanitizedPath = path - - // Ensure path is relative to the current directory. - while sanitizedPath.hasPrefix("/") { - sanitizedPath = String(sanitizedPath.dropFirst()) - } - - // Ensure path doesn't contain any parent directories. - guard !sanitizedPath.contains("../") else { - throw HTTPError(.forbidden) - } - - // Route / to - if sanitizedPath.isEmpty { - sanitizedPath = "index.html" - } - - return sanitizedPath - } -} diff --git a/Sources/Alchemy/Queue/Drivers/QueueDriver.swift b/Sources/Alchemy/Queue/Drivers/QueueDriver.swift deleted file mode 100644 index 7619a7ad..00000000 --- a/Sources/Alchemy/Queue/Drivers/QueueDriver.swift +++ /dev/null @@ -1,127 +0,0 @@ -import NIO - -/// Conform to this protocol to implement a custom driver for the -/// `Queue` class. -public protocol QueueDriver { - /// Add a job to the end of the Queue. - func enqueue(_ job: JobData) -> EventLoopFuture - /// Dequeue the next job from the given channel. - func dequeue(from channel: String) -> EventLoopFuture - /// Handle an in progress job that has been completed with the - /// given outcome. - /// - /// The `JobData` will have any fields that should be updated - /// (such as `attempts`) already updated when it is passed - /// to this function. - func complete(_ job: JobData, outcome: JobOutcome) -> EventLoopFuture -} - -/// An outcome of when a job is run. It should either be flagged as -/// successful, failed, or be retried. -public enum JobOutcome { - /// The job succeeded. - case success - /// The job failed. - case failed - /// The job should be requeued. - case retry -} - -extension QueueDriver { - /// Dequeue the next job from a given set of channels, ordered by - /// priority. - /// - /// - Parameter channels: The channels to dequeue from. - /// - Returns: A future containing a dequeued `Job`, if there is - /// one. - func dequeue(from channels: [String]) -> EventLoopFuture { - guard let channel = channels.first else { - return .new(nil) - } - - return dequeue(from: channel) - .flatMap { result in - guard let result = result else { - return dequeue(from: Array(channels.dropFirst())) - } - - return .new(result) - } - } - - /// Start monitoring a queue for jobs to run. - /// - /// - Parameters: - /// - channels: The channels this worker should monitor. - /// - pollRate: The rate at which the worker should check the - /// queue for work. - /// - eventLoop: The loop on which this worker should run. - func startWorker(for channels: [String], pollRate: TimeAmount, on eventLoop: EventLoop) { - return eventLoop.execute { - self.runNext(from: channels) - .whenComplete { _ in - // Run check again in the `pollRate`. - eventLoop.scheduleTask(in: pollRate) { - self.startWorker(for: channels, pollRate: pollRate, on: eventLoop) - } - } - } - } - - private func runNext(from channels: [String]) -> EventLoopFuture { - dequeue(from: channels) - .flatMapErrorThrowing { - Log.error("[Queue] error dequeueing job from `\(channels)`. \($0)") - throw $0 - } - .flatMap { jobData in - guard let jobData = jobData else { - return .new() - } - - Log.debug("Dequeued job \(jobData.jobName) from queue \(jobData.channel)") - return self.execute(jobData) - .flatMap { self.runNext(from: channels) } - } - } - - private func execute(_ jobData: JobData) -> EventLoopFuture { - var jobData = jobData - return catchError { - do { - let job = try JobDecoding.decode(jobData) - return job.run() - .always { - job.finished(result: $0) - do { - jobData.json = try job.jsonString() - } catch { - Log.error("[QueueWorker] tried updating Job persistance object after completion, but encountered error \(error)") - } - } - } catch { - Log.error("error decoding job named \(jobData.jobName). Error was: \(error).") - throw error - } - } - .flatMapAlways { (result: Result) -> EventLoopFuture in - jobData.attempts += 1 - switch result { - case .success: - return self.complete(jobData, outcome: .success) - case .failure where jobData.canRetry: - jobData.backoffUntil = jobData.nextRetryDate() - return self.complete(jobData, outcome: .retry) - case .failure(let error): - if let err = error as? JobError, err == JobError.unknownType { - // Always retry if the type was unknown, and - // ignore the attempt. - jobData.attempts -= 1 - return self.complete(jobData, outcome: .retry) - } else { - return self.complete(jobData, outcome: .failed) - } - } - } - } -} diff --git a/Sources/Alchemy/Queue/Drivers/RedisQueue.swift b/Sources/Alchemy/Queue/Drivers/RedisQueue.swift deleted file mode 100644 index fa6c512b..00000000 --- a/Sources/Alchemy/Queue/Drivers/RedisQueue.swift +++ /dev/null @@ -1,124 +0,0 @@ -import NIO -import RediStack - -/// A queue that persists jobs to a Redis instance. -final class RedisQueue: QueueDriver { - /// The underlying redis connection. - private let redis: Redis - /// All job data. - private let dataKey = RedisKey("jobs:data") - /// All processing jobs. - private let processingKey = RedisKey("jobs:processing") - /// All backed off jobs. "job_id" : "backoff:channel" - private let backoffsKey = RedisKey("jobs:backoffs") - - /// Initialize with a Redis instance to persist jobs to. - /// - /// - Parameter redis: The Redis instance. - init(redis: Redis = .default) { - self.redis = redis - monitorBackoffs() - } - - private func monitorBackoffs() { - let loop = Loop.group.next() - loop.scheduleRepeatedAsyncTask(initialDelay: .zero, delay: .seconds(1)) { (task: RepeatedTask) -> - EventLoopFuture in - return self.redis - // Get and remove backoffs that can be rerun. - .transaction { conn -> EventLoopFuture in - let set = RESPValue(from: self.backoffsKey.rawValue) - let min = RESPValue(from: 0) - let max = RESPValue(from: Date().timeIntervalSince1970) - return conn.send(command: "ZRANGEBYSCORE", with: [set, min, max]) - .flatMap { _ in conn.send(command: "ZREMRANGEBYSCORE", with: [set, min, max]) } - } - .map { (value: RESPValue) -> [String] in - guard let values = value.array, let scores = values.first?.array, !scores.isEmpty else { - return [] - } - - return scores.compactMap(\.string) - } - .flatMapEach(on: loop) { backoffKey -> EventLoopFuture in - let values = backoffKey.split(separator: ":") - let jobId = String(values[0]) - let channel = String(values[1]) - let queueList = self.key(for: channel) - return self.redis.lpush(jobId, into: queueList).voided() - } - .voided() - } - } - - // MARK: - Queue - - func enqueue(_ job: JobData) -> EventLoopFuture { - return self.storeJobData(job) - .flatMap { self.redis.lpush(job.id, into: self.key(for: job.channel)) } - .voided() - } - - private func storeJobData(_ job: JobData) -> EventLoopFuture { - catchError { - let jsonString = try job.jsonString() - return redis.hset(job.id, to: jsonString, in: self.dataKey).voided() - } - } - - func dequeue(from channel: String) -> EventLoopFuture { - /// Move from queueList to processing - let queueList = key(for: channel) - return self.redis.rpoplpush(from: queueList, to: self.processingKey, valueType: String.self) - .flatMap { jobID in - guard let jobID = jobID else { - return .new(nil) - } - - return self.redis - .hget(jobID, from: self.dataKey, as: String.self) - .unwrap(orError: JobError("Missing job data for key `\(jobID)`.")) - .flatMapThrowing { try JobData(jsonString: $0) } - } - } - - func complete(_ job: JobData, outcome: JobOutcome) -> EventLoopFuture { - switch outcome { - case .success, .failed: - // Remove from processing. - return self.redis.lrem(job.id, from: self.processingKey) - // Remove job data. - .flatMap { _ in self.redis.hdel(job.id, from: self.dataKey) } - .voided() - case .retry: - // Remove from processing - return self.redis.lrem(job.id, from: self.processingKey) - .flatMap { _ in - if let backoffUntil = job.backoffUntil { - let backoffKey = "\(job.id):\(job.channel)" - let backoffScore = backoffUntil.timeIntervalSince1970 - return self.storeJobData(job) - .flatMap { self.redis.zadd((backoffKey, backoffScore), to: self.backoffsKey) } - .voided() - } else { - return self.enqueue(job) - } - } - } - } - - private func key(for channel: String) -> RedisKey { - RedisKey("jobs:queue:\(channel)") - } -} - -public extension Queue { - /// A queue backed by a Redis connection. - /// - /// - Parameter redis: A redis connection to drive this queue. - /// Defaults to your default redis connection. - /// - Returns: The configured queue. - static func redis(_ redis: Redis = Redis.default) -> Queue { - Queue(RedisQueue(redis: redis)) - } -} diff --git a/Sources/Alchemy/Queue/Job.swift b/Sources/Alchemy/Queue/Job.swift index 09f89c14..6a3e52da 100644 --- a/Sources/Alchemy/Queue/Job.swift +++ b/Sources/Alchemy/Queue/Job.swift @@ -1,6 +1,6 @@ import NIO -/// A task that can be persisted and queued for future handling. +/// A task that can be persisted and queued for background processing. public protocol Job: Codable { /// The name of this Job. Defaults to the type name. static var name: String { get } @@ -14,8 +14,10 @@ public protocol Job: Codable { /// Called when a job finishes, either successfully or with too /// many failed attempts. func finished(result: Result) + /// Called when a job fails, whether it can be retried or not. + func failed(error: Error) /// Run this Job. - func run() -> EventLoopFuture + func run() async throws } // Default implementations. @@ -32,9 +34,11 @@ extension Job { Log.error("[Queue] Job '\(Self.name)' failed with error: \(error).") } } + + public func failed(error: Error) {} } -public enum RecoveryStrategy { +public enum RecoveryStrategy: Equatable { /// Removes task from the queue case none /// Retries the task a specified amount of times @@ -51,6 +55,17 @@ public enum RecoveryStrategy { } } +extension TimeAmount: Codable { + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(nanoseconds) + } + + public init(from decoder: Decoder) throws { + self = .nanoseconds(try decoder.singleValueContainer().decode(Int64.self)) + } +} + extension RecoveryStrategy: Codable { enum CodingKeys: String, CodingKey { case none, retry diff --git a/Sources/Alchemy/Queue/JobEncoding/JobData.swift b/Sources/Alchemy/Queue/JobEncoding/JobData.swift index ccba2e5c..b99f8e0e 100644 --- a/Sources/Alchemy/Queue/JobEncoding/JobData.swift +++ b/Sources/Alchemy/Queue/JobEncoding/JobData.swift @@ -3,9 +3,9 @@ import NIO public typealias JobID = String public typealias JSONString = String -/// Represents a persisted Job, contains the serialized Job as well -/// as some additional info for `Queue`s & `QueueWorker`s. -public struct JobData: Codable { +/// Represents a persisted Job, contains the serialized Job as well as some +/// additional info for `Queue`s. +public struct JobData: Codable, Equatable { /// The unique id of this job, by default this is a UUID string. public let id: JobID /// The serialized Job this persists. @@ -18,15 +18,14 @@ public struct JobData: Codable { public let recoveryStrategy: RecoveryStrategy /// How long should be waited before retrying a Job after a /// failure. - public let backoffSeconds: Int + public let backoff: TimeAmount /// Don't run this again until this time. public var backoffUntil: Date? /// The number of attempts this Job has been attempted. public var attempts: Int - /// Can this job be retried. public var canRetry: Bool { - self.attempts <= self.recoveryStrategy.maximumRetries + attempts <= recoveryStrategy.maximumRetries } /// Indicates if this job is currently in backoff, and should not @@ -47,12 +46,17 @@ public struct JobData: Codable { /// - channel: The name of the queue the `job` belongs on. /// - Throws: If the `job` is unable to be serialized to a String. public init(_ job: J, id: String = UUID().uuidString, channel: String) throws { + // If the Job hasn't been registered, register it. + if !JobDecoding.isRegistered(J.self) { + JobDecoding.register(J.self) + } + self.id = id self.jobName = J.name self.channel = channel self.recoveryStrategy = job.recoveryStrategy self.attempts = 0 - self.backoffSeconds = job.retryBackoff.seconds + self.backoff = job.retryBackoff self.backoffUntil = nil do { self.json = try job.jsonString() @@ -81,7 +85,7 @@ public struct JobData: Codable { self.jobName = jobName self.channel = channel self.recoveryStrategy = recoveryStrategy - self.backoffSeconds = retryBackoff.seconds + self.backoff = retryBackoff self.attempts = attempts self.backoffUntil = backoffUntil } @@ -89,19 +93,6 @@ public struct JobData: Codable { /// The next date this job can be attempted. `nil` if the job can /// be retried immediately. func nextRetryDate() -> Date? { - return backoffSeconds > 0 ? Date().addingTimeInterval(TimeInterval(backoffSeconds)) : nil - } - - /// Update the job payload. - /// - /// - Parameter job: The new job payload. - /// - Throws: Any error encountered while encoding this payload - /// to a string. - mutating func updatePayload(_ job: J) throws { - do { - self.json = try job.jsonString() - } catch { - throw JobError("Error updating JobData payload to Job type `\(J.name)`: \(error)") - } + return backoff.seconds > 0 ? Date().addingTimeInterval(TimeInterval(backoff.seconds)) : nil } } diff --git a/Sources/Alchemy/Queue/JobEncoding/JobDecoding.swift b/Sources/Alchemy/Queue/JobEncoding/JobDecoding.swift index 16995931..d9705c93 100644 --- a/Sources/Alchemy/Queue/JobEncoding/JobDecoding.swift +++ b/Sources/Alchemy/Queue/JobEncoding/JobDecoding.swift @@ -1,13 +1,22 @@ +import NIOConcurrencyHelpers + /// Storage for `Job` decoding behavior. struct JobDecoding { + static var registeredJobs: [Job.Type] = [] + /// Stored decoding behavior for jobs. - @Locked private static var decoders: [String: (JobData) throws -> Job] = [:] + private static var decoders: [String: (JobData) throws -> Job] = [:] + + private static let lock = Lock() /// Register a job to cache its decoding behavior. /// /// - Parameter type: A job type. static func register(_ type: J.Type) { - self.decoders[J.name] = { try J(jsonString: $0.json) } + lock.withLock { + decoders[J.name] = { try J(jsonString: $0.json) } + registeredJobs.append(type) + } } /// Indicates if the given type is already registered. @@ -15,7 +24,9 @@ struct JobDecoding { /// - Parameter type: A job type. /// - Returns: Whether this job type is already registered. static func isRegistered(_ type: J.Type) -> Bool { - decoders[J.name] != nil + lock.withLock { + decoders[J.name] != nil + } } /// Decode a job from the given job data. @@ -24,10 +35,25 @@ struct JobDecoding { /// - Throws: Any errors encountered while decoding the job. /// - Returns: The decoded job. static func decode(_ jobData: JobData) throws -> Job { - guard let decoder = JobDecoding.decoders[jobData.jobName] else { - throw JobError("Unknown job of type '\(jobData.jobName)'. Please register it via `app.registerJob(MyJob.self)`.") + try lock.withLock { + guard let decoder = decoders[jobData.jobName] else { + Log.warning("Unknown job of type '\(jobData.jobName)'. Please register it via `app.registerJob(\(jobData.jobName).self)`.") + throw JobError.unknownType + } + + do { + return try decoder(jobData) + } catch { + Log.error("[Queue] error decoding job named \(jobData.jobName). Error was: \(error).") + throw error + } + } + } + + static func reset() { + lock.withLock { + decoders = [:] + registeredJobs = [] } - - return try decoder(jobData) } } diff --git a/Sources/Alchemy/Queue/Drivers/DatabaseQueue.swift b/Sources/Alchemy/Queue/Providers/DatabaseQueue.swift similarity index 68% rename from Sources/Alchemy/Queue/Drivers/DatabaseQueue.swift rename to Sources/Alchemy/Queue/Providers/DatabaseQueue.swift index 1cc84163..7f7c22d4 100644 --- a/Sources/Alchemy/Queue/Drivers/DatabaseQueue.swift +++ b/Sources/Alchemy/Queue/Providers/DatabaseQueue.swift @@ -1,7 +1,7 @@ import Foundation /// A queue that persists jobs to a database. -final class DatabaseQueue: QueueDriver { +final class DatabaseQueue: QueueProvider { /// The database backing this queue. private let database: Database @@ -9,46 +9,43 @@ final class DatabaseQueue: QueueDriver { /// /// - Parameters: /// - database: The database. - init(database: Database = .default) { + init(database: Database = DB) { self.database = database } // MARK: - Queue - func enqueue(_ job: JobData) -> EventLoopFuture { - JobModel(jobData: job).insert(db: database).voided() + func enqueue(_ job: JobData) async throws { + _ = try await JobModel(jobData: job).insertReturn(db: database) } - func dequeue(from channel: String) -> EventLoopFuture { - return database.transaction { (database: Database) -> EventLoopFuture in - return JobModel.query(database: database) + func dequeue(from channel: String) async throws -> JobData? { + return try await database.transaction { conn in + let job = try await JobModel.query(database: conn) .where("reserved" != true) .where("channel" == channel) .where { $0.whereNull(key: "backoff_until").orWhere("backoff_until" < Date()) } - .orderBy(column: "queued_at") + .orderBy("queued_at") .limit(1) - .forLock(.update, option: .skipLocked) - .firstModel() - .optionalFlatMap { job -> EventLoopFuture in - var job = job - job.reserved = true - job.reservedAt = Date() - return job.save(db: database) - } - .map { $0?.toJobData() } + .lock(for: .update, option: .skipLocked) + .first() + + return try await job?.update(db: conn) { + $0.reserved = true + $0.reservedAt = Date() + }.toJobData() } } - func complete(_ job: JobData, outcome: JobOutcome) -> EventLoopFuture { + func complete(_ job: JobData, outcome: JobOutcome) async throws { switch outcome { case .success, .failed: - return JobModel.query(database: database) + _ = try await JobModel.query(database: database) .where("id" == job.id) .where("channel" == job.channel) .delete() - .voided() case .retry: - return JobModel(jobData: job).update(db: database).voided() + _ = try await JobModel(jobData: job).update(db: database) } } } @@ -59,15 +56,20 @@ public extension Queue { /// - Parameter database: A database to drive this queue with. /// Defaults to your default database. /// - Returns: The configured queue. - static func database(_ database: Database = .default) -> Queue { - Queue(DatabaseQueue(database: database)) + static func database(_ database: Database = DB) -> Queue { + Queue(provider: DatabaseQueue(database: database)) + } + + /// A queue backed by the default SQL database. + static var database: Queue { + .database() } } // MARK: - Models /// Represents the table of jobs backing a `DatabaseQueue`. -private struct JobModel: Model { +struct JobModel: Model { static var tableName: String = "jobs" var id: String? @@ -90,14 +92,14 @@ private struct JobModel: Model { json = jobData.json attempts = jobData.attempts recoveryStrategy = jobData.recoveryStrategy - backoffSeconds = jobData.backoffSeconds + backoffSeconds = jobData.backoff.seconds backoffUntil = jobData.backoffUntil reserved = false } - func toJobData() -> JobData { - return JobData( - id: (try? getID()) ?? "N/A", + func toJobData() throws -> JobData { + JobData( + id: try getID(), json: json, jobName: jobName, channel: channel, diff --git a/Sources/Alchemy/Queue/Drivers/MockQueue.swift b/Sources/Alchemy/Queue/Providers/MemoryQueue.swift similarity index 60% rename from Sources/Alchemy/Queue/Drivers/MockQueue.swift rename to Sources/Alchemy/Queue/Providers/MemoryQueue.swift index bb9f7261..d6a53840 100644 --- a/Sources/Alchemy/Queue/Drivers/MockQueue.swift +++ b/Sources/Alchemy/Queue/Providers/MemoryQueue.swift @@ -3,10 +3,10 @@ import NIO /// A queue that persists jobs to memory. Jobs will be lost if the /// app shuts down. Useful for tests. -final class MockQueue: QueueDriver { - private var jobs: [JobID: JobData] = [:] - private var pending: [String: [JobID]] = [:] - private var reserved: [String: [JobID]] = [:] +public final class MemoryQueue: QueueProvider { + var jobs: [JobID: JobData] = [:] + var pending: [String: [JobID]] = [:] + var reserved: [String: [JobID]] = [:] private let lock = NSRecursiveLock() @@ -14,16 +14,15 @@ final class MockQueue: QueueDriver { // MARK: - Queue - func enqueue(_ job: JobData) -> EventLoopFuture { + public func enqueue(_ job: JobData) async throws { lock.lock() defer { lock.unlock() } jobs[job.id] = job append(id: job.id, on: job.channel, dict: &pending) - return .new() } - func dequeue(from channel: String) -> EventLoopFuture { + public func dequeue(from channel: String) async throws -> JobData? { lock.lock() defer { lock.unlock() } @@ -34,14 +33,14 @@ final class MockQueue: QueueDriver { }), let job = jobs[id] else { - return .new(nil) + return nil } append(id: id, on: job.channel, dict: &reserved) - return .new(job) + return job } - func complete(_ job: JobData, outcome: JobOutcome) -> EventLoopFuture { + public func complete(_ job: JobData, outcome: JobOutcome) async throws { lock.lock() defer { lock.unlock() } @@ -49,10 +48,9 @@ final class MockQueue: QueueDriver { case .success, .failed: reserved[job.channel]?.removeAll(where: { $0 == job.id }) jobs.removeValue(forKey: job.id) - return .new() case .retry: reserved[job.channel]?.removeAll(where: { $0 == job.id }) - return enqueue(job) + try await enqueue(job) } } @@ -64,9 +62,21 @@ final class MockQueue: QueueDriver { } extension Queue { - /// An in memory queue. Useful primarily for testing. - public static func mock() -> Queue { - Queue(MockQueue()) + /// An in memory queue. + public static var memory: Queue { + Queue(provider: MemoryQueue()) + } + + /// Fake the queue with an in memory queue. Useful for testing. + /// + /// - Parameter id: The identifier of the queue to fake. Defaults to + /// `default`. + /// - Returns: A `MemoryQueue` for verifying test expectations. + @discardableResult + public static func fake(_ identifier: Identifier = .default) -> MemoryQueue { + let mock = MemoryQueue() + bind(identifier, Queue(provider: mock)) + return mock } } @@ -77,10 +87,10 @@ extension Array { /// - Returns: The first matching element, or nil if no elements /// match. fileprivate mutating func popFirst(where conditional: (Element) -> Bool) -> Element? { - if let firstIndex = firstIndex(where: conditional) { - return remove(at: firstIndex) - } else { + guard let firstIndex = firstIndex(where: conditional) else { return nil } + + return remove(at: firstIndex) } } diff --git a/Sources/Alchemy/Queue/Providers/QueueProvider.swift b/Sources/Alchemy/Queue/Providers/QueueProvider.swift new file mode 100644 index 00000000..a489007f --- /dev/null +++ b/Sources/Alchemy/Queue/Providers/QueueProvider.swift @@ -0,0 +1,29 @@ +import NIO + +/// Conform to this protocol to implement a custom queue provider. +public protocol QueueProvider { + /// Enqueue a job. + func enqueue(_ job: JobData) async throws + + /// Dequeue the next job from the given channel. + func dequeue(from channel: String) async throws -> JobData? + + /// Handle an in progress job that has been completed with the + /// given outcome. + /// + /// The `JobData` will have any fields that should be updated + /// (such as `attempts`) already updated when it is passed + /// to this function. + func complete(_ job: JobData, outcome: JobOutcome) async throws +} + +/// An outcome of when a job is run. It should either be flagged as +/// successful, failed, or be retried. +public enum JobOutcome { + /// The job succeeded. + case success + /// The job failed. + case failed + /// The job should be requeued. + case retry +} diff --git a/Sources/Alchemy/Queue/Providers/RedisQueue.swift b/Sources/Alchemy/Queue/Providers/RedisQueue.swift new file mode 100644 index 00000000..a00149fc --- /dev/null +++ b/Sources/Alchemy/Queue/Providers/RedisQueue.swift @@ -0,0 +1,112 @@ +import NIO +import RediStack + +/// A queue that persists jobs to a Redis instance. +struct RedisQueue: QueueProvider { + /// The underlying redis connection. + private let redis: RedisClient + /// All job data. + private let dataKey = RedisKey("jobs:data") + /// All processing jobs. + private let processingKey = RedisKey("jobs:processing") + /// All backed off jobs. "job_id" : "backoff:channel" + private let backoffsKey = RedisKey("jobs:backoffs") + + /// Initialize with a Redis instance to persist jobs to. + /// + /// - Parameter redis: The Redis instance. + init(redis: RedisClient = Redis) { + self.redis = redis + monitorBackoffs() + } + + // MARK: - Queue + + func enqueue(_ job: JobData) async throws { + try await storeJobData(job) + _ = try await redis.lpush(job.id, into: key(for: job.channel)).get() + } + + func dequeue(from channel: String) async throws -> JobData? { + let jobId = try await redis.rpoplpush(from: key(for: channel), to: processingKey, valueType: String.self).get() + guard let jobId = jobId else { + return nil + } + + let jobString = try await redis.hget(jobId, from: dataKey, as: String.self).get() + let unwrappedJobString = try jobString.unwrap(or: JobError("Missing job data for key `\(jobId)`.")) + return try JobData(jsonString: unwrappedJobString) + } + + func complete(_ job: JobData, outcome: JobOutcome) async throws { + _ = try await redis.lrem(job.id, from: processingKey).get() + switch outcome { + case .success, .failed: + _ = try await redis.hdel(job.id, from: dataKey).get() + case .retry: + if let backoffUntil = job.backoffUntil { + let backoffKey = "\(job.id):\(job.channel)" + let backoffScore = backoffUntil.timeIntervalSince1970 + try await storeJobData(job) + _ = try await redis.zadd((backoffKey, backoffScore), to: backoffsKey).get() + } else { + try await enqueue(job) + } + } + } + + // MARK: - Private Helpers + + private func key(for channel: String) -> RedisKey { + RedisKey("jobs:queue:\(channel)") + } + + private func monitorBackoffs() { + let loop = Loop.group.next() + loop.scheduleRepeatedAsyncTask(initialDelay: .zero, delay: .seconds(1)) { _ in + loop.asyncSubmit { + let result = try await redis + // Get and remove backoffs that can be rerun. + .transaction { conn in + let set = RESPValue(from: backoffsKey.rawValue) + let min = RESPValue(from: 0) + let max = RESPValue(from: Date().timeIntervalSince1970) + _ = try await conn.send(command: "ZRANGEBYSCORE", with: [set, min, max]).get() + _ = try await conn.send(command: "ZREMRANGEBYSCORE", with: [set, min, max]).get() + } + + guard let values = result.array, let scores = values.first?.array, !scores.isEmpty else { + return + } + + for backoffKey in scores.compactMap(\.string) { + let values = backoffKey.split(separator: ":") + let jobId = String(values[0]) + let channel = String(values[1]) + _ = try await redis.lpush(jobId, into: key(for: channel)).get() + } + } + } + } + + private func storeJobData(_ job: JobData) async throws { + let jsonString = try job.jsonString() + _ = try await redis.hset(job.id, to: jsonString, in: dataKey).get() + } +} + +public extension Queue { + /// A queue backed by a Redis connection. + /// + /// - Parameter redis: A redis connection to drive this queue. + /// Defaults to your default redis connection. + /// - Returns: The configured queue. + static func redis(_ redis: RedisClient = Redis) -> Queue { + Queue(provider: RedisQueue(redis: redis)) + } + + /// A queue backed by the default Redis connection. + static var redis: Queue { + .redis() + } +} diff --git a/Sources/Alchemy/Queue/Queue+Config.swift b/Sources/Alchemy/Queue/Queue+Config.swift new file mode 100644 index 00000000..c92dddbf --- /dev/null +++ b/Sources/Alchemy/Queue/Queue+Config.swift @@ -0,0 +1,25 @@ +extension Queue { + public struct Config { + public struct JobType { + private init(_ type: J.Type) { + JobDecoding.register(type) + } + + public static func job(_ type: J.Type) -> JobType { + JobType(type) + } + } + + public let queues: [Identifier: Queue] + public let jobs: [JobType] + + public init(queues: [Queue.Identifier : Queue], jobs: [Queue.Config.JobType]) { + self.queues = queues + self.jobs = jobs + } + } + + public static func configure(with config: Config) { + config.queues.forEach { Queue.bind($0, $1) } + } +} diff --git a/Sources/Alchemy/Queue/Queue+Worker.swift b/Sources/Alchemy/Queue/Queue+Worker.swift new file mode 100644 index 00000000..fda5456e --- /dev/null +++ b/Sources/Alchemy/Queue/Queue+Worker.swift @@ -0,0 +1,99 @@ +extension Queue { + /// Start a worker that dequeues and runs jobs from this queue. + /// + /// - Parameters: + /// - channels: The channels this worker should monitor for + /// work. Defaults to `Queue.defaultChannel`. + /// - pollRate: The rate at which this worker should poll the + /// queue for new work. Defaults to `Queue.defaultPollRate`. + /// - eventLoop: The loop this worker will run on. Defaults to + /// your apps next available loop. + public func startWorker(for channels: [String] = [Queue.defaultChannel], pollRate: TimeAmount = Queue.defaultPollRate, untilEmpty: Bool = true, on eventLoop: EventLoop = Loop.group.next()) { + let worker = eventLoop.queueId + Log.info("[Queue] starting worker \(worker)") + workers.append(worker) + _startWorker(for: channels, pollRate: pollRate, untilEmpty: untilEmpty, on: eventLoop) + } + + private func _startWorker(for channels: [String] = [Queue.defaultChannel], pollRate: TimeAmount = Queue.defaultPollRate, untilEmpty: Bool, on eventLoop: EventLoop = Loop.group.next()) { + eventLoop.asyncSubmit { try await self.runNext(from: channels, untilEmpty: untilEmpty) } + .whenComplete { _ in + // Run check again in the `pollRate`. + eventLoop.scheduleTask(in: pollRate) { + self._startWorker(for: channels, pollRate: pollRate, untilEmpty: untilEmpty, on: eventLoop) + } + } + } + + func runNext(from channels: [String], untilEmpty: Bool) async throws { + do { + guard let jobData = try await dequeue(from: channels) else { + return + } + + Log.info("[Queue] dequeued job \(jobData.jobName) from queue \(jobData.channel)") + try await execute(jobData) + + if untilEmpty { + try await runNext(from: channels, untilEmpty: untilEmpty) + } + } catch { + Log.error("[Queue] error running job \(name(of: Self.self)) from `\(channels)`. \(error)") + throw error + } + } + + /// Dequeue the next job from a given set of channels, ordered by + /// priority. + /// + /// - Parameter channels: The channels to dequeue from. + /// - Returns: A dequeued `Job`, if there is one. + func dequeue(from channels: [String]) async throws -> JobData? { + guard let channel = channels.first else { + return nil + } + + if let job = try await provider.dequeue(from: channel) { + return job + } else { + return try await dequeue(from: Array(channels.dropFirst())) + } + } + + private func execute(_ jobData: JobData) async throws { + var jobData = jobData + jobData.attempts += 1 + + func retry(ignoreAttempt: Bool = false) async throws { + if ignoreAttempt { jobData.attempts -= 1 } + jobData.backoffUntil = jobData.nextRetryDate() + try await provider.complete(jobData, outcome: .retry) + } + + var job: Job? + do { + job = try JobDecoding.decode(jobData) + try await job?.run() + try await provider.complete(jobData, outcome: .success) + job?.finished(result: .success(())) + } catch where jobData.canRetry { + try await retry() + job?.failed(error: error) + } catch where (error as? JobError) == JobError.unknownType { + // So that an old worker won't fail new, unrecognized jobs. + try await retry(ignoreAttempt: true) + job?.failed(error: error) + throw error + } catch { + try await provider.complete(jobData, outcome: .failed) + job?.finished(result: .failure(error)) + job?.failed(error: error) + } + } +} + +extension EventLoop { + var queueId: String { + String(ObjectIdentifier(self).debugDescription.dropLast().suffix(6)) + } +} diff --git a/Sources/Alchemy/Queue/Queue.swift b/Sources/Alchemy/Queue/Queue.swift index 38a0d572..bae71bbf 100644 --- a/Sources/Alchemy/Queue/Queue.swift +++ b/Sources/Alchemy/Queue/Queue.swift @@ -1,21 +1,30 @@ import NIO /// Queue lets you run queued jobs to be processed in the background. -/// Jobs are persisted by the given `QueueDriver`. +/// Jobs are persisted by the given `QueueProvider`. public final class Queue: Service { + public struct Identifier: ServiceIdentifier { + private let hashable: AnyHashable + public init(hashable: AnyHashable) { self.hashable = hashable } + } + /// The default channel to dispatch jobs on for all queues. public static let defaultChannel = "default" /// The default rate at which workers poll queues. public static let defaultPollRate: TimeAmount = .seconds(1) - /// The driver backing this queue. - private let driver: QueueDriver + /// The ids of any workers associated with this queue and running in this + /// process. + public var workers: [String] = [] - /// Initialize a queue backed by the given driver. + /// The provider backing this queue. + let provider: QueueProvider + + /// Initialize a queue backed by the given provider. /// - /// - Parameter driver: A queue driver to back this queue with. - public init(_ driver: QueueDriver) { - self.driver = driver + /// - Parameter provider: A queue provider to back this queue with. + public init(provider: QueueProvider) { + self.provider = provider } /// Enqueues a generic `Job` to this queue on the given channel. @@ -24,30 +33,8 @@ public final class Queue: Service { /// - job: A job to enqueue to this queue. /// - channel: The channel on which to enqueue the job. Defaults /// to `Queue.defaultChannel`. - /// - Returns: An future that completes when the job is enqueued. - public func enqueue(_ job: J, channel: String = defaultChannel) -> EventLoopFuture { - // If the Job hasn't been registered, register it. - if !JobDecoding.isRegistered(J.self) { - JobDecoding.register(J.self) - } - return catchError { driver.enqueue(try JobData(job, channel: channel)) } - } - - /// Start a worker that dequeues and runs jobs from this queue. - /// - /// - Parameters: - /// - channels: The channels this worker should monitor for - /// work. Defaults to `Queue.defaultChannel`. - /// - pollRate: The rate at which this worker should poll the - /// queue for new work. Defaults to `Queue.defaultPollRate`. - /// - eventLoop: The loop this worker will run on. Defaults to - /// your apps next available loop. - public func startWorker( - for channels: [String] = [Queue.defaultChannel], - pollRate: TimeAmount = Queue.defaultPollRate, - on eventLoop: EventLoop = Loop.group.next() - ) { - driver.startWorker(for: channels, pollRate: pollRate, on: eventLoop) + public func enqueue(_ job: J, channel: String = defaultChannel) async throws { + try await provider.enqueue(JobData(job, channel: channel)) } } @@ -57,9 +44,7 @@ extension Job { /// - Parameters: /// - queue: The queue to dispatch on. /// - channel: The name of the channel to dispatch on. - /// - Returns: A future that completes when this job has been - /// dispatched to the queue. - public func dispatch(on queue: Queue = .default, channel: String = Queue.defaultChannel) -> EventLoopFuture { - queue.enqueue(self, channel: channel) + public func dispatch(on queue: Queue = Q, channel: String = Queue.defaultChannel) async throws { + try await queue.enqueue(self, channel: channel) } } diff --git a/Sources/Alchemy/Redis/Redis+Commands.swift b/Sources/Alchemy/Redis/Redis+Commands.swift index 255e8d41..233c86ed 100644 --- a/Sources/Alchemy/Redis/Redis+Commands.swift +++ b/Sources/Alchemy/Redis/Redis+Commands.swift @@ -1,21 +1,21 @@ -import Foundation +import NIO import RediStack /// RedisClient conformance. See `RedisClient` for docs. -extension Redis: RedisClient { +extension RedisClient: RediStack.RedisClient { - // MARK: RedisClient + // MARK: RediStack.RedisClient public var eventLoop: EventLoop { Loop.current } - public func logging(to logger: Logger) -> RedisClient { - driver.getClient().logging(to: logger) + public func logging(to logger: Logger) -> RediStack.RedisClient { + provider.getClient().logging(to: logger) } public func send(command: String, with arguments: [RESPValue]) -> EventLoopFuture { - driver.getClient() + provider.getClient() .send(command: command, with: arguments).hop(to: Loop.current) } @@ -25,7 +25,7 @@ extension Redis: RedisClient { onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?, onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler? ) -> EventLoopFuture { - driver.getClient() + provider.getClient() .subscribe( to: channels, messageReceiver: receiver, @@ -40,7 +40,7 @@ extension Redis: RedisClient { onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?, onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler? ) -> EventLoopFuture { - driver.getClient() + provider.getClient() .psubscribe( to: patterns, messageReceiver: receiver, @@ -50,11 +50,11 @@ extension Redis: RedisClient { } public func unsubscribe(from channels: [RedisChannelName]) -> EventLoopFuture { - driver.getClient().unsubscribe(from: channels) + provider.getClient().unsubscribe(from: channels) } public func punsubscribe(from patterns: [String]) -> EventLoopFuture { - driver.getClient().punsubscribe(from: patterns) + provider.getClient().punsubscribe(from: patterns) } // MARK: - Alchemy sugar @@ -64,10 +64,9 @@ extension Redis: RedisClient { /// - Parameters: /// - name: The name of the command. /// - args: Any arguments for the command. - /// - Returns: A future containing the return value of the - /// command. - public func command(_ name: String, args: RESPValueConvertible...) -> EventLoopFuture { - self.command(name, args: args) + /// - Returns: The return value of the command. + public func command(_ name: String, args: RESPValueConvertible...) async throws -> RESPValue { + try await command(name, args: args) } /// Wrapper around sending commands to Redis. @@ -75,10 +74,9 @@ extension Redis: RedisClient { /// - Parameters: /// - name: The name of the command. /// - args: An array of arguments for the command. - /// - Returns: A future containing the return value of the - /// command. - public func command(_ name: String, args: [RESPValueConvertible]) -> EventLoopFuture { - self.send(command: name, with: args.map { $0.convertedToRESPValue() }) + /// - Returns: The return value of the command. + public func command(_ name: String, args: [RESPValueConvertible]) async throws -> RESPValue { + try await send(command: name, with: args.map { $0.convertedToRESPValue() }).get() } /// Evaluate the given Lua script. @@ -88,10 +86,9 @@ extension Redis: RedisClient { /// - keys: The arguments that represent Redis keys. See /// [EVAL](https://redis.io/commands/eval) docs for details. /// - args: All other arguments. - /// - Returns: A future that completes with the result of the - /// script. - public func eval(_ script: String, keys: [String] = [], args: [RESPValueConvertible] = []) -> EventLoopFuture { - self.command("EVAL", args: [script] + [keys.count] + keys + args) + /// - Returns: The result of the script. + public func eval(_ script: String, keys: [String] = [], args: [RESPValueConvertible] = []) async throws -> RESPValue { + try await command("EVAL", args: [script] + [keys.count] + keys + args) } /// Subscribe to a single channel. @@ -100,33 +97,33 @@ extension Redis: RedisClient { /// - channel: The name of the channel to subscribe to. /// - messageReciver: The closure to execute when a message /// comes through the given channel. - /// - Returns: A future that completes when the subscription is - /// established. - public func subscribe(to channel: RedisChannelName, messageReciver: @escaping (RESPValue) -> Void) -> EventLoopFuture { - self.subscribe(to: [channel]) { _, value in messageReciver(value) } + public func subscribe(to channel: RedisChannelName, messageReciver: @escaping (RESPValue) -> Void) async throws { + try await subscribe(to: [channel]) { _, value in messageReciver(value) }.get() } /// Sends a Redis transaction over a single connection. Wrapper around /// "MULTI" ... "EXEC". - public func transaction(_ action: @escaping (Redis) -> EventLoopFuture) -> EventLoopFuture { - driver.leaseConnection { conn in - return conn.send(command: "MULTI") - .flatMap { _ in action(Redis(driver: conn)) } - .flatMap { _ in return conn.send(command: "EXEC") } + /// + /// - Returns: The result of finishing the transaction. + public func transaction(_ action: @escaping (RedisClient) async throws -> Void) async throws -> RESPValue { + try await provider.transaction { conn in + _ = try await conn.getClient().send(command: "MULTI").get() + try await action(RedisClient(provider: conn)) + return try await conn.getClient().send(command: "EXEC").get() } } } -extension RedisConnection: RedisDriver { - func getClient() -> RedisClient { +extension RedisConnection: RedisProvider { + public func getClient() -> RediStack.RedisClient { self } - func shutdown() throws { + public func shutdown() throws { try close().wait() } - func leaseConnection(_ transaction: @escaping (RedisConnection) -> EventLoopFuture) -> EventLoopFuture { - transaction(self) + public func transaction(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T { + try await transaction(self) } } diff --git a/Sources/Alchemy/Redis/Redis.swift b/Sources/Alchemy/Redis/RedisClient.swift similarity index 61% rename from Sources/Alchemy/Redis/Redis.swift rename to Sources/Alchemy/Redis/RedisClient.swift index 24d7c9a0..0250da47 100644 --- a/Sources/Alchemy/Redis/Redis.swift +++ b/Sources/Alchemy/Redis/RedisClient.swift @@ -1,14 +1,23 @@ import NIO +import NIOConcurrencyHelpers import RediStack /// A client for interfacing with a Redis instance. -public struct Redis: Service { - let driver: RedisDriver +public struct RedisClient: Service { + public struct Identifier: ServiceIdentifier { + private let hashable: AnyHashable + public init(hashable: AnyHashable) { self.hashable = hashable } + } + + let provider: RedisProvider + + public init(provider: RedisProvider) { + self.provider = provider + } - /// Shuts down this `Redis` client, closing it's associated - /// connection pools. + /// Shuts down this client, closing it's associated connection pools. public func shutdown() throws { - try driver.shutdown() + try provider.shutdown() } /// A single redis connection @@ -18,7 +27,7 @@ public struct Redis: Service { password: String? = nil, database: Int? = nil, poolSize: RedisConnectionPoolSize = .maximumActiveConnections(1) - ) -> Redis { + ) -> RedisClient { return .cluster(.ip(host: host, port: port), password: password, database: database, poolSize: poolSize) } @@ -40,8 +49,8 @@ public struct Redis: Service { password: String? = nil, database: Int? = nil, poolSize: RedisConnectionPoolSize = .maximumActiveConnections(1) - ) -> Redis { - return .rawPoolConfiguration( + ) -> RedisClient { + return .configuration( RedisConnectionPool.Configuration( initialServerConnectionAddresses: sockets.map { do { @@ -71,31 +80,33 @@ public struct Redis: Service { /// - Parameters: /// - config: The configuration of the pool backing this `Redis` /// client. - public static func rawPoolConfiguration(_ config: RedisConnectionPool.Configuration) -> Redis { - return Redis(driver: ConnectionPool(config: config)) + public static func configuration(_ config: RedisConnectionPool.Configuration) -> RedisClient { + return RedisClient(provider: ConnectionPool(config: config)) } } -/// Under the hood driver for `Redis`. Used so either connection pools +/// Under the hood provider for `Redis`. Used so either connection pools /// or connections can be injected into `Redis` for accessing redis. -protocol RedisDriver { +public protocol RedisProvider { /// Get a redis client for running commands. - func getClient() -> RedisClient + func getClient() -> RediStack.RedisClient /// Shut down. func shutdown() throws - /// Lease a private connection for the duration of a transaction. + /// Runs a transaction on the redis client using a given closure. /// /// - Parameter transaction: An asynchronous transaction to run on /// the connection. - func leaseConnection(_ transaction: @escaping (RedisConnection) -> EventLoopFuture) -> EventLoopFuture + /// - Returns: The resulting value of the transaction. + func transaction(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T } -/// A connection pool is a redis driver with a pool per `EventLoop`. -private final class ConnectionPool: RedisDriver { +/// A connection pool is a redis provider with a pool per `EventLoop`. +private final class ConnectionPool: RedisProvider { /// Map of `EventLoop` identifiers to respective connection pools. - @Locked private var poolStorage: [ObjectIdentifier: RedisConnectionPool] = [:] + private var poolStorage: [ObjectIdentifier: RedisConnectionPool] = [:] + private var poolLock = Lock() /// The configuration to create pools with. private var config: RedisConnectionPool.Configuration @@ -104,19 +115,24 @@ private final class ConnectionPool: RedisDriver { self.config = config } - func getClient() -> RedisClient { + func getClient() -> RediStack.RedisClient { getPool() } - func leaseConnection(_ transaction: @escaping (RedisConnection) -> EventLoopFuture) -> EventLoopFuture { - getPool().leaseConnection(transaction) + func transaction(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T { + let pool = getPool() + return try await pool.leaseConnection { conn in + pool.eventLoop.asyncSubmit { try await transaction(conn) } + }.get() } func shutdown() throws { - try poolStorage.values.forEach { - let promise: EventLoopPromise = $0.eventLoop.makePromise() - $0.close(promise: promise) - try promise.futureResult.wait() + try poolLock.withLock { + try poolStorage.values.forEach { + let promise: EventLoopPromise = $0.eventLoop.makePromise() + $0.close(promise: promise) + try promise.futureResult.wait() + } } } @@ -127,12 +143,14 @@ private final class ConnectionPool: RedisDriver { private func getPool() -> RedisConnectionPool { let loop = Loop.current let key = ObjectIdentifier(loop) - if let pool = self.poolStorage[key] { - return pool - } else { - let newPool = RedisConnectionPool(configuration: self.config, boundEventLoop: loop) - self.poolStorage[key] = newPool - return newPool + return poolLock.withLock { + if let pool = self.poolStorage[key] { + return pool + } else { + let newPool = RedisConnectionPool(configuration: self.config, boundEventLoop: loop) + self.poolStorage[key] = newPool + return newPool + } } } } diff --git a/Sources/Alchemy/Routing/ResponseConvertible.swift b/Sources/Alchemy/Routing/ResponseConvertible.swift index bc956e77..c3dd11a4 100644 --- a/Sources/Alchemy/Routing/ResponseConvertible.swift +++ b/Sources/Alchemy/Routing/ResponseConvertible.swift @@ -1,43 +1,25 @@ -import NIO - /// Represents any type that can be converted into a response & is /// thus returnable from a request handler. public protocol ResponseConvertible { - /// Takes the response and turns it into an - /// `EventLoopFuture`. + /// Takes the type and turns it into a `Response`. /// /// - Throws: Any error that might occur when this is turned into - /// a `Response` future. - /// - Returns: A future containing an `Response` to respond to a - /// `Request` with. - func convert() throws -> EventLoopFuture + /// a `Response`. + /// - Returns: A `Response` to respond to a `Request` with. + func response() async throws -> Response } // MARK: Convenient `ResponseConvertible` Conformances. -extension Array: ResponseConvertible where Element: Encodable { - public func convert() throws -> EventLoopFuture { - .new(Response(status: .ok, body: try HTTPBody(json: self))) - } -} - extension Response: ResponseConvertible { - public func convert() throws -> EventLoopFuture { - .new(self) - } -} - -extension EventLoopFuture: ResponseConvertible where Value: ResponseConvertible { - public func convert() throws -> EventLoopFuture { - self.flatMap { res in - catchError { try res.convert() } - } + public func response() -> Response { + self } } extension String: ResponseConvertible { - public func convert() throws -> EventLoopFuture { - return .new(Response(status: .ok, body: HTTPBody(text: self))) + public func response() -> Response { + Response(status: .ok).withString(self) } } @@ -46,7 +28,7 @@ extension String: ResponseConvertible { // implementation here (and a special case router // `.on` specifically for `Encodable`) types. extension Encodable { - public func encode() throws -> EventLoopFuture { - .new(Response(status: .ok, body: try HTTPBody(json: self))) + public func response() throws -> Response { + try Response(status: .ok).withValue(self) } } diff --git a/Sources/Alchemy/Routing/Router.swift b/Sources/Alchemy/Routing/Router.swift index 85eb2f53..87efbb3e 100644 --- a/Sources/Alchemy/Routing/Router.swift +++ b/Sources/Alchemy/Routing/Router.swift @@ -1,5 +1,6 @@ import NIO import NIOHTTP1 +import Hummingbird /// The escape character for escaping path parameters. /// @@ -10,24 +11,39 @@ fileprivate let kRouterPathParameterEscape = ":" /// An `Router` responds to HTTP requests from the client. /// Specifically, it takes an `Request` and routes it to /// a handler that returns an `ResponseConvertible`. -public final class Router: HTTPRouter, Service { - /// A router handler. Takes a request and returns a future with a - /// response. - private typealias RouterHandler = (Request) -> EventLoopFuture +public final class Router { + public struct RouteOptions: OptionSet { + public let rawValue: Int + + public init(rawValue: Int) { + self.rawValue = rawValue + } + public static let stream = RouteOptions(rawValue: 1 << 0) + } + + private struct HandlerEntry { + let options: RouteOptions + let handler: (Request) async -> Response + } + + /// A route handler. Takes a request and returns a response. + public typealias Handler = (Request) async throws -> ResponseConvertible + + /// A handler for returning a response after an error is + /// encountered while initially handling the request. + public typealias ErrorHandler = (Request, Error) async throws -> ResponseConvertible + /// The default response for when there is an error along the /// routing chain that does not conform to /// `ResponseConvertible`. - public static var internalErrorResponse = Response( - status: .internalServerError, - body: HTTPBody(text: HTTPResponseStatus.internalServerError.reasonPhrase) - ) - + var internalErrorHandler: ErrorHandler = Router.uncaughtErrorHandler + /// The response for when no handler is found for a Request. - public static var notFoundResponse = Response( - status: .notFound, - body: HTTPBody(text: HTTPResponseStatus.notFound.reasonPhrase) - ) + var notFoundHandler: Handler = { _ in + Response(status: .notFound) + .withString(HTTPResponseStatus.notFound.reasonPhrase) + } /// `Middleware` that will intercept all requests through this /// router, before all other `Middleware` regardless of @@ -41,7 +57,7 @@ public final class Router: HTTPRouter, Service { var pathPrefixes: [String] = [] /// A trie that holds all the handlers. - private let trie = RouterTrieNode() + private let trie = Trie() /// Creates a new router. init() {} @@ -54,22 +70,20 @@ public final class Router: HTTPRouter, Service { /// given method and path. /// - method: The method of a request this handler expects. /// - path: The path of a requst this handler can handle. - func add(handler: @escaping (Request) throws -> ResponseConvertible, for method: HTTPMethod, path: String) { - let pathPrefixes = pathPrefixes.map { $0.hasPrefix("/") ? String($0.dropFirst()) : $0 } - let splitPath = pathPrefixes + path.tokenized - let middlewareClosures = middlewares.reversed().map(Middleware.interceptConvertError) - trie.insert(path: splitPath, storageKey: method) { - var next = { request in - catchError { try handler(request).convert() }.convertErrorToResponse() - } - + func add(handler: @escaping Handler, for method: HTTPMethod, path: String, options: RouteOptions) { + let splitPath = pathPrefixes + path.tokenized(with: method) + let middlewareClosures = middlewares.reversed().map(Middleware.intercept) + let entry = HandlerEntry(options: options) { + var next = self.cleanHandler(handler) for middleware in middlewareClosures { let oldNext = next - next = { middleware($0, oldNext) } + next = self.cleanHandler { try await middleware($0, oldNext) } } - return next($0) + return await next($0) } + + trie.insert(path: splitPath, value: entry) } /// Handles a request. If the request has any dynamic path @@ -78,68 +92,78 @@ public final class Router: HTTPRouter, Service { /// passing it to the handler closure. /// /// - Parameter request: The request this router will handle. - /// - Returns: A future containing the response of a handler or a - /// `.notFound` response if there was not a matching handler. - func handle(request: Request) -> EventLoopFuture { - var handler = notFoundHandler - - // Find a matching handler - if let match = trie.search(path: request.path.tokenized, storageKey: request.method) { - request.pathParameters = match.1 - handler = match.0 + /// - Returns: The response of a matching handler or a + /// `.notFound` response if there was not a + /// matching handler. + func handle(request: Request) async -> Response { + var handler = cleanHandler(notFoundHandler) + var additionalMiddlewares = Array(globalMiddlewares.reversed()) + let hbApp: HBApplication? = Container.resolve() + + if let length = request.headers.contentLength, length > hbApp?.configuration.maxUploadSize ?? .max { + handler = cleanHandler { _ in throw HTTPError(.payloadTooLarge) } + } else if let match = trie.search(path: request.path.tokenized(with: request.method)) { + request.parameters = match.parameters + handler = match.value.handler + + // Collate the request if streaming isn't specified. + if !match.value.options.contains(.stream) { + additionalMiddlewares.append(AccumulateMiddleware()) + } } - + // Apply global middlewares - for middleware in globalMiddlewares.reversed() { + for middleware in additionalMiddlewares { let lastHandler = handler - handler = { middleware.interceptConvertError($0, next: lastHandler) } + handler = cleanHandler { + try await middleware.intercept($0, next: lastHandler) + } } - - return handler(request) - } - - private func notFoundHandler(_ request: Request) -> EventLoopFuture { - return .new(Router.notFoundResponse) + + return await handler(request) } -} - -private extension Middleware { - func interceptConvertError(_ request: Request, next: @escaping Next) -> EventLoopFuture { - return catchError { - try intercept(request, next: next) - }.convertErrorToResponse() - } -} - -private extension EventLoopFuture where Value == Response { - func convertErrorToResponse() -> EventLoopFuture { - return flatMapError { error in - func serverError() -> EventLoopFuture { - Log.error("[Server] encountered internal error: \(error).") - return .new(Router.internalErrorResponse) - } - + + /// Converts a throwing, ResponseConvertible handler into a + /// non-throwing Response handler. + private func cleanHandler(_ handler: @escaping Handler) -> (Request) async -> Response { + return { req in do { - if let error = error as? ResponseConvertible { - return try error.convert() - } else { - return serverError() - } + return try await handler(req).response() } catch { - return serverError() + do { + if let error = error as? ResponseConvertible { + do { + return try await error.response() + } catch { + return try await self.internalErrorHandler(req, error).response() + } + } + + return try await self.internalErrorHandler(req, error).response() + } catch { + return Router.uncaughtErrorHandler(req: req, error: error) + } } } } + + /// The default error handler if an error is encountered while handling a + /// request. + private static func uncaughtErrorHandler(req: Request, error: Error) -> Response { + Log.error("[Server] encountered internal error: \(error).") + return Response(status: .internalServerError) + .withString(HTTPResponseStatus.internalServerError.reasonPhrase) + } } -private extension String { - var tokenized: [String] { - return split(separator: "/").map(String.init) +extension String { + fileprivate func tokenized(with method: HTTPMethod) -> [String] { + split(separator: "/").map(String.init).filter { !$0.isEmpty } + [method.rawValue] } } -extension HTTPMethod: Hashable { - public func hash(into hasher: inout Hasher) { - hasher.combine(self.rawValue) +private struct AccumulateMiddleware: Middleware { + func intercept(_ request: Request, next: (Request) async throws -> Response) async throws -> Response { + try await next(request.collect()) } } diff --git a/Sources/Alchemy/Routing/RouterTrieNode.swift b/Sources/Alchemy/Routing/RouterTrieNode.swift deleted file mode 100644 index 065036da..00000000 --- a/Sources/Alchemy/Routing/RouterTrieNode.swift +++ /dev/null @@ -1,65 +0,0 @@ -/// A trie that stores objects at each node. Supports wildcard path -/// elements denoted by a ":" at the beginning. -final class RouterTrieNode { - /// Storage of the objects at this node. - private var storage: [StorageKey: StorageObject] = [:] - /// This node's children, mapped by their path for instant lookup. - private var children: [String: RouterTrieNode] = [:] - /// Any children with wildcards in their path. - private var wildcardChildren: [String: RouterTrieNode] = [:] - - /// Search this node & it's children for an object at a path, - /// stored with the given key. - /// - /// - Parameters: - /// - path: The path of the object to search for. If this is - /// empty, it is assumed the object can only be at this node. - /// - storageKey: The key by which the object is stored. - /// - Returns: A tuple containing the object and any parsed path - /// parameters. `nil` if the object isn't in this node or its - /// children. - func search(path: [String], storageKey: StorageKey) -> (StorageObject, [PathParameter])? { - if let first = path.first { - let newPath = Array(path.dropFirst()) - if let matchingChild = self.children[first] { - return matchingChild.search(path: newPath, storageKey: storageKey) - } else { - for (wildcard, node) in self.wildcardChildren { - guard var val = node.search(path: newPath, storageKey: storageKey) else { - continue - } - - val.1.insert(PathParameter(parameter: wildcard, stringValue: first), at: 0) - return val - } - return nil - } - } else { - return self.storage[storageKey].map { ($0, []) } - } - } - - /// Inserts a value at the given path with a storage key. - /// - /// - Parameters: - /// - path: The path to the node where this value should be - /// stored. - /// - storageKey: The key by which to store the value. - /// - value: The value to store. - func insert(path: [String], storageKey: StorageKey, value: StorageObject) { - if let first = path.first { - if first.hasPrefix(":") { - let firstWithoutEscape = String(first.dropFirst()) - let child = self.wildcardChildren[firstWithoutEscape] ?? Self() - child.insert(path: Array(path.dropFirst()), storageKey: storageKey, value: value) - self.wildcardChildren[firstWithoutEscape] = child - } else { - let child = self.children[first] ?? Self() - child.insert(path: Array(path.dropFirst()), storageKey: storageKey, value: value) - self.children[first] = child - } - } else { - self.storage[storageKey] = value - } - } -} diff --git a/Sources/Alchemy/Routing/Trie.swift b/Sources/Alchemy/Routing/Trie.swift new file mode 100644 index 00000000..336d249c --- /dev/null +++ b/Sources/Alchemy/Routing/Trie.swift @@ -0,0 +1,62 @@ +/// A trie that stores objects at each node. Supports wildcard path +/// elements denoted by a ":" at the beginning. +final class Trie { + /// Storage of the object at this node. + private var value: Value? + /// This node's children, mapped by their path for instant lookup. + private var children: [String: Trie] = [:] + /// Any children with parameters in their path. + private var parameterChildren: [String: Trie] = [:] + + /// Search this node & it's children for an object at a path. + /// + /// - Parameter path: The path of the object to search for. If this is + /// empty, it is assumed the object can only be at this node. + /// - Returns: A tuple containing the object and any parsed path + /// parameters. `nil` if the object isn't in this node or its + /// children. + func search(path: [String]) -> (value: Value, parameters: [Parameter])? { + if let first = path.first { + let newPath = Array(path.dropFirst()) + if let matchingChild = children[first] { + return matchingChild.search(path: newPath) + } + + for (wildcard, node) in parameterChildren { + guard var val = node.search(path: newPath) else { + continue + } + + val.parameters.insert(Parameter(key: wildcard, value: first), at: 0) + return val + } + + return nil + } + + return value.map { ($0, []) } + } + + /// Inserts a value at the given path. + /// + /// - Parameters: + /// - path: The path to the node where this value should be + /// stored. + /// - value: The value to store. + func insert(path: [String], value: Value) { + if let first = path.first { + if first.hasPrefix(":") { + let firstWithoutEscape = String(first.dropFirst()) + let child = parameterChildren[firstWithoutEscape] ?? Self() + child.insert(path: Array(path.dropFirst()), value: value) + parameterChildren[firstWithoutEscape] = child + } else { + let child = children[first] ?? Self() + child.insert(path: Array(path.dropFirst()), value: value) + children[first] = child + } + } else { + self.value = value + } + } +} diff --git a/Sources/Alchemy/Rune/Model/Decoding/DatabaseFieldDecoder.swift b/Sources/Alchemy/Rune/Model/Decoding/DatabaseFieldDecoder.swift deleted file mode 100644 index 281bfb4e..00000000 --- a/Sources/Alchemy/Rune/Model/Decoding/DatabaseFieldDecoder.swift +++ /dev/null @@ -1,111 +0,0 @@ -/// Used in the internals of the `DatabaseRowDecoder`, used when -/// the `DatabaseRowDecoder` attempts to decode a `Decodable`, -/// not primitive, property from a single `DatabaseField`. -struct DatabaseFieldDecoder: ModelDecoder { - /// The field this `Decoder` will be decoding from. - let field: DatabaseField - - // MARK: Decoder - - var codingPath: [CodingKey] = [] - var userInfo: [CodingUserInfoKey : Any] = [:] - - func container( - keyedBy type: Key.Type - ) throws -> KeyedDecodingContainer where Key: CodingKey { - throw DatabaseCodingError("`container` shouldn't be called; this is only for single " - + "values.") - } - - func unkeyedContainer() throws -> UnkeyedDecodingContainer { - throw DatabaseCodingError("`unkeyedContainer` shouldn't be called; this is only for " - + "single values.") - } - - func singleValueContainer() throws -> SingleValueDecodingContainer { - _SingleValueDecodingContainer(field: self.field) - } -} - -/// A `SingleValueDecodingContainer` for decoding from a -/// `DatabaseField`. -private struct _SingleValueDecodingContainer: SingleValueDecodingContainer { - /// The field from which the container will be decoding from. - let field: DatabaseField - - // MARK: SingleValueDecodingContainer - - var codingPath: [CodingKey] = [] - - func decodeNil() -> Bool { - self.field.value.isNil - } - - func decode(_ type: Bool.Type) throws -> Bool { - try self.field.bool() - } - - func decode(_ type: String.Type) throws -> String { - try self.field.string() - } - - func decode(_ type: Double.Type) throws -> Double { - try self.field.double() - } - - func decode(_ type: Float.Type) throws -> Float { - Float(try self.field.double()) - } - - func decode(_ type: Int.Type) throws -> Int { - try self.field.int() - } - - func decode(_ type: Int8.Type) throws -> Int8 { - Int8(try self.field.int()) - } - - func decode(_ type: Int16.Type) throws -> Int16 { - Int16(try self.field.int()) - } - - func decode(_ type: Int32.Type) throws -> Int32 { - Int32(try self.field.int()) - } - - func decode(_ type: Int64.Type) throws -> Int64 { - Int64(try self.field.int()) - } - - func decode(_ type: UInt.Type) throws -> UInt { - UInt(try self.field.int()) - } - - func decode(_ type: UInt8.Type) throws -> UInt8 { - UInt8(try self.field.int()) - } - - func decode(_ type: UInt16.Type) throws -> UInt16 { - UInt16(try self.field.int()) - } - - func decode(_ type: UInt32.Type) throws -> UInt32 { - UInt32(try self.field.int()) - } - - func decode(_ type: UInt64.Type) throws -> UInt64 { - UInt64(try self.field.int()) - } - - func decode(_ type: T.Type) throws -> T where T: Decodable { - if type == Int.self { - return try self.field.int() as! T - } else if type == UUID.self { - return try self.field.uuid() as! T - } else if type == String.self { - return try self.field.string() as! T - } else { - throw DatabaseCodingError("Decoding a \(type) from a `DatabaseField` is not supported. \(field.column)") - } - } -} diff --git a/Sources/Alchemy/Rune/Model/FieldReading/Model+Fields.swift b/Sources/Alchemy/Rune/Model/FieldReading/Model+Fields.swift deleted file mode 100644 index e34e2d6f..00000000 --- a/Sources/Alchemy/Rune/Model/FieldReading/Model+Fields.swift +++ /dev/null @@ -1,27 +0,0 @@ -extension Model { - /// Returns all `DatabaseField`s on a `Model` object. Useful for - /// inserting or updating values into a database. - /// - /// - Throws: A `DatabaseCodingError` if there is an error - /// creating any of the fields of this instance. - /// - Returns: An array of database fields representing the stored - /// properties of `self`. - public func fields() throws -> [DatabaseField] { - try ModelFieldReader(Self.keyMapping).getFields(of: self) - } - - /// Returns an ordered dictionary of column names to `Parameter` - /// values, appropriate for working with the QueryBuilder. - /// - /// - Throws: A `DatabaseCodingError` if there is an error - /// creating any of the fields of this instance. - /// - Returns: An ordered dictionary mapping column names to - /// parameters for use in a QueryBuilder `Query`. - public func fieldDictionary() throws -> OrderedDictionary { - var dict = OrderedDictionary() - for field in try self.fields() { - dict.updateValue(field.value, forKey: field.column) - } - return dict - } -} diff --git a/Sources/Alchemy/Rune/Model/FieldReading/ModelFieldReader.swift b/Sources/Alchemy/Rune/Model/FieldReading/ModelFieldReader.swift deleted file mode 100644 index eea5d54b..00000000 --- a/Sources/Alchemy/Rune/Model/FieldReading/ModelFieldReader.swift +++ /dev/null @@ -1,263 +0,0 @@ -import Foundation - -/// Used so `Relationship` types can know not to encode themselves to -/// a `ModelEncoder`. -protocol ModelEncoder: Encoder {} - -/// Used for turning any `Model` into an array of `DatabaseField`s -/// (column/value combinations) based on its stored properties. -final class ModelFieldReader: ModelEncoder { - /// Used for keeping track of the database fields pulled off the - /// object encoded to this encoder. - fileprivate var readFields: [DatabaseField] = [] - - /// The mapping strategy for associating `CodingKey`s on an object - /// with column names in a database. - fileprivate let mappingStrategy: DatabaseKeyMapping - - // MARK: Encoder - - var codingPath = [CodingKey]() - var userInfo: [CodingUserInfoKey: Any] = [:] - - /// Create with an associated `DatabasekeyMapping`. - /// - /// - Parameter mappingStrategy: The strategy for mapping - /// `CodingKey` string values to the `column`s of - /// `DatabaseField`s. - init(_ mappingStrategy: DatabaseKeyMapping) { - self.mappingStrategy = mappingStrategy - } - - /// Read and return the stored properties of an `Model` object as - /// a `[DatabaseField]`. - /// - /// - Parameter value: The `Model` instance to read from. - /// - Throws: A `DatabaseCodingError` if there is an error reading - /// fields from `value`. - /// - Returns: An array of `DatabaseField`s representing the - /// properties of `value`. - func getFields(of value: M) throws -> [DatabaseField] { - try value.encode(to: self) - let toReturn = self.readFields - self.readFields = [] - return toReturn - } - - func container(keyedBy: Key.Type) -> KeyedEncodingContainer { - let container = _KeyedEncodingContainer(encoder: self, codingPath: codingPath) - return KeyedEncodingContainer(container) - } - - func unkeyedContainer() -> UnkeyedEncodingContainer { - fatalError("`Model`s should never encode to an unkeyed container.") - } - - func singleValueContainer() -> SingleValueEncodingContainer { - fatalError("`Model`s should never encode to a single value container.") - } -} - -/// Encoder helper for pulling out `DatabaseField`s from any fields -/// that encode to a `SingleValueEncodingContainer`. -private struct _SingleValueEncoder: ModelEncoder { - /// The database column to which a value encoded here should map - /// to. - let column: String - - /// The `DatabaseFieldReader` that is being used to read the - /// stored properties of an object. Need to pass it around - /// so various containers can add to it's `readFields`. - let encoder: ModelFieldReader - - // MARK: Encoder - - var codingPath: [CodingKey] = [] - var userInfo: [CodingUserInfoKey : Any] = [:] - - func container( - keyedBy type: Key.Type - ) -> KeyedEncodingContainer where Key : CodingKey { - KeyedEncodingContainer( - _KeyedEncodingContainer(encoder: self.encoder, codingPath: codingPath) - ) - } - - func unkeyedContainer() -> UnkeyedEncodingContainer { - fatalError("Arrays aren't supported by `Model`.") - } - - func singleValueContainer() -> SingleValueEncodingContainer { - _SingleValueEncodingContainer(column: self.column, encoder: self.encoder) - } -} - -private struct _SingleValueEncodingContainer< - M: Model ->: SingleValueEncodingContainer, ModelValueReader { - /// The database column to which a value encoded to this container - /// should map to. - let column: String - - /// The `DatabaseFieldReader` that is being used to read the - /// stored properties of an object. Need to pass it around - /// so various containers can add to it's `readFields`. - var encoder: ModelFieldReader - - // MARK: SingleValueEncodingContainer - - var codingPath: [CodingKey] = [] - - mutating func encodeNil() throws { - // Can't infer the type so not much we can do here. - } - - mutating func encode(_ value: Bool) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .bool(value))) - } - - mutating func encode(_ value: String) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .string(value))) - } - - mutating func encode(_ value: Double) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .double(value))) - } - - mutating func encode(_ value: Float) throws { - let field = DatabaseField(column: self.column, value: .double(Double(value))) - self.encoder.readFields.append(field) - } - - mutating func encode(_ value: Int) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(value))) - } - - mutating func encode(_ value: Int8) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: Int16) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: Int32) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: Int64) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: UInt) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: UInt8) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: UInt16) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: UInt32) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: UInt64) throws { - self.encoder.readFields.append(DatabaseField(column: self.column, value: .int(Int(value)))) - } - - mutating func encode(_ value: T) throws where T : Encodable { - if let value = try self.databaseValue(of: value) { - self.encoder.readFields.append(DatabaseField(column: self.column, value: value)) - } else { - throw DatabaseCodingError("Error encoding type `\(type(of: T.self))` into single value " - + "container.") - } - } -} - -private struct _KeyedEncodingContainer< - M: Model, - Key: CodingKey ->: KeyedEncodingContainerProtocol, ModelValueReader { - var encoder: ModelFieldReader - - // MARK: KeyedEncodingContainerProtocol - - var codingPath = [CodingKey]() - - mutating func encodeNil(forKey key: Key) throws { - print("Got nil for \(self.encoder.mappingStrategy.map(input: key.stringValue)).") - } - - mutating func encode(_ value: T, forKey key: Key) throws { - if let theType = try self.databaseValue(of: value) { - let keyString = self.encoder.mappingStrategy.map(input: key.stringValue) - self.encoder.readFields.append(DatabaseField(column: keyString, value: theType)) - } else if value is AnyBelongsTo { - // Special case parent relationships to append - // `M.belongsToColumnSuffix` to the property name. - let keyString = self.encoder.mappingStrategy - .map(input: key.stringValue + "Id") - try value.encode( - to: _SingleValueEncoder(column: keyString, encoder: self.encoder) - ) - } else { - let keyString = self.encoder.mappingStrategy.map(input: key.stringValue) - try value.encode(to: _SingleValueEncoder(column: keyString, encoder: self.encoder)) - } - } - - mutating func nestedContainer( - keyedBy keyType: NestedKey.Type, forKey key: Key - ) -> KeyedEncodingContainer where NestedKey: CodingKey { - fatalError("Nested coding of `Model` not supported.") - } - - mutating func nestedUnkeyedContainer(forKey key: Key) -> UnkeyedEncodingContainer { - fatalError("Nested coding of `Model` not supported.") - } - - mutating func superEncoder() -> Encoder { - fatalError("Superclass encoding of `Model` not supported.") - } - - mutating func superEncoder(forKey key: Key) -> Encoder { - fatalError("Superclass encoding of `Model` not supported.") - } -} - -/// Used for passing along the type of the `Model` various containers -/// are working with so that the `Model`'s custom encoders can be -/// used. -private protocol ModelValueReader { - /// The `Model` type this field reader is reading from. - associatedtype M: Model -} - -extension ModelValueReader { - /// Returns a `DatabaseValue` for a `Model` value. If the value - /// isn't a supported `DatabaseValue`, it is encoded to `Data` - /// returned as `.json(Data)`. This is special cased to - /// return nil if the value is a Rune relationship. - /// - /// - Parameter value: The value to map to a `DatabaseValue`. - /// - Throws: An `EncodingError` if there is an issue encoding a - /// value perceived to be JSON. - /// - Returns: A `DatabaseValue` representing `value` or `nil` if - /// value is a Rune relationship. - fileprivate func databaseValue(of value: E) throws -> DatabaseValue? { - if let value = value as? Parameter { - return value.value - } else if value is AnyBelongsTo || value is AnyHas { - return nil - } else { - // Assume anything else is JSON. - let jsonData = try M.jsonEncoder.encode(value) - return .json(jsonData) - } - } -} diff --git a/Sources/Alchemy/Rune/Model/Model+CRUD.swift b/Sources/Alchemy/Rune/Model/Model+CRUD.swift deleted file mode 100644 index 913aaa10..00000000 --- a/Sources/Alchemy/Rune/Model/Model+CRUD.swift +++ /dev/null @@ -1,296 +0,0 @@ -import NIO - -/// Useful extensions for various CRUD operations of a `Model`. -extension Model { - /// Load all models of this type from a database. - /// - /// - Parameter db: The database to load models from. Defaults to - /// `Database.default`. - /// - Returns: An `EventLoopFuture` with an array of this model, - /// loaded from the database. - public static func all(db: Database = .default) -> EventLoopFuture<[Self]> { - Self.query(database: db) - .allModels() - } - - /// Fetch the first model with the given id. - /// - /// - Parameters: - /// - db: The database to fetch the model from. Defaults to - /// `Database.default`. - /// - id: The id of the model to find. - /// - Returns: A future with a matching model. - public static func find(db: Database = .default, _ id: Self.Identifier) -> EventLoopFuture { - Self.firstWhere("id" == id, db: db) - } - - /// Fetch the first model with the given id, throwing the given - /// error if it doesn't exist. - /// - /// - Parameters: - /// - db: The database to fetch the model from. Defaults to - /// `Database.default`. - /// - id: The id of the model to find. - /// - error: An error to throw if the model doesn't exist. - /// - Returns: A future with a matching model. - public static func find(db: Database = .default, _ id: Self.Identifier, or error: Error) -> EventLoopFuture { - Self.firstWhere("id" == id, db: db).unwrap(orError: error) - } - - /// Delete the first model with the given id. - /// - /// - Parameters: - /// - db: The database to delete the model from. Defaults to - /// `Database.default`. - /// - id: The id of the model to delete. - /// - Returns: A future that completes when the model is deleted. - public static func delete(db: Database = .default, _ id: Self.Identifier) -> EventLoopFuture { - query().where("id" == id).delete().voided() - } - - /// Delete all models of this type from a database. - /// - /// - Parameter - /// - db: The database to delete models from. Defaults - /// to `Database.default`. - /// - where: An optional where clause to specify the elements - /// to delete. - /// - Returns: A future that completes when the models are - /// deleted. - public static func deleteAll(db: Database = .default, where: WhereValue? = nil) -> EventLoopFuture { - var query = Self.query(database: db) - if let clause = `where` { query = query.where(clause) } - return query.delete().voided() - } - - /// Throws an error if a query with the specified where clause - /// returns a value. The opposite of `unwrapFirstWhere(...)`. - /// - /// Useful for detecting if a value with a key that may conflict - /// (such as a unique email) already exists on a table. - /// - /// - Parameters: - /// - where: The where clause to attempt to match. - /// - error: The error that will be thrown, should a query with - /// the where clause find a result. - /// - db: The database to query. Defaults to `Database.default`. - /// - Returns: A future that will result in an error out if there - /// is a row on the table matching the given `where` clause. - public static func ensureNotExists( - _ where: WhereValue, - else error: Error, - db: Database = .default -) -> EventLoopFuture { - Self.query(database: db) - .where(`where`) - .first() - .flatMapThrowing { try $0.map { _ in throw error } } - } - - /// Creates a query on the given model with the given where - /// clause. - /// - /// - Parameters: - /// - where: A clause to match. - /// - db: The database to query. Defaults to `Database.default`. - /// - Returns: A query on the `Model`'s table that matches the - /// given where clause. - public static func `where`(_ where: WhereValue, db: Database = .default) -> ModelQuery { - Self.query(database: db) - .where(`where`) - } - - /// Gets the first element that meets the given where value. - /// - /// - Parameters: - /// - where: The table will be queried for a row matching this - /// clause. - /// - db: The database to query. Defaults to `Database.default`. - /// - Returns: A future containing the first result matching the - /// `where` clause, if one exists. - public static func firstWhere(_ where: WhereValue, db: Database = .default) -> EventLoopFuture { - Self.query(database: db) - .where(`where`) - .firstModel() - } - - /// Gets all elements that meets the given where value. - /// - /// - Parameters: - /// - where: The table will be queried for a row matching this - /// clause. - /// - db: The database to query. Defaults to `Database.default`. - /// - Returns: A future containing all the results matching the - /// `where` clause. - public static func allWhere(_ where: WhereValue, db: Database = .default) -> EventLoopFuture<[Self]> { - Self.query(database: db) - .where(`where`) - .allModels() - } - - /// Gets the first element that meets the given where value. - /// Throws an error if no results match. The opposite of - /// `ensureNotExists(...)`. - /// - /// - Parameters: - /// - where: The table will be queried for a row matching this - /// clause. - /// - error: The error to throw if there are no results. - /// - db: The database to query. Defaults to `Database.default`. - /// - Returns: A future containing the first result matching the - /// `where` clause. Will result in `error` if no result is - /// found. - public static func unwrapFirstWhere( - _ where: WhereValue, - or error: Error, - db: Database = .default - ) -> EventLoopFuture { - Self.query(database: db) - .where(`where`) - .unwrapFirst(or: error) - } - - /// Saves this model to a database. If this model's `id` is nil, - /// it inserts it. If the `id` is not nil, it updates. - /// - /// - Parameter db: The database to save this model to. Defaults - /// to `Database.default`. - /// - Returns: A future that contains an updated version of self - /// with an updated copy of this model, reflecting any changes - /// that may have occurred saving this object to the database - /// (an `id` being populated, for example). - public func save(db: Database = .default) -> EventLoopFuture { - if self.id != nil { - return self.update(db: db) - } else { - return self.insert(db: db) - } - } - - /// Update this model in a database. - /// - /// - Parameter db: The database to update this model to. Defaults - /// to `Database.default`. - /// - Returns: A future that contains an updated version of self - /// with an updated copy of this model, reflecting any changes - /// that may have occurred saving this object to the database. - public func update(db: Database = .default) -> EventLoopFuture { - return catchError { - let id = try self.getID() - return Self.query(database: db) - .where("id" == id) - .update(values: try self.fieldDictionary().unorderedDictionary) - .map { _ in self } - } - } - - public func update(db: Database = .default, updateClosure: (inout Self) -> Void) -> EventLoopFuture { - return catchError { - let id = try self.getID() - var copy = self - updateClosure(©) - return Self.query(database: db) - .where("id" == id) - .update(values: try copy.fieldDictionary().unorderedDictionary) - .map { _ in copy } - } - } - - public static func update(db: Database = .default, _ id: Identifier, with dict: [String: Any]?) -> EventLoopFuture { - Self.find(id) - .optionalFlatMap { $0.update(with: dict ?? [:]) } - } - - public func update(db: Database = .default, with dict: [String: Any]) -> EventLoopFuture { - Self.query() - .where("id" == id) - .update(values: dict.compactMapValues { $0 as? Parameter }) - .flatMap { _ in self.sync() } - } - - /// Inserts this model to a database. - /// - /// - Parameter db: The database to insert this model to. Defaults - /// to `Database.default`. - /// - Returns: A future that contains an updated version of self - /// with an updated copy of this model, reflecting any changes - /// that may have occurred saving this object to the database. - /// (an `id` being populated, for example). - public func insert(db: Database = .default) -> EventLoopFuture { - catchError { - Self.query(database: db) - .insert(try self.fieldDictionary()) - .flatMapThrowing { try $0.first.unwrap(or: RuneError.notFound) } - .flatMapThrowing { try $0.decode(Self.self) } - } - } - - /// Deletes this model from a database. This will fail if the - /// model has a nil `id` field. - /// - /// - Parameter db: The database to remove this model from. - /// Defaults to `Database.default`. - /// - Returns: A future that completes when the model has been - /// deleted. - public func delete(db: Database = .default) -> EventLoopFuture { - catchError { - let idField = try self.getID() - return Self.query(database: db) - .where("id" == idField) - .delete() - .voided() - } - } - - /// Fetches an copy of this model from a database, with any - /// updates that may have been made since it was last - /// fetched. - /// - /// - Parameter db: The database to load from. Defaults to - /// `Database.default`. - /// - Returns: A future containing a freshly synced copy of this - /// model. - public func sync(db: Database = .default, query: ((ModelQuery) -> ModelQuery) = { $0 }) -> EventLoopFuture { - catchError { - guard let id = self.id else { - throw RuneError.syncErrorNoId - } - - return query(Self.query(database: db).where("id" == id)) - .firstModel() - .unwrap(orError: RuneError.syncErrorNoMatch(table: Self.tableName, id: id)) - } - } -} - -/// Usefuly extensions for CRUD operations on an array of `Model`s. -extension Array where Element: Model { - /// Inserts each element in this array to a database. - /// - /// - Parameter db: The database to insert the models into. - /// Defaults to `Database.default`. - /// - Returns: A future that contains copies of all models in this - /// array, updated to reflect any changes in the model caused by inserting. - public func insertAll(db: Database = .default) -> EventLoopFuture { - catchError { - Element.query(database: db) - .insert(try self.map { try $0.fieldDictionary() }) - .flatMapEachThrowing { try $0.decode(Element.self) } - } - } - - /// Deletes all objects in this array from a database. If an - /// object in this array isn't actually in the database, it - /// will be ignored. - /// - /// - Parameter db: The database to delete from. Defaults to - /// `Database.default`. - /// - Returns: A future that completes when all models in this - /// array are deleted from the database. - public func deleteAll(db: Database = .default) -> EventLoopFuture { - Element.query(database: db) - .where(key: "id", in: self.compactMap { $0.id }) - .delete() - .voided() - } -} diff --git a/Sources/Alchemy/Rune/Model/Model+Query.swift b/Sources/Alchemy/Rune/Model/Model+Query.swift deleted file mode 100644 index 6e5dc276..00000000 --- a/Sources/Alchemy/Rune/Model/Model+Query.swift +++ /dev/null @@ -1,238 +0,0 @@ -import Foundation -import NIO - -public extension Model { - /// Begin a `ModelQuery` from a given database. - /// - /// - Parameter database: The database to run the query on. - /// Defaults to `Database.default`. - /// - Returns: A builder for building your query. - static func query(database: Database = .default) -> ModelQuery { - ModelQuery(database: database.driver).from(table: Self.tableName) - } -} - -/// A `ModelQuery` is just a subclass of `Query` with some added -/// typing and convenience functions for querying the table of -/// a specific `Model`. -public class ModelQuery: Query { - /// A closure for defining any nested eager loading when loading a - /// relationship on this `Model`. - /// - /// "Eager loading" refers to loading a model at the other end of - /// a relationship of this queried model. Nested eager loads - /// refers to loading a model from a relationship on that - /// _other_ model. - public typealias NestedEagerLoads = (ModelQuery) -> ModelQuery - - /// The closures of any eager loads to run. To be run after the - /// initial models of type `Self` are fetched. - /// - /// - Warning: Right now these only run when the query is - /// finished with `allModels` or `firstModel`. If the user - /// finishes a query with a `get()` we don't know if/when the - /// decode will happen and how to handle it. A potential ways - /// of doing this could be to call eager loading @ the - /// `.decode` level of a `DatabaseRow`, but that's too - /// complicated for now). - private var eagerLoadQueries: [([(M, DatabaseRow)]) -> EventLoopFuture<[(M, DatabaseRow)]>] = [] - - /// Gets all models matching this query from the database. - /// - /// - Returns: A future containing all models matching this query. - public func allModels() -> EventLoopFuture<[M]> { - self._allModels().mapEach(\.0) - } - - private func _allModels(columns: [Column]? = ["\(M.tableName).*"]) -> EventLoopFuture<[(M, DatabaseRow)]> { - return self.get(columns) - .flatMapThrowing { - try $0.map { (try $0.decode(M.self), $0) } - } - .flatMap { self.evaluateEagerLoads(for: $0) } - } - - /// Get the first model matching this query from the database. - /// - /// - Returns: A future containing the first model matching this - /// query or nil if this query has no results. - public func firstModel() -> EventLoopFuture { - self.first() - .flatMapThrowing { result -> (M, DatabaseRow)? in - guard let result = result else { - return nil - } - - return (try result.decode(M.self), result) - } - .flatMap { result -> EventLoopFuture<(M, DatabaseRow)?> in - if let result = result { - return self.evaluateEagerLoads(for: [result]).map { $0.first } - } else { - return .new(nil) - } - } - .map { $0?.0 } - } - - /// Similary to `getFirst`. Gets the first result of a query, but - /// unwraps the element, throwing an error if it doesn't exist. - /// - /// - Parameter error: The error to throw should no element be - /// found. Defaults to `RuneError.notFound`. - /// - Returns: A future containing the unwrapped first result of - /// this query, or the supplied error if no result was found. - public func unwrapFirst(or error: Error = RuneError.notFound) -> EventLoopFuture { - self.firstModel() - .flatMapThrowing { try $0.unwrap(or: error) } - } - - /// Eager loads (loads a related `Model`) a `Relationship` on this - /// model. - /// - /// Eager loads are evaluated in a single query per eager load - /// after the initial model query has completed. - /// - /// - Warning: **PLEASE NOTE** Eager loads only load when your - /// query is completed with functions from `ModelQuery`, such as - /// `allModels` or `firstModel`. If you finish your query with - /// functions from `Query`, such as `delete`, `insert`, `save`, - /// or `get`, the `Model` type isn't guaranteed to be decoded so - /// we can't run the eager loads. **TL;DR**: only finish your - /// query with functions that automatically decode your model - /// when using eager loads (i.e. doesn't result in - /// `EventLoopFuture<[DatabaseRow]>`). - /// - /// Usage: - /// ```swift - /// // Consider three types, `Pet`, `Person`, and `Plant`. They - /// // have the following relationships: - /// struct Pet: Model { - /// ... - /// - /// @BelongsTo - /// var owner: Person - /// } - /// - /// struct Person: Model { - /// ... - /// - /// @BelongsTo - /// var favoritePlant: Plant - /// } - /// - /// struct Plant: Model { ... } - /// - /// // A `Pet` query that loads each pet's related owner _as well_ - /// // as those owners' favorite plants would look like this: - /// Pet.query() - /// // An eager load - /// .with(\.$owner) { ownerQuery in - /// // `ownerQuery` is the query that will be run when - /// // fetching owner objects; we can give it its own - /// // eager loads (aka nested eager loading) - /// ownerQuery.with(\.$favoritePlant) - /// } - /// .getAll() - /// ``` - /// - Parameters: - /// - relationshipKeyPath: The `KeyPath` of the relationship to - /// load. Please note that this is a `KeyPath` to a - /// `Relationship`, not a `Model`, so it will likely - /// start with a '$', such as `\.$user`. - /// - nested: A closure for any nested loading to do. See - /// example above. Defaults to an empty closure. - /// - Returns: A query builder for extending the query. - public func with( - _ relationshipKeyPath: KeyPath, - nested: @escaping NestedEagerLoads = { $0 } - ) -> ModelQuery where R.From == M { - self.eagerLoadQueries.append { fromResults in - catchError { - let mapper = RelationshipMapper() - M.mapRelations(mapper) - let config = mapper.getConfig(for: relationshipKeyPath) - - // If there are no results, don't need to eager load. - guard !fromResults.isEmpty else { - return .new([]) - } - - // Alias whatever key we'll join the relationship on - let toJoinKeyAlias = "_to_join_key" - let toJoinKey: String = { - let table = config.through?.table ?? config.toTable - let key = config.through?.fromKey ?? config.toKey - return "\(table).\(key) as \(toJoinKeyAlias)" - }() - - let allRows = fromResults.map(\.1) - return nested(try config.load(allRows)) - ._allModels(columns: ["\(R.To.Value.tableName).*", toJoinKey]) - .flatMapEachThrowing { (try R.To.from($0), $1) } - // Key the results by the "from" identifier - .flatMapThrowing { - try Dictionary(grouping: $0) { _, row in - try row.getField(column: toJoinKeyAlias).value - } - } - // For each `from` populate it's relationship - .flatMapThrowing { toResultsKeyedByFromId in - return try fromResults.map { model, row in - let pk = try row.getField(column: config.fromKey).value - let models = toResultsKeyedByFromId[pk]?.map(\.0) ?? [] - try model[keyPath: relationshipKeyPath].set(values: models) - return (model, row) - } - } - } - } - - return self - } - - /// Evaluate all eager loads in this `ModelQuery` sequentially. - /// This occurs after the inital `M` query has completed. - /// - /// - Parameter models: The models that were loaded by the initial - /// query. - /// - Returns: A future containing the loaded models that will - /// have all specified relationships loaded. - private func evaluateEagerLoads(for models: [(M, DatabaseRow)]) -> EventLoopFuture<[(M, DatabaseRow)]> { - self.eagerLoadQueries - .reduce(.new(models)) { future, eagerLoad in - future.flatMap { eagerLoad($0) } - } - } -} - -private extension RelationshipMapping { - func load(_ values: [DatabaseRow]) throws -> ModelQuery { - var query = M.query().from(table: toTable) - var whereKey = "\(toTable).\(toKey)" - if let through = through { - whereKey = "\(through.table).\(through.fromKey)" - query = query.leftJoin(table: through.table, first: "\(through.table).\(through.toKey)", second: "\(toTable).\(toKey)") - } - - let ids = try values.map { try $0.getField(column: fromKey).value } - query = query.where(key: "\(whereKey)", in: ids.uniques) - return query - } -} - -private extension Array where Element: Hashable { - /// Removes any duplicates from the array while maintaining the - /// original order. - var uniques: Array { - var buffer = Array() - var added = Set() - for elem in self { - if !added.contains(elem) { - buffer.append(elem) - added.insert(elem) - } - } - return buffer - } -} diff --git a/Sources/Alchemy/Rune/Model/ModelEnum.swift b/Sources/Alchemy/Rune/Model/ModelEnum.swift deleted file mode 100644 index 7ef386a3..00000000 --- a/Sources/Alchemy/Rune/Model/ModelEnum.swift +++ /dev/null @@ -1,28 +0,0 @@ -/// A protocol to which enums on `Model`s should conform to. The enum -/// will be modeled in the backing table by it's raw value. -/// -/// Usage: -/// ```swift -/// enum TaskPriority: Int, ModelEnum { -/// case low, medium, high -/// } -/// -/// struct Todo: Model { -/// var id: Int? -/// let name: String -/// let isDone: Bool -/// let priority: TaskPriority // Stored as `Int` in the database. -/// } -/// ``` -public protocol ModelEnum: AnyModelEnum, CaseIterable {} - -/// A type erased `ModelEnum`. -public protocol AnyModelEnum: Codable, Parameter { - /// The default case of this enum. Defaults to the first of - /// `Self.allCases`. - static var defaultCase: Self { get } -} - -extension ModelEnum { - public static var defaultCase: Self { Self.allCases.first! } -} diff --git a/Sources/Alchemy/SQL/Database/Abstract/DatabaseConfig.swift b/Sources/Alchemy/SQL/Database/Abstract/DatabaseConfig.swift deleted file mode 100644 index 4c41bab5..00000000 --- a/Sources/Alchemy/SQL/Database/Abstract/DatabaseConfig.swift +++ /dev/null @@ -1,29 +0,0 @@ -/// The information needed to connect to a database. -public struct DatabaseConfig { - /// The socket where this database server is available. - public let socket: Socket - /// The name of the database on the database server to connect to. - public let database: String - /// The username to connect to the database with. - public let username: String - /// The password to connect to the database with. - public let password: String - /// Should the connection use SSL. - public let enableSSL: Bool - - /// Initialize a database configuration with the relevant info. - /// - /// - Parameters: - /// - socket: The location of the database. - /// - database: The name of the database to connect to. - /// - username: The username to connect with. - /// - password: The password to connect with. - /// - enableSSL: Should the connection use SSL. - public init(socket: Socket, database: String, username: String, password: String, enableSSL: Bool = false) { - self.socket = socket - self.database = database - self.username = username - self.password = password - self.enableSSL = enableSSL - } -} diff --git a/Sources/Alchemy/SQL/Database/Abstract/DatabaseField.swift b/Sources/Alchemy/SQL/Database/Abstract/DatabaseField.swift deleted file mode 100644 index ccea6889..00000000 --- a/Sources/Alchemy/SQL/Database/Abstract/DatabaseField.swift +++ /dev/null @@ -1,150 +0,0 @@ -/// Represents a column & value pair in a database row. -/// -/// If there were a table with columns "id", "email", "phone" and a -/// row with values 1 ,"josh@alchemy.dev", "(555) 555-5555", -/// `DatabaseField(column: id, .int(1))` would represent a -/// field on that table. -public struct DatabaseField: Equatable { - /// The name of the column this value came from. - public let column: String - /// The value of this field. - public let value: DatabaseValue -} - -/// Functions for easily accessing the unwrapped contents of -/// `DatabaseField` values. -extension DatabaseField { - /// Unwrap and return an `Int` value from this `DatabaseField`. - /// This throws if the underlying `value` isn't an `.int` or - /// the `.int` has a `nil` associated value. - /// - /// - Throws: A `DatabaseError` if this field's `value` isn't a - /// `DatabaseValue.int` or its contents is nil. - /// - Returns: The unwrapped `Int` of this field's value, if it - /// was indeed a non-null `.int`. - public func int() throws -> Int { - guard case let .int(value) = self.value else { - throw typeError("int") - } - - return try self.unwrapOrError(value) - } - - /// Unwrap and return a `String` value from this `DatabaseField`. - /// This throws if the underlying `value` isn't a `.string` or - /// the `.string` has a nil associated value. - /// - /// - Throws: A `DatabaseError` if this field's `value` isn't a - /// `DatabaseValue.string` or its contents is nil. - /// - Returns: The unwrapped `String` of this field's value, if - /// it was indeed a non-null `.string`. - public func string() throws -> String { - guard case let .string(value) = self.value else { - throw typeError("string") - } - - return try self.unwrapOrError(value) - } - - /// Unwrap and return a `Double` value from this `DatabaseField`. - /// This throws if the underlying `value` isn't a `.double` or - /// the `.double` has a nil associated value. - /// - /// - Throws: A `DatabaseError` if this field's `value` isn't a - /// `DatabaseValue.double` or its contents is nil. - /// - Returns: The unwrapped `Double` of this field's value, if it - /// was indeed a non-null `.double`. - public func double() throws -> Double { - guard case let .double(value) = self.value else { - throw typeError("double") - } - - return try self.unwrapOrError(value) - } - - /// Unwrap and return a `Bool` value from this `DatabaseField`. - /// This throws if the underlying `value` isn't a `.bool` or - /// the `.bool` has a nil associated value. - /// - /// - Throws: A `DatabaseError` if this field's `value` isn't a - /// `DatabaseValue.bool` or its contents is nil. - /// - Returns: The unwrapped `Bool` of this field's value, if it - /// was indeed a non-null `.bool`. - public func bool() throws -> Bool { - guard case let .bool(value) = self.value else { - throw typeError("bool") - } - - return try self.unwrapOrError(value) - } - - /// Unwrap and return a `Date` value from this `DatabaseField`. - /// This throws if the underlying `value` isn't a `.date` or - /// the `.date` has a nil associated value. - /// - /// - Throws: A `DatabaseError` if this field's `value` isn't a - /// `DatabaseValue.date` or its contents is nil. - /// - Returns: The unwrapped `Date` of this field's value, if it - /// was indeed a non-null `.date`. - public func date() throws -> Date { - guard case let .date(value) = self.value else { - throw typeError("date") - } - - return try self.unwrapOrError(value) - } - - /// Unwrap and return a JSON `Data` value from this - /// `DatabaseField`. This throws if the underlying `value` isn't - /// a `.json` or the `.json` has a nil associated value. - /// - /// - Throws: A `DatabaseError` if this field's `value` isn't a - /// `DatabaseValue.json` or its contents is nil. - /// - Returns: The `Data` of this field's unwrapped json value, if - /// it was indeed a non-null `.json`. - public func json() throws -> Data { - guard case let .json(value) = self.value else { - throw typeError("json") - } - - return try self.unwrapOrError(value) - } - - /// Unwrap and return a `UUID` value from this `DatabaseField`. - /// This throws if the underlying `value` isn't a `.uuid` or - /// the `.uuid` has a nil associated value. - /// - /// - Throws: A `DatabaseError` if this field's `value` isn't a - /// `DatabaseValue.uuid` or its contents is nil. - /// - Returns: The unwrapped `UUID` of this field's value, if it - /// was indeed a non-null `.uuid`. - public func uuid() throws -> UUID { - guard case let .uuid(value) = self.value else { - throw typeError("uuid") - } - - return try self.unwrapOrError(value) - } - - /// Generates an `DatabaseError` appropriate to throw if the user - /// tries to get a type that isn't compatible with this - /// `DatabaseField`'s `value`. - /// - /// - Parameter typeName: The name of the type the user tried to - /// get. - /// - Returns: A `DatabaseError` with a message describing the - /// predicament. - private func typeError(_ typeName: String) -> Error { - DatabaseError("Field at column '\(self.column)' expected to be `\(typeName)` but wasn't.") - } - - /// Unwraps a value of type `T`, or throws an error detailing the - /// nil data at the column. - /// - /// - Parameter value: The value to unwrap. - /// - Throws: A `DatabaseError` if the value is nil. - /// - Returns: The value, `T`, if it was successfully unwrapped. - private func unwrapOrError(_ value: T?) throws -> T { - try value.unwrap(or: DatabaseError("Tried to get a value from '\(self.column)' but it was `nil`.")) - } -} diff --git a/Sources/Alchemy/SQL/Database/Abstract/DatabaseRow.swift b/Sources/Alchemy/SQL/Database/Abstract/DatabaseRow.swift deleted file mode 100644 index 02272627..00000000 --- a/Sources/Alchemy/SQL/Database/Abstract/DatabaseRow.swift +++ /dev/null @@ -1,33 +0,0 @@ -/// A row of data returned from a database. Various database packages -/// can use this as an abstraction around their internal row types. -public protocol DatabaseRow { - /// The `String` names of all columns that have values in this - /// `DatabaseRow`. - var allColumns: Set { get } - - /// Get the `DatabaseField` of a column from this row. - /// - /// - Parameter column: The column to get the value for. - /// - Throws: A `DatabaseError` if the column does not exist on - /// this row. - /// - Returns: The field at `column`. - func getField(column: String) throws -> DatabaseField - - /// Decode a `Model` type `D` from this row. - /// - /// The default implementation of this function populates the - /// properties of `D` with the value of the column named the - /// same as the property. - /// - /// - Parameter type: The type to decode from this row. - func decode(_ type: D.Type) throws -> D -} - -extension DatabaseRow { - public func decode(_ type: M.Type) throws -> M { - // For each stored coding key, pull out the column name. Will - // need to write a custom decoder that pulls out of a database - // row. - try M(from: DatabaseRowDecoder(row: self)) - } -} diff --git a/Sources/Alchemy/SQL/Database/Abstract/DatabaseValue.swift b/Sources/Alchemy/SQL/Database/Abstract/DatabaseValue.swift deleted file mode 100644 index 9e24f93f..00000000 --- a/Sources/Alchemy/SQL/Database/Abstract/DatabaseValue.swift +++ /dev/null @@ -1,47 +0,0 @@ -import Foundation - -/// Represents the type / value combo of an SQL database field. These -/// don't necessarily correspond to a specific SQL database's types; -/// they just represent the types that Alchemy current supports. -/// -/// All fields are optional by default, it's up to the end user to -/// decide if a nil value in that field is appropriate and -/// potentially throw an error. -public enum DatabaseValue: Equatable, Hashable { - /// An `Int` value. - case int(Int?) - /// A `Double` value. - case double(Double?) - /// A `Bool` value. - case bool(Bool?) - /// A `String` value. - case string(String?) - /// A `Date` value. - case date(Date?) - /// A JSON value, given as `Data`. - case json(Data?) - /// A `UUID` value. - case uuid(UUID?) -} - -extension DatabaseValue { - /// Indicates if the associated value inside this enum is nil. - public var isNil: Bool { - switch self { - case .int(let value): - return value == nil - case .double(let value): - return value == nil - case .bool(let value): - return value == nil - case .string(let value): - return value == nil - case .date(let value): - return value == nil - case .json(let value): - return value == nil - case .uuid(let value): - return value == nil - } - } -} diff --git a/Sources/Alchemy/SQL/Database/Abstract/DatabaseCodingError.swift b/Sources/Alchemy/SQL/Database/Core/DatabaseCodingError.swift similarity index 73% rename from Sources/Alchemy/SQL/Database/Abstract/DatabaseCodingError.swift rename to Sources/Alchemy/SQL/Database/Core/DatabaseCodingError.swift index 87b9ceba..a08a4317 100644 --- a/Sources/Alchemy/SQL/Database/Abstract/DatabaseCodingError.swift +++ b/Sources/Alchemy/SQL/Database/Core/DatabaseCodingError.swift @@ -1,5 +1,4 @@ -/// An error encountered when decoding a `Model` from a `DatabaseRow` -/// or encoding it to a `[DatabaseField]`. +/// An error encountered when decoding or encoding a `Model`. struct DatabaseCodingError: Error { /// What went wrong. let message: String diff --git a/Sources/Alchemy/SQL/Database/Abstract/DatabaseError.swift b/Sources/Alchemy/SQL/Database/Core/DatabaseError.swift similarity index 100% rename from Sources/Alchemy/SQL/Database/Abstract/DatabaseError.swift rename to Sources/Alchemy/SQL/Database/Core/DatabaseError.swift diff --git a/Sources/Alchemy/SQL/Database/Abstract/DatabaseKeyMapping.swift b/Sources/Alchemy/SQL/Database/Core/DatabaseKeyMapping.swift similarity index 100% rename from Sources/Alchemy/SQL/Database/Abstract/DatabaseKeyMapping.swift rename to Sources/Alchemy/SQL/Database/Core/DatabaseKeyMapping.swift diff --git a/Sources/Alchemy/SQL/Database/Core/SQL.swift b/Sources/Alchemy/SQL/Database/Core/SQL.swift new file mode 100644 index 00000000..ed7b3caa --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Core/SQL.swift @@ -0,0 +1,26 @@ +public struct SQL: Equatable { + let statement: String + let bindings: [SQLValue] + + public init(_ statement: String = "", bindings: [SQLValue] = []) { + self.statement = statement + self.bindings = bindings + } +} + +extension SQL: ExpressibleByStringLiteral { + public init(stringLiteral value: StringLiteralType) { + self.statement = value + self.bindings = [] + } +} + +extension SQL: SQLConvertible { + public var sql: SQL { self } +} + +extension SQL: SQLValueConvertible { + public var value: SQLValue { + .string(statement) + } +} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Sequelizable.swift b/Sources/Alchemy/SQL/Database/Core/SQLConvertible.swift similarity index 56% rename from Sources/Alchemy/SQL/QueryBuilder/Sequelizable.swift rename to Sources/Alchemy/SQL/Database/Core/SQLConvertible.swift index 481c58fb..13f68459 100644 --- a/Sources/Alchemy/SQL/QueryBuilder/Sequelizable.swift +++ b/Sources/Alchemy/SQL/Database/Core/SQLConvertible.swift @@ -1,7 +1,5 @@ -import Foundation - /// Something that can be turned into SQL. -public protocol Sequelizable { +public protocol SQLConvertible { /// Returns an SQL representation of this type. - func toSQL() -> SQL + var sql: SQL { get } } diff --git a/Sources/Alchemy/SQL/Database/Core/SQLRow.swift b/Sources/Alchemy/SQL/Database/Core/SQLRow.swift new file mode 100644 index 00000000..aec3e706 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Core/SQLRow.swift @@ -0,0 +1,44 @@ +import Foundation + +/// A row of data returned from a database. Various database packages +/// can use this as an abstraction around their internal row types. +public protocol SQLRow { + /// The `String` names of all columns that have values in this row. + var columns: Set { get } + + /// Get the `SQLValue` of a column from this row. + /// + /// - Parameter column: The column to get the value for. + /// - Throws: A `DatabaseError` if the column does not exist on + /// this row. + /// - Returns: The value at `column`. + func get(_ column: String) throws -> SQLValue + + /// Decode a `Model` type `D` from this row. + /// + /// The default implementation of this function populates the + /// properties of `D` with the value of the column named the + /// same as the property. + /// + /// - Parameter type: The type to decode from this row. + func decode(_ type: D.Type) throws -> D +} + +extension SQLRow { + public func decode( + _ type: D.Type, + keyMapping: DatabaseKeyMapping = .useDefaultKeys, + jsonDecoder: JSONDecoder = JSONDecoder() + ) throws -> D { + try D(from: SQLRowDecoder(row: self, keyMapping: keyMapping, jsonDecoder: jsonDecoder)) + } + + public func decode(_ type: M.Type) throws -> M { + try M(from: SQLRowDecoder(row: self, keyMapping: M.keyMapping, jsonDecoder: M.jsonDecoder)) + } + + /// Subscript for convenience access. + public subscript(column: String) -> SQLValue? { + columns.contains(column) ? try? get(column) : nil + } +} diff --git a/Sources/Alchemy/SQL/Database/Core/SQLValue.swift b/Sources/Alchemy/SQL/Database/Core/SQLValue.swift new file mode 100644 index 00000000..5fb71761 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Core/SQLValue.swift @@ -0,0 +1,234 @@ +import Foundation + +/// Represents the type / value combo of an SQL database field. These +/// don't necessarily correspond to a specific SQL database's types; +/// they just represent the types that Alchemy current supports. +/// +/// All fields are optional by default, it's up to the end user to +/// decide if a nil value in that field is appropriate and +/// potentially throw an error. +public enum SQLValue: Equatable, Hashable, CustomStringConvertible { + /// An `Int` value. + case int(Int) + /// A `Double` value. + case double(Double) + /// A `Bool` value. + case bool(Bool) + /// A `String` value. + case string(String) + /// A `Date` value. + case date(Date) + /// A JSON value, given as `Data`. + case json(Data) + /// A `UUID` value. + case uuid(UUID) + /// A null value of any type. + case null + + public var description: String { + switch self { + case .int(let int): + return "SQLValue.int(\(int))" + case .double(let double): + return "SQLValue.double(\(double))" + case .bool(let bool): + return "SQLValue.bool(\(bool))" + case .string(let string): + return "SQLValue.string(`\(string)`)" + case .date(let date): + return "SQLValue.date(\(date))" + case .json(let data): + return "SQLValue.json(\(String(data: data, encoding: .utf8) ?? "\(data)"))" + case .uuid(let uuid): + return "SQLValue.uuid(\(uuid.uuidString))" + case .null: + return "SQLValue.null" + } + } +} + +/// Extension for easily accessing the unwrapped contents of an `SQLValue`. +extension SQLValue { + static let iso8601DateFormatter = ISO8601DateFormatter() + static let simpleFormatter: DateFormatter = { + let formatter = DateFormatter() + formatter.dateFormat = "yyyy-MM-dd HH:mm:ss" + return formatter + }() + + /// Unwrap and return an `Int` value from this `SQLValue`. + /// This throws if the underlying `value` isn't an `.int` or + /// the `.int` has a `nil` associated value. + /// + /// - Throws: A `DatabaseError` if this field's `value` isn't a + /// `SQLValue.int` or its contents is nil. + /// - Returns: The unwrapped `Int` of this field's value, if it + /// was indeed a non-null `.int`. + public func int(_ columnName: String? = nil) throws -> Int { + try ensureNotNull(columnName) + + switch self { + case .int(let value): + return value + default: + throw typeError("Int", columnName: columnName) + } + } + + /// Unwrap and return a `String` value from this `SQLValue`. + /// This throws if the underlying `value` isn't a `.string` or + /// the `.string` has a nil associated value. + /// + /// - Throws: A `DatabaseError` if this field's `value` isn't a + /// `SQLValue.string` or its contents is nil. + /// - Returns: The unwrapped `String` of this field's value, if + /// it was indeed a non-null `.string`. + public func string(_ columnName: String? = nil) throws -> String { + try ensureNotNull(columnName) + + switch self { + case .string(let value): + return value + default: + throw typeError("String", columnName: columnName) + } + } + + /// Unwrap and return a `Double` value from this `SQLValue`. + /// This throws if the underlying `value` isn't a `.double` or + /// the `.double` has a nil associated value. + /// + /// - Throws: A `DatabaseError` if this field's `value` isn't a + /// `SQLValue.double` or its contents is nil. + /// - Returns: The unwrapped `Double` of this field's value, if it + /// was indeed a non-null `.double`. + public func double(_ columnName: String? = nil) throws -> Double { + try ensureNotNull(columnName) + + switch self { + case .double(let value): + return value + default: + throw typeError("Double", columnName: columnName) + } + } + + /// Unwrap and return a `Bool` value from this `SQLValue`. + /// This throws if the underlying `value` isn't a `.bool` or + /// the `.bool` has a nil associated value. + /// + /// - Throws: A `DatabaseError` if this field's `value` isn't a + /// `SQLValue.bool` or its contents is nil. + /// - Returns: The unwrapped `Bool` of this field's value, if it + /// was indeed a non-null `.bool`. + public func bool(_ columnName: String? = nil) throws -> Bool { + try ensureNotNull(columnName) + + switch self { + case .bool(let value): + return value + case .int(let value): + return value != 0 + default: + throw typeError("Bool", columnName: columnName) + } + } + + /// Unwrap and return a `Date` value from this `SQLValue`. + /// This throws if the underlying `value` isn't a `.date` or + /// the `.date` has a nil associated value. + /// + /// - Throws: A `DatabaseError` if this field's `value` isn't a + /// `SQLValue.date` or its contents is nil. + /// - Returns: The unwrapped `Date` of this field's value, if it + /// was indeed a non-null `.date`. + public func date(_ columnName: String? = nil) throws -> Date { + try ensureNotNull(columnName) + + switch self { + case .date(let value): + return value + case .string(let value): + guard + let date = SQLValue.iso8601DateFormatter.date(from: value) + ?? SQLValue.simpleFormatter.date(from: value) + else { + throw typeError("Date", columnName: columnName) + } + + return date + default: + throw typeError("Date", columnName: columnName) + } + } + + /// Unwrap and return a JSON `Data` value from this + /// `SQLValue`. This throws if the underlying `value` isn't + /// a `.json` or the `.json` has a nil associated value. + /// + /// - Throws: A `DatabaseError` if this field's `value` isn't a + /// `SQLValue.json` or its contents is nil. + /// - Returns: The `Data` of this field's unwrapped json value, if + /// it was indeed a non-null `.json`. + public func json(_ columnName: String? = nil) throws -> Data { + try ensureNotNull(columnName) + + switch self { + case .json(let value): + return value + case .string(let string): + guard let data = string.data(using: .utf8) else { + throw typeError("JSON", columnName: columnName) + } + + return data + default: + throw typeError("JSON", columnName: columnName) + } + } + + /// Unwrap and return a `UUID` value from this `SQLValue`. + /// This throws if the underlying `value` isn't a `.uuid` or + /// the `.uuid` has a nil associated value. + /// + /// - Throws: A `DatabaseError` if this field's `value` isn't a + /// `SQLValue.uuid` or its contents is nil. + /// - Returns: The unwrapped `UUID` of this field's value, if it + /// was indeed a non-null `.uuid`. + public func uuid(_ columnName: String? = nil) throws -> UUID { + try ensureNotNull(columnName) + + switch self { + case .uuid(let value): + return value + case .string(let string): + guard let uuid = UUID(string) else { + throw typeError("UUID", columnName: columnName) + } + + return uuid + default: + throw typeError("UUID", columnName: columnName) + } + } + + /// Generates an error appropriate to throw if the user tries to get a type + /// that isn't compatible with this value. + /// + /// - Parameter typeName: The name of the type the user tried to get. + /// - Returns: A `DatabaseError` with a message describing the predicament. + private func typeError(_ typeName: String, columnName: String? = nil) -> Error { + if let columnName = columnName { + return DatabaseError("Unable to coerce \(self) at column `\(columnName)` to \(typeName)") + } + + return DatabaseError("Unable to coerce \(self) to \(typeName).") + } + + private func ensureNotNull(_ columnName: String? = nil) throws { + if case .null = self { + let desc = columnName.map { "column `\($0)`" } ?? "SQLValue" + throw DatabaseError("Expected \(desc) to have a value but it was `nil`.") + } + } +} diff --git a/Sources/Alchemy/SQL/Database/Core/SQLValueConvertible.swift b/Sources/Alchemy/SQL/Database/Core/SQLValueConvertible.swift new file mode 100644 index 00000000..495e3d6f --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Core/SQLValueConvertible.swift @@ -0,0 +1,114 @@ +import Foundation + +public protocol SQLValueConvertible: SQLConvertible { + var value: SQLValue { get } +} + +extension SQLValueConvertible { + public var sql: SQL { + (self as? SQL) ?? SQL(sqlLiteral) + } + + /// A string appropriate for representing this value in a non-parameterized + /// query. + public var sqlLiteral: String { + switch self.value { + case .int(let value): + return "\(value)" + case .double(let value): + return "\(value)" + case .bool(let value): + return "\(value)" + case .string(let value): + // ' -> '' is escape for MySQL & Postgres... not sure if this will break elsewhere. + return "'\(value.replacingOccurrences(of: "'", with: "''"))'" + case .date(let value): + return "'\(value)'" + case .json(let value): + let rawString = String(data: value, encoding: .utf8) ?? "" + return "'\(rawString)'" + case .uuid(let value): + return "'\(value.uuidString)'" + case .null: + return "NULL" + } + } +} + +extension SQLValue: SQLValueConvertible { + public var value: SQLValue { self } +} + +extension String: SQLValueConvertible { + public var value: SQLValue { .string(self) } +} + +extension Int: SQLValueConvertible { + public var value: SQLValue { .int(self) } +} + +extension Int8: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension Int16: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension Int32: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension Int64: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension UInt: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension UInt8: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension UInt16: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension UInt32: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension UInt64: SQLValueConvertible { + public var value: SQLValue { .int(Int(self)) } +} + +extension Bool: SQLValueConvertible { + public var value: SQLValue { .bool(self) } +} + +extension Double: SQLValueConvertible { + public var value: SQLValue { .double(self) } +} + +extension Float: SQLValueConvertible { + public var value: SQLValue { .double(Double(self)) } +} + +extension Date: SQLValueConvertible { + public var value: SQLValue { .date(self) } +} + +extension UUID: SQLValueConvertible { + public var value: SQLValue { .uuid(self) } +} + +extension Optional: SQLConvertible where Wrapped: SQLValueConvertible {} + +extension Optional: SQLValueConvertible where Wrapped: SQLValueConvertible { + public var value: SQLValue { self?.value ?? .null } +} + +extension RawRepresentable where RawValue: SQLValueConvertible { + public var value: SQLValue { rawValue.value } +} diff --git a/Sources/Alchemy/SQL/Database/Database+Config.swift b/Sources/Alchemy/SQL/Database/Database+Config.swift new file mode 100644 index 00000000..a86904d1 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Database+Config.swift @@ -0,0 +1,25 @@ +extension Database { + public struct Config { + public let databases: [Identifier: Database] + public let migrations: [Migration] + public let seeders: [Seeder] + public let redis: [RedisClient.Identifier: RedisClient] + + public init(databases: [Database.Identifier: Database], migrations: [Migration], seeders: [Seeder], redis: [RedisClient.Identifier: RedisClient]) { + self.databases = databases + self.migrations = migrations + self.seeders = seeders + self.redis = redis + } + } + + public static func configure(with config: Config) { + config.databases.forEach { id, db in + db.migrations = config.migrations + db.seeders = config.seeders + Database.bind(id, db) + } + + config.redis.forEach { RedisClient.bind($0, $1) } + } +} diff --git a/Sources/Alchemy/SQL/Database/Database.swift b/Sources/Alchemy/SQL/Database/Database.swift index 1cf8f909..b83aeb49 100644 --- a/Sources/Alchemy/SQL/Database/Database.swift +++ b/Sources/Alchemy/SQL/Database/Database.swift @@ -1,45 +1,32 @@ import Foundation -import PostgresKit /// Used for interacting with an SQL database. This class is an /// injectable `Service` so you can register the default one /// via `Database.config(default: .postgres())`. public final class Database: Service { - /// The driver of this database. - let driver: DatabaseDriver + public struct Identifier: ServiceIdentifier { + private let hashable: AnyHashable + public init(hashable: AnyHashable) { self.hashable = hashable } + } /// Any migrations associated with this database, whether applied /// yet or not. public var migrations: [Migration] = [] - /// Create a database backed by the given driver. - /// - /// - Parameter driver: The driver. - public init(driver: DatabaseDriver) { - self.driver = driver - } + /// Any seeders associated with this database. + public var seeders: [Seeder] = [] - /// Start a QueryBuilder query on this database. See `Query` or - /// QueryBuilder guides. - /// - /// Usage: - /// ```swift - /// database.query() - /// .from(table: "users") - /// .where("id" == 1) - /// .first() - /// .whenSuccess { row in - /// guard let row = row else { - /// return print("No row found :(") - /// } - /// - /// print("Got a row with fields: \(row.allColumns)") - /// } - /// ``` + /// The provider of this database. + let provider: DatabaseProvider + + /// Indicates whether migrations were run on this database, by this process. + var didRunMigrations: Bool = false + + /// Create a database backed by the given provider. /// - /// - Returns: The start of a QueryBuilder `Query`. - public func query() -> Query { - Query(database: driver) + /// - Parameter provider: The provider. + public init(provider: DatabaseProvider) { + self.provider = provider } /// Run a parameterized query on the database. Parameterization @@ -48,107 +35,47 @@ public final class Database: Service { /// Usage: /// ```swift /// // No bindings - /// db.rawQuery("SELECT * FROM users where id = 1") - /// .whenSuccess { rows - /// guard let first = rows.first else { - /// return print("No rows found :(") - /// } - /// - /// print("Got a user row with columns \(rows.allColumns)!") - /// } + /// let rows = try await db.rawQuery("SELECT * FROM users where id = 1") + /// print("Got \(rows.count) users.") /// /// // Bindings, to protect against SQL injection. - /// db.rawQuery("SELECT * FROM users where id = ?", values = [.int(1)]) - /// .whenSuccess { rows - /// ... - /// } + /// let rows = db.rawQuery("SELECT * FROM users where id = ?", values = [.int(1)]) + /// print("Got \(rows.count) users.") /// ``` /// /// - Parameters: /// - sql: The SQL string with '?'s denoting variables that /// should be parameterized. - /// - values: An array, `[DatabaseValue]`, that will replace the - /// '?'s in `sql`. Ensure there are the same amnount of values + /// - values: An array, `[SQLValue]`, that will replace the + /// '?'s in `sql`. Ensure there are the same amount of values /// as there are '?'s in `sql`. - /// - Returns: A future containing the rows returned by the query. - public func rawQuery(_ sql: String, values: [DatabaseValue] = []) -> EventLoopFuture<[DatabaseRow]> { - driver.runRawQuery(sql, values: values) + /// - Returns: The database rows returned by the query. + public func query(_ sql: String, values: [SQLValue] = []) async throws -> [SQLRow] { + try await provider.query(sql, values: values) + } + + /// Run a raw, not parametrized SQL string. + /// + /// - Returns: The rows returned by the query. + public func raw(_ sql: String) async throws -> [SQLRow] { + try await provider.raw(sql) } /// Runs a transaction on the database, using the given closure. /// All database queries in the closure are executed atomically. /// - /// Uses START TRANSACTION; and COMMIT; under the hood. + /// Uses START TRANSACTION; and COMMIT; or similar under the hood. /// /// - Parameter action: The action to run atomically. - /// - Returns: A future that completes when the transaction is - /// finished. - public func transaction(_ action: @escaping (Database) -> EventLoopFuture) -> EventLoopFuture { - driver.transaction { action(Database(driver: $0)) } + /// - Returns: The return value of the transaction. + public func transaction(_ action: @escaping (Database) async throws -> T) async throws -> T { + try await provider.transaction { try await action(Database(provider: $0)) } } /// Called when the database connection will shut down. /// /// - Throws: Any error that occurred when shutting down. public func shutdown() throws { - try driver.shutdown() - } - - /// Returns a `Query` for the default database. - public static func query() -> Query { - Query(database: Database.default.driver) + try provider.shutdown() } } - -/// A generic type to represent any database you might be interacting -/// with. Currently, the only two implementations are -/// `PostgresDatabase` and `MySQLDatabase`. The QueryBuilder and Rune -/// ORM are built on top of this abstraction. -public protocol DatabaseDriver { - /// Functions around compiling SQL statments for this database's - /// SQL dialect when using the QueryBuilder or Rune. - var grammar: Grammar { get } - - /// Run a parameterized query on the database. Parameterization - /// helps protect against SQL injection. - /// - /// Usage: - /// ```swift - /// // No bindings - /// db.runRawQuery("SELECT * FROM users where id = 1") - /// .whenSuccess { rows - /// guard let first = rows.first else { - /// return print("No rows found :(") - /// } - /// - /// print("Got a user row with columns \(rows.allColumns)!") - /// } - /// - /// // Bindings, to protect against SQL injection. - /// db.runRawQuery("SELECT * FROM users where id = ?", values = [.int(1)]) - /// .whenSuccess { rows - /// ... - /// } - /// ``` - /// - /// - Parameters: - /// - sql: The SQL string with '?'s denoting variables that - /// should be parameterized. - /// - values: An array, `[DatabaseValue]`, that will replace the - /// '?'s in `sql`. Ensure there are the same amnount of values - /// as there are '?'s in `sql`. - /// - Returns: An `EventLoopFuture` of the rows returned by the - /// query. - func runRawQuery(_ sql: String, values: [DatabaseValue]) -> EventLoopFuture<[DatabaseRow]> - - /// Runs a transaction on the database, using the given closure. - /// All database queries in the closure are executed atomically. - /// - /// Uses START TRANSACTION; and COMMIT; under the hood. - func transaction(_ action: @escaping (DatabaseDriver) -> EventLoopFuture) -> EventLoopFuture - - /// Called when the database connection will shut down. - /// - /// - Throws: Any error that occurred when shutting down. - func shutdown() throws -} diff --git a/Sources/Alchemy/SQL/Database/DatabaseProvider.swift b/Sources/Alchemy/SQL/Database/DatabaseProvider.swift new file mode 100644 index 00000000..8e04b870 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/DatabaseProvider.swift @@ -0,0 +1,48 @@ +/// A generic type to represent any database you might be interacting +/// with. Currently, the only two implementations are +/// `PostgresDatabase` and `MySQLDatabase`. The QueryBuilder and Rune +/// ORM are built on top of this abstraction. +public protocol DatabaseProvider { + /// Functions around compiling SQL statments for this database's + /// SQL dialect when using the QueryBuilder or Rune. + var grammar: Grammar { get } + + /// Run a parameterized query on the database. Parameterization + /// helps protect against SQL injection. + /// + /// Usage: + /// + /// // No bindings + /// let rows = try await db.query("SELECT * FROM users where id = 1") + /// print("Got \(rows.count) users.") + /// + /// // Bindings, to protect against SQL injection. + /// let rows = try await db.query("SELECT * FROM users where id = ?", values = [.int(1)]) + /// print("Got \(rows.count) users.") + /// + /// - Parameters: + /// - sql: The SQL string with '?'s denoting variables that + /// should be parameterized. + /// - values: An array, `[SQLValue]`, that will replace the + /// '?'s in `sql`. Ensure there are the same amnount of values + /// as there are '?'s in `sql`. + /// - Returns: The database rows returned by the query. + func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] + + /// Run a raw, not parametrized SQL string. + /// + /// - Returns: The rows returned by the query. + func raw(_ sql: String) async throws -> [SQLRow] + + /// Runs a transaction on the database, using the given closure. + /// All database queries in the closure are executed atomically. + /// + /// Uses START TRANSACTION; and COMMIT; under the hood. + /// + /// - Parameter action: The action to run atomically. + /// - Returns: The return value of the transaction. + func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T + + /// Called when the database connection will shut down. + func shutdown() throws +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/MySQL/Database+MySQL.swift b/Sources/Alchemy/SQL/Database/Drivers/MySQL/Database+MySQL.swift new file mode 100644 index 00000000..c817630f --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/MySQL/Database+MySQL.swift @@ -0,0 +1,24 @@ +import NIOSSL + +extension Database { + /// Creates a PostgreSQL database configuration. + /// + /// - Parameters: + /// - host: The host the database is running on. + /// - port: The port the database is running on. + /// - database: The name of the database to connect to. + /// - username: The username to authorize with. + /// - password: The password to authorize with. + /// - enableSSL: Should the connection use SSL. + /// - Returns: The configuration for connecting to this database. + public static func mysql(host: String, port: Int = 3306, database: String, username: String, password: String, enableSSL: Bool = false) -> Database { + var tlsConfig = enableSSL ? TLSConfiguration.makeClientConfiguration() : nil + tlsConfig?.certificateVerification = .none + return mysql(socket: .ip(host: host, port: port), database: database, username: username, password: password, tlsConfiguration: tlsConfig) + } + + /// Create a PostgreSQL database configuration. + public static func mysql(socket: Socket, database: String, username: String, password: String, tlsConfiguration: TLSConfiguration? = nil) -> Database { + Database(provider: MySQLDatabase(socket: socket, database: database, username: username, password: password, tlsConfiguration: tlsConfiguration)) + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+Database.swift b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+Database.swift deleted file mode 100644 index cacd9e6f..00000000 --- a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+Database.swift +++ /dev/null @@ -1,144 +0,0 @@ -import MySQLKit -import NIO - -final class MySQLDatabase: DatabaseDriver { - /// The connection pool from which to make connections to the - /// database with. - private let pool: EventLoopGroupConnectionPool - - var grammar: Grammar = MySQLGrammar() - - /// Initialize with the given configuration. The configuration - /// will be connected to when a query is run. - /// - /// - Parameter config: The info needed to connect to the - /// database. - init(config: DatabaseConfig) { - self.pool = EventLoopGroupConnectionPool( - source: MySQLConnectionSource(configuration: { - switch config.socket { - case .ip(let host, let port): - var tlsConfig = config.enableSSL ? TLSConfiguration.makeClientConfiguration() : nil - tlsConfig?.certificateVerification = .none - return MySQLConfiguration( - hostname: host, - port: port, - username: config.username, - password: config.password, - database: config.database, - tlsConfiguration: tlsConfig - ) - case .unix(let name): - return MySQLConfiguration( - unixDomainSocketPath: name, - username: config.username, - password: config.password, - database: config.database - ) - } - }()), - on: Loop.group - ) - } - - // MARK: Database - - func runRawQuery(_ sql: String, values: [DatabaseValue]) -> EventLoopFuture<[DatabaseRow]> { - withConnection { $0.runRawQuery(sql, values: values) } - } - - /// MySQL doesn't have a way to return a row after inserting. This - /// runs a query and if MySQL metadata contains a `lastInsertID`, - /// fetches the row with that id from the given table. - /// - /// - Parameters: - /// - sql: The SQL to run. - /// - table: The table from which `lastInsertID` should be - /// fetched. - /// - values: Any bindings for the query. - /// - Returns: A future containing the result of fetching the last - /// inserted id, or the result of the original query. - func runAndReturnLastInsertedItem(_ sql: String, table: String, values: [DatabaseValue]) -> EventLoopFuture<[DatabaseRow]> { - pool.withConnection(logger: Log.logger, on: Loop.current) { conn in - var lastInsertId: Int? - return conn - .query(sql, values.map(MySQLData.init), onMetadata: { lastInsertId = $0.lastInsertID.map(Int.init) }) - .flatMap { rows -> EventLoopFuture<[MySQLRow]> in - if let lastInsertId = lastInsertId { - return conn.query("select * from \(table) where id = ?;", [MySQLData(.int(lastInsertId))]) - } else { - return .new(rows) - } - } - .map { $0.map(MySQLDatabaseRow.init) } - } - } - - func transaction(_ action: @escaping (DatabaseDriver) -> EventLoopFuture) -> EventLoopFuture { - withConnection { database in - let conn = database.conn - // SimpleQuery since MySQL can't handle START TRANSACTION in prepared statements. - return conn.simpleQuery("START TRANSACTION;") - .flatMap { _ in action(database) } - .flatMap { conn.simpleQuery("COMMIT;").transform(to: $0) } - } - } - - private func withConnection(_ action: @escaping (MySQLConnectionDatabase) -> EventLoopFuture) -> EventLoopFuture { - return pool.withConnection(logger: Log.logger, on: Loop.current) { - action(MySQLConnectionDatabase(conn: $0, grammar: self.grammar)) - } - } - - func shutdown() throws { - try self.pool.syncShutdownGracefully() - } -} - -public extension Database { - /// Creates a MySQL database configuration. - /// - /// - Parameters: - /// - host: The host the database is running on. - /// - port: The port the database is running on. - /// - database: The name of the database to connect to. - /// - username: The username to authorize with. - /// - password: The password to authorize with. - /// - Returns: The configuration for connecting to this database. - static func mysql(host: String, port: Int = 3306, database: String, username: String, password: String) -> Database { - return mysql(config: DatabaseConfig( - socket: .ip(host: host, port: port), - database: database, - username: username, - password: password - )) - } - - /// Create a MySQL database configuration. - /// - /// - Parameter config: The raw configuration to connect with. - /// - Returns: The configured database. - static func mysql(config: DatabaseConfig) -> Database { - Database(driver: MySQLDatabase(config: config)) - } -} - - -/// A database to send through on transactions. -private struct MySQLConnectionDatabase: DatabaseDriver { - let conn: MySQLConnection - let grammar: Grammar - - func runRawQuery(_ sql: String, values: [DatabaseValue]) -> EventLoopFuture<[DatabaseRow]> { - return conn.query(sql, values.map(MySQLData.init)) - .map { $0.map(MySQLDatabaseRow.init) } - } - - func transaction(_ action: @escaping (DatabaseDriver) -> EventLoopFuture) -> EventLoopFuture { - action(self) - } - - func shutdown() throws { - _ = conn.close() - } -} diff --git a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+DatabaseRow.swift b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+DatabaseRow.swift deleted file mode 100644 index bed3f736..00000000 --- a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+DatabaseRow.swift +++ /dev/null @@ -1,104 +0,0 @@ -import MySQLNIO -import MySQLKit -import NIO - -public final class MySQLDatabaseRow: DatabaseRow { - public let allColumns: Set - private let row: MySQLRow - - init(_ row: MySQLRow) { - self.row = row - self.allColumns = Set(self.row.columnDefinitions.map(\.name)) - } - - public func getField(column: String) throws -> DatabaseField { - try self.row.column(column) - .unwrap(or: DatabaseError("No column named `\(column)` was found.")) - .toDatabaseField(from: column) - } -} - -extension MySQLData { - /// Initialize from an Alchemy `DatabaseValue`. - /// - /// - Parameter value: The value with which to initialize. Given - /// the type of the value, the `MySQLData` will be initialized - /// with the best corresponding type. - init(_ value: DatabaseValue) { - switch value { - case .bool(let value): - self = value.map(MySQLData.init(bool:)) ?? .null - case .date(let value): - self = value.map(MySQLData.init(date:)) ?? .null - case .double(let value): - self = value.map(MySQLData.init(double:)) ?? .null - case .int(let value): - self = value.map(MySQLData.init(int:)) ?? .null - case .json(let value): - guard let data = value else { - self = .null - return - } - - // `MySQLData` doesn't support initializing from - // `Foundation.Data`. - var buffer = ByteBufferAllocator().buffer(capacity: data.count) - buffer.writeBytes(data) - self = MySQLData(type: .string, format: .text, buffer: buffer, isUnsigned: true) - case .string(let value): - self = value.map(MySQLData.init(string:)) ?? .null - case .uuid(let value): - self = value.map(MySQLData.init(uuid:)) ?? .null - } - } - - /// Converts a `MySQLData` to the Alchemy `DatabaseField` type. - /// - /// - Parameter column: The name of the column this data is at. - /// - Throws: A `DatabaseError` if there is an issue converting - /// the `MySQLData` to its expected type. - /// - Returns: A `DatabaseField` with the column, type and value, - /// best representing this `MySQLData`. - func toDatabaseField(from column: String) throws -> DatabaseField { - func validateNil(_ value: T?) throws -> T? { - if self.buffer == nil { - return nil - } else { - let errorMessage = "Unable to unwrap expected type " - + "`\(Swift.type(of: T.self))` from column '\(column)'." - return try value.unwrap(or: DatabaseError(errorMessage)) - } - } - - switch self.type { - case .int24, .short, .long, .longlong: - let value = DatabaseValue.int(try validateNil(self.int)) - return DatabaseField(column: column, value: value) - case .tiny: - let value = DatabaseValue.bool(try validateNil(self.bool)) - return DatabaseField(column: column, value: value) - case .varchar, .string, .varString, .blob, .tinyBlob, .mediumBlob, .longBlob: - let value = DatabaseValue.string(try validateNil(self.string)) - return DatabaseField(column: column, value: value) - case .date, .timestamp, .timestamp2, .datetime, .datetime2: - let value = DatabaseValue.date(try validateNil(self.time?.date)) - return DatabaseField(column: column, value: value) - case .time: - throw DatabaseError("Times aren't supported yet.") - case .float, .decimal, .double: - let value = DatabaseValue.double(try validateNil(self.double)) - return DatabaseField(column: column, value: value) - case .json: - guard var buffer = self.buffer else { - return DatabaseField(column: column, value: .json(nil)) - } - - let data = buffer.readData(length: buffer.writerIndex) - return DatabaseField(column: column, value: .json(data)) - default: - let errorMessage = "Couldn't parse a `\(self.type)` from column " - + "'\(column)'. That MySQL datatype isn't supported, yet." - throw DatabaseError(errorMessage) - } - } -} diff --git a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+Grammar.swift b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+Grammar.swift deleted file mode 100644 index 0d3d76ac..00000000 --- a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQL+Grammar.swift +++ /dev/null @@ -1,67 +0,0 @@ -import NIO - -/// A MySQL specific Grammar for compiling QueryBuilder statements -/// into SQL strings. -final class MySQLGrammar: Grammar { - override func compileDropIndex(table: String, indexName: String) -> SQL { - SQL("DROP INDEX \(indexName) ON \(table)") - } - - override func typeString(for type: ColumnType) -> String { - switch type { - case .bool: - return "boolean" - case .date: - return "datetime" - case .double: - return "double" - case .increments: - return "serial" - case .int: - return "int" - case .bigInt: - return "bigint" - case .json: - return "json" - case .string(let length): - switch length { - case .unlimited: - return "text" - case .limit(let characters): - return "varchar(\(characters))" - } - case .uuid: - // There isn't a MySQL UUID type; store UUIDs as a 36 - // length varchar. - return "varchar(36)" - } - } - - override func jsonLiteral(from jsonString: String) -> String { - "('\(jsonString)')" - } - - override func allowsUnsigned() -> Bool { - true - } - - // MySQL needs custom insert behavior, since bulk inserting and - // returning is not supported. - override func insert(_ values: [OrderedDictionary], query: Query, returnItems: Bool) -> EventLoopFuture<[DatabaseRow]> { - catchError { - guard - returnItems, - let table = query.from, - let database = query.database as? MySQLDatabase - else { - return super.insert(values, query: query, returnItems: returnItems) - } - - return try values - .map { try self.compileInsert(query, values: [$0]) } - .map { database.runAndReturnLastInsertedItem($0.query, table: table, values: $0.bindings) } - .flatten(on: Loop.current) - .map { $0.flatMap { $0 } } - } - } -} diff --git a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabase.swift b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabase.swift new file mode 100644 index 00000000..54d4506a --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabase.swift @@ -0,0 +1,87 @@ +import MySQLKit +import NIO + +final class MySQLDatabase: DatabaseProvider { + /// The connection pool from which to make connections to the + /// database with. + let pool: EventLoopGroupConnectionPool + + var grammar: Grammar = MySQLGrammar() + + init(socket: Socket, database: String, username: String, password: String, tlsConfiguration: TLSConfiguration? = nil) { + pool = EventLoopGroupConnectionPool( + source: MySQLConnectionSource(configuration: { + switch socket { + case .ip(let host, let port): + return MySQLConfiguration( + hostname: host, + port: port, + username: username, + password: password, + database: database, + tlsConfiguration: tlsConfiguration + ) + case .unix(let name): + return MySQLConfiguration( + unixDomainSocketPath: name, + username: username, + password: password, + database: database + ) + } + }()), + on: Loop.group + ) + } + + // MARK: Database + + func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] { + try await withConnection { try await $0.query(sql, values: values) } + } + + func raw(_ sql: String) async throws -> [SQLRow] { + try await withConnection { try await $0.raw(sql) } + } + + func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await withConnection { + _ = try await $0.raw("START TRANSACTION;") + let val = try await action($0) + _ = try await $0.raw("COMMIT;") + return val + } + } + + private func withConnection(_ action: @escaping (MySQLConnectionDatabase) async throws -> T) async throws -> T { + try await pool.withConnection(logger: Log.logger, on: Loop.current) { + try await action(MySQLConnectionDatabase(conn: $0, grammar: self.grammar)) + } + } + + func shutdown() throws { + try self.pool.syncShutdownGracefully() + } +} + +/// A database to send through on transactions. +private struct MySQLConnectionDatabase: DatabaseProvider { + let conn: MySQLConnection + let grammar: Grammar + + func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] { + try await conn.query(sql, values.map(MySQLData.init)).get().map(MySQLDatabaseRow.init) + } + + func raw(_ sql: String) async throws -> [SQLRow] { + try await conn.simpleQuery(sql).get().map(MySQLDatabaseRow.init) + } + + func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await action(self) + } + + func shutdown() throws { + _ = conn.close() + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseRow.swift b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseRow.swift new file mode 100644 index 00000000..006113d2 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseRow.swift @@ -0,0 +1,84 @@ +import MySQLNIO +import MySQLKit +import NIO + +final class MySQLDatabaseRow: SQLRow { + let columns: Set + private let row: MySQLRow + + init(_ row: MySQLRow) { + self.row = row + self.columns = Set(self.row.columnDefinitions.map(\.name)) + } + + func get(_ column: String) throws -> SQLValue { + try row.column(column) + .unwrap(or: DatabaseError("No column named `\(column)` was found.")) + .toSQLValue(column) + } +} + +extension MySQLData { + /// Initialize from an Alchemy `SQLValue`. + /// + /// - Parameter value: The value with which to initialize. Given + /// the type of the value, the `MySQLData` will be initialized + /// with the best corresponding type. + init(_ value: SQLValue) { + switch value { + case .bool(let value): + self = MySQLData(bool: value) + case .date(let value): + self = MySQLData(date: value) + case .double(let value): + self = MySQLData(double: value) + case .int(let value): + self = MySQLData(int: value) + case .json(let value): + self = MySQLData(type: .json, format: .text, buffer: ByteBuffer(data: value)) + case .string(let value): + self = MySQLData(string: value) + case .uuid(let value): + self = MySQLData(string: value.uuidString) + case .null: + self = .null + } + } + + /// Converts a `MySQLData` to the Alchemy `SQLValue` type. + /// + /// - Parameter column: The name of the column this data is at. + /// - Throws: A `DatabaseError` if there is an issue converting + /// the `MySQLData` to its expected type. + /// - Returns: An `SQLValue` with the column, type and value, + /// best representing this `MySQLData`. + func toSQLValue(_ column: String? = nil) throws -> SQLValue { + switch self.type { + case .int24, .short, .long, .longlong: + return int.map { .int($0) } ?? .null + case .tiny: + return bool.map { .bool($0) } ?? .null + case .varchar, .string, .varString, .blob, .tinyBlob, .mediumBlob, .longBlob: + return string.map { .string($0) } ?? .null + case .date, .timestamp, .timestamp2, .datetime, .datetime2: + guard let date = time?.date else { + return .null + } + + return .date(date) + case .float, .decimal, .double: + return double.map { .double($0) } ?? .null + case .json: + guard let data = self.buffer?.data else { + return .null + } + + return .json(data) + case .null: + return .null + default: + let desc = column.map { "from column `\($0)`" } ?? "from MySQL column" + throw DatabaseError("Couldn't parse a `\(type)` from \(desc). That MySQL datatype isn't supported, yet.") + } + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLGrammar.swift b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLGrammar.swift new file mode 100644 index 00000000..8bab86a8 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/MySQL/MySQLGrammar.swift @@ -0,0 +1,61 @@ +import NIO + +/// A MySQL specific Grammar for compiling QueryBuilder statements +/// into SQL strings. +final class MySQLGrammar: Grammar { + override func compileInsertReturn(_ table: String, values: [[String : SQLValueConvertible]]) -> [SQL] { + return values.flatMap { + return [ + compileInsert(table, values: [$0]), + SQL("select * from \(table) where id = LAST_INSERT_ID()") + ] + } + } + + override func compileDropIndex(on table: String, indexName: String) -> SQL { + SQL("DROP INDEX \(indexName) ON \(table)") + } + + override func columnTypeString(for type: ColumnType) -> String { + switch type { + case .bool: + return "boolean" + case .date: + return "datetime" + case .double: + return "double" + case .increments: + return "serial" + case .int: + return "int" + case .bigInt: + return "bigint" + case .json: + return "json" + case .string(let length): + switch length { + case .unlimited: + return "text" + case .limit(let characters): + return "varchar(\(characters))" + } + case .uuid: + // There isn't a MySQL UUID type; store UUIDs as a 36 + // length varchar. + return "varchar(36)" + } + } + + override func columnConstraintString(for constraint: ColumnConstraint, on column: String, of type: ColumnType) -> String? { + switch constraint { + case .unsigned: + return "UNSIGNED" + default: + return super.columnConstraintString(for: constraint, on: column, of: type) + } + } + + override func jsonLiteral(for jsonString: String) -> String { + "('\(jsonString)')" + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/Postgres/Database+Postgres.swift b/Sources/Alchemy/SQL/Database/Drivers/Postgres/Database+Postgres.swift new file mode 100644 index 00000000..6ff64a4f --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/Postgres/Database+Postgres.swift @@ -0,0 +1,24 @@ +import NIOSSL + +extension Database { + /// Creates a PostgreSQL database configuration. + /// + /// - Parameters: + /// - host: The host the database is running on. + /// - port: The port the database is running on. + /// - database: The name of the database to connect to. + /// - username: The username to authorize with. + /// - password: The password to authorize with. + /// - enableSSL: Should the connection use SSL. + /// - Returns: The configuration for connecting to this database. + public static func postgres(host: String, port: Int = 5432, database: String, username: String, password: String, enableSSL: Bool = false) -> Database { + var tlsConfig = enableSSL ? TLSConfiguration.makeClientConfiguration() : nil + tlsConfig?.certificateVerification = .none + return postgres(socket: .ip(host: host, port: port), database: database, username: username, password: password, tlsConfiguration: tlsConfig) + } + + /// Create a PostgreSQL database configuration. + public static func postgres(socket: Socket, database: String, username: String, password: String, tlsConfiguration: TLSConfiguration? = nil) -> Database { + Database(provider: PostgresDatabase(socket: socket, database: database, username: username, password: password, tlsConfiguration: tlsConfiguration)) + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/Postgres/Postgres+Database.swift b/Sources/Alchemy/SQL/Database/Drivers/Postgres/Postgres+Database.swift deleted file mode 100644 index dd996cd1..00000000 --- a/Sources/Alchemy/SQL/Database/Drivers/Postgres/Postgres+Database.swift +++ /dev/null @@ -1,158 +0,0 @@ -import Fusion -import Foundation -import PostgresKit -import NIO - -/// A concrete `Database` for connecting to and querying a PostgreSQL -/// database. -final class PostgresDatabase: DatabaseDriver { - /// The connection pool from which to make connections to the - /// database with. - private let pool: EventLoopGroupConnectionPool - - let grammar: Grammar = PostgresGrammar() - - /// Initialize with the given configuration. The configuration - /// will be connected to when a query is run. - /// - /// - Parameter config: the info needed to connect to the - /// database. - init(config: DatabaseConfig) { - self.pool = EventLoopGroupConnectionPool( - source: PostgresConnectionSource(configuration: { - switch config.socket { - case .ip(let host, let port): - var tlsConfig = config.enableSSL ? TLSConfiguration.makeClientConfiguration() : nil - tlsConfig?.certificateVerification = .none - return PostgresConfiguration( - hostname: host, - port: port, - username: config.username, - password: config.password, - database: config.database, - tlsConfiguration: tlsConfig - ) - case .unix(let name): - return PostgresConfiguration( - unixDomainSocketPath: name, - username: config.username, - password: config.password, - database: config.database - ) - } - }()), - on: Loop.group - ) - } - - // MARK: Database - - func runRawQuery(_ sql: String, values: [DatabaseValue]) -> EventLoopFuture<[DatabaseRow]> { - withConnection { $0.runRawQuery(sql, values: values) } - } - - func transaction(_ action: @escaping (DatabaseDriver) -> EventLoopFuture) -> EventLoopFuture { - withConnection { conn in - conn.runRawQuery("START TRANSACTION;", values: []) - .flatMap { _ in action(conn) } - .flatMap { conn.runRawQuery("COMMIT;", values: []).transform(to: $0) } - } - } - - func shutdown() throws { - try pool.syncShutdownGracefully() - } - - private func withConnection(_ action: @escaping (DatabaseDriver) -> EventLoopFuture) -> EventLoopFuture { - return pool.withConnection(logger: Log.logger, on: Loop.current) { - action(PostgresConnectionDatabase(conn: $0, grammar: self.grammar)) - } - } -} - -public extension Database { - /// Creates a PostgreSQL database configuration. - /// - /// - Parameters: - /// - host: The host the database is running on. - /// - port: The port the database is running on. - /// - database: The name of the database to connect to. - /// - username: The username to authorize with. - /// - password: The password to authorize with. - /// - Returns: The configuration for connecting to this database. - static func postgres(host: String, port: Int = 5432, database: String, username: String, password: String) -> Database { - return postgres(config: DatabaseConfig( - socket: .ip(host: host, port: port), - database: database, - username: username, - password: password - )) - } - - /// Create a PostgreSQL database configuration. - /// - /// - Parameter config: The raw configuration to connect with. - /// - Returns: The configured database. - static func postgres(config: DatabaseConfig) -> Database { - Database(driver: PostgresDatabase(config: config)) - } -} - -/// A database driver that is wrapped around a single connection to -/// with which to send transactions. -private struct PostgresConnectionDatabase: DatabaseDriver { - let conn: PostgresConnection - let grammar: Grammar - - func runRawQuery(_ sql: String, values: [DatabaseValue]) -> EventLoopFuture<[DatabaseRow]> { - conn.query(sql.positionPostgresBindings(), values.map(PostgresData.init)) - .map { $0.rows.map(PostgresDatabaseRow.init) } - } - - func transaction(_ action: @escaping (DatabaseDriver) -> EventLoopFuture) -> EventLoopFuture { - action(self) - } - - func shutdown() throws { - _ = conn.close() - } -} - -private extension String { - /// The Alchemy query builder constructs bindings with question - /// marks ('?') in the SQL string. PostgreSQL requires bindings - /// to be denoted by $1, $2, etc. This function converts all - /// '?'s to strings appropriate for Postgres bindings. - /// - /// - Parameter sql: The SQL string to replace bindings with. - /// - Returns: An SQL string appropriate for running in Postgres. - func positionPostgresBindings() -> String { - // TODO: Ensure a user can enter ? into their content? - replaceAll(matching: "(\\?)") { (index, _) in "$\(index + 1)" } - } - - /// Replace all instances of a regex pattern with a string, - /// determined by a closure. - /// - /// - Parameters: - /// - pattern: The pattern to replace. - /// - callback: The closure used to define replacements for the - /// pattern. Takes an index and a string that is the token to - /// replace. - /// - Returns: The string with replaced patterns. - func replaceAll(matching pattern: String, callback: (Int, String) -> String?) -> String { - let expression = try! NSRegularExpression(pattern: pattern, options: []) - let matches = expression - .matches(in: self, options: [], range: NSRange(startIndex.. - - private let row: PostgresRow - - init(_ row: PostgresRow) { - self.row = row - self.allColumns = Set(self.row.rowDescription.fields.map(\.name)) - } - - public func getField(column: String) throws -> DatabaseField { - try self.row.column(column) - .unwrap(or: DatabaseError("No column named `\(column)` was found \(allColumns).")) - .toDatabaseField(from: column) - } -} - -extension PostgresData { - /// Initialize from an Alchemy `DatabaseValue`. - /// - /// - Parameter value: the value with which to initialize. Given - /// the type of the value, the `PostgresData` will be - /// initialized with the best corresponding type. - init(_ value: DatabaseValue) { - switch value { - case .bool(let value): - self = value.map(PostgresData.init(bool:)) ?? PostgresData(type: .bool) - case .date(let value): - self = value.map(PostgresData.init(date:)) ?? PostgresData(type: .date) - case .double(let value): - self = value.map(PostgresData.init(double:)) ?? PostgresData(type: .float8) - case .int(let value): - self = value.map(PostgresData.init(int:)) ?? PostgresData(type: .int4) - case .json(let value): - self = value.map(PostgresData.init(json:)) ?? PostgresData(type: .json) - case .string(let value): - self = value.map(PostgresData.init(string:)) ?? PostgresData(type: .text) - case .uuid(let value): - self = value.map(PostgresData.init(uuid:)) ?? PostgresData(type: .uuid) - } - } - - /// Converts a `PostgresData` to the Alchemy `DatabaseField` type. - /// - /// - Parameter column: The name of the column this data is at. - /// - Throws: A `DatabaseError` if there is an issue converting - /// the `PostgresData` to its expected type. - /// - Returns: A `DatabaseField` with the column, type and value, - /// best representing this `PostgresData`. - fileprivate func toDatabaseField(from column: String) throws -> DatabaseField { - // Ensures that if value is nil, it's because the database - // column is actually nil and not because we are attempting - // to pull out the wrong type. - func validateNil(_ value: T?) throws -> T? { - if self.value == nil { - return nil - } else { - let errorMessage = "Unable to unwrap expected type" - + " `\(name(of: T.self))` from column '\(column)'." - return try value.unwrap(or: DatabaseError(errorMessage)) - } - } - - switch self.type { - case .int2, .int4, .int8: - let value = DatabaseValue.int(try validateNil(self.int)) - return DatabaseField(column: column, value: value) - case .bool: - let value = DatabaseValue.bool(try validateNil(self.bool)) - return DatabaseField(column: column, value: value) - case .varchar, .text: - let value = DatabaseValue.string(try validateNil(self.string)) - return DatabaseField(column: column, value: value) - case .date: - let value = DatabaseValue.date(try validateNil(self.date)) - return DatabaseField(column: column, value: value) - case .timestamptz, .timestamp: - let value = DatabaseValue.date(try validateNil(self.date)) - return DatabaseField(column: column, value: value) - case .time, .timetz: - throw DatabaseError("Times aren't supported yet.") - case .float4, .float8: - let value = DatabaseValue.double(try validateNil(self.double)) - return DatabaseField(column: column, value: value) - case .uuid: - // The `PostgresNIO` `UUID` parser doesn't seem to work - // properly `self.uuid` returns nil. - let string = try validateNil(self.string) - let uuid = try string.map { string -> UUID in - guard let uuid = UUID(uuidString: string) else { - throw DatabaseError( - "Invalid UUID '\(string)' at column '\(column)'" - ) - } - - return uuid - } - return DatabaseField(column: column, value: .uuid(uuid)) - case .json, .jsonb: - let value = DatabaseValue.json(try validateNil(self.json)) - return DatabaseField(column: column, value: value) - default: - throw DatabaseError("Couldn't parse a `\(self.type)` from column " - + "'\(column)'. That Postgres datatype " - + "isn't supported, yet.") - } - } -} diff --git a/Sources/Alchemy/SQL/Database/Drivers/Postgres/Postgres+Grammar.swift b/Sources/Alchemy/SQL/Database/Drivers/Postgres/Postgres+Grammar.swift deleted file mode 100644 index 36d0ec87..00000000 --- a/Sources/Alchemy/SQL/Database/Drivers/Postgres/Postgres+Grammar.swift +++ /dev/null @@ -1,9 +0,0 @@ -/// A Postgres specific Grammar for compiling QueryBuilder statements -/// into SQL strings. -final class PostgresGrammar: Grammar { - override func compileInsert(_ query: Query, values: [OrderedDictionary]) throws -> SQL { - var initial = try super.compileInsert(query, values: values) - initial.query.append(" returning *") - return initial - } -} diff --git a/Sources/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabase.swift b/Sources/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabase.swift new file mode 100644 index 00000000..83bc7e87 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabase.swift @@ -0,0 +1,139 @@ +import Fusion +import Foundation +import PostgresKit +import NIO +import MySQLKit + +/// A concrete `Database` for connecting to and querying a PostgreSQL +/// database. +final class PostgresDatabase: DatabaseProvider { + /// The connection pool from which to make connections to the + /// database with. + let pool: EventLoopGroupConnectionPool + + let grammar: Grammar = PostgresGrammar() + + init(socket: Socket, database: String, username: String, password: String, tlsConfiguration: TLSConfiguration? = nil) { + pool = EventLoopGroupConnectionPool( + source: PostgresConnectionSource(configuration: { + switch socket { + case .ip(let host, let port): + return PostgresConfiguration( + hostname: host, + port: port, + username: username, + password: password, + database: database, + tlsConfiguration: tlsConfiguration + ) + case .unix(let name): + return PostgresConfiguration( + unixDomainSocketPath: name, + username: username, + password: password, + database: database + ) + } + }()), + on: Loop.group + ) + } + + // MARK: Database + + func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] { + try await withConnection { try await $0.query(sql, values: values) } + } + + func raw(_ sql: String) async throws -> [SQLRow] { + try await withConnection { try await $0.raw(sql) } + } + + func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await withConnection { conn in + _ = try await conn.raw("START TRANSACTION;") + do { + let val = try await action(conn) + _ = try await conn.raw("COMMIT;") + return val + } catch { + Log.error("[Database] postgres transaction failed with error \(error). Rolling back.") + _ = try await conn.raw("ROLLBACK;") + _ = try await conn.raw("COMMIT;") + throw error + } + } + } + + func shutdown() throws { + try pool.syncShutdownGracefully() + } + + private func withConnection(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await pool.withConnection(logger: Log.logger, on: Loop.current) { + try await action($0) + } + } +} + +/// A database provider that is wrapped around a single connection to with which +/// to send transactions. +extension PostgresConnection: DatabaseProvider { + public var grammar: Grammar { PostgresGrammar() } + + public func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] { + try await query(sql.positionPostgresBindings(), values.map(PostgresData.init)) + .get().rows.map(PostgresDatabaseRow.init) + } + + public func raw(_ sql: String) async throws -> [SQLRow] { + try await simpleQuery(sql).get().map(PostgresDatabaseRow.init) + } + + public func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await action(self) + } + + public func shutdown() throws { + _ = close() + } +} + +extension String { + /// The Alchemy query builder constructs bindings with question + /// marks ('?') in the SQL string. PostgreSQL requires bindings + /// to be denoted by $1, $2, etc. This function converts all + /// '?'s to strings appropriate for Postgres bindings. + /// + /// - Parameter sql: The SQL string to replace bindings with. + /// - Returns: An SQL string appropriate for running in Postgres. + func positionPostgresBindings() -> String { + // TODO: Ensure a user can enter ? into their content? + replaceAll(matching: "(\\?)") { (index, _) in "$\(index + 1)" } + } + + /// Replace all instances of a regex pattern with a string, + /// determined by a closure. + /// + /// - Parameters: + /// - pattern: The pattern to replace. + /// - callback: The closure used to define replacements for the + /// pattern. Takes an index and a string that is the token to + /// replace. + /// - Returns: The string with replaced patterns. + func replaceAll(matching pattern: String, callback: (Int, String) -> String) -> String { + let expression = try! NSRegularExpression(pattern: pattern, options: []) + let matches = expression + .matches(in: self, options: [], range: NSRange(startIndex.. + private let row: PostgresRow + + init(_ row: PostgresRow) { + self.row = row + self.columns = Set(self.row.rowDescription.fields.map(\.name)) + } + + func get(_ column: String) throws -> SQLValue { + try row.column(column) + .unwrap(or: DatabaseError("No column named `\(column)` was found \(columns).")) + .toSQLValue(column) + } +} + +extension PostgresData { + /// Initialize from an Alchemy `SQLValue`. + /// + /// - Parameter value: the value with which to initialize. Given + /// the type of the value, the `PostgresData` will be + /// initialized with the best corresponding type. + init(_ value: SQLValue) { + switch value { + case .bool(let value): + self = PostgresData(bool: value) + case .date(let value): + self = PostgresData(date: value) + case .double(let value): + self = PostgresData(double: value) + case .int(let value): + self = PostgresData(int: value) + case .json(let value): + self = PostgresData(json: value) + case .string(let value): + self = PostgresData(string: value) + case .uuid(let value): + self = PostgresData(uuid: value) + case .null: + self = .null + } + } + + /// Converts a `PostgresData` to the Alchemy `SQLValue` type. + /// + /// - Parameter column: The name of the column this data is at. + /// - Throws: A `DatabaseError` if there is an issue converting + /// the `PostgresData` to its expected type. + /// - Returns: An `SQLValue` with the column, type and value, + /// best representing this `PostgresData`. + func toSQLValue(_ column: String? = nil) throws -> SQLValue { + switch self.type { + case .int2, .int4, .int8: + return int.map { .int($0) } ?? .null + case .bool: + return bool.map { .bool($0) } ?? .null + case .varchar, .text: + return string.map { .string($0) } ?? .null + case .date, .timestamptz, .timestamp: + return date.map { .date($0) } ?? .null + case .float4, .float8: + return double.map { .double($0) } ?? .null + case .uuid: + return uuid.map { .uuid($0) } ?? .null + case .json, .jsonb: + return json.map { .json($0) } ?? .null + case .null: + return .null + default: + let desc = column.map { "from column `\($0)`" } ?? "from PostgreSQL column" + throw DatabaseError("Couldn't parse a `\(type)` from \(desc). That PostgreSQL datatype isn't supported, yet.") + } + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/Postgres/PostgresGrammar.swift b/Sources/Alchemy/SQL/Database/Drivers/Postgres/PostgresGrammar.swift new file mode 100644 index 00000000..372b5954 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/Postgres/PostgresGrammar.swift @@ -0,0 +1,4 @@ +/// A Postgres specific Grammar for compiling QueryBuilder statements into SQL +/// strings. The base Grammar class is made for Postgres, so there isn't +/// anything to override at the moment. +final class PostgresGrammar: Grammar {} diff --git a/Sources/Alchemy/SQL/Database/Drivers/SQLite/Database+SQLite.swift b/Sources/Alchemy/SQL/Database/Drivers/SQLite/Database+SQLite.swift new file mode 100644 index 00000000..55666aaa --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/SQLite/Database+SQLite.swift @@ -0,0 +1,20 @@ +extension Database { + /// A file based SQLite database configuration. + /// + /// - Parameter path: The path of the SQLite database file. + /// - Returns: The configuration for connecting to this database. + public static func sqlite(path: String) -> Database { + Database(provider: SQLiteDatabase(config: .file(path))) + } + + /// An in memory SQLite database configuration with the given identifier. + public static func sqlite(identifier: String) -> Database { + Database(provider: SQLiteDatabase(config: .memory(identifier: identifier))) + } + + /// An in memory SQLite database configuration. + public static var sqlite: Database { .memory } + + /// An in memory SQLite database configuration. + public static var memory: Database { Database(provider: SQLiteDatabase(config: .memory)) } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabase.swift b/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabase.swift new file mode 100644 index 00000000..f0bcac7e --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabase.swift @@ -0,0 +1,86 @@ +import SQLiteKit + +final class SQLiteDatabase: DatabaseProvider { + /// The connection pool from which to make connections to the + /// database with. + let pool: EventLoopGroupConnectionPool + let config: Config + let grammar: Grammar = SQLiteGrammar() + + enum Config: Equatable { + case memory(identifier: String = UUID().uuidString) + case file(String) + + static var memory: Config { memory() } + } + + /// Initialize with the given configuration. The configuration + /// will be connected to when a query is run. + /// + /// - Parameter config: the info needed to connect to the + /// database. + init(config: Config) { + self.config = config + self.pool = EventLoopGroupConnectionPool( + source: SQLiteConnectionSource(configuration: { + switch config { + case .memory(let id): + return SQLiteConfiguration(storage: .memory(identifier: id), enableForeignKeys: true) + case .file(let path): + return SQLiteConfiguration(storage: .file(path: path), enableForeignKeys: true) + } + }(), threadPool: Thread.pool), + on: Loop.group + ) + } + + // MARK: Database + + func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] { + try await withConnection { try await $0.query(sql, values: values) } + } + + func raw(_ sql: String) async throws -> [SQLRow] { + try await withConnection { try await $0.raw(sql) } + } + + func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await withConnection { conn in + _ = try await conn.raw("BEGIN;") + let val = try await action(conn) + _ = try await conn.raw("COMMIT;") + return val + } + } + + func shutdown() throws { + try pool.syncShutdownGracefully() + } + + private func withConnection(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await pool.withConnection(logger: Log.logger, on: Loop.current) { + try await action(SQLiteConnectionDatabase(conn: $0, grammar: self.grammar)) + } + } +} + +private struct SQLiteConnectionDatabase: DatabaseProvider { + let conn: SQLiteConnection + let grammar: Grammar + + func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] { + try await conn.query(sql, values.map(SQLiteData.init)).get().map(SQLiteDatabaseRow.init) + } + + func raw(_ sql: String) async throws -> [SQLRow] { + try await conn.query(sql).get().map(SQLiteDatabaseRow.init) + } + + func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await action(self) + } + + func shutdown() throws { + _ = conn.close() + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabaseRow.swift b/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabaseRow.swift new file mode 100644 index 00000000..a07b9d91 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabaseRow.swift @@ -0,0 +1,71 @@ +import SQLiteNIO + +struct SQLiteDatabaseRow: SQLRow { + let columns: Set + private let row: SQLiteRow + + init(_ row: SQLiteRow) { + self.row = row + self.columns = Set(row.columns.map(\.name)) + } + + func get(_ column: String) throws -> SQLValue { + try row.column(column) + .unwrap(or: DatabaseError("No column named `\(column)` was found \(columns).")) + .toSQLValue() + } +} + +extension SQLiteData { + /// Initialize from an Alchemy `SQLValue`. + /// + /// - Parameter value: the value with which to initialize. Given + /// the type of the value, the `SQLiteData` will be + /// initialized with the best corresponding type. + init(_ value: SQLValue) { + switch value { + case .bool(let value): + self = value ? .integer(1) : .integer(0) + case .date(let value): + self = .text(SQLValue.iso8601DateFormatter.string(from: value)) + case .double(let value): + self = .float(value) + case .int(let value): + self = .integer(value) + case .json(let value): + guard let jsonString = String(data: value, encoding: .utf8) else { + self = .null + return + } + + self = .text(jsonString) + case .string(let value): + self = .text(value) + case .uuid(let value): + self = .text(value.uuidString) + case .null: + self = .null + } + } + + /// Converts a `SQLiteData` to the Alchemy `SQLValue` type. + /// + /// - Throws: A `DatabaseError` if there is an issue converting + /// the `SQLiteData` to its expected type. + /// - Returns: A `SQLValue` with the column, type and value, + /// best representing this `SQLiteData`. + func toSQLValue() throws -> SQLValue { + switch self { + case .integer(let int): + return .int(int) + case .float(let double): + return .double(double) + case .text(let string): + return .string(string) + case .blob: + throw DatabaseError("SQLite blob isn't supported yet") + case .null: + return .null + } + } +} diff --git a/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteGrammar.swift b/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteGrammar.swift new file mode 100644 index 00000000..73af6691 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Drivers/SQLite/SQLiteGrammar.swift @@ -0,0 +1,56 @@ +final class SQLiteGrammar: Grammar { + override func compileInsertReturn(_ table: String, values: [[String : SQLValueConvertible]]) -> [SQL] { + return values.flatMap { fields -> [SQL] in + // If the id is already set, search the database for that. Otherwise + // assume id is autoincrementing and search for the last rowid. + let id = fields["id"] + let idString = id == nil ? "last_insert_rowid()" : "?" + return [ + compileInsert(table, values: [fields]), + SQL("select * from \(table) where id = \(idString)", bindings: [id?.value].compactMap { $0 }) + ] + } + } + + // No locks are supported with SQLite; the entire database is locked on + // write anyways. + override func compileLock(_ lock: Query.Lock?) -> SQL? { + return nil + } + + override func columnTypeString(for type: ColumnType) -> String { + switch type { + case .bool: + return "integer" + case .date: + return "datetime" + case .double: + return "double" + case .increments: + return "integer PRIMARY KEY AUTOINCREMENT" + case .int: + return "integer" + case .bigInt: + return "integer" + case .json: + return "text" + case .string: + return "text" + case .uuid: + return "text" + } + } + + override func columnConstraintString(for constraint: ColumnConstraint, on column: String, of type: ColumnType) -> String? { + switch constraint { + case .primaryKey where type == .increments: + return nil + default: + return super.columnConstraintString(for: constraint, on: column, of: type) + } + } + + override func jsonLiteral(for jsonString: String) -> String { + "'\(jsonString)'" + } +} diff --git a/Sources/Alchemy/SQL/Database/Seeding/Database+Seeder.swift b/Sources/Alchemy/SQL/Database/Seeding/Database+Seeder.swift new file mode 100644 index 00000000..c91664e0 --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Seeding/Database+Seeder.swift @@ -0,0 +1,34 @@ +extension Database { + /// Seeds the database by running each seeder in `seeders` + /// consecutively. + public func seed() async throws { + for seeder in seeders { + try await seeder.run() + } + } + + public func seed(with seeder: Seeder) async throws { + try await seeder.run() + } + + func seed(names seederNames: [String]) async throws { + let toRun = try seederNames.map { name in + return try seeders + .first(where: { + $0.name.lowercased() == name.lowercased() || + $0.name.lowercased().droppingSuffix("seeder") == name.lowercased() + }) + .unwrap(or: DatabaseError("Unable to find a seeder on this database named \(name) or \(name)Seeder.")) + } + + for seeder in toRun { + try await seeder.run() + } + } +} + +extension Seeder { + fileprivate var name: String { + Alchemy.name(of: Self.self) + } +} diff --git a/Sources/Alchemy/SQL/Database/Seeding/Seeder.swift b/Sources/Alchemy/SQL/Database/Seeding/Seeder.swift new file mode 100644 index 00000000..da7411be --- /dev/null +++ b/Sources/Alchemy/SQL/Database/Seeding/Seeder.swift @@ -0,0 +1,34 @@ +import Fakery + +public protocol Seeder { + func run() async throws +} + +public protocol Seedable { + static func generate() async throws -> Self +} + +extension Seedable where Self: Model { + @discardableResult + public static func seed() async throws -> Self { + try await generate().save() + } + + @discardableResult + public static func seed(_ count: Int) async throws -> [Self] { + var rows: [Self] = [] + for _ in 1...count { + rows.append(try await generate()) + } + + return try await rows.insertReturnAll() + } +} + +extension Faker { + static let `default` = Faker() +} + +extension Model { + public static var faker: Faker { .default } +} diff --git a/Sources/Alchemy/SQL/Migrations/Builders/AlterTableBuilder.swift b/Sources/Alchemy/SQL/Migrations/Builders/AlterTableBuilder.swift index 6aa6c18a..9e95e6f5 100644 --- a/Sources/Alchemy/SQL/Migrations/Builders/AlterTableBuilder.swift +++ b/Sources/Alchemy/SQL/Migrations/Builders/AlterTableBuilder.swift @@ -17,12 +17,12 @@ extension AlterTableBuilder { /// /// - Parameter column: The name of the column to drop. public func drop(column: String) { - self.dropColumns.append(column) + dropColumns.append(column) } /// Drop the `created_at` and `updated_at` columns. public func dropTimestamps() { - self.dropColumns.append(contentsOf: ["created_at", "updated_at"]) + dropColumns.append(contentsOf: ["created_at", "updated_at"]) } /// Rename a column. @@ -31,13 +31,13 @@ extension AlterTableBuilder { /// - column: The name of the column to rename. /// - to: The new name for the column. public func rename(column: String, to: String) { - self.renameColumns.append((from: column, to: to)) + renameColumns.append((from: column, to: to)) } /// Drop an index. /// /// - Parameter index: The name of the index to drop. public func drop(index: String) { - self.dropIndexes.append(index) + dropIndexes.append(index) } } diff --git a/Sources/Alchemy/SQL/Migrations/Builders/CreateColumnBuilder.swift b/Sources/Alchemy/SQL/Migrations/Builders/CreateColumnBuilder.swift index 34e15f12..e2586654 100644 --- a/Sources/Alchemy/SQL/Migrations/Builders/CreateColumnBuilder.swift +++ b/Sources/Alchemy/SQL/Migrations/Builders/CreateColumnBuilder.swift @@ -4,46 +4,11 @@ protocol ColumnBuilderErased { func toCreate() -> CreateColumn } -/// Options for an `onDelete` or `onUpdate` reference constraint. -public enum ReferenceOption: String { - /// RESTRICT - case restrict = "RESTRICT" - /// CASCADE - case cascade = "CASCADE" - /// SET NULL - case setNull = "SET NULL" - /// NO ACTION - case noAction = "NO ACTION" - /// SET DEFAULT - case setDefault = "SET DEFAULT" -} - -/// Various constraints for columns. -enum ColumnConstraint { - /// This column shouldn't be null. - case notNull - /// The default value for this column. - case `default`(String) - /// This column is the primary key of it's table. - case primaryKey - /// This column is unique on this table. - case unique - /// This column references a `column` on another `table`. - case foreignKey( - column: String, - table: String, - onDelete: ReferenceOption? = nil, - onUpdate: ReferenceOption? = nil - ) - /// This int column is unsigned. - case unsigned -} - /// A builder for creating columns on a table in a relational database. /// /// `Default` is a Swift type that can be used to add a default value /// to this column. -public final class CreateColumnBuilder: ColumnBuilderErased { +public final class CreateColumnBuilder: ColumnBuilderErased { /// The grammar of this builder. private let grammar: Grammar @@ -71,6 +36,14 @@ public final class CreateColumnBuilder: ColumnBuilderEras self.constraints = constraints } + // MARK: ColumnBuilderErased + + func toCreate() -> CreateColumn { + CreateColumn(name: self.name, type: self.type, constraints: self.constraints) + } +} + +extension CreateColumnBuilder { /// Adds an expression as the default value of this column. /// /// - Parameter expression: An expression for generating the @@ -88,10 +61,10 @@ public final class CreateColumnBuilder: ColumnBuilderEras // Janky, but MySQL requires parentheses around text (but not // varchar...) literals. if case .string(.unlimited) = self.type, self.grammar is MySQLGrammar { - return self.adding(constraint: .default("(\(val.toSQL().query))")) + return self.adding(constraint: .default("(\(val.sqlLiteral))")) } - return self.adding(constraint: .default(val.toSQL().query)) + return self.adding(constraint: .default(val.sqlLiteral)) } /// Define this column as not nullable. @@ -115,8 +88,8 @@ public final class CreateColumnBuilder: ColumnBuilderEras @discardableResult public func references( _ column: String, on table: String, - onDelete: ReferenceOption? = nil, - onUpdate: ReferenceOption? = nil + onDelete: ColumnConstraint.ReferenceOption? = nil, + onUpdate: ColumnConstraint.ReferenceOption? = nil ) -> Self { self.adding(constraint: .foreignKey(column: column, table: table, onDelete: onDelete, onUpdate: onUpdate)) } @@ -143,12 +116,6 @@ public final class CreateColumnBuilder: ColumnBuilderEras self.constraints.append(constraint) return self } - - // MARK: ColumnBuilderErased - - func toCreate() -> CreateColumn { - CreateColumn(column: self.name, type: self.type, constraints: self.constraints) - } } extension CreateColumnBuilder where Default == Int { @@ -167,7 +134,7 @@ extension CreateColumnBuilder where Default == Date { /// /// - Returns: This column builder. @discardableResult public func defaultNow() -> Self { - self.default(expression: "NOW()") + self.default(expression: "CURRENT_TIMESTAMP") } } @@ -179,7 +146,7 @@ extension CreateColumnBuilder where Default == SQLJSON { /// for this column. /// - Returns: This column builder. @discardableResult public func `default`(jsonString: String) -> Self { - self.adding(constraint: .default(self.grammar.jsonLiteral(from: jsonString))) + self.adding(constraint: .default(self.grammar.jsonLiteral(for: jsonString))) } /// Adds an `Encodable` as the default for this column. @@ -199,44 +166,10 @@ extension CreateColumnBuilder where Default == SQLJSON { } let jsonString = String(decoding: jsonData, as: UTF8.self) - return self.adding(constraint: .default(self.grammar.jsonLiteral(from: jsonString))) + return self.adding(constraint: .default(self.grammar.jsonLiteral(for: jsonString))) } } -extension Bool: Sequelizable { - public func toSQL() -> SQL { SQL("\(self)") } -} - -extension UUID: Sequelizable { - public func toSQL() -> SQL { SQL("'\(self.uuidString)'") } -} - -extension String: Sequelizable { - public func toSQL() -> SQL { SQL("'\(self)'") } -} - -extension Int: Sequelizable { - public func toSQL() -> SQL { SQL("\(self)") } -} - -extension Double: Sequelizable { - public func toSQL() -> SQL { SQL("\(self)") } -} - -extension Date: Sequelizable { - /// The date formatter for turning this `Date` into an SQL string. - private static let sqlFormatter: DateFormatter = { - let df = DateFormatter() - df.timeZone = TimeZone(abbreviation: "GMT") - df.dateFormat = "yyyy-MM-dd'T'HH:mm:ss" - return df - }() - - // MARK: Sequelizable - - public func toSQL() -> SQL { SQL("'\(Date.sqlFormatter.string(from: self))'") } -} - /// A type used to signify that a column on a database has a JSON /// type. /// @@ -244,11 +177,11 @@ extension Date: Sequelizable { /// generic `default` function on `CreateColumnBuilder`. Instead, /// opt to use `.default(jsonString:)` or `.default(encodable:)` /// to set a default value for a JSON column. -public struct SQLJSON: Sequelizable { +public struct SQLJSON: SQLValueConvertible { /// `init()` is kept private to this from ever being instantiated. private init() {} - // MARK: Sequelizable + // MARK: SQLConvertible - public func toSQL() -> SQL { SQL() } + public var value: SQLValue { .null } } diff --git a/Sources/Alchemy/SQL/Migrations/Builders/CreateTableBuilder.swift b/Sources/Alchemy/SQL/Migrations/Builders/CreateTableBuilder.swift index f021e86d..8c926617 100644 --- a/Sources/Alchemy/SQL/Migrations/Builders/CreateTableBuilder.swift +++ b/Sources/Alchemy/SQL/Migrations/Builders/CreateTableBuilder.swift @@ -11,9 +11,14 @@ public class CreateTableBuilder { /// All the columns to create on this table. var createColumns: [CreateColumn] { - self.columnBuilders.map { $0.toCreate() } + columnBuilders.map { $0.toCreate() } } + /// References to the builders for all the columns on this table. + /// Need to store these since they may be modified via column + /// builder functions. + private var columnBuilders: [ColumnBuilderErased] = [] + /// Create a table builder with the given grammar. /// /// - Parameter grammar: The grammar with which this builder will @@ -21,12 +26,9 @@ public class CreateTableBuilder { init(grammar: Grammar) { self.grammar = grammar } - - /// References to the builders for all the columns on this table. - /// Need to store these since they may be modified via column - /// builder functions. - private var columnBuilders: [ColumnBuilderErased] = [] - +} + +extension CreateTableBuilder { /// Add an index. /// /// It's name will be `__...` @@ -79,7 +81,7 @@ public class CreateTableBuilder { /// - Returns: A builder for adding modifiers to the column. @discardableResult public func string( _ column: String, - length: StringLength = .limit(255) + length: ColumnType.StringLength = .limit(255) ) -> CreateColumnBuilder { self.appendAndReturn(builder: CreateColumnBuilder(grammar: self.grammar, name: column, type: .string(length))) } @@ -134,66 +136,8 @@ public class CreateTableBuilder { /// - Parameter builder: The column builder to add to this table /// builder. /// - Returns: The passed in `builder`. - private func appendAndReturn( builder: CreateColumnBuilder) -> CreateColumnBuilder { + private func appendAndReturn( builder: CreateColumnBuilder) -> CreateColumnBuilder { self.columnBuilders.append(builder) return builder } } - -/// A type for keeping track of data associated with creating an -/// index. -public struct CreateIndex { - /// The columns that make up this index. - let columns: [String] - - /// Whether this index is unique or not. - let isUnique: Bool - - /// Generate an SQL string for creating this index on a given - /// table. - /// - /// - Parameter table: The name of the table this index will be - /// created on. - /// - Returns: An SQL string for creating this index on the given - /// table. - func toSQL(table: String) -> String { - let indexType = self.isUnique ? "UNIQUE INDEX" : "INDEX" - let indexName = self.name(table: table) - let indexColumns = "(\(self.columns.map(\.sqlEscaped).joined(separator: ", ")))" - return "CREATE \(indexType) \(indexName) ON \(table) \(indexColumns)" - } - - /// Generate the name of this index given the table it will be - /// created on. - /// - /// - Parameter table: The table this index will be created on. - /// - Returns: The name of this index. - private func name(table: String) -> String { - ([table] + self.columns + [self.nameSuffix]).joined(separator: "_") - } - - /// The suffix of the index name. "key" if it's a unique index, - /// "idx" if not. - private var nameSuffix: String { - self.isUnique ? "key" : "idx" - } -} - -/// A type for keeping track of data associated with creating an -/// column. -public struct CreateColumn { - /// The name. - let column: String - - /// The type string. - let type: ColumnType - - /// Any constraints. - let constraints: [ColumnConstraint] -} - -extension String { - var sqlEscaped: String { - "\"\(self)\"" - } -} diff --git a/Sources/Alchemy/SQL/Migrations/Schema.swift b/Sources/Alchemy/SQL/Migrations/Builders/Schema.swift similarity index 53% rename from Sources/Alchemy/SQL/Migrations/Schema.swift rename to Sources/Alchemy/SQL/Migrations/Builders/Schema.swift index f1210c01..94a50678 100644 --- a/Sources/Alchemy/SQL/Migrations/Schema.swift +++ b/Sources/Alchemy/SQL/Migrations/Builders/Schema.swift @@ -22,22 +22,12 @@ public class Schema { /// - ifNotExists: If the query should silently not be run if /// the table already exists. Defaults to `false`. /// - builder: A closure for building the new table. - public func create( - table: String, - ifNotExists: Bool = false, - builder: (inout CreateTableBuilder) -> Void - ) { - var createBuilder = CreateTableBuilder(grammar: self.grammar) + public func create(table: String, ifNotExists: Bool = false, builder: (inout CreateTableBuilder) -> Void) { + var createBuilder = CreateTableBuilder(grammar: grammar) builder(&createBuilder) - - let createColumns = self.grammar.compileCreate( - table: table, - ifNotExists: ifNotExists, - columns: createBuilder.createColumns - ) - let createIndexes = self.grammar - .compileCreateIndexes(table: table, indexes: createBuilder.createIndexes) - self.statements.append(contentsOf: [createColumns] + createIndexes) + let createColumns = grammar.compileCreateTable(table, ifNotExists: ifNotExists, columns: createBuilder.createColumns) + let createIndexes = grammar.compileCreateIndexes(on: table, indexes: createBuilder.createIndexes) + statements.append(contentsOf: [createColumns] + createIndexes) } /// Alter an existing table with the supplied builder. @@ -47,28 +37,20 @@ public class Schema { /// - builder: A closure passing a builder for defining what /// should be altered. public func alter(table: String, builder: (inout AlterTableBuilder) -> Void) { - var alterBuilder = AlterTableBuilder(grammar: self.grammar) + var alterBuilder = AlterTableBuilder(grammar: grammar) builder(&alterBuilder) - - let changes = self.grammar.compileAlter( - table: table, - dropColumns: alterBuilder.dropColumns, - addColumns: alterBuilder.createColumns - ) - let renames = alterBuilder.renameColumns - .map { self.grammar.compileRenameColumn(table: table, column: $0.from, to: $0.to) } - let dropIndexes = alterBuilder.dropIndexes - .map { self.grammar.compileDropIndex(table: table, indexName: $0) } - let createIndexes = self.grammar - .compileCreateIndexes(table: table, indexes: alterBuilder.createIndexes) - self.statements.append(contentsOf: changes + renames + dropIndexes + createIndexes) + let changes = grammar.compileAlterTable(table, dropColumns: alterBuilder.dropColumns, addColumns: alterBuilder.createColumns) + let renames = alterBuilder.renameColumns.map { grammar.compileRenameColumn(on: table, column: $0.from, to: $0.to) } + let dropIndexes = alterBuilder.dropIndexes.map { grammar.compileDropIndex(on: table, indexName: $0) } + let createIndexes = grammar.compileCreateIndexes(on: table, indexes: alterBuilder.createIndexes) + statements.append(contentsOf: changes + renames + dropIndexes + createIndexes) } /// Drop a table. /// /// - Parameter table: The table to drop. public func drop(table: String) { - self.statements.append(self.grammar.compileDrop(table: table)) + statements.append(grammar.compileDropTable(table)) } /// Rename a table. @@ -77,7 +59,7 @@ public class Schema { /// - table: The table to rename. /// - to: The new name for the table. public func rename(table: String, to: String) { - self.statements.append(self.grammar.compileRename(table: table, to: to)) + statements.append(grammar.compileRenameTable(table, to: to)) } /// Execute a raw SQL statement when running this migration @@ -85,6 +67,6 @@ public class Schema { /// /// - Parameter sql: The raw SQL string to execute. public func raw(sql: String) { - self.statements.append(SQL(sql, bindings: [])) + statements.append(SQL(sql, bindings: [])) } } diff --git a/Sources/Alchemy/SQL/Migrations/CreateColumn.swift b/Sources/Alchemy/SQL/Migrations/CreateColumn.swift new file mode 100644 index 00000000..1791e429 --- /dev/null +++ b/Sources/Alchemy/SQL/Migrations/CreateColumn.swift @@ -0,0 +1,79 @@ +/// A type for keeping track of data associated with creating an +/// column. +public struct CreateColumn { + /// The name for this column. + let name: String + + /// The type string. + let type: ColumnType + + /// Any constraints. + let constraints: [ColumnConstraint] +} + +/// An abstraction around various supported SQL column types. +/// `Grammar`s will map the `ColumnType` to the backing +/// dialect type string. +public enum ColumnType: Equatable { + /// The length of an SQL string column in characters. + public enum StringLength: Equatable { + /// This value of this column can be any number of characters. + case unlimited + /// This value of this column must be at most the provided number + /// of characters. + case limit(Int) + } + + /// Self incrementing integer. + case increments + /// Integer. + case int + /// Big integer. + case bigInt + /// Double. + case double + /// String, with a given max length. + case string(StringLength) + /// UUID. + case uuid + /// Boolean. + case bool + /// Date. + case date + /// JSON. + case json +} + +/// Various constraints for columns. +public enum ColumnConstraint { + /// Options for an `onDelete` or `onUpdate` reference constraint. + public enum ReferenceOption: String { + /// RESTRICT + case restrict = "RESTRICT" + /// CASCADE + case cascade = "CASCADE" + /// SET NULL + case setNull = "SET NULL" + /// NO ACTION + case noAction = "NO ACTION" + /// SET DEFAULT + case setDefault = "SET DEFAULT" + } + + /// This column shouldn't be null. + case notNull + /// The default value for this column. + case `default`(String) + /// This column is the primary key of it's table. + case primaryKey + /// This column is unique on this table. + case unique + /// This column references a `column` on another `table`. + case foreignKey( + column: String, + table: String, + onDelete: ReferenceOption? = nil, + onUpdate: ReferenceOption? = nil) + /// This int column is unsigned. + case unsigned +} diff --git a/Sources/Alchemy/SQL/Migrations/CreateIndex.swift b/Sources/Alchemy/SQL/Migrations/CreateIndex.swift new file mode 100644 index 00000000..7f31cadf --- /dev/null +++ b/Sources/Alchemy/SQL/Migrations/CreateIndex.swift @@ -0,0 +1,20 @@ +/// A type for keeping track of data associated with creating an +/// index. +public struct CreateIndex { + /// The columns that make up this index. + let columns: [String] + + /// Whether this index is unique or not. + let isUnique: Bool + + /// Generate the name of this index given the table it will be created on. + /// The name will be suffixed with "key" if it's a unique index or "idx" + /// if not. + /// + /// - Parameter table: The table this index will be created on. + /// - Returns: The name of this index. + func name(table: String) -> String { + let suffix = isUnique ? "key" : "idx" + return ([table] + columns + [suffix]).joined(separator: "_") + } +} diff --git a/Sources/Alchemy/SQL/Migrations/Database+Migration.swift b/Sources/Alchemy/SQL/Migrations/Database+Migration.swift index 93e14a92..3f289d09 100644 --- a/Sources/Alchemy/SQL/Migrations/Database+Migration.swift +++ b/Sources/Alchemy/SQL/Migrations/Database+Migration.swift @@ -4,96 +4,77 @@ import NIO extension Database { /// Applies all outstanding migrations to the database in a single /// batch. Migrations are read from `database.migrations`. - /// - /// - Returns: A future that completes when all migrations have - /// been applied. - public func migrate() -> EventLoopFuture { - // 1. Get all already migrated migrations - self.getMigrations() - // 2. Figure out which database migrations should be - // migrated - .map { alreadyMigrated in - let currentBatch = alreadyMigrated.map(\.batch).max() ?? 0 - let migrationsToRun = self.migrations.filter { pendingMigration in - !alreadyMigrated.contains(where: { $0.name == pendingMigration.name }) - } - - if migrationsToRun.isEmpty { - Log.info("[Migration] no new migrations to apply.") - } else { - Log.info("[Migration] applying \(migrationsToRun.count) migrations.") - } - - return (migrationsToRun, currentBatch + 1) - } - // 3. Run migrations & record in migration table - .flatMap(self.upMigrations) + public func migrate() async throws { + let alreadyMigrated = try await getMigrations() + + let currentBatch = alreadyMigrated.map(\.batch).max() ?? 0 + let migrationsToRun = migrations.filter { pendingMigration in + !alreadyMigrated.contains(where: { $0.name == pendingMigration.name }) + } + + if migrationsToRun.isEmpty { + Log.info("[Migration] no new migrations to apply.") + } else { + Log.info("[Migration] applying \(migrationsToRun.count) migrations.") + } + + try await upMigrations(migrationsToRun, batch: currentBatch + 1) + didRunMigrations = true } /// Rolls back the latest migration batch. - /// - /// - Returns: A future that completes when the rollback is - /// complete. - public func rollbackMigrations() -> EventLoopFuture { - Log.info("[Migration] rolling back last batch of migrations.") - return self.getMigrations() - .map { alreadyMigrated -> [Migration] in - guard let latestBatch = alreadyMigrated.map({ $0.batch }).max() else { - return [] - } - - let namesToRollback = alreadyMigrated.filter { $0.batch == latestBatch }.map(\.name) - let migrationsToRollback = self.migrations.filter { namesToRollback.contains($0.name) } - - return migrationsToRollback - } - .flatMap(self.downMigrations) + public func rollbackMigrations() async throws { + let alreadyMigrated = try await getMigrations() + guard let latestBatch = alreadyMigrated.map({ $0.batch }).max() else { + return + } + + let namesToRollback = alreadyMigrated.filter { $0.batch == latestBatch }.map(\.name) + let migrationsToRollback = migrations.filter { namesToRollback.contains($0.name) } + + if migrationsToRollback.isEmpty { + Log.info("[Migration] no migrations roll back.") + } else { + Log.info("[Migration] rolling back the \(migrationsToRollback.count) migrations from the last batch.") + } + + try await downMigrations(migrationsToRollback) } /// Gets any existing migrations. Creates the migration table if /// it doesn't already exist. /// - /// - Returns: A future containing an array of all the migrations - /// that have been applied to this database. - private func getMigrations() -> EventLoopFuture<[AlchemyMigration]> { - query() - .from(table: "information_schema.tables") - .where("table_name" == AlchemyMigration.tableName) - .count() - .flatMap { value in - guard value != 0 else { - Log.info("[Migration] creating '\(AlchemyMigration.tableName)' table.") - let statements = AlchemyMigration.Migration().upStatements(for: self.driver.grammar) - return self.rawQuery(statements.first!.query).voided() - } - - return .new() - } - .flatMap { - AlchemyMigration.query(database: self).allModels() - } + /// - Returns: The migrations that are applied to this database. + private func getMigrations() async throws -> [AlchemyMigration] { + let count: Int + if provider is PostgresDatabase || provider is MySQLDatabase { + count = try await table("information_schema.tables").where("table_name" == AlchemyMigration.tableName).count() + } else { + count = try await table("sqlite_master") + .where("type" == "table") + .where(Query.Where(type: .value(key: "name", op: .notLike, value: .string("sqlite_%")), boolean: .and)) + .count() + } + + if count == 0 { + Log.info("[Migration] creating '\(AlchemyMigration.tableName)' table.") + let statements = AlchemyMigration.Migration().upStatements(for: provider.grammar) + try await runStatements(statements: statements) + } + + return try await AlchemyMigration.query(database: self).get() } /// Run the `.down` functions of an array of migrations, in order. /// /// - Parameter migrations: The migrations to rollback on this /// database. - /// - Returns: A future that completes when the rollback is - /// finished. - private func downMigrations(_ migrations: [Migration]) -> EventLoopFuture { - var elf = Loop.current.future() + private func downMigrations(_ migrations: [Migration]) async throws { for m in migrations.sorted(by: { $0.name > $1.name }) { - let statements = m.downStatements(for: self.driver.grammar) - elf = elf.flatMap { self.runStatements(statements: statements) } - .flatMap { - AlchemyMigration.query() - .where("name" == m.name) - .delete() - .voided() - } + let statements = m.downStatements(for: provider.grammar) + try await runStatements(statements: statements) + try await AlchemyMigration.query(database: self).where("name" == m.name).delete() } - - return elf } /// Run the `.up` functions of an array of migrations in order. @@ -103,37 +84,20 @@ extension Database { /// - batch: The migration batch of these migrations. Based on /// any existing batches that have been applied on the /// database. - /// - Returns: A future that completes when the migration is - /// applied. - private func upMigrations(_ migrations: [Migration], batch: Int) -> EventLoopFuture { - var elf = Loop.current.future() + private func upMigrations(_ migrations: [Migration], batch: Int) async throws { for m in migrations { - let statements = m.upStatements(for: self.driver.grammar) - elf = elf.flatMap { self.runStatements(statements: statements) } - .flatMap { - AlchemyMigration(name: m.name, batch: batch, runAt: Date()) - .save(db: self) - .voided() - } + let statements = m.upStatements(for: provider.grammar) + try await runStatements(statements: statements) + _ = try await AlchemyMigration(name: m.name, batch: batch, runAt: Date()).save(db: self) } - - return elf } /// Consecutively run a list of SQL statements on this database. /// /// - Parameter statements: The statements to consecutively run. - /// - Returns: A future that completes when all statements have - /// been run. - private func runStatements(statements: [SQL]) -> EventLoopFuture { - var elf = Loop.current.future() + private func runStatements(statements: [SQL]) async throws { for statement in statements { - elf = elf.flatMap { _ in - self.rawQuery(statement.query, values: statement.bindings) - .voided() - } + _ = try await query(statement.statement, values: statement.bindings) } - - return elf.voided() } } diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+CRUD.swift b/Sources/Alchemy/SQL/Query/Builder/Query+CRUD.swift new file mode 100644 index 00000000..a7768f29 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+CRUD.swift @@ -0,0 +1,131 @@ +extension Query { + /// Run a select query and return the database rows. + /// + /// - Note: Optional columns can be provided that override the + /// original select columns. + /// - Parameter columns: The columns you would like returned. + /// Defaults to `nil`. + /// - Returns: The rows returned by the database. + public func getRows(_ columns: [String]? = nil) async throws -> [SQLRow] { + if let columns = columns { + self.columns = columns + } + + let sql = try database.grammar.compileSelect( + table: table, + isDistinct: isDistinct, + columns: self.columns, + joins: joins, + wheres: wheres, + groups: groups, + havings: havings, + orders: orders, + limit: limit, + offset: offset, + lock: lock) + return try await database.query(sql.statement, values: sql.bindings) + } + + /// Run a select query and return the first database row only row. + /// + /// - Note: Optional columns can be provided that override the + /// original select columns. + /// - Parameter columns: The columns you would like returned. + /// Defaults to `nil`. + /// - Returns: The first row in the database, if it exists. + public func firstRow(_ columns: [String]? = nil) async throws -> SQLRow? { + try await limit(1).getRows(columns).first + } + + /// Run a select query that looks for a single row matching the + /// given database column and value. + /// + /// - Note: Optional columns can be provided that override the + /// original select columns. + /// - Parameter columns: The columns you would like returned. + /// Defaults to `nil`. + /// - Returns: The row from the database, if it exists. + public func findRow(_ column: String, equals value: SQLValue, columns: [String]? = nil) async throws -> SQLRow? { + wheres.append(column == value) + return try await limit(1).getRows(columns).first + } + + /// Find the total count of the rows that match the given query. + /// + /// - Parameter column: What column to count. Defaults to `*`. + /// - Returns: The count returned by the database. + public func count(column: String = "*") async throws -> Int { + let row = try await select(["COUNT(\(column))"]).firstRow() + .unwrap(or: DatabaseError("a COUNT query didn't return any rows")) + let column = try row.columns.first + .unwrap(or: DatabaseError("a COUNT query didn't return any columns")) + return try row.get(column).value.int() + } + + /// Perform an insert and create a database row from the provided + /// data. + /// + /// - Parameter value: A dictionary containing the values to be + /// inserted. + public func insert(_ value: [String: SQLValueConvertible]) async throws { + try await insert([value]) + } + + /// Perform an insert and create database rows from the provided data. + /// + /// - Parameter values: An array of dictionaries containing the values to be + /// inserted. + public func insert(_ values: [[String: SQLValueConvertible]]) async throws { + let sql = database.grammar.compileInsert(table, values: values) + _ = try await database.query(sql.statement, values: sql.bindings) + return + } + + public func insertReturn(_ values: [String: SQLValueConvertible]) async throws -> [SQLRow] { + try await insertReturn([values]) + } + + /// Perform an insert and return the inserted records. + /// + /// - Parameter values: An array of dictionaries containing the values to be + /// inserted. + /// - Returns: The inserted rows. + public func insertReturn(_ values: [[String: SQLValueConvertible]]) async throws -> [SQLRow] { + let statements = database.grammar.compileInsertReturn(table, values: values) + return try await database.transaction { conn in + var toReturn: [SQLRow] = [] + for sql in statements { + toReturn.append(contentsOf: try await conn.query(sql.statement, values: sql.bindings)) + } + + return toReturn + } + } + + /// Perform an update on all data matching the query in the + /// builder with the values provided. + /// + /// For example, if you wanted to update the first name of a user + /// whose ID equals 10, you could do so as follows: + /// ```swift + /// database + /// .table("users") + /// .where("id" == 10) + /// .update(values: [ + /// "first_name": "Ashley" + /// ]) + /// ``` + /// + /// - Parameter values: An dictionary containing the values to be + /// updated. + public func update(values: [String: SQLValueConvertible]) async throws { + let sql = try database.grammar.compileUpdate(table, joins: joins, wheres: wheres, values: values) + _ = try await database.query(sql.statement, values: sql.bindings) + } + + /// Perform a deletion on all data matching the given query. + public func delete() async throws { + let sql = try database.grammar.compileDelete(table, wheres: wheres) + _ = try await database.query(sql.statement, values: sql.bindings) + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Grouping.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Grouping.swift new file mode 100644 index 00000000..234b3305 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Grouping.swift @@ -0,0 +1,49 @@ +extension Query { + /// Group returned data by a given column. + /// + /// - Parameter group: The table column to group data on. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func groupBy(_ group: String) -> Self { + groups.append(group) + return self + } + + /// Add a having clause to filter results from aggregate + /// functions. + /// + /// - Parameter clause: A `WhereValue` clause matching a column to a + /// value. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func having(_ clause: Where) -> Self { + havings.append(clause) + return self + } + + /// An alias for `having(_ clause:) ` that appends an or clause + /// instead of an and clause. + /// + /// - Parameter clause: A `WhereValue` clause matching a column to a + /// value. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orHaving(_ clause: Where) -> Self { + having(Where(type: clause.type, boolean: .or)) + } + + /// Add a having clause to filter results from aggregate functions + /// that matches a given key to a provided value. + /// + /// - Parameters: + /// - key: The column to match against. + /// - op: The `Operator` to be used in the comparison. + /// - value: The value that the column should match. + /// - boolean: How the clause should be appended (`.and` or + /// `.or`). + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func having(key: String, op: Operator, value: SQLValueConvertible, boolean: WhereBoolean = .and) -> Self { + having(Where(type: .value(key: key, op: op, value: value.value), boolean: boolean)) + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Join.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Join.swift new file mode 100644 index 00000000..dd3a7592 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Join.swift @@ -0,0 +1,135 @@ +extension Query { + /// The type of the join clause. + public enum JoinType: String { + /// INNER JOIN. + case inner + /// OUTER JOIN. + case outer + /// LEFT JOIN. + case left + /// RIGHT JOIN. + case right + /// CROSS JOIN. + case cross + } + + /// A JOIN query builder. + public final class Join: Query { + /// The type of the join to perform. + var type: JoinType + /// The table to join to. + let joinTable: String + /// The join conditions + var joinWheres: [Query.Where] = [] + + /// Create a join builder with a query, type, and table. + /// + /// - Parameters: + /// - database: The database the join table is on. + /// - type: The type of join this is. + /// - joinTable: The name of the table to join to. + init(database: DatabaseProvider, table: String, type: JoinType, joinTable: String) { + self.type = type + self.joinTable = joinTable + super.init(database: database, table: table) + } + + func on(first: String, op: Operator, second: String, boolean: WhereBoolean = .and) -> Join { + joinWheres.append(Where(type: .column(first: first, op: op, second: second), boolean: boolean)) + return self + } + + func orOn(first: String, op: Operator, second: String) -> Join { + on(first: first, op: op, second: second, boolean: .or) + } + + override func isEqual(to other: Query) -> Bool { + guard let other = other as? Join else { + return false + } + + return super.isEqual(to: other) && + type == other.type && + joinTable == other.joinTable && + joinWheres == other.joinWheres + } + } + + /// Join data from a separate table into the current query data. + /// + /// - Parameters: + /// - table: The table to be joined. + /// - first: The column from the current query to be matched. + /// - op: The `Operator` to be used in the comparison. Defaults + /// to `.equals`. + /// - second: The column from the joining table to be matched. + /// - type: The `JoinType` of the sql join. Defaults to + /// `.inner`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func join(table: String, first: String, op: Operator = .equals, second: String, type: JoinType = .inner) -> Self { + joins.append( + Join(database: database, table: self.table, type: type, joinTable: table) + .on(first: first, op: op, second: second) + ) + return self + } + + /// Joins data from a separate table into the current query, using the given + /// conditions closure. + /// + /// - Parameters: + /// - table: The table to join with. + /// - type: The type of join. Defaults to `.inner` + /// - conditions: A closure that sets the conditions on the join using. + /// - Returns: This query builder. + public func join(table: String, type: JoinType = .inner, conditions: (Join) -> Join) -> Self { + joins.append(conditions(Join(database: database, table: self.table, type: type, joinTable: table))) + return self + } + + /// Left join data from a separate table into the current query + /// data. + /// + /// - Parameters: + /// - table: The table to be joined. + /// - first: The column from the current query to be matched. + /// - op: The `Operator` to be used in the comparison. Defaults + /// to `.equals`. + /// - second: The column from the joining table to be matched. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func leftJoin(table: String, first: String, op: Operator = .equals, second: String) -> Self { + join(table: table, first: first, op: op, second: second, type: .left) + } + + /// Right join data from a separate table into the current query + /// data. + /// + /// - Parameters: + /// - table: The table to be joined. + /// - first: The column from the current query to be matched. + /// - op: The `Operator` to be used in the comparison. Defaults + /// to `.equals`. + /// - second: The column from the joining table to be matched. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func rightJoin(table: String, first: String, op: Operator = .equals, second: String) -> Self { + join(table: table, first: first, op: op, second: second, type: .right) + } + + /// Cross join data from a separate table into the current query + /// data. + /// + /// - Parameters: + /// - table: The table to be joined. + /// - first: The column from the current query to be matched. + /// - op: The `Operator` to be used in the comparison. Defaults + /// to `.equals`. + /// - second: The column from the joining table to be matched. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func crossJoin(table: String, first: String, op: Operator = .equals, second: String) -> Self { + join(table: table, first: first, op: op, second: second, type: .cross) + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Lock.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Lock.swift new file mode 100644 index 00000000..be96d785 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Lock.swift @@ -0,0 +1,22 @@ +extension Query { + public struct Lock: Equatable { + public enum Strength: String { + case update + case share + } + + public enum Option: String { + case noWait + case skipLocked + } + + let strength: Strength + let option: Option? + } + + /// Adds custom locking SQL to the end of a SELECT query. + public func lock(for strength: Lock.Strength, option: Lock.Option? = nil) -> Self { + self.lock = Lock(strength: strength, option: option) + return self + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Operator.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Operator.swift new file mode 100644 index 00000000..7cace1a4 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Operator.swift @@ -0,0 +1,27 @@ +extension Query { + public enum Operator: CustomStringConvertible, Equatable { + case equals + case lessThan + case greaterThan + case lessThanOrEqualTo + case greaterThanOrEqualTo + case notEqualTo + case like + case notLike + case raw(String) + + public var description: String { + switch self { + case .equals: return "=" + case .lessThan: return "<" + case .greaterThan: return ">" + case .lessThanOrEqualTo: return "<=" + case .greaterThanOrEqualTo: return ">=" + case .notEqualTo: return "!=" + case .like: return "LIKE" + case .notLike: return "NOT LIKE" + case .raw(let value): return value + } + } + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Order.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Order.swift new file mode 100644 index 00000000..c108c6a2 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Order.swift @@ -0,0 +1,40 @@ +extension Query { + /// A clause for ordering rows by a certain column. + public struct Order: Equatable { + /// A sorting direction. + public enum Direction: String { + /// Sort elements in ascending order. + case asc + /// Sort elements in descending order. + case desc + } + + /// The column to order by. + let column: String + /// The direction to order by. + let direction: Direction + } + + /// Order the data from the query based on given clause. + /// + /// - Parameter order: The `OrderClause` that defines the + /// ordering. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orderBy(_ order: Order) -> Self { + orders.append(order) + return self + } + + /// Order the data from the query based on a column and direction. + /// + /// - Parameters: + /// - column: The column to order data by. + /// - direction: The `OrderClause.Sort` direction (either `.asc` + /// or `.desc`). Defaults to `.asc`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orderBy(_ column: String, direction: Order.Direction = .asc) -> Self { + orderBy(Order(column: column, direction: direction)) + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Paging.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Paging.swift new file mode 100644 index 00000000..6aa92d83 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Paging.swift @@ -0,0 +1,37 @@ +extension Query { + /// Limit the returned results to a given amount. + /// + /// - Parameter value: An amount to cap the total result at. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func limit(_ value: Int) -> Self { + self.limit = max(0, value) + return self + } + + /// Offset the returned results by a given amount. + /// + /// - Parameter value: An amount representing the offset. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func offset(_ value: Int) -> Self { + self.offset = max(0, value) + return self + } + + /// A helper method to be used when needing to page returned + /// results. Internally this uses the `limit` and `offset` + /// methods. + /// + /// - Note: Paging starts at index 1, not 0. + /// + /// - Parameters: + /// - page: What `page` of results to offset by. + /// - perPage: How many results to show on each page. Defaults + /// to `25`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func forPage(_ page: Int, perPage: Int = 25) -> Self { + offset((page - 1) * perPage).limit(perPage) + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Select.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Select.swift new file mode 100644 index 00000000..7cf277c6 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Select.swift @@ -0,0 +1,37 @@ +extension Query { + /// Set the columns that should be returned by the query. + /// + /// - Parameters: + /// - columns: An array of columns to be returned by the query. + /// Defaults to `[*]`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func select(_ columns: String...) -> Self { + self.columns = columns + return self + } + + /// Set the columns that should be returned by the query. + /// + /// - Parameters: + /// - columns: An array of columns to be returned by the query. + /// Defaults to `[*]`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func select(_ columns: [String] = ["*"]) -> Self { + self.columns = columns + return self + } + + /// Set query to only return distinct entries. + /// + /// - Parameter columns: An array of columns to be returned by the query. + /// Defaults to `[*]`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func distinct(_ columns: [String] = ["*"]) -> Self { + self.columns = columns + self.isDistinct = true + return self + } +} diff --git a/Sources/Alchemy/SQL/Query/Builder/Query+Where.swift b/Sources/Alchemy/SQL/Query/Builder/Query+Where.swift new file mode 100644 index 00000000..5a0008b6 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Builder/Query+Where.swift @@ -0,0 +1,278 @@ +protocol WhereClause: SQLConvertible {} + +extension Query { + public indirect enum WhereType: Equatable { + case value(key: String, op: Operator, value: SQLValue) + case column(first: String, op: Operator, second: String) + case nested(wheres: [Where]) + case `in`(key: String, values: [SQLValue], type: WhereInType) + case raw(SQL) + } + + public enum WhereBoolean: String { + case and + case or + } + + public enum WhereInType: String { + case `in` + case notIn + } + + public struct Where: Equatable { + public let type: WhereType + public let boolean: WhereBoolean + } + + /// Add a basic where clause to the query to filter down results. + /// + /// - Parameters: + /// - clause: A `WhereValue` clause matching a column to a given + /// value. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func `where`(_ clause: Where) -> Self { + wheres.append(clause) + return self + } + + /// An alias for `where(_ clause: WhereValue) ` that appends an or + /// clause instead of an and clause. + /// + /// - Parameters: + /// - clause: A `WhereValue` clause matching a column to a given + /// value. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhere(_ clause: Where) -> Self { + `where`(Where(type: clause.type, boolean: .or)) + } + + /// Add a nested where clause that is a group of combined clauses. + /// This can be used for logically grouping where clauses like + /// you would inside of an if statement. Each clause is + /// wrapped in parenthesis. + /// + /// For example if you want to logically ensure a user is under 30 + /// and named Paul, or over the age of 50 having any name, you + /// could use a nested where clause along with a separate + /// where value clause: + /// ```swift + /// Query + /// .from("users") + /// .where { + /// $0.where("age" < 30) + /// .orWhere("first_name" == "Paul") + /// } + /// .where("age" > 50) + /// ``` + /// + /// - Parameters: + /// - closure: A `WhereNestedClosure` that provides a nested + /// clause to attach nested where clauses to. + /// - boolean: How the clause should be appended(`.and` or + /// `.or`). Defaults to `.and`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func `where`(_ closure: @escaping (Query) -> Query, boolean: WhereBoolean = .and) -> Self { + let query = closure(Query(database: database, table: table)) + wheres.append(Where(type: .nested(wheres: query.wheres), boolean: boolean)) + return self + } + + /// A helper for adding an **or** `where` nested closure clause. + /// + /// - Parameters: + /// - closure: A `WhereNestedClosure` that provides a nested + /// query to attach nested where clauses to. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhere(_ closure: @escaping (Query) -> Query) -> Self { + `where`(closure, boolean: .or) + } + + /// Add a clause requiring that a column match any values in a + /// given array. + /// + /// - Parameters: + /// - key: The column to match against. + /// - values: The values that the column should not match. + /// - type: How the match should happen (*in* or *notIn*). + /// Defaults to `.in`. + /// - boolean: How the clause should be appended (`.and` or + /// `.or`). Defaults to `.and`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func `where`(key: String, in values: [SQLValueConvertible], type: WhereInType = .in, boolean: WhereBoolean = .and) -> Self { + wheres.append(Where(type: .in(key: key, values: values.map { $0.value }, type: type), boolean: boolean)) + return self + } + + /// A helper for adding an **or** variant of the `where(key:in:)` clause. + /// + /// - Parameters: + /// - key: The column to match against. + /// - values: The values that the column should not match. + /// - type: How the match should happen (`.in` or `.notIn`). + /// Defaults to `.in`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhere(key: String, in values: [SQLValueConvertible], type: WhereInType = .in) -> Self { + `where`(key: key, in: values, type: type, boolean: .or) + } + + /// Add a clause requiring that a column not match any values in a + /// given array. This is a helper method for the where in method. + /// + /// - Parameters: + /// - key: The column to match against. + /// - values: The values that the column should not match. + /// - boolean: How the clause should be appended (`.and` or + /// `.or`). Defaults to `.and`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func whereNot(key: String, in values: [SQLValueConvertible], boolean: WhereBoolean = .and) -> Self { + `where`(key: key, in: values, type: .notIn, boolean: boolean) + } + + /// A helper for adding an **or** `whereNot` clause. + /// + /// - Parameters: + /// - key: The column to match against. + /// - values: The values that the column should not match. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhereNot(key: String, in values: [SQLValueConvertible]) -> Self { + `where`(key: key, in: values, type: .notIn, boolean: .or) + } + + /// Add a raw SQL where clause to your query. + /// + /// - Parameters: + /// - sql: A string representing the SQL where clause to be run. + /// - bindings: Any variables for binding in the SQL. + /// - boolean: How the clause should be appended (`.and` or + /// `.or`). Defaults to `.and`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func whereRaw(sql: String, bindings: [SQLValueConvertible], boolean: WhereBoolean = .and) -> Self { + wheres.append(Where(type: .raw(SQL(sql, bindings: bindings.map(\.value))), boolean: boolean)) + return self + } + + /// A helper for adding an **or** `whereRaw` clause. + /// + /// - Parameters: + /// - sql: A string representing the SQL where clause to be run. + /// - bindings: Any variables for binding in the SQL. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhereRaw(sql: String, bindings: [SQLValueConvertible]) -> Self { + whereRaw(sql: sql, bindings: bindings, boolean: .or) + } + + /// Add a where clause requiring that two columns match each other + /// + /// - Parameters: + /// - first: The first column to match against. + /// - op: The `Operator` to be used in the comparison. + /// - second: The second column to match against. + /// - boolean: How the clause should be appended (`.and` + /// or `.or`). + /// - Returns: The current query builder `Query` to chain future + /// queries to. + @discardableResult + public func whereColumn(first: String, op: Operator, second: String, boolean: WhereBoolean = .and) -> Self { + wheres.append(Where(type: .column(first: first, op: op, second: second), boolean: boolean)) + return self + } + + /// A helper for adding an **or** `whereColumn` clause. + /// + /// - Parameters: + /// - first: The first column to match against. + /// - op: The `Operator` to be used in the comparison. + /// - second: The second column to match against. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhereColumn(first: String, op: Operator, second: String) -> Self { + whereColumn(first: first, op: op, second: second, boolean: .or) + } + + /// Add a where clause requiring that a column be null. + /// + /// - Parameters: + /// - key: The column to match against. + /// - boolean: How the clause should be appended (`.and` or + /// `.or`). + /// - not: Should the value be null or not null. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func whereNull(key: String, boolean: WhereBoolean = .and, not: Bool = false) -> Self { + let action = not ? "IS NOT" : "IS" + wheres.append(Where(type: .raw(SQL("\(key) \(action) NULL")), boolean: boolean)) + return self + } + + /// A helper for adding an **or** `whereNull` clause. + /// + /// - Parameter key: The column to match against. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhereNull(key: String) -> Self { + whereNull(key: key, boolean: .or) + } + + /// Add a where clause requiring that a column not be null. + /// + /// - Parameters: + /// - key: The column to match against. + /// - boolean: How the clause should be appended (`.and` or + /// `.or`). + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func whereNotNull(key: String, boolean: WhereBoolean = .and) -> Self { + whereNull(key: key, boolean: boolean, not: true) + } + + /// A helper for adding an **or** `whereNotNull` clause. + /// + /// - Parameter key: The column to match against. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func orWhereNotNull(key: String) -> Self { + whereNotNull(key: key, boolean: .or) + } +} + +extension String { + // MARK: Custom Swift Operators + + public static func == (lhs: String, rhs: SQLValueConvertible) -> Query.Where { + Query.Where(type: .value(key: lhs, op: .equals, value: rhs.value), boolean: .and) + } + + public static func != (lhs: String, rhs: SQLValueConvertible) -> Query.Where { + Query.Where(type: .value(key: lhs, op: .notEqualTo, value: rhs.value), boolean: .and) + } + + public static func < (lhs: String, rhs: SQLValueConvertible) -> Query.Where { + Query.Where(type: .value(key: lhs, op: .lessThan, value: rhs.value), boolean: .and) + } + + public static func > (lhs: String, rhs: SQLValueConvertible) -> Query.Where { + Query.Where(type: .value(key: lhs, op: .greaterThan, value: rhs.value), boolean: .and) + } + + public static func <= (lhs: String, rhs: SQLValueConvertible) -> Query.Where { + Query.Where(type: .value(key: lhs, op: .lessThanOrEqualTo, value: rhs.value), boolean: .and) + } + + public static func >= (lhs: String, rhs: SQLValueConvertible) -> Query.Where { + Query.Where(type: .value(key: lhs, op: .greaterThanOrEqualTo, value: rhs.value), boolean: .and) + } + + public static func ~= (lhs: String, rhs: SQLValueConvertible) -> Query.Where { + Query.Where(type: .value(key: lhs, op: .like, value: rhs.value), boolean: .and) + } +} diff --git a/Sources/Alchemy/SQL/Query/Database+Query.swift b/Sources/Alchemy/SQL/Query/Database+Query.swift new file mode 100644 index 00000000..b6848e46 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Database+Query.swift @@ -0,0 +1,37 @@ +extension Database { + /// Start a QueryBuilder query on this database. See `Query` or + /// QueryBuilder guides. + /// + /// Usage: + /// ```swift + /// if let row = try await database.table("users").where("id" == 1).first() { + /// print("Got a row with fields: \(row.allColumns)") + /// } + /// ``` + /// + /// - Parameters: + /// - table: The table to run the query on. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func table(_ table: String, as alias: String? = nil) -> Query { + guard let alias = alias else { + return Query(database: provider, table: table) + } + + return Query(database: provider, table: "\(table) as \(alias)") + } + + /// An alias for `table(_ table: String)` to be used when running. + /// a `select` query that also lets you alias the table name. + /// + /// - Parameters: + /// - table: The table to select data from. + /// - alias: An alias to use in place of table name. Defaults to + /// `nil`. + /// - Returns: The current query builder `Query` to chain future + /// queries to. + public func from(_ table: String, as alias: String? = nil) -> Query { + self.table(table, as: alias) + } +} + diff --git a/Sources/Alchemy/SQL/Query/Grammar/Grammar.swift b/Sources/Alchemy/SQL/Query/Grammar/Grammar.swift new file mode 100644 index 00000000..00b9a267 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Grammar/Grammar.swift @@ -0,0 +1,388 @@ +import Foundation + +/// Used for compiling query builders into raw SQL statements. +open class Grammar { + public init() {} + + // MARK: Compiling Query Builder + + open func compileSelect( + table: String, + isDistinct: Bool, + columns: [String], + joins: [Query.Join], + wheres: [Query.Where], + groups: [String], + havings: [Query.Where], + orders: [Query.Order], + limit: Int?, + offset: Int?, + lock: Query.Lock? + ) throws -> SQL { + let select = isDistinct ? "select distinct" : "select" + return [ + SQL("\(select) \(columns.joined(separator: ", "))"), + SQL("from \(table)"), + compileJoins(joins), + compileWheres(wheres), + compileGroups(groups), + compileHavings(havings), + compileOrders(orders), + compileLimit(limit), + compileOffset(offset), + compileLock(lock) + ].compactMap { $0 }.joinedSQL() + } + + open func compileJoins(_ joins: [Query.Join]) -> SQL? { + guard !joins.isEmpty else { + return nil + } + + var bindings: [SQLValue] = [] + let query = joins.compactMap { join -> String? in + guard let whereSQL = compileWheres(join.joinWheres, isJoin: true) else { + return nil + } + + bindings += whereSQL.bindings + if let nestedSQL = compileJoins(join.joins) { + bindings += nestedSQL.bindings + return "\(join.type) join (\(join.joinTable)\(nestedSQL.statement)) \(whereSQL.statement)" + .trimmingCharacters(in: .whitespacesAndNewlines) + } + + return "\(join.type) join \(join.joinTable) \(whereSQL.statement)" + .trimmingCharacters(in: .whitespacesAndNewlines) + }.joined(separator: " ") + + return SQL(query, bindings: bindings) + } + + open func compileWheres(_ wheres: [Query.Where], isJoin: Bool = false) -> SQL? { + guard wheres.count > 0 else { + return nil + } + + let conjunction = isJoin ? "on" : "where" + let sql = wheres.joinedSQL().droppingLeadingBoolean() + return SQL("\(conjunction) \(sql.statement)", bindings: sql.bindings) + } + + open func compileGroups(_ groups: [String]) -> SQL? { + guard !groups.isEmpty else { + return nil + } + + return SQL("group by \(groups.joined(separator: ", "))") + } + + open func compileHavings(_ havings: [Query.Where]) -> SQL? { + guard havings.count > 0 else { + return nil + } + + let sql = havings.joinedSQL().droppingLeadingBoolean() + return SQL("having \(sql.statement)", bindings: sql.bindings) + } + + open func compileOrders(_ orders: [Query.Order]) -> SQL? { + guard !orders.isEmpty else { + return nil + } + + let ordersSQL = orders + .map { "\($0.column) \($0.direction)" } + .joined(separator: ", ") + return SQL("order by \(ordersSQL)") + } + + open func compileLimit(_ limit: Int?) -> SQL? { + limit.map { SQL("limit \($0)") } + } + + open func compileOffset(_ offset: Int?) -> SQL? { + offset.map { SQL("offset \($0)") } + } + + open func compileInsert(_ table: String, values: [[String: SQLValueConvertible]]) -> SQL { + guard !values.isEmpty else { + return SQL("insert into \(table) default values") + } + + let columns = values[0].map { $0.key } + var parameters: [SQLValue] = [] + var placeholders: [String] = [] + + for value in values { + let orderedValues = columns.compactMap { value[$0]?.value } + parameters.append(contentsOf: orderedValues) + placeholders.append("(\(parameterize(orderedValues)))") + } + + let columnsJoined = columns.joined(separator: ", ") + return SQL("insert into \(table) (\(columnsJoined)) values \(placeholders.joined(separator: ", "))", bindings: parameters) + } + + open func compileInsertReturn(_ table: String, values: [[String: SQLValueConvertible]]) -> [SQL] { + let insert = compileInsert(table, values: values) + return [SQL("\(insert.statement) returning *", bindings: insert.bindings)] + } + + open func compileUpdate(_ table: String, joins: [Query.Join], wheres: [Query.Where], values: [String: SQLValueConvertible]) throws -> SQL { + var bindings: [SQLValue] = [] + let columnStatements: [SQL] = values.map { key, val in + if let expression = val as? SQL { + return SQL("\(key) = \(expression.statement)") + } else { + return SQL("\(key) = ?", bindings: [val.value.value]) + } + } + + let columnSQL = SQL(columnStatements.map(\.statement).joined(separator: ", "), bindings: columnStatements.flatMap(\.bindings)) + + var base = "update \(table)" + if let joinSQL = compileJoins(joins) { + bindings += joinSQL.bindings + base += " \(joinSQL)" + } + + bindings += columnSQL.bindings + base += " set \(columnSQL.statement)" + + if let whereSQL = compileWheres(wheres) { + bindings += whereSQL.bindings + base += " \(whereSQL.statement)" + } + + return SQL(base, bindings: bindings) + } + + open func compileDelete(_ table: String, wheres: [Query.Where]) throws -> SQL { + if let whereSQL = compileWheres(wheres) { + return SQL("delete from \(table) \(whereSQL.statement)", bindings: whereSQL.bindings) + } else { + return SQL("delete from \(table)") + } + } + + open func compileLock(_ lock: Query.Lock?) -> SQL? { + guard let lock = lock else { + return nil + } + + var string = "" + switch lock.strength { + case .update: + string = "FOR UPDATE" + case .share: + string = "FOR SHARE" + } + + switch lock.option { + case .noWait: + string.append(" NO WAIT") + case .skipLocked: + string.append(" SKIP LOCKED") + case .none: + break + } + + return SQL(string) + } + + // MARK: - Compiling Migrations + + open func compileCreateTable(_ table: String, ifNotExists: Bool, columns: [CreateColumn]) -> SQL { + var columnStrings: [String] = [] + var constraintStrings: [String] = [] + for (column, constraints) in columns.map({ createColumnString(for: $0) }) { + columnStrings.append(column) + constraintStrings.append(contentsOf: constraints) + } + + return SQL( + """ + CREATE TABLE\(ifNotExists ? " IF NOT EXISTS" : "") \(table) ( + \((columnStrings + constraintStrings).joined(separator: ",\n ")) + ) + """ + ) + } + + open func compileRenameTable(_ table: String, to: String) -> SQL { + SQL("ALTER TABLE \(table) RENAME TO \(to)") + } + + open func compileDropTable(_ table: String) -> SQL { + SQL("DROP TABLE \(table)") + } + + open func compileAlterTable(_ table: String, dropColumns: [String], addColumns: [CreateColumn]) -> [SQL] { + guard !dropColumns.isEmpty || !addColumns.isEmpty else { + return [] + } + + var adds: [String] = [] + var constraints: [String] = [] + for (sql, tableConstraints) in addColumns.map({ createColumnString(for: $0) }) { + adds.append("ADD COLUMN \(sql)") + constraints.append(contentsOf: tableConstraints.map { "ADD \($0)" }) + } + + let drops = dropColumns.map { "DROP COLUMN \($0.escapedColumn)" } + return [ + SQL(""" + ALTER TABLE \(table) + \((adds + drops + constraints).joined(separator: ",\n ")) + """)] + } + + open func compileRenameColumn(on table: String, column: String, to: String) -> SQL { + SQL("ALTER TABLE \(table) RENAME COLUMN \(column.escapedColumn) TO \(to.escapedColumn)") + } + + /// Compile the given create indexes into SQL. + /// + /// - Parameter table: The name of the table this index will be + /// created on. + /// - Returns: SQL objects for creating these indexes on the given table. + open func compileCreateIndexes(on table: String, indexes: [CreateIndex]) -> [SQL] { + indexes.map { index in + let indexType = index.isUnique ? "UNIQUE INDEX" : "INDEX" + let indexName = index.name(table: table) + let indexColumns = "(\(index.columns.map(\.escapedColumn).joined(separator: ", ")))" + return SQL("CREATE \(indexType) \(indexName) ON \(table) \(indexColumns)") + } + } + + open func compileDropIndex(on table: String, indexName: String) -> SQL { + SQL("DROP INDEX \(indexName)") + } + + // MARK: - Misc + + open func columnTypeString(for type: ColumnType) -> String { + switch type { + case .bool: + return "bool" + case .date: + return "timestamptz" + case .double: + return "float8" + case .increments: + return "serial" + case .int: + return "int" + case .bigInt: + return "bigint" + case .json: + return "json" + case .string(let length): + switch length { + case .unlimited: + return "text" + case .limit(let characters): + return "varchar(\(characters))" + } + case .uuid: + return "uuid" + } + } + + /// Convert a `CreateColumn` to a `String` for inserting into an SQL + /// statement. + /// + /// - Returns: The SQL `String` describing the column and any table level + /// constraints to add. + open func createColumnString(for column: CreateColumn) -> (String, [String]) { + let columnEscaped = column.name.escapedColumn + var baseSQL = "\(columnEscaped) \(columnTypeString(for: column.type))" + var tableConstraints: [String] = [] + for constraint in column.constraints { + guard let constraintString = columnConstraintString(for: constraint, on: column.name.escapedColumn, of: column.type) else { + continue + } + + switch constraint { + case .notNull: + baseSQL.append(" \(constraintString)") + case .default: + baseSQL.append(" \(constraintString)") + case .unsigned: + baseSQL.append(" \(constraintString)") + case .primaryKey: + tableConstraints.append(constraintString) + case .unique: + tableConstraints.append(constraintString) + case .foreignKey: + tableConstraints.append(constraintString) + } + } + + return (baseSQL, tableConstraints) + } + + open func columnConstraintString(for constraint: ColumnConstraint, on column: String, of type: ColumnType) -> String? { + switch constraint { + case .notNull: + return "NOT NULL" + case .default(let string): + return "DEFAULT \(string)" + case .primaryKey: + return "PRIMARY KEY (\(column))" + case .unique: + return "UNIQUE (\(column))" + case .foreignKey(let fkColumn, let table, let onDelete, let onUpdate): + var fkBase = "FOREIGN KEY (\(column)) REFERENCES \(table) (\(fkColumn.escapedColumn))" + if let delete = onDelete { fkBase.append(" ON DELETE \(delete.rawValue)") } + if let update = onUpdate { fkBase.append(" ON UPDATE \(update.rawValue)") } + return fkBase + case .unsigned: + return nil + } + } + + open func jsonLiteral(for jsonString: String) -> String { + "'\(jsonString)'::jsonb" + } + + private func parameterize(_ values: [SQLValueConvertible]) -> String { + values.map { ($0 as? SQL)?.statement ?? "?" }.joined(separator: ", ") + } +} + +extension String { + fileprivate var escapedColumn: String { + "\"\(self)\"" + } +} + +extension Query.Where: SQLConvertible { + public var sql: SQL { + switch type { + case .value(let key, let op, let value): + if value == .null { + if op == .notEqualTo { + return SQL("\(boolean) \(key) IS NOT NULL") + } else if op == .equals { + return SQL("\(boolean) \(key) IS NULL") + } else { + fatalError("Can't use any where operators other than .notEqualTo or .equals if the value is NULL.") + } + } else { + return SQL("\(boolean) \(key) \(op) ?", bindings: [value]) + } + case .column(let first, let op, let second): + return SQL("\(boolean) \(first) \(op) \(second)") + case .nested(let wheres): + let nestedSQL = wheres.joinedSQL().droppingLeadingBoolean() + return SQL("\(boolean) (\(nestedSQL.statement))", bindings: nestedSQL.bindings) + case .in(let key, let values, let type): + let placeholders = Array(repeating: "?", count: values.count).joined(separator: ", ") + return SQL("\(boolean) \(key) \(type)(\(placeholders))", bindings: values) + case .raw(let sql): + return SQL("\(boolean) \(sql.statement)", bindings: sql.bindings) + } + } +} diff --git a/Sources/Alchemy/SQL/Query/Query.swift b/Sources/Alchemy/SQL/Query/Query.swift new file mode 100644 index 00000000..680de668 --- /dev/null +++ b/Sources/Alchemy/SQL/Query/Query.swift @@ -0,0 +1,42 @@ +import Foundation +import NIO + +public class Query: Equatable { + let database: DatabaseProvider + var table: String + + var columns: [String] = ["*"] + var isDistinct = false + var limit: Int? = nil + var offset: Int? = nil + var lock: Lock? = nil + + var joins: [Join] = [] + var wheres: [Where] = [] + var groups: [String] = [] + var havings: [Where] = [] + var orders: [Order] = [] + + public init(database: DatabaseProvider, table: String) { + self.database = database + self.table = table + } + + func isEqual(to other: Query) -> Bool { + return table == other.table && + columns == other.columns && + isDistinct == other.isDistinct && + limit == other.limit && + offset == other.offset && + lock == other.lock && + joins == other.joins && + wheres == other.wheres && + groups == other.groups && + havings == other.havings && + orders == other.orders + } + + public static func == (lhs: Query, rhs: Query) -> Bool { + lhs.isEqual(to: rhs) + } +} diff --git a/Sources/Alchemy/SQL/Query/SQL+Utilities.swift b/Sources/Alchemy/SQL/Query/SQL+Utilities.swift new file mode 100644 index 00000000..023a1c8a --- /dev/null +++ b/Sources/Alchemy/SQL/Query/SQL+Utilities.swift @@ -0,0 +1,12 @@ +extension Array where Element: SQLConvertible { + public func joinedSQL() -> SQL { + let statements = map(\.sql) + return SQL(statements.map(\.statement).joined(separator: " "), bindings: statements.flatMap(\.bindings)) + } +} + +extension SQL { + func droppingLeadingBoolean() -> SQL { + SQL(statement.droppingPrefix("and ").droppingPrefix("or "), bindings: bindings) + } +} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Clauses/JoinClause.swift b/Sources/Alchemy/SQL/QueryBuilder/Clauses/JoinClause.swift deleted file mode 100644 index 55702565..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Clauses/JoinClause.swift +++ /dev/null @@ -1,44 +0,0 @@ -import Foundation - -/// The type of the join clause. -public enum JoinType: String { - /// INNER JOIN. - case inner - /// OUTER JOIN. - case outer - /// LEFT JOIN. - case left - /// RIGHT JOIN. - case right - /// CROSS JOIN. - case cross -} - -/// A JOIN query builder. -public final class JoinClause: Query { - /// The type of the join to perform. - public let type: JoinType - /// The table to join to. - public let table: String - - /// Create a join builder with a query, type, and table. - /// - /// - Parameters: - /// - database: The database the join table is on. - /// - type: The type of join this is. - /// - table: The name of the table to join to. - init(database: DatabaseDriver, type: JoinType, table: String) { - self.type = type - self.table = table - super.init(database: database) - } - - func on(first: String, op: Operator, second: String, boolean: WhereBoolean = .and) -> JoinClause { - self.whereColumn(first: first, op: op, second: second, boolean: boolean) - return self - } - - func orOn(first: String, op: Operator, second: String) -> JoinClause { - return self.on(first: first, op: op, second: second, boolean: .or) - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Clauses/OrderClause.swift b/Sources/Alchemy/SQL/QueryBuilder/Clauses/OrderClause.swift deleted file mode 100644 index 86f61ca0..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Clauses/OrderClause.swift +++ /dev/null @@ -1,26 +0,0 @@ -import Foundation - -/// A clause for ordering rows by a certain column. -public struct OrderClause: Sequelizable { - /// A sorting direction. - public enum Sort: String { - /// Sort elements in ascending order. - case asc - /// Sort elements in descending order. - case desc - } - - /// The column to order by. - let column: Column - /// The direction to order by. - let direction: Sort - - // MARK: - Sequelizable - - public func toSQL() -> SQL { - if let raw = column as? SQL { - return raw - } - return SQL("\(column) \(direction)") - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Clauses/WhereClause.swift b/Sources/Alchemy/SQL/QueryBuilder/Clauses/WhereClause.swift deleted file mode 100644 index 6c95e853..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Clauses/WhereClause.swift +++ /dev/null @@ -1,93 +0,0 @@ -import Foundation - -protocol WhereClause: Sequelizable {} - -public enum WhereBoolean: String { - case and - case or -} - -public struct WhereValue: WhereClause { - let key: String - let op: Operator - let value: DatabaseValue - var boolean: WhereBoolean = .and - - // MARK: - Sequelizable - - public func toSQL() -> SQL { - if self.value.isNil { - if self.op == .notEqualTo { - return SQL("\(boolean) \(key) IS NOT NULL") - } else if self.op == .equals { - return SQL("\(boolean) \(key) IS NULL") - } else { - fatalError("Can't use any where operators other than .notEqualTo or .equals if the value is NULL.") - } - } else { - return SQL("\(boolean) \(key) \(op) ?", binding: value) - } - } -} - -public struct WhereColumn: WhereClause { - let first: String - let op: Operator - let second: Expression - var boolean: WhereBoolean = .and - - // MARK: - Sequelizable - - public func toSQL() -> SQL { - return SQL("\(boolean) \(first) \(op) \(second.description)") - } -} - -public typealias WhereNestedClosure = (Query) -> Query -public struct WhereNested: WhereClause { - let database: DatabaseDriver - let closure: WhereNestedClosure - var boolean: WhereBoolean = .and - - // MARK: - Sequelizable - - public func toSQL() -> SQL { - let query = self.closure(Query(database: self.database)) - let (sql, bindings) = QueryHelpers.groupSQL(values: query.wheres) - let clauses = QueryHelpers.removeLeadingBoolean( - sql.joined(separator: " ") - ) - return SQL("\(boolean) (\(clauses))", bindings: bindings) - } -} - -public struct WhereIn: WhereClause { - public enum InType: String { - case `in` - case notIn - } - - let key: String - let values: [DatabaseValue] - let type: InType - var boolean: WhereBoolean = .and - - // MARK: - Sequelizable - - public func toSQL() -> SQL { - let placeholders = Array(repeating: "?", count: values.count).joined(separator: ", ") - return SQL("\(boolean) \(key) \(type)(\(placeholders))", bindings: values) - } -} - -public struct WhereRaw: WhereClause { - let query: String - var values: [DatabaseValue] = [] - var boolean: WhereBoolean = .and - - // MARK: - Sequelizable - - public func toSQL() -> SQL { - return SQL("\(boolean) \(self.query)", bindings: values) - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Grammar.swift b/Sources/Alchemy/SQL/QueryBuilder/Grammar.swift deleted file mode 100644 index d2fb70c5..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Grammar.swift +++ /dev/null @@ -1,366 +0,0 @@ -import Foundation - -/// Used for compiling query builders into raw SQL statements. -open class Grammar { - struct GrammarError: Error { - let message: String - static let missingTable = GrammarError(message: "Missing a table to run the query on.") - } - - // MARK: Compiling Query Builders - - open func compileSelect(query: Query) throws -> SQL { - let parts: [SQL?] = [ - self.compileColumns(query, columns: query.columns), - try self.compileFrom(query, table: query.from), - self.compileJoins(query, joins: query.joins), - self.compileWheres(query), - self.compileGroups(query, groups: query.groups), - self.compileHavings(query), - self.compileOrders(query, orders: query.orders), - self.compileLimit(query, limit: query.limit), - self.compileOffset(query, offset: query.offset), - query.lock.map { SQL($0) } - ] - - let (sql, bindings) = QueryHelpers.groupSQL(values: parts) - return SQL(sql.joined(separator: " "), bindings: bindings) - } - - open func compileJoins(_ query: Query, joins: [JoinClause]?) -> SQL? { - guard let joins = joins else { return nil } - var bindings: [DatabaseValue] = [] - let query = joins.compactMap { join -> String? in - guard let whereSQL = compileWheres(join) else { - return nil - } - bindings += whereSQL.bindings - if let nestedJoins = join.joins, - let nestedSQL = compileJoins(query, joins: nestedJoins) { - bindings += nestedSQL.bindings - return self.trim("\(join.type) join (\(join.table)\(nestedSQL.query)) \(whereSQL.query)") - } - return self.trim("\(join.type) join \(join.table) \(whereSQL.query)") - }.joined(separator: " ") - return SQL(query, bindings: bindings) - } - - open func compileGroups(_ query: Query, groups: [String]) -> SQL? { - if groups.isEmpty { return nil } - return SQL("group by \(groups.joined(separator: ", "))") - } - - open func compileHavings(_ query: Query) -> SQL? { - let (sql, bindings) = QueryHelpers.groupSQL(values: query.havings) - if (sql.count > 0) { - let clauses = QueryHelpers.removeLeadingBoolean( - sql.joined(separator: " ") - ) - return SQL("having \(clauses)", bindings: bindings) - } - return nil - } - - open func compileOrders(_ query: Query, orders: [OrderClause]) -> SQL? { - if orders.isEmpty { return nil } - let ordersSQL = orders.map { $0.toSQL().query }.joined(separator: ", ") - return SQL("order by \(ordersSQL)") - } - - open func compileLimit(_ query: Query, limit: Int?) -> SQL? { - guard let limit = limit else { return nil } - return SQL("limit \(limit)") - } - - open func compileOffset(_ query: Query, offset: Int?) -> SQL? { - guard let offset = offset else { return nil } - return SQL("offset \(offset)") - } - - open func compileInsert(_ query: Query, values: [OrderedDictionary]) throws -> SQL { - - guard let table = query.from else { throw GrammarError.missingTable } - - if values.isEmpty { - return SQL("insert into \(table) default values") - } - - let columns = values[0].map { $0.key }.joined(separator: ", ") - var parameters: [DatabaseValue] = [] - var placeholders: [String] = [] - - for value in values { - parameters.append(contentsOf: value.map { $0.value.value }) - placeholders.append("(\(parameterize(value.map { $0.value })))") - } - return SQL( - "insert into \(table) (\(columns)) values \(placeholders.joined(separator: ", "))", - bindings: parameters - ) - } - - open func insert(_ values: [OrderedDictionary], query: Query, returnItems: Bool) - -> EventLoopFuture<[DatabaseRow]> - { - catchError { - let sql = try self.compileInsert(query, values: values) - return query.database.runRawQuery(sql.query, values: sql.bindings) - } - } - - open func compileUpdate(_ query: Query, values: [String: Parameter]) throws -> SQL { - guard let table = query.from else { throw GrammarError.missingTable } - var bindings: [DatabaseValue] = [] - let columnSQL = compileUpdateColumns(query, values: values) - - var base = "update \(table)" - if let clauses = query.joins, - let joinSQL = compileJoins(query, joins: clauses) { - bindings += joinSQL.bindings - base += " \(joinSQL)" - } - - bindings += columnSQL.bindings - base += " set \(columnSQL.query)" - - if let whereSQL = compileWheres(query) { - bindings += whereSQL.bindings - base += " \(whereSQL.query)" - } - return SQL(base, bindings: bindings) - } - - open func compileUpdateColumns(_ query: Query, values: [String: Parameter]) -> SQL { - var bindings: [DatabaseValue] = [] - var parts: [String] = [] - for value in values { - if let expression = value.value as? Expression { - parts.append("\(value.key) = \(expression.description)") - } - else { - bindings.append(value.value.value) - parts.append("\(value.key) = ?") - } - } - - return SQL(parts.joined(separator: ", "), bindings: bindings) - } - - open func compileDelete(_ query: Query) throws -> SQL { - guard let table = query.from else { throw GrammarError.missingTable } - if let whereSQL = compileWheres(query) { - return SQL("delete from \(table) \(whereSQL.query)", bindings: whereSQL.bindings) - } - else { - return SQL("delete from \(table)") - } - } - - // MARK: - Compiling Migrations - - open func compileCreate(table: String, ifNotExists: Bool, columns: [CreateColumn]) -> SQL { - var columnStrings: [String] = [] - var constraintStrings: [String] = [] - for (column, constraints) in columns.map({ $0.sqlString(with: self) }) { - columnStrings.append(column) - constraintStrings.append(contentsOf: constraints) - } - return SQL( - """ - CREATE TABLE\(ifNotExists ? " IF NOT EXISTS" : "") \(table) ( - \((columnStrings + constraintStrings).joined(separator: ",\n ")) - ) - """ - ) - } - - open func compileRename(table: String, to: String) -> SQL { - SQL("ALTER TABLE \(table) RENAME TO \(to)") - } - - open func compileDrop(table: String) -> SQL { - SQL("DROP TABLE \(table)") - } - - open func compileAlter(table: String, dropColumns: [String], addColumns: [CreateColumn]) -> [SQL] { - guard !dropColumns.isEmpty || !addColumns.isEmpty else { - return [] - } - - var adds: [String] = [] - var constraints: [String] = [] - for (sql, tableConstraints) in addColumns.map({ $0.sqlString(with: self) }) { - adds.append("ADD COLUMN \(sql)") - constraints.append(contentsOf: tableConstraints.map { "ADD \($0)" }) - } - - let drops = dropColumns.map { "DROP COLUMN \($0.sqlEscaped)" } - return [ - SQL(""" - ALTER TABLE \(table) - \((adds + drops + constraints).joined(separator: ",\n ")) - """)] - } - - open func compileRenameColumn(table: String, column: String, to: String) -> SQL { - SQL("ALTER TABLE \(table) RENAME COLUMN \(column.sqlEscaped) TO \(to.sqlEscaped)") - } - - open func compileCreateIndexes(table: String, indexes: [CreateIndex]) -> [SQL] { - indexes.map { SQL($0.toSQL(table: table)) } - } - - open func compileDropIndex(table: String, indexName: String) -> SQL { - SQL("DROP INDEX \(indexName)") - } - - open func typeString(for type: ColumnType) -> String { - switch type { - case .bool: - return "bool" - case .date: - return "timestamptz" - case .double: - return "float8" - case .increments: - return "serial" - case .int: - return "int" - case .bigInt: - return "bigint" - case .json: - return "json" - case .string(let length): - switch length { - case .unlimited: - return "text" - case .limit(let characters): - return "varchar(\(characters))" - } - case .uuid: - return "uuid" - } - } - - open func jsonLiteral(from jsonString: String) -> String { - "'\(jsonString)'::jsonb" - } - - open func allowsUnsigned() -> Bool { - false - } - - private func parameterize(_ values: [Parameter]) -> String { - return values.map { parameter($0) }.joined(separator: ", ") - } - - private func parameter(_ value: Parameter) -> String { - if let value = value as? Expression { - return value.description - } - return "?" - } - - private func trim(_ value: String) -> String { - return value.trimmingCharacters(in: .whitespacesAndNewlines) - } - - private func compileWheres(_ query: Query) -> SQL? { - // If we actually have some where clauses, we will strip off - // the first boolean operator, which is added by the query - // builders for convenience so we can avoid checking for - // the first clauses in each of the compilers methods. - - // Need to handle nested stuff somehow - - let (sql, bindings) = QueryHelpers.groupSQL(values: query.wheres) - if (sql.count > 0) { - let conjunction = query is JoinClause ? "on" : "where" - let clauses = QueryHelpers.removeLeadingBoolean( - sql.joined(separator: " ") - ) - return SQL("\(conjunction) \(clauses)", bindings: bindings) - } - return nil - } - - private func compileColumns(_ query: Query, columns: [SQL]) -> SQL { - let select = query.isDistinct ? "select distinct" : "select" - let (sql, bindings) = QueryHelpers.groupSQL(values: columns) - return SQL("\(select) \(sql.joined(separator: ", "))", bindings: bindings) - } - - private func compileFrom(_ query: Query, table: String?) throws -> SQL { - guard let table = table else { throw GrammarError.missingTable } - return SQL("from \(table)") - } -} - -/// An abstraction around various supported SQL column types. -/// `Grammar`s will map the `ColumnType` to the backing -/// dialect type string. -public enum ColumnType { - /// Self incrementing integer. - case increments - /// Integer. - case int - /// Big integer. - case bigInt - /// Double. - case double - /// String, with a given max length. - case string(StringLength) - /// UUID. - case uuid - /// Boolean. - case bool - /// Date. - case date - /// JSON. - case json -} - -/// The length of an SQL string column in characters. -public enum StringLength { - /// This value of this column can be any number of characters. - case unlimited - /// This value of this column must be at most the provided number - /// of characters. - case limit(Int) -} - -extension CreateColumn { - /// Convert this `CreateColumn` to a `String` for inserting into - /// an SQL statement. - /// - /// - Returns: The SQL `String` describing this column and any - /// table level constraints to add. - func sqlString(with grammar: Grammar) -> (String, [String]) { - let columnEscaped = self.column.sqlEscaped - var baseSQL = "\(columnEscaped) \(grammar.typeString(for: self.type))" - var tableConstraints: [String] = [] - for constraint in self.constraints { - switch constraint { - case .notNull: - baseSQL.append(" NOT NULL") - case .primaryKey: - tableConstraints.append("PRIMARY KEY (\(columnEscaped))") - case .unique: - tableConstraints.append("UNIQUE (\(columnEscaped))") - case let .default(val): - baseSQL.append(" DEFAULT \(val)") - case let .foreignKey(column, table, onDelete, onUpdate): - var fkBase = "FOREIGN KEY (\(columnEscaped)) REFERENCES \(table) (\(column.sqlEscaped))" - if let delete = onDelete { fkBase.append(" ON DELETE \(delete.rawValue)") } - if let update = onUpdate { fkBase.append(" ON UPDATE \(update.rawValue)") } - tableConstraints.append(fkBase) - case .unsigned: - if grammar.allowsUnsigned() { - baseSQL.append(" UNSIGNED") - } - } - } - - return (baseSQL, tableConstraints) - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Query.swift b/Sources/Alchemy/SQL/QueryBuilder/Query.swift deleted file mode 100644 index f7eaa9a4..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Query.swift +++ /dev/null @@ -1,785 +0,0 @@ -import Foundation -import NIO - -public class Query: Sequelizable { - public enum LockStrength: String { - case update = "FOR UPDATE", share = "FOR SHARE" - } - - public enum LockOption: String { - case noWait = "NO WAIT", skipLocked = "SKIP LOCKED" - } - - let database: DatabaseDriver - - private(set) var columns: [SQL] = [SQL("*")] - private(set) var from: String? - private(set) var joins: [JoinClause]? = nil - private(set) var wheres: [WhereClause] = [] - private(set) var groups: [String] = [] - private(set) var havings: [WhereClause] = [] - private(set) var orders: [OrderClause] = [] - private(set) var limit: Int? = nil - private(set) var offset: Int? = nil - private(set) var isDistinct = false - private(set) var lock: String? = nil - - public init(database: DatabaseDriver) { - self.database = database - } - - /// Get the raw `SQL` for a given query. - /// - /// - Returns: A `SQL` value wrapper containing the executable - /// query and bindings. - public func toSQL() -> SQL { - return (try? self.database.grammar.compileSelect(query: self)) - ?? SQL() - } - - /// Set the columns that should be returned by the query. - /// - /// - Parameters: - /// - columns: An array of columns to be returned by the query. - /// Defaults to `[*]`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - @discardableResult - public func select(_ columns: [Column] = ["*"]) -> Self { - self.columns = columns.map(\.columnSQL) - return self - } - - /// Set the table to perform a query from. - /// - /// - Parameters: - /// - table: The table to run the query on. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func table(_ table: String) -> Self { - self.from = table - return self - } - - /// An alias for `table(_ table: String)` to be used when running. - /// a `select` query that also lets you alias the table name. - /// - /// - Parameters: - /// - table: The table to select data from. - /// - alias: An alias to use in place of table name. Defaults to - /// `nil`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func from(table: String, as alias: String? = nil) -> Self { - guard let alias = alias else { - return self.table(table) - } - return self.table("\(table) as \(alias)") - } - - /// Join data from a separate table into the current query data. - /// - /// - Parameters: - /// - table: The table to be joined. - /// - first: The column from the current query to be matched. - /// - op: The `Operator` to be used in the comparison. Defaults - /// to `.equals`. - /// - second: The column from the joining table to be matched. - /// - type: The `JoinType` of the sql join. Defaults to - /// `.inner`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func join( - table: String, - first: String, - op: Operator = .equals, - second: String, - type: JoinType = .inner - ) -> Self { - let join = JoinClause(database: self.database, type: type, table: table) - .on(first: first, op: op, second: second) - if joins == nil { - joins = [join] - } - else { - joins?.append(join) - } - return self - } - - /// Left join data from a separate table into the current query - /// data. - /// - /// - Parameters: - /// - table: The table to be joined. - /// - first: The column from the current query to be matched. - /// - op: The `Operator` to be used in the comparison. Defaults - /// to `.equals`. - /// - second: The column from the joining table to be matched. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func leftJoin( - table: String, - first: String, - op: Operator = .equals, - second: String - ) -> Self { - self.join( - table: table, - first: first, - op: op, - second: second, - type: .left - ) - } - - /// Right join data from a separate table into the current query - /// data. - /// - /// - Parameters: - /// - table: The table to be joined. - /// - first: The column from the current query to be matched. - /// - op: The `Operator` to be used in the comparison. Defaults - /// to `.equals`. - /// - second: The column from the joining table to be matched. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func rightJoin( - table: String, - first: String, - op: Operator = .equals, - second: String - ) -> Self { - self.join( - table: table, - first: first, - op: op, - second: second, - type: .right - ) - } - - /// Cross join data from a separate table into the current query - /// data. - /// - /// - Parameters: - /// - table: The table to be joined. - /// - first: The column from the current query to be matched. - /// - op: The `Operator` to be used in the comparison. Defaults - /// to `.equals`. - /// - second: The column from the joining table to be matched. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func crossJoin( - table: String, - first: String, - op: Operator = .equals, - second: String - ) -> Self { - self.join( - table: table, - first: first, - op: op, - second: second, - type: .cross - ) - } - - /// Add a basic where clause to the query to filter down results. - /// - /// - Parameters: - /// - clause: A `WhereValue` clause matching a column to a given - /// value. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func `where`(_ clause: WhereValue) -> Self { - self.wheres.append(clause) - return self - } - - /// An alias for `where(_ clause: WhereValue) ` that appends an or - /// clause instead of an and clause. - /// - /// - Parameters: - /// - clause: A `WhereValue` clause matching a column to a given - /// value. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhere(_ clause: WhereValue) -> Self { - var clause = clause - clause.boolean = .or - return self.where(clause) - } - - /// Add a nested where clause that is a group of combined clauses. - /// This can be used for logically grouping where clauses like - /// you would inside of an if statement. Each clause is - /// wrapped in parenthesis. - /// - /// For example if you want to logically ensure a user is under 30 - /// and named Paul, or over the age of 50 having any name, you - /// could use a nested where clause along with a separate - /// where value clause: - /// ```swift - /// Query - /// .from("users") - /// .where { - /// $0.where("age" < 30) - /// .orWhere("first_name" == "Paul") - /// } - /// .where("age" > 50) - /// ``` - /// - /// - Parameters: - /// - closure: A `WhereNestedClosure` that provides a nested - /// clause to attach nested where clauses to. - /// - boolean: How the clause should be appended(`.and` or - /// `.or`). Defaults to `.and`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func `where`(_ closure: @escaping WhereNestedClosure, boolean: WhereBoolean = .and) -> Self { - self.wheres.append( - WhereNested( - database: database, - closure: closure, - boolean: boolean - ) - ) - return self - } - - /// A helper for adding an **or** `where` nested closure clause. - /// - /// - Parameters: - /// - closure: A `WhereNestedClosure` that provides a nested - /// query to attach nested where clauses to. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhere(_ closure: @escaping WhereNestedClosure) -> Self { - self.where(closure, boolean: .or) - } - - /// Add a clause requiring that a column match any values in a - /// given array. - /// - /// - Parameters: - /// - key: The column to match against. - /// - values: The values that the column should not match. - /// - type: How the match should happen (*in* or *notIn*). - /// Defaults to `.in`. - /// - boolean: How the clause should be appended (`.and` or - /// `.or`). Defaults to `.and`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func `where`( - key: String, - in values: [Parameter], - type: WhereIn.InType = .in, - boolean: WhereBoolean = .and - ) -> Self { - self.wheres.append(WhereIn( - key: key, - values: values.map { $0.value }, - type: type, - boolean: boolean) - ) - return self - } - - /// A helper for adding an **or** variant of the `where(key:in:)` clause. - /// - /// - Parameters: - /// - key: The column to match against. - /// - values: The values that the column should not match. - /// - type: How the match should happen (`.in` or `.notIn`). - /// Defaults to `.in`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhere(key: String, in values: [Parameter], type: WhereIn.InType = .in) -> Self { - return self.where( - key: key, - in: values, - type: type, - boolean: .or - ) - } - - /// Add a clause requiring that a column not match any values in a - /// given array. This is a helper method for the where in method. - /// - /// - Parameters: - /// - key: The column to match against. - /// - values: The values that the column should not match. - /// - boolean: How the clause should be appended (`.and` or - /// `.or`). Defaults to `.and`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func whereNot(key: String, in values: [Parameter], boolean: WhereBoolean = .and) -> Self { - return self.where(key: key, in: values, type: .notIn, boolean: boolean) - } - - /// A helper for adding an **or** `whereNot` clause. - /// - /// - Parameters: - /// - key: The column to match against. - /// - values: The values that the column should not match. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhereNot(key: String, in values: [Parameter]) -> Self { - self.where(key: key, in: values, type: .notIn, boolean: .or) - } - - /// Add a raw SQL where clause to your query. - /// - /// - Parameters: - /// - sql: A string representing the SQL where clause to be run. - /// - bindings: Any variables for binding in the SQL. - /// - boolean: How the clause should be appended (`.and` or - /// `.or`). Defaults to `.and`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func whereRaw(sql: String, bindings: [Parameter], boolean: WhereBoolean = .and) -> Self { - self.wheres.append(WhereRaw( - query: sql, - values: bindings.map { $0.value }, - boolean: boolean) - ) - return self - } - - /// A helper for adding an **or** `whereRaw` clause. - /// - /// - Parameters: - /// - sql: A string representing the SQL where clause to be run. - /// - bindings: Any variables for binding in the SQL. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhereRaw(sql: String, bindings: [Parameter]) -> Self { - self.whereRaw(sql: sql, bindings: bindings, boolean: .or) - } - - /// Add a where clause requiring that two columns match each other - /// - /// - Parameters: - /// - first: The first column to match against. - /// - op: The `Operator` to be used in the comparison. - /// - second: The second column to match against. - /// - boolean: How the clause should be appended (`.and` - /// or `.or`). - /// - Returns: The current query builder `Query` to chain future - /// queries to. - @discardableResult - public func whereColumn(first: String, op: Operator, second: String, boolean: WhereBoolean = .and) -> Self { - self.wheres.append(WhereColumn(first: first, op: op, second: Expression(second), boolean: boolean)) - return self - } - - /// A helper for adding an **or** `whereColumn` clause. - /// - /// - Parameters: - /// - first: The first column to match against. - /// - op: The `Operator` to be used in the comparison. - /// - second: The second column to match against. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhereColumn(first: String, op: Operator, second: String) -> Self { - self.whereColumn(first: first, op: op, second: second, boolean: .or) - } - - /// Add a where clause requiring that a column be null. - /// - /// - Parameters: - /// - key: The column to match against. - /// - boolean: How the clause should be appended (`.and` or - /// `.or`). - /// - not: Should the value be null or not null. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func whereNull( - key: String, - boolean: WhereBoolean = .and, - not: Bool = false - ) -> Self { - let action = not ? "IS NOT" : "IS" - self.wheres.append(WhereRaw( - query: "\(key) \(action) NULL", - boolean: boolean) - ) - return self - } - - /// A helper for adding an **or** `whereNull` clause. - /// - /// - Parameter key: The column to match against. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhereNull(key: String) -> Self { - self.whereNull(key: key, boolean: .or) - } - - /// Add a where clause requiring that a column not be null. - /// - /// - Parameters: - /// - key: The column to match against. - /// - boolean: How the clause should be appended (`.and` or - /// `.or`). - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func whereNotNull(key: String, boolean: WhereBoolean = .and) -> Self { - self.whereNull(key: key, boolean: boolean, not: true) - } - - /// A helper for adding an **or** `whereNotNull` clause. - /// - /// - Parameter key: The column to match against. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orWhereNotNull(key: String) -> Self { - self.whereNotNull(key: key, boolean: .or) - } - - /// Add a having clause to filter results from aggregate - /// functions. - /// - /// - Parameter clause: A `WhereValue` clause matching a column to a - /// value. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func having(_ clause: WhereValue) -> Self { - self.havings.append(clause) - return self - } - - /// An alias for `having(_ clause:) ` that appends an or clause - /// instead of an and clause. - /// - /// - Parameter clause: A `WhereValue` clause matching a column to a - /// value. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orHaving(_ clause: WhereValue) -> Self { - var clause = clause - clause.boolean = .or - return self.having(clause) - } - - /// Add a having clause to filter results from aggregate functions - /// that matches a given key to a provided value. - /// - /// - Parameters: - /// - key: The column to match against. - /// - op: The `Operator` to be used in the comparison. - /// - value: The value that the column should match. - /// - boolean: How the clause should be appended (`.and` or - /// `.or`). - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func having(key: String, op: Operator, value: Parameter, boolean: WhereBoolean = .and) -> Self { - return self.having(WhereValue( - key: key, - op: op, - value: value.value, - boolean: boolean) - ) - } - - /// Group returned data by a given column. - /// - /// - Parameter group: The table column to group data on. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func groupBy(_ group: String) -> Self { - self.groups.append(group) - return self - } - - /// Order the data from the query based on given clause. - /// - /// - Parameter order: The `OrderClause` that defines the - /// ordering. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orderBy(_ order: OrderClause) -> Self { - self.orders.append(order) - return self - } - - /// Order the data from the query based on a column and direction. - /// - /// - Parameters: - /// - column: The column to order data by. - /// - direction: The `OrderClause.Sort` direction (either `.asc` - /// or `.desc`). Defaults to `.asc`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func orderBy(column: Column, direction: OrderClause.Sort = .asc) -> Self { - self.orderBy(OrderClause(column: column, direction: direction)) - } - - /// Set query to only return distinct entries. - /// - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func distinct() -> Self { - self.isDistinct = true - return self - } - - /// Offset the returned results by a given amount. - /// - /// - Parameter value: An amount representing the offset. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func offset(_ value: Int) -> Self { - self.offset = max(0, value) - return self - } - - /// Limit the returned results to a given amount. - /// - /// - Parameter value: An amount to cap the total result at. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func limit(_ value: Int) -> Self { - if (value >= 0) { - self.limit = value - } else { - fatalError("No negative limits allowed!") - } - return self - } - - /// A helper method to be used when needing to page returned - /// results. Internally this uses the `limit` and `offset` - /// methods. - /// - /// - Note: Paging starts at index 1, not 0. - /// - /// - Parameters: - /// - page: What `page` of results to offset by. - /// - perPage: How many results to show on each page. Defaults - /// to `25`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public func forPage(_ page: Int, perPage: Int = 25) -> Self { - offset((page - 1) * perPage).limit(perPage) - } - - /// Adds custom SQL to the end of a SELECT query. - public func forLock(_ lock: LockStrength, option: LockOption? = nil) -> Self { - let lockOptionString = option.map { " \($0.rawValue)" } ?? "" - self.lock = lock.rawValue + lockOptionString - return self - } - - /// Run a select query and return the database rows. - /// - /// - Note: Optional columns can be provided that override the - /// original select columns. - /// - Parameter columns: The columns you would like returned. - /// Defaults to `nil`. - /// - Returns: An `EventLoopFuture` to be run that contains the - /// returned rows from the database. - public func get(_ columns: [Column]? = nil) -> EventLoopFuture<[DatabaseRow]> { - if let columns = columns { - self.select(columns) - } - do { - let sql = try self.database.grammar.compileSelect(query: self) - return self.database.runRawQuery(sql.query, values: sql.bindings) - } - catch let error { - return .new(error: error) - } - } - - /// Run a select query and return the first database row only row. - /// - /// - Note: Optional columns can be provided that override the - /// original select columns. - /// - Parameter columns: The columns you would like returned. - /// Defaults to `nil`. - /// - Returns: An `EventLoopFuture` to be run that contains the - /// returned row from the database. - public func first(_ columns: [Column]? = nil) -> EventLoopFuture { - return self.limit(1) - .get(columns) - .map { $0.first } - } - - /// Run a select query that looks for a single row matching the - /// given database column and value. - /// - /// - Note: Optional columns can be provided that override the - /// original select columns. - /// - Parameter columns: The columns you would like returned. - /// Defaults to `nil`. - /// - Returns: An `EventLoopFuture` to be run that contains the - /// returned row from the database. - public func find(field: DatabaseField, columns: [Column]? = nil) -> EventLoopFuture { - self.wheres.append(WhereValue(key: field.column, op: .equals, value: field.value)) - return self.limit(1) - .get(columns) - .map { $0.first } - } - - /// Find the total count of the rows that match the given query. - /// - /// - Parameters: - /// - column: What column to count. Defaults to `*`. - /// - name: The alias that can be used for renaming the returned - /// count. - /// - Returns: An `EventLoopFuture` to be run that contains the - /// returned count value. - public func count(column: Column = "*", as name: String? = nil) -> EventLoopFuture { - var query = "COUNT(\(column))" - if let name = name { - query += " as \(name)" - } - return self.select([query]) - .first() - .unwrap(orError: DatabaseError("a COUNT query didn't return any rows")) - .flatMapThrowing { - guard let column = $0.allColumns.first else { - throw DatabaseError("a COUNT query didn't return any columns") - } - - return try $0.getField(column: column).int() - } - } - - /// Perform an insert and create a database row from the provided - /// data. - /// - /// - Parameter value: A dictionary containing the values to be - /// inserted. - /// - Parameter returnItems: Indicates whether the inserted items - /// should be returned with any fields updated/set by the - /// insert. Defaults to `true`. This flag doesn't affect - /// Postgres which always returns inserted items, but on MySQL - /// it means this will run two queries; one to insert and one to - /// fetch. - /// - Returns: An `EventLoopFuture` to be run that contains the - /// inserted rows. - public func insert(_ value: OrderedDictionary, returnItems: Bool = true) -> EventLoopFuture<[DatabaseRow]> { - return insert([value], returnItems: returnItems) - } - - /// Perform an insert and create database rows from the provided - /// data. - /// - /// - Parameter values: An array of dictionaries containing the - /// values to be inserted. - /// - Parameter returnItems: Indicates whether the inserted items - /// should be returned with any fields updated/set by the - /// insert. Defaults to `true`. This flag doesn't affect - /// Postgres which always runs a single query and returns - /// inserted items. On MySQL it means this will run two queries - /// _per value_; one to insert and one to fetch. If this is - /// `false`, MySQL will run a single query inserting all values. - /// - Returns: An `EventLoopFuture` to be run that contains the - /// inserted rows. - public func insert(_ values: [OrderedDictionary], returnItems: Bool = true) -> EventLoopFuture<[DatabaseRow]> { - self.database.grammar.insert(values, query: self, returnItems: returnItems) - } - - /// Perform an update on all data matching the query in the - /// builder with the values provided. - /// - /// For example, if you wanted to update the first name of a user - /// whose ID equals 10, you could do so as follows: - /// ```swift - /// Query - /// .table("users") - /// .where("id" == 10) - /// .update(values: [ - /// "first_name": "Ashley" - /// ]) - /// ``` - /// - /// - Parameter values: An dictionary containing the values to be - /// updated. - /// - Returns: An `EventLoopFuture` to be run that will update all - /// matched rows. - public func update(values: [String: Parameter]) -> EventLoopFuture<[DatabaseRow]> { - catchError { - let sql = try self.database.grammar.compileUpdate(self, values: values) - return self.database.runRawQuery(sql.query, values: sql.bindings) - } - } - - /// Perform a deletion on all data matching the given query. - /// - /// - Returns: An `EventLoopFuture` to be run that will delete all - /// matched rows. - public func delete() -> EventLoopFuture<[DatabaseRow]> { - do { - let sql = try self.database.grammar.compileDelete(self) - return self.database.runRawQuery(sql.query, values: sql.bindings) - } - catch let error { - return .new(error: error) - } - } -} - -extension Query { - /// Shortcut for running a query with the given table on - /// `Database.default`. - /// - /// - Parameter table: The table to run the query on. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public static func table(_ table: String) -> Query { - Database.default.query().table(table) - } - - /// Shortcut for running a query with the given table on - /// `Database.default`. - /// - /// An alias for `table(_ table: String)` to be used when running - /// a `select` query that also lets you alias the table name. - /// - /// - Parameters: - /// - table: The table to select data from. - /// - alias: An alias to use in place of table name. Defaults to - /// `nil`. - /// - Returns: The current query builder `Query` to chain future - /// queries to. - public static func from(table: String, as alias: String? = nil) -> Query { - guard let alias = alias else { - return Query.table(table) - } - return Query.table("\(table) as \(alias)") - } -} - -extension String { - public static func ==(lhs: String, rhs: Parameter) -> WhereValue { - return WhereValue(key: lhs, op: .equals, value: rhs.value) - } - - public static func !=(lhs: String, rhs: Parameter) -> WhereValue { - return WhereValue(key: lhs, op: .notEqualTo, value: rhs.value) - } - - public static func <(lhs: String, rhs: Parameter) -> WhereValue { - return WhereValue(key: lhs, op: .lessThan, value: rhs.value) - } - - public static func >(lhs: String, rhs: Parameter) -> WhereValue { - return WhereValue(key: lhs, op: .greaterThan, value: rhs.value) - } - - public static func <=(lhs: String, rhs: Parameter) -> WhereValue { - return WhereValue(key: lhs, op: .lessThanOrEqualTo, value: rhs.value) - } - - public static func >=(lhs: String, rhs: Parameter) -> WhereValue { - return WhereValue(key: lhs, op: .greaterThanOrEqualTo, value: rhs.value) - } - - public static func ~=(lhs: String, rhs: Parameter) -> WhereValue { - return WhereValue(key: lhs, op: .like, value: rhs.value) - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/QueryHelpers.swift b/Sources/Alchemy/SQL/QueryBuilder/QueryHelpers.swift deleted file mode 100644 index c5502cd5..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/QueryHelpers.swift +++ /dev/null @@ -1,27 +0,0 @@ -import Foundation - -enum QueryHelpers { - static func removeLeadingBoolean(_ value: String) -> String { - if value.hasPrefix("and ") { - return String(value.dropFirst(4)) - } - else if value.hasPrefix("or ") { - return String(value.dropFirst(3)) - } - return value - } - - static func groupSQL(values: [Sequelizable]) -> ([String], [DatabaseValue]) { - self.groupSQL(values: values.map { $0.toSQL() }) - } - - static func groupSQL(values: [SQL?]) -> ([String], [DatabaseValue]) { - return values.reduce(([String](), [DatabaseValue]())) { - var parts = $0 - guard let sql = $1 else { return parts } - parts.0.append(sql.query) - parts.1.append(contentsOf: sql.bindings) - return parts - } - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Types/Column.swift b/Sources/Alchemy/SQL/QueryBuilder/Types/Column.swift deleted file mode 100644 index 78f30e69..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Types/Column.swift +++ /dev/null @@ -1,18 +0,0 @@ -import Foundation - -/// Something convertible to a table column in an SQL database. -public protocol Column { - var columnSQL: SQL { get } -} - -extension String: Column { - public var columnSQL: SQL { - SQL(self) - } -} - -extension SQL: Column { - public var columnSQL: SQL { - self - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Types/Expression.swift b/Sources/Alchemy/SQL/QueryBuilder/Types/Expression.swift deleted file mode 100644 index 6a2ec766..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Types/Expression.swift +++ /dev/null @@ -1,14 +0,0 @@ -import Foundation - -struct Expression: Parameter { - private var _value: String - public var value: DatabaseValue { .string(_value) } - - init(_ value: String) { - self._value = value - } -} - -extension Expression: CustomStringConvertible { - var description: String { return self._value } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Types/Operator.swift b/Sources/Alchemy/SQL/QueryBuilder/Types/Operator.swift deleted file mode 100644 index 051645af..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Types/Operator.swift +++ /dev/null @@ -1,27 +0,0 @@ -import Foundation - -public enum Operator: CustomStringConvertible, Equatable { - case equals - case lessThan - case greaterThan - case lessThanOrEqualTo - case greaterThanOrEqualTo - case notEqualTo - case like - case notLike - case raw(operator: String) - - public var description: String { - switch self { - case .equals: return "=" - case .lessThan: return "<" - case .greaterThan: return ">" - case .lessThanOrEqualTo: return "<=" - case .greaterThanOrEqualTo: return ">=" - case .notEqualTo: return "!=" - case .like: return "LIKE" - case .notLike: return "NOT LIKE" - case .raw(let value): return value - } - } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Types/Parameter.swift b/Sources/Alchemy/SQL/QueryBuilder/Types/Parameter.swift deleted file mode 100644 index 156eb5dc..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Types/Parameter.swift +++ /dev/null @@ -1,41 +0,0 @@ -import Foundation - -public protocol Parameter { - var value: DatabaseValue { get } -} - -extension DatabaseValue: Parameter { - public var value: DatabaseValue { self } -} - -extension String: Parameter { - public var value: DatabaseValue { .string(self) } -} - -extension Int: Parameter { - public var value: DatabaseValue { .int(self) } -} - -extension Bool: Parameter { - public var value: DatabaseValue { .bool(self) } -} - -extension Double: Parameter { - public var value: DatabaseValue { .double(self) } -} - -extension Date: Parameter { - public var value: DatabaseValue { .date(self) } -} - -extension UUID: Parameter { - public var value: DatabaseValue { .uuid(self) } -} - -extension Optional: Parameter where Wrapped: Parameter { - public var value: DatabaseValue { self?.value ?? .string(nil) } -} - -extension RawRepresentable where RawValue: Parameter { - public var value: DatabaseValue { self.rawValue.value } -} diff --git a/Sources/Alchemy/SQL/QueryBuilder/Types/SQL.swift b/Sources/Alchemy/SQL/QueryBuilder/Types/SQL.swift deleted file mode 100644 index 726cef5d..00000000 --- a/Sources/Alchemy/SQL/QueryBuilder/Types/SQL.swift +++ /dev/null @@ -1,34 +0,0 @@ -import Foundation - -public struct SQL { - var query: String - let bindings: [DatabaseValue] - - public init(_ query: String = "", bindings: [DatabaseValue] = []) { - self.query = query - self.bindings = bindings - } - - public init(_ query: String, binding: DatabaseValue) { - self.init(query, bindings: [binding]) - } - - @discardableResult - func bind(_ bindings: inout [DatabaseValue]) -> SQL { - bindings.append(contentsOf: self.bindings) - return self - } - - @discardableResult - func bind(queries: inout [String], bindings: inout [DatabaseValue]) -> SQL { - queries.append(self.query) - bindings.append(contentsOf: self.bindings) - return self - } -} - -extension SQL: Equatable { - public static func == (lhs: SQL, rhs: SQL) -> Bool { - lhs.query == rhs.query && lhs.bindings == rhs.bindings - } -} diff --git a/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLDecodable.swift b/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLDecodable.swift new file mode 100644 index 00000000..32a23f08 --- /dev/null +++ b/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLDecodable.swift @@ -0,0 +1,3 @@ +protocol SQLDecodable { + init(from sql: SQLValue, at column: String) throws +} diff --git a/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLDecoder.swift b/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLDecoder.swift new file mode 100644 index 00000000..9854e5ec --- /dev/null +++ b/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLDecoder.swift @@ -0,0 +1,3 @@ +/// Used so `Relationship` types can know not to decode themselves properly from +/// an `SQLDecoder`. +protocol SQLDecoder: Decoder {} diff --git a/Sources/Alchemy/Rune/Model/Decoding/DatabaseRowDecoder.swift b/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLRowDecoder.swift similarity index 54% rename from Sources/Alchemy/Rune/Model/Decoding/DatabaseRowDecoder.swift rename to Sources/Alchemy/SQL/Rune/Model/Decoding/SQLRowDecoder.swift index 44922519..7a5489b1 100644 --- a/Sources/Alchemy/Rune/Model/Decoding/DatabaseRowDecoder.swift +++ b/Sources/Alchemy/SQL/Rune/Model/Decoding/SQLRowDecoder.swift @@ -1,142 +1,150 @@ import Foundation -/// Used so `Relationship` types can know not to encode themselves to -/// a `ModelEncoder`. -protocol ModelDecoder: Decoder {} - -/// Decoder for decoding `Model` types from a `DatabaseRow`. +/// Decoder for decoding `Model` types from an `SQLRow`. /// Properties of the `Decodable` type are matched to /// columns with matching names (either the same /// name or a specific name mapping based on /// the supplied `keyMapping`). -struct DatabaseRowDecoder: ModelDecoder { +struct SQLRowDecoder: SQLDecoder { /// The row that will be decoded out of. - let row: DatabaseRow + let row: SQLRow + let keyMapping: DatabaseKeyMapping + let jsonDecoder: JSONDecoder // MARK: Decoder var codingPath: [CodingKey] = [] var userInfo: [CodingUserInfoKey : Any] = [:] - func container( - keyedBy type: Key.Type - ) throws -> KeyedDecodingContainer where Key: CodingKey { - KeyedDecodingContainer( - KeyedContainer(row: self.row) - ) + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + KeyedDecodingContainer(KeyedContainer(row: row, decoder: self, keyMapping: keyMapping, jsonDecoder: jsonDecoder)) } - + func unkeyedContainer() throws -> UnkeyedDecodingContainer { - /// This is for arrays, which we don't support. throw DatabaseCodingError("This shouldn't be called; top level is keyed.") } - + func singleValueContainer() throws -> SingleValueDecodingContainer { - /// This is for non-primitives that encode to a single value - /// and should be handled by `DatabaseFieldDecoder`. throw DatabaseCodingError("This shouldn't be called; top level is keyed.") } } /// A `KeyedDecodingContainerProtocol` used to decode keys from a -/// `DatabaseRow`. -private struct KeyedContainer: KeyedDecodingContainerProtocol { +/// `SQLRow`. +private struct KeyedContainer: KeyedDecodingContainerProtocol { /// The row to decode from. - let row: DatabaseRow + let row: SQLRow + let decoder: SQLRowDecoder + let keyMapping: DatabaseKeyMapping + let jsonDecoder: JSONDecoder // MARK: KeyedDecodingContainerProtocol var codingPath: [CodingKey] = [] - var allKeys: [Key] { [] } + var allKeys: [Key] = [] func contains(_ key: Key) -> Bool { - self.row.allColumns.contains(self.string(for: key)) + row.columns.contains(string(for: key)) } func decodeNil(forKey key: Key) throws -> Bool { - try self.row.getField(column: self.string(for: key)).value.isNil + let column = string(for: key) + return try row.get(column) == .null } func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { - try self.row.getField(column: self.string(for: key)).bool() + let column = string(for: key) + return try row.get(column).bool(column) } func decode(_ type: String.Type, forKey key: Key) throws -> String { - try self.row.getField(column: self.string(for: key)).string() + let column = string(for: key) + return try row.get(column).string(column) } func decode(_ type: Double.Type, forKey key: Key) throws -> Double { - try self.row.getField(column: self.string(for: key)).double() + let column = string(for: key) + return try row.get(column).double(column) } func decode(_ type: Float.Type, forKey key: Key) throws -> Float { - Float(try self.row.getField(column: self.string(for: key)).double()) + let column = string(for: key) + return Float(try row.get(column).double(column)) } func decode(_ type: Int.Type, forKey key: Key) throws -> Int { - try self.row.getField(column: self.string(for: key)).int() + let column = string(for: key) + return try row.get(column).int(column) } func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { - Int8(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return Int8(try row.get(column).int(column)) } func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { - Int16(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return Int16(try row.get(column).int(column)) } func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { - Int32(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return Int32(try row.get(column).int(column)) } func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { - Int64(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return Int64(try row.get(column).int(column)) } func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { - UInt(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return UInt(try row.get(column).int(column)) } func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { - UInt8(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return UInt8(try row.get(column).int(column)) } func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { - UInt16(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return UInt16(try row.get(column).int(column)) } func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { - UInt32(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return UInt32(try row.get(column).int(column)) } func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { - UInt64(try self.row.getField(column: self.string(for: key)).int()) + let column = string(for: key) + return UInt64(try row.get(column).int(column)) } func decode(_ type: T.Type, forKey key: Key) throws -> T where T : Decodable { + let column = string(for: key) if type == UUID.self { - return try self.row.getField(column: self.string(for: key)).uuid() as! T + return try row.get(column).uuid(column) as! T } else if type == Date.self { - return try self.row.getField(column: self.string(for: key)).date() as! T + return try row.get(column).date(column) as! T } else if type is AnyBelongsTo.Type { - let field = try self.row.getField(column: self.string(for: key, includeIdSuffix: true)) - return try T(from: DatabaseFieldDecoder(field: field)) + // need relationship mapping + let belongsToColumn = string(for: key, includeIdSuffix: true) + let value = row.columns.contains(belongsToColumn) ? try row.get(belongsToColumn) : nil + return try (type as! AnyBelongsTo.Type).init(from: value) as! T } else if type is AnyHas.Type { - // Special case the `AnyHas` to decode dummy data. - let field = DatabaseField(column: "key", value: .string(key.stringValue)) - return try T(from: DatabaseFieldDecoder(field: field)) + return try T(from: decoder) } else if type is AnyModelEnum.Type { - let field = try self.row.getField(column: self.string(for: key)) - return try T(from: DatabaseFieldDecoder(field: field)) - } else { - let field = try self.row.getField(column: self.string(for: key)) - return try M.jsonDecoder.decode(T.self, from: field.json()) + let field = try row.get(column) + return try (type as! AnyModelEnum.Type).init(from: field) as! T } + + let field = try row.get(column) + return try jsonDecoder.decode(T.self, from: field.json(column)) } - func nestedContainer( - keyedBy type: NestedKey.Type, forKey key: Key - ) throws -> KeyedDecodingContainer where NestedKey : CodingKey { + func nestedContainer(keyedBy type: NestedKey.Type, forKey key: Key) throws -> KeyedDecodingContainer where NestedKey : CodingKey { throw DatabaseCodingError("Nested decoding isn't supported.") } @@ -165,6 +173,6 @@ private struct KeyedContainer: KeyedDecodingContainerP /// - Returns: The column name that `key` is mapped to. private func string(for key: Key, includeIdSuffix: Bool = false) -> String { let value = key.stringValue + (includeIdSuffix ? "Id" : "") - return M.keyMapping.map(input: value) + return keyMapping.map(input: value) } } diff --git a/Sources/Alchemy/SQL/Rune/Model/Fields/Model+Fields.swift b/Sources/Alchemy/SQL/Rune/Model/Fields/Model+Fields.swift new file mode 100644 index 00000000..45f26c62 --- /dev/null +++ b/Sources/Alchemy/SQL/Rune/Model/Fields/Model+Fields.swift @@ -0,0 +1,12 @@ +extension Model { + /// Returns an ordered dictionary of column names to `Parameter` + /// values, appropriate for working with the QueryBuilder. + /// + /// - Throws: A `DatabaseCodingError` if there is an error + /// creating any of the fields of this instance. + /// - Returns: An ordered dictionary mapping column names to + /// parameters for use in a QueryBuilder `Query`. + public func fields() throws -> [String: SQLValue] { + try ModelFieldReader(Self.keyMapping).getFields(of: self) + } +} diff --git a/Sources/Alchemy/SQL/Rune/Model/Fields/ModelFieldReader.swift b/Sources/Alchemy/SQL/Rune/Model/Fields/ModelFieldReader.swift new file mode 100644 index 00000000..0c10c31b --- /dev/null +++ b/Sources/Alchemy/SQL/Rune/Model/Fields/ModelFieldReader.swift @@ -0,0 +1,113 @@ +import Foundation + +/// Used so `Relationship` types can know not to encode themselves to +/// a `SQLEncoder`. +protocol SQLEncoder: Encoder {} + +/// Used for turning any `Model` into an ordered dictionary of columns to +/// `SQLValue`s based on its stored properties. +final class ModelFieldReader: SQLEncoder { + /// Used for keeping track of the database fields pulled off the + /// object encoded to this encoder. + fileprivate var readFields: [(column: String, value: SQLValue)] = [] + + /// The mapping strategy for associating `CodingKey`s on an object + /// with column names in a database. + fileprivate let mappingStrategy: DatabaseKeyMapping + + // MARK: Encoder + + var codingPath = [CodingKey]() + var userInfo: [CodingUserInfoKey: Any] = [:] + + /// Create with an associated `DatabasekeyMapping`. + /// + /// - Parameter mappingStrategy: The strategy for mapping `CodingKey` string + /// values to SQL columns. + init(_ mappingStrategy: DatabaseKeyMapping) { + self.mappingStrategy = mappingStrategy + } + + /// Read and return the stored properties of an `Model` object. + /// + /// - Parameter value: The `Model` instance to read from. + /// - Throws: A `DatabaseCodingError` if there is an error reading + /// fields from `value`. + /// - Returns: An ordered dictionary of the model's columns and values. + func getFields(of model: M) throws -> [String: SQLValue] { + try model.encode(to: self) + let toReturn = Dictionary(uniqueKeysWithValues: readFields) + readFields = [] + return toReturn + } + + func container(keyedBy: Key.Type) -> KeyedEncodingContainer { + KeyedEncodingContainer(_KeyedEncodingContainer(encoder: self, codingPath: codingPath)) + } + + func unkeyedContainer() -> UnkeyedEncodingContainer { + fatalError("`Model`s should never encode to an unkeyed container.") + } + + func singleValueContainer() -> SingleValueEncodingContainer { + fatalError("`Model`s should never encode to a single value container.") + } +} + +private struct _KeyedEncodingContainer: KeyedEncodingContainerProtocol, ModelValueReader { + var encoder: ModelFieldReader + + // MARK: KeyedEncodingContainerProtocol + + var codingPath = [CodingKey]() + + mutating func encodeNil(forKey key: Key) throws { + let keyString = encoder.mappingStrategy.map(input: key.stringValue) + encoder.readFields.append((keyString, SQLValue.null)) + } + + mutating func encode(_ value: T, forKey key: Key) throws { + guard !(value is AnyBelongsTo) else { + let keyString = encoder.mappingStrategy.map(input: key.stringValue + "Id") + if let idValue = (value as? AnyBelongsTo)?.idValue { + encoder.readFields.append((keyString, idValue)) + } + + return + } + + let keyString = encoder.mappingStrategy.map(input: key.stringValue) + guard let convertible = value as? SQLValueConvertible else { + // Assume anything else is JSON. + let jsonData = try M.jsonEncoder.encode(value) + encoder.readFields.append((column: keyString, value: .json(jsonData))) + return + } + + encoder.readFields.append((column: keyString, value: convertible.value)) + } + + mutating func nestedContainer(keyedBy keyType: NestedKey.Type, forKey key: Key) -> KeyedEncodingContainer where NestedKey: CodingKey { + fatalError("Nested coding of `Model` not supported.") + } + + mutating func nestedUnkeyedContainer(forKey key: Key) -> UnkeyedEncodingContainer { + fatalError("Nested coding of `Model` not supported.") + } + + mutating func superEncoder() -> Encoder { + fatalError("Superclass encoding of `Model` not supported.") + } + + mutating func superEncoder(forKey key: Key) -> Encoder { + fatalError("Superclass encoding of `Model` not supported.") + } +} + +/// Used for passing along the type of the `Model` various containers +/// are working with so that the `Model`'s custom encoders can be +/// used. +private protocol ModelValueReader { + /// The `Model` type this field reader is reading from. + associatedtype M: Model +} diff --git a/Sources/Alchemy/SQL/Rune/Model/Model+CRUD.swift b/Sources/Alchemy/SQL/Rune/Model/Model+CRUD.swift new file mode 100644 index 00000000..42cf0ae5 --- /dev/null +++ b/Sources/Alchemy/SQL/Rune/Model/Model+CRUD.swift @@ -0,0 +1,315 @@ +import NIO + +/// Useful extensions for various CRUD operations of a `Model`. +extension Model { + + // MARK: - Fetch + + /// Load all models of this type from a database. + /// + /// - Parameter db: The database to load models from. Defaults to + /// `Database.default`. + /// - Returns: An array of this model, loaded from the database. + public static func all(db: Database = DB) async throws -> [Self] { + try await Self.query(database: db).get() + } + + /// Fetch the first model with the given id. + /// + /// - Parameters: + /// - db: The database to fetch the model from. Defaults to + /// `Database.default`. + /// - id: The id of the model to find. + /// - Returns: A matching model, if one exists. + public static func find(_ id: Self.Identifier, db: Database = DB) async throws -> Self? { + try await Self.firstWhere("id" == id, db: db) + } + + /// Fetch the first model that matches the given where clause. + /// + /// - Parameters: + /// - where: A where clause for filtering models. + /// - db: The database to fetch the model from. Defaults to + /// `Database.default`. + /// - Returns: A matching model, if one exists. + public static func find(_ where: Query.Where, db: Database = DB) async throws -> Self? { + try await Self.firstWhere(`where`, db: db) + } + + /// Fetch the first model with the given id, throwing the given + /// error if it doesn't exist. + /// + /// - Parameters: + /// - db: The database to delete the model from. Defaults to + /// `Database.default`. + /// - id: The id of the model to delete. + /// - error: An error to throw if the model doesn't exist. + /// - Returns: A matching model. + public static func find(db: Database = DB, _ id: Self.Identifier, or error: Error) async throws -> Self { + try await Self.firstWhere("id" == id, db: db).unwrap(or: error) + } + + /// Fetch the first model of this type. + /// + /// - Parameters: db: The database to search the model for. + /// Defaults to `Database.default`. + /// - Returns: The first model, if one exists. + public static func first(db: Database = DB) async throws -> Self? { + try await Self.query().first() + } + + /// Returns a random model of this type, if one exists. + public static func random() async throws -> Self? { + // Note; MySQL should be `RAND()` + try await Self.query().select().orderBy("RANDOM()").limit(1).first() + } + + /// Gets the first element that meets the given where value. + /// + /// - Parameters: + /// - where: The table will be queried for a row matching this + /// clause. + /// - db: The database to query. Defaults to `Database.default`. + /// - Returns: The first result matching the `where` clause, if + /// one exists. + public static func firstWhere(_ where: Query.Where, db: Database = DB) async throws -> Self? { + try await Self.query(database: db).where(`where`).first() + } + + /// Gets all elements that meets the given where value. + /// + /// - Parameters: + /// - where: The table will be queried for a row matching this + /// clause. + /// - db: The database to query. Defaults to `Database.default`. + /// - Returns: All the models matching the `where` clause. + public static func allWhere(_ where: Query.Where, db: Database = DB) async throws -> [Self] { + try await Self.where(`where`, db: db).get() + } + + /// Gets the first element that meets the given where value. + /// Throws an error if no results match. The opposite of + /// `ensureNotExists(...)`. + /// + /// - Parameters: + /// - where: The table will be queried for a row matching this + /// clause. + /// - error: The error to throw if there are no results. + /// - db: The database to query. Defaults to `Database.default`. + /// - Returns: The first result matching the `where` clause. + public static func unwrapFirstWhere(_ where: Query.Where, or error: Error, db: Database = DB) async throws -> Self { + try await Self.where(`where`, db: db).unwrapFirst(or: error) + } + + /// Creates a query on the given model with the given where + /// clause. + /// + /// - Parameters: + /// - where: A clause to match. + /// - db: The database to query. Defaults to `Database.default`. + /// - Returns: A query on the `Model`'s table that matches the + /// given where clause. + public static func `where`(_ where: Query.Where, db: Database = DB) -> ModelQuery { + Self.query(database: db).where(`where`) + } + + // MARK: - Insert + + /// Inserts this model to a database. + /// + /// - Parameter db: The database to insert this model to. Defaults + /// to `Database.default`. + public func insert(db: Database = DB) async throws { + try await Self.query(database: db).insert(fields()) + } + + /// Inserts this model to a database. Return the newly created model. + /// + /// - Parameter db: The database to insert this model to. Defaults + /// to `Database.default`. + /// - Returns: An updated version of this model, reflecting any + /// changes that may have occurred saving this object to the + /// database. (an `id` being populated, for example). + public func insertReturn(db: Database = DB) async throws -> Self { + try await Self.query(database: db) + .insertReturn(try fields()) + .first + .unwrap(or: RuneError.notFound) + .decode(Self.self) + } + + // MARK: - Update + + /// Update this model in a database. + /// + /// - Parameter db: The database to update this model to. Defaults + /// to `Database.default`. + /// - Returns: An updated version of this model, reflecting any + /// changes that may have occurred saving this object to the + /// database. + @discardableResult + public func update(db: Database = DB) async throws -> Self { + let id = try getID() + let fields = try fields() + try await Self.query(database: db).where("id" == id).update(values: fields) + return self + } + + @discardableResult + public func update(db: Database = DB, updateClosure: (inout Self) -> Void) async throws -> Self { + let id = try self.getID() + var copy = self + updateClosure(©) + let fields = try copy.fields() + try await Self.query(database: db).where("id" == id).update(values: fields) + return copy + } + + @discardableResult + public static func update(db: Database = DB, _ id: Identifier, with dict: [String: Any]) async throws -> Self? { + try await Self.find(id)?.update(with: dict) + } + + @discardableResult + public func update(db: Database = DB, with dict: [String: Any]) async throws -> Self { + let updateValues = dict.compactMapValues { $0 as? SQLValueConvertible } + try await Self.query().where("id" == id).update(values: updateValues) + return try await sync() + } + + // MARK: - Save + + /// Saves this model to a database. If this model's `id` is nil, + /// it inserts it. If the `id` is not nil, it updates. + /// + /// - Parameter db: The database to save this model to. Defaults + /// to `Database.default`. + /// - Returns: An updated version of this model, reflecting any + /// changes that may have occurred saving this object to the + /// database (an `id` being populated, for example). + @discardableResult + public func save(db: Database = DB) async throws -> Self { + guard id != nil else { + return try await insertReturn(db: db) + } + + return try await update(db: db) + } + + // MARK: - Delete + + /// Delete all models that match the given where clause. + /// + /// - Parameters: + /// - db: The database to fetch the model from. Defaults to + /// `Database.default`. + /// - where: A where clause to filter models. + public static func delete(_ where: Query.Where, db: Database = DB) async throws { + try await query().where(`where`).delete() + } + + /// Delete the first model with the given id. + /// + /// - Parameters: + /// - db: The database to delete the model from. Defaults to + /// `Database.default`. + /// - id: The id of the model to delete. + public static func delete(db: Database = DB, _ id: Self.Identifier) async throws { + try await query().where("id" == id).delete() + } + + /// Delete all models of this type from a database. + /// + /// - Parameter + /// - db: The database to delete models from. Defaults + /// to `Database.default`. + /// - where: An optional where clause to specify the elements + /// to delete. + public static func deleteAll(db: Database = DB, where: Query.Where? = nil) async throws { + var query = Self.query(database: db) + if let clause = `where` { query = query.where(clause) } + try await query.delete() + } + + /// Deletes this model from a database. This will fail if the + /// model has a nil `id` field. + /// + /// - Parameter db: The database to remove this model from. + /// Defaults to `Database.default`. + public func delete(db: Database = DB) async throws { + try await Self.query(database: db).where("id" == id).delete() + } + + // MARK: - Sync + + /// Fetches an copy of this model from a database, with any + /// updates that may have been made since it was last + /// fetched. + /// + /// - Parameter db: The database to load from. Defaults to + /// `Database.default`. + /// - Returns: A freshly synced copy of this model. + public func sync(db: Database = DB, query: ((ModelQuery) -> ModelQuery) = { $0 }) async throws -> Self { + try await query(Self.query(database: db).where("id" == id)) + .first() + .unwrap(or: RuneError.syncErrorNoMatch(table: Self.tableName, id: id)) + } + + // MARK: - Misc + + /// Throws an error if a query with the specified where clause + /// returns a value. The opposite of `unwrapFirstWhere(...)`. + /// + /// Useful for detecting if a value with a key that may conflict + /// (such as a unique email) already exists on a table. + /// + /// - Parameters: + /// - where: The where clause to attempt to match. + /// - error: The error that will be thrown, should a query with + /// the where clause find a result. + /// - db: The database to query. Defaults to `Database.default`. + public static func ensureNotExists(_ where: Query.Where, else error: Error, db: Database = DB) async throws { + try await Self.query(database: db).where(`where`).firstRow() + .map { _ in throw error } + } +} + +// MARK: - Array Extensions + +/// Usefuly extensions for CRUD operations on an array of `Model`s. +extension Array where Element: Model { + /// Inserts each element in this array to a database. + /// + /// - Parameter db: The database to insert the models into. + /// Defaults to `Database.default`. + /// - Returns: All models in array, updated to reflect any changes + /// in the model caused by inserting. + public func insertAll(db: Database = DB) async throws { + try await Element.query(database: db) + .insert(try self.map { try $0.fields().mapValues { $0 } }) + } + + /// Inserts and returns each element in this array to a database. + /// + /// - Parameter db: The database to insert the models into. + /// Defaults to `Database.default`. + /// - Returns: All models in array, updated to reflect any changes + /// in the model caused by inserting. + public func insertReturnAll(db: Database = DB) async throws -> Self { + try await Element.query(database: db) + .insertReturn(try self.map { try $0.fields().mapValues { $0 } }) + .map { try $0.decode(Element.self) } + } + + /// Deletes all objects in this array from a database. If an + /// object in this array isn't actually in the database, it + /// will be ignored. + /// + /// - Parameter db: The database to delete from. Defaults to + /// `Database.default`. + public func deleteAll(db: Database = DB) async throws { + _ = try await Element.query(database: db) + .where(key: "id", in: self.compactMap { $0.id }) + .delete() + } +} diff --git a/Sources/Alchemy/Rune/Relationships/Model+PrimaryKey.swift b/Sources/Alchemy/SQL/Rune/Model/Model+PrimaryKey.swift similarity index 50% rename from Sources/Alchemy/Rune/Relationships/Model+PrimaryKey.swift rename to Sources/Alchemy/SQL/Rune/Model/Model+PrimaryKey.swift index 716193c6..8b0eacf8 100644 --- a/Sources/Alchemy/Rune/Relationships/Model+PrimaryKey.swift +++ b/Sources/Alchemy/SQL/Rune/Model/Model+PrimaryKey.swift @@ -1,5 +1,36 @@ import Foundation +/// Represents a type that may be a primary key in a database. Out of +/// the box `UUID`, `String` and `Int` are supported but you can +/// easily support your own by conforming to this protocol. +public protocol PrimaryKey: Hashable, SQLValueConvertible, Codable { + /// Initialize this value from an `SQLValue`. + /// + /// - Throws: If there is an error decoding this type from the + /// given database value. + /// - Parameter field: The field with which this type should be + /// initialzed from. + init(value: SQLValue) throws +} + +extension UUID: PrimaryKey { + public init(value: SQLValue) throws { + self = try value.uuid() + } +} + +extension Int: PrimaryKey { + public init(value: SQLValue) throws { + self = try value.int() + } +} + +extension String: PrimaryKey { + public init(value: SQLValue) throws { + self = try value.string() + } +} + extension Model { /// Initialize this model from a primary key. All other fields /// will be populated with dummy data. Useful for setting a @@ -14,7 +45,7 @@ extension Model { } } -private struct DummyDecoder: Decoder { +struct DummyDecoder: Decoder { var codingPath: [CodingKey] = [] var userInfo: [CodingUserInfoKey : Any] = [:] @@ -24,165 +55,11 @@ private struct DummyDecoder: Decoder { } func unkeyedContainer() throws -> UnkeyedDecodingContainer { - Unkeyed() + throw RuneError("Unkeyed containers aren't supported yet.") } func singleValueContainer() throws -> SingleValueDecodingContainer { - Single() - } -} - -private struct Single: SingleValueDecodingContainer { - var codingPath: [CodingKey] = [] - - func decodeNil() -> Bool { - false - } - - func decode(_ type: Bool.Type) throws -> Bool { - true - } - - func decode(_ type: String.Type) throws -> String { - "foo" - } - - func decode(_ type: Double.Type) throws -> Double { - 0 - } - - func decode(_ type: Float.Type) throws -> Float { - 0 - } - - func decode(_ type: Int.Type) throws -> Int { - 0 - } - - func decode(_ type: Int8.Type) throws -> Int8 { - 0 - } - - func decode(_ type: Int16.Type) throws -> Int16 { - 0 - } - - func decode(_ type: Int32.Type) throws -> Int32 { - 0 - } - - func decode(_ type: Int64.Type) throws -> Int64 { - 0 - } - - func decode(_ type: UInt.Type) throws -> UInt { - 0 - } - - func decode(_ type: UInt8.Type) throws -> UInt8 { - 0 - } - - func decode(_ type: UInt16.Type) throws -> UInt16 { - 0 - } - - func decode(_ type: UInt32.Type) throws -> UInt32 { - 0 - } - - func decode(_ type: UInt64.Type) throws -> UInt64 { - 0 - } - - func decode(_ type: T.Type) throws -> T where T : Decodable { - try T(from: DummyDecoder()) - } -} - -private struct Unkeyed: UnkeyedDecodingContainer { - var codingPath: [CodingKey] = [] - - var count: Int? = nil - - var isAtEnd: Bool = false - - var currentIndex: Int = 0 - - mutating func decodeNil() throws -> Bool { - false - } - - mutating func decode(_ type: Bool.Type) throws -> Bool { - true - } - - mutating func decode(_ type: String.Type) throws -> String { - "foo" - } - - mutating func decode(_ type: Double.Type) throws -> Double { - 0 - } - - mutating func decode(_ type: Float.Type) throws -> Float { - 0 - } - - mutating func decode(_ type: Int.Type) throws -> Int { - 0 - } - - mutating func decode(_ type: Int8.Type) throws -> Int8 { - 0 - } - - mutating func decode(_ type: Int16.Type) throws -> Int16 { - 0 - } - - mutating func decode(_ type: Int32.Type) throws -> Int32 { - 0 - } - - mutating func decode(_ type: Int64.Type) throws -> Int64 { - 0 - } - - mutating func decode(_ type: UInt.Type) throws -> UInt { - 0 - } - - mutating func decode(_ type: UInt8.Type) throws -> UInt8 { - 0 - } - - mutating func decode(_ type: UInt16.Type) throws -> UInt16 { - 0 - } - - mutating func decode(_ type: UInt32.Type) throws -> UInt32 { - 0 - } - - mutating func decode(_ type: UInt64.Type) throws -> UInt64 { - 0 - } - - mutating func decode(_ type: T.Type) throws -> T where T : Decodable { - try T(from: DummyDecoder()) - } - - mutating func nestedContainer(keyedBy type: NestedKey.Type) throws -> KeyedDecodingContainer where NestedKey : CodingKey { - throw RuneError("`DummyDecoder` doesn't support nested keyed containers yet.") - } - - mutating func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { - throw RuneError("`DummyDecoder` doesn't support nested unkeyed containers yet.") - } - - mutating func superDecoder() throws -> Decoder { - throw RuneError("`DummyDecoder` doesn't support super decoders yet.") + throw RuneError("Single value containers aren't supported yet, if you're using an enum, please conform it to `ModelEnum`.") } } @@ -261,9 +138,15 @@ private struct Keyed: KeyedDecodingContainerProtocol { return (type as! AnyModelEnum.Type).defaultCase as! T } else if type is AnyArray.Type { return [] as! T - } else { - return try T(from: DummyDecoder()) + } else if type is AnyBelongsTo.Type { + return try (type as! AnyBelongsTo.Type).init(from: nil) as! T + } else if type is UUID.Type { + return UUID() as! T + } else if type is Date.Type { + return Date() as! T } + + return try T(from: DummyDecoder()) } func nestedContainer(keyedBy type: NestedKey.Type, forKey key: K) throws -> KeyedDecodingContainer where NestedKey : CodingKey { diff --git a/Sources/Alchemy/Rune/Model/Model.swift b/Sources/Alchemy/SQL/Rune/Model/Model.swift similarity index 78% rename from Sources/Alchemy/Rune/Model/Model.swift rename to Sources/Alchemy/SQL/Rune/Model/Model.swift index 9c6e9f73..c8346ff3 100644 --- a/Sources/Alchemy/Rune/Model/Model.swift +++ b/Sources/Alchemy/SQL/Rune/Model/Model.swift @@ -88,34 +88,3 @@ extension Model { try self.id.unwrap(or: DatabaseError("Object of type \(type(of: self)) had a nil id.")) } } - -/// Represents a type that may be a primary key in a database. Out of -/// the box `UUID`, `String` and `Int` are supported but you can -/// easily support your own by conforming to this protocol. -public protocol PrimaryKey: Hashable, Parameter, Codable { - /// Initialize this value from a `DatabaseField`. - /// - /// - Throws: If there is an error decoding this type from the - /// given database value. - /// - Parameter field: The field with which this type should be - /// initialzed from. - init(field: DatabaseField) throws -} - -extension UUID: PrimaryKey { - public init(field: DatabaseField) throws { - self = try field.uuid() - } -} - -extension Int: PrimaryKey { - public init(field: DatabaseField) throws { - self = try field.int() - } -} - -extension String: PrimaryKey { - public init(field: DatabaseField) throws { - self = try field.string() - } -} diff --git a/Sources/Alchemy/SQL/Rune/Model/ModelEnum.swift b/Sources/Alchemy/SQL/Rune/Model/ModelEnum.swift new file mode 100644 index 00000000..a0d85a5b --- /dev/null +++ b/Sources/Alchemy/SQL/Rune/Model/ModelEnum.swift @@ -0,0 +1,54 @@ +/// A protocol to which enums on `Model`s should conform to. The enum +/// will be modeled in the backing table by it's raw value. +/// +/// Usage: +/// ```swift +/// enum TaskPriority: Int, ModelEnum { +/// case low, medium, high +/// } +/// +/// struct Todo: Model { +/// var id: Int? +/// let name: String +/// let isDone: Bool +/// let priority: TaskPriority // Stored as `Int` in the database. +/// } +/// ``` +public protocol ModelEnum: AnyModelEnum, CaseIterable {} + +/// A type erased `ModelEnum`. +public protocol AnyModelEnum: Codable, SQLValueConvertible { + init(from sqlValue: SQLValue) throws + + /// The default case of this enum. Defaults to the first of + /// `Self.allCases`. + static var defaultCase: Self { get } +} + +extension ModelEnum { + public static var defaultCase: Self { Self.allCases.first! } +} + +extension AnyModelEnum where Self: RawRepresentable, RawValue == String { + public init(from sqlValue: SQLValue) throws { + let string = try sqlValue.string() + self = try Self(rawValue: string) + .unwrap(or: DatabaseCodingError("Error decoding \(name(of: Self.self)) from \(string)")) + } +} + +extension AnyModelEnum where Self: RawRepresentable, RawValue == Int { + public init(from sqlValue: SQLValue) throws { + let int = try sqlValue.int() + self = try Self(rawValue: int) + .unwrap(or: DatabaseCodingError("Error decoding \(name(of: Self.self)) from \(int)")) + } +} + +extension AnyModelEnum where Self: RawRepresentable, RawValue == Double { + public init(from sqlValue: SQLValue) throws { + let double = try sqlValue.double() + self = try Self(rawValue: double) + .unwrap(or: DatabaseCodingError("Error decoding \(name(of: Self.self)) from \(double)")) + } +} diff --git a/Sources/Alchemy/SQL/Rune/Model/ModelQuery.swift b/Sources/Alchemy/SQL/Rune/Model/ModelQuery.swift new file mode 100644 index 00000000..fb82946a --- /dev/null +++ b/Sources/Alchemy/SQL/Rune/Model/ModelQuery.swift @@ -0,0 +1,211 @@ +import Foundation +import NIO + +public extension Model { + /// Begin a `ModelQuery` from a given database. + /// + /// - Parameter database: The database to run the query on. + /// Defaults to `Database.default`. + /// - Returns: A builder for building your query. + static func query(database: Database = DB) -> ModelQuery { + ModelQuery(database: database.provider, table: Self.tableName) + } +} + +/// A `ModelQuery` is just a subclass of `Query` with some added +/// typing and convenience functions for querying the table of +/// a specific `Model`. +public class ModelQuery: Query { + /// A closure for defining any nested eager loading when loading a + /// relationship on this `Model`. + /// + /// "Eager loading" refers to loading a model at the other end of + /// a relationship of this queried model. Nested eager loads + /// refers to loading a model from a relationship on that + /// _other_ model. + public typealias NestedEagerLoads = (ModelQuery) -> ModelQuery + + private typealias ModelRow = (model: M, row: SQLRow) + + /// The closures of any eager loads to run. To be run after the + /// initial models of type `Self` are fetched. + /// + /// - Warning: Right now these only run when the query is + /// finished with `allModels` or `firstModel`. If the user + /// finishes a query with a `get()` we don't know if/when the + /// decode will happen and how to handle it. A potential ways + /// of doing this could be to call eager loading @ the + /// `.decode` level of a `SQLRow`, but that's too + /// complicated for now). + private var eagerLoadQueries: [([ModelRow]) async throws -> [ModelRow]] = [] + + /// Gets all models matching this query from the database. + /// + /// - Returns: All models matching this query. + public func get() async throws -> [M] { + try await _get().map(\.model) + } + + private func _get(columns: [String]? = ["\(M.tableName).*"]) async throws -> [ModelRow] { + let initialResults = try await getRows(columns).map { (try $0.decode(M.self), $0) } + return try await evaluateEagerLoads(for: initialResults) + } + + /// Get the first model matching this query from the database. + /// + /// - Returns: The first model matching this query if one exists. + public func first() async throws -> M? { + guard let result = try await firstRow() else { + return nil + } + + return try await evaluateEagerLoads(for: [(result.decode(M.self), result)]).first?.0 + } + + /// Similar to `firstModel`. Gets the first result of a query, but + /// unwraps the element, throwing an error if it doesn't exist. + /// + /// - Parameter error: The error to throw should no element be + /// found. Defaults to `RuneError.notFound`. + /// - Returns: The unwrapped first result of this query, or the + /// supplied error if no result was found. + public func unwrapFirst(or error: Error = RuneError.notFound) async throws -> M { + try await first().unwrap(or: error) + } + + /// Eager loads (loads a related `Model`) a `Relationship` on this + /// model. + /// + /// Eager loads are evaluated in a single query per eager load + /// after the initial model query has completed. + /// + /// Usage: + /// ```swift + /// // Consider three types, `Pet`, `Person`, and `Plant`. They + /// // have the following relationships: + /// struct Pet: Model { + /// ... + /// @BelongsTo var owner: Person + /// } + /// + /// struct Person: Model { + /// ... + /// @BelongsTo var favoritePlant: Plant + /// } + /// + /// struct Plant: Model { ... } + /// + /// // A `Pet` query that loads each pet's related owner _as well_ + /// // as those owners' favorite plants would look like this: + /// Pet.query() + /// // An eager load + /// .with(\.$owner) { ownerQuery in + /// // `ownerQuery` is the query that will be run when + /// // fetching owner objects; we can give it its own + /// // eager loads (aka nested eager loading) + /// ownerQuery.with(\.$favoritePlant) + /// } + /// .getAll() + /// ``` + /// - Parameters: + /// - relationshipKeyPath: The `KeyPath` of the relationship to + /// load. Please note that this is a `KeyPath` to a + /// `Relationship`, not a `Model`, so it will likely + /// start with a '$', such as `\.$user`. + /// - nested: A closure for any nested loading to do. See + /// example above. Defaults to an empty closure. + /// - Returns: A query builder for extending the query. + public func with( + _ relationshipKeyPath: KeyPath, + nested: @escaping NestedEagerLoads = { $0 } + ) -> ModelQuery where R.From == M { + eagerLoadQueries.append { fromResults in + let mapper = RelationshipMapper() + M.mapRelations(mapper) + let config = mapper.getConfig(for: relationshipKeyPath) + + // If there are no results, don't need to eager load. + guard !fromResults.isEmpty else { + return [] + } + + // Alias whatever key we'll join the relationship on + let toJoinKeyAlias = "_to_join_key" + let toJoinKey: String = { + let table = config.through?.table ?? config.toTable + let key = config.through?.fromKey ?? config.toKey + return "\(table).\(key) as \(toJoinKeyAlias)" + }() + + // Load the matching `To` rows + let allRows = fromResults.map(\.1) + let query = try nested(config.load(allRows, database: Database(provider: self.database))) + let toResults = try await query + ._get(columns: ["\(R.To.Value.tableName).*", toJoinKey]) + .map { (try R.To.from($0), $1) } + + // Key the results by the join key value + let toResultsKeyedByJoinKey = try Dictionary(grouping: toResults) { _, row in + try row.get(toJoinKeyAlias).value + } + + // For each `from` populate it's relationship + return try fromResults.map { model, row in + let pk = try row.get(config.fromKey).value + let models = toResultsKeyedByJoinKey[pk]?.map(\.0) ?? [] + try model[keyPath: relationshipKeyPath].set(values: models) + return (model, row) + } + } + + return self + } + + /// Evaluate all eager loads in this `ModelQuery` sequentially. + /// This occurs after the inital `M` query has completed. + /// + /// - Parameter models: The models that were loaded by the initial + /// query. + /// - Returns: The loaded models that will have all specified + /// relationships loaded. + private func evaluateEagerLoads(for models: [ModelRow]) async throws -> [ModelRow] { + var results: [ModelRow] = models + for query in eagerLoadQueries { + results = try await query(results) + } + + return results + } +} + +private extension RelationshipMapping { + func load(_ values: [SQLRow], database: Database) throws -> ModelQuery { + var query = M.query(database: database) + query.table = toTable + var whereKey = "\(toTable).\(toKey)" + if let through = through { + whereKey = "\(through.table).\(through.fromKey)" + query = query.leftJoin(table: through.table, first: "\(through.table).\(through.toKey)", second: "\(toTable).\(toKey)") + } + + let ids = try values.map { try $0.get(fromKey).value } + query = query.where(key: "\(whereKey)", in: ids.uniques) + return query + } +} + +private extension Array where Element: Hashable { + /// Removes any duplicates from the array while maintaining the + /// original order. + var uniques: Array { + var buffer = Array() + var added = Set() + for elem in self { + if !added.contains(elem) { + buffer.append(elem) + added.insert(elem) + } + } + return buffer + } +} diff --git a/Sources/Alchemy/Rune/Relationships/Model+Relationships.swift b/Sources/Alchemy/SQL/Rune/Relationships/Model+Relationships.swift similarity index 100% rename from Sources/Alchemy/Rune/Relationships/Model+Relationships.swift rename to Sources/Alchemy/SQL/Rune/Relationships/Model+Relationships.swift diff --git a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/AnyRelationships.swift b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/AnyRelationships.swift similarity index 68% rename from Sources/Alchemy/Rune/Relationships/PropertyWrappers/AnyRelationships.swift rename to Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/AnyRelationships.swift index 160354eb..51f0b549 100644 --- a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/AnyRelationships.swift +++ b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/AnyRelationships.swift @@ -4,4 +4,8 @@ protocol AnyHas {} /// A type erased `BelongsToRelationship`. Used for special casing /// decoding behavior for `BelongsTo`s. -protocol AnyBelongsTo {} +protocol AnyBelongsTo { + var idValue: SQLValue? { get } + + init(from sqlValue: SQLValue?) throws +} diff --git a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/BelongsToRelationship.swift b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/BelongsToRelationship.swift similarity index 71% rename from Sources/Alchemy/Rune/Relationships/PropertyWrappers/BelongsToRelationship.swift rename to Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/BelongsToRelationship.swift index aa454193..b85610e8 100644 --- a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/BelongsToRelationship.swift +++ b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/BelongsToRelationship.swift @@ -18,20 +18,19 @@ import NIO /// } /// ``` @propertyWrapper -public final class BelongsToRelationship< - Child: Model, - Parent: ModelMaybeOptional ->: AnyBelongsTo, Codable, Relationship { +public final class BelongsToRelationship: AnyBelongsTo, Relationship, Codable { public typealias From = Child public typealias To = Parent /// The identifier of this relationship's parent. - public var id: Parent.Value.Identifier! { + public var id: Parent.Value.Identifier? { didSet { - self.value = nil + value = nil } } + var idValue: SQLValue? { id.value } + /// The underlying relationship object, if there is one. Populated /// by eager loading. private var value: Parent? @@ -48,8 +47,8 @@ public final class BelongsToRelationship< } } set { - self.id = newValue.id - self.value = newValue + id = newValue.id + value = newValue } } @@ -66,8 +65,8 @@ public final class BelongsToRelationship< /// belongs. public init(wrappedValue: Parent) { do { - self.value = try Parent.from(wrappedValue) - self.id = value?.id + value = try Parent.from(wrappedValue) + id = value?.id } catch { fatalError("Error initializing `BelongsTo`; expected a value but got nil. Perhaps this relationship should be optional?") } @@ -86,7 +85,7 @@ public final class BelongsToRelationship< // MARK: Codable public func encode(to encoder: Encoder) throws { - if !(encoder is ModelEncoder) { + if !(encoder is SQLEncoder) { try value.encode(to: encoder) } else { // When encoding to the database, just encode the Parent's ID. @@ -96,23 +95,28 @@ public final class BelongsToRelationship< } public init(from decoder: Decoder) throws { - if !(decoder is ModelDecoder) { - let container = try decoder.singleValueContainer() - if container.decodeNil() { - id = nil - } else { - let parent = try Parent(from: decoder) - id = parent.id - value = parent - } + let container = try decoder.singleValueContainer() + if container.decodeNil() { + id = nil } else { - let container = try decoder.singleValueContainer() - if container.decodeNil() { - id = nil - } else { - // When decode from a database, just decode the Parent's ID. - id = try container.decode(Parent.Value.Identifier.self) - } + let parent = try Parent(from: decoder) + id = parent.id + value = parent } } + + init(from sqlValue: SQLValue?) throws { + guard sqlValue != .null else { + id = nil + return + } + + id = try sqlValue.map { try Parent.Value.Identifier.init(value: $0) } + } +} + +extension BelongsToRelationship: Equatable { + public static func == (lhs: BelongsToRelationship, rhs: BelongsToRelationship) -> Bool { + lhs.id == rhs.id + } } diff --git a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/HasManyRelationship.swift b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/HasManyRelationship.swift similarity index 75% rename from Sources/Alchemy/Rune/Relationships/PropertyWrappers/HasManyRelationship.swift rename to Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/HasManyRelationship.swift index 43edce30..c0b661c6 100644 --- a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/HasManyRelationship.swift +++ b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/HasManyRelationship.swift @@ -4,10 +4,7 @@ import NIO /// relationship. The details of this relationship are defined /// in the initializers inherited from `HasRelationship`. @propertyWrapper -public final class HasManyRelationship< - From: Model, - To: ModelMaybeOptional ->: AnyHas, Codable, Relationship { +public final class HasManyRelationship: AnyHas, Relationship, Codable { /// Internal value for storing the `To` objects of this /// relationship, when they are loaded. fileprivate var value: [To]? @@ -17,12 +14,13 @@ public final class HasManyRelationship< /// or set manually. public var wrappedValue: [To] { get { - guard let value = self.value else { + guard let value = value else { fatalError("Relationship of type `\(name(of: To.self))` was not loaded!") } + return value } - set { self.value = newValue } + set { value = newValue } } /// The projected value of this property wrapper is itself. Used @@ -41,7 +39,7 @@ public final class HasManyRelationship< } public func set(values: [To]) throws { - self.wrappedValue = try values.map { try To.from($0) } + wrappedValue = try values.map { try To.from($0) } } // MARK: Codable @@ -49,12 +47,18 @@ public final class HasManyRelationship< public init(from decoder: Decoder) throws {} public func encode(to encoder: Encoder) throws { - if !(encoder is ModelEncoder) { - try self.value.encode(to: encoder) + if !(encoder is SQLEncoder) { + try value.encode(to: encoder) } } } +extension HasManyRelationship: Equatable where To: Equatable { + public static func == (lhs: HasManyRelationship, rhs: HasManyRelationship) -> Bool { + lhs.value == rhs.value + } +} + public extension KeyedEncodingContainer { // Only encode the underlying value if it exists. mutating func encode(_ value: HasManyRelationship, forKey key: Key) throws { diff --git a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/HasOneRelationship.swift b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/HasOneRelationship.swift similarity index 77% rename from Sources/Alchemy/Rune/Relationships/PropertyWrappers/HasOneRelationship.swift rename to Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/HasOneRelationship.swift index afbd3f09..db14f256 100644 --- a/Sources/Alchemy/Rune/Relationships/PropertyWrappers/HasOneRelationship.swift +++ b/Sources/Alchemy/SQL/Rune/Relationships/PropertyWrappers/HasOneRelationship.swift @@ -4,10 +4,7 @@ import NIO /// relationship are defined in the initializers inherited from /// `HasRelationship`. @propertyWrapper -public final class HasOneRelationship< - From: Model, - To: ModelMaybeOptional ->: AnyHas, Codable, Relationship { +public final class HasOneRelationship: AnyHas, Codable, Relationship { /// Internal value for storing the `To` object of this /// relationship, when it is loaded. fileprivate var value: To? @@ -28,7 +25,7 @@ public final class HasOneRelationship< fatalError("Relationship of type `\(name(of: To.self))` was not loaded!") } } - set { self.value = newValue } + set { value = newValue } } // MARK: Overrides @@ -42,7 +39,7 @@ public final class HasOneRelationship< } public func set(values: [To]) throws { - self.wrappedValue = try To.from(values.first) + wrappedValue = try To.from(values.first) } // MARK: Codable @@ -50,12 +47,18 @@ public final class HasOneRelationship< public init(from decoder: Decoder) throws {} public func encode(to encoder: Encoder) throws { - if !(encoder is ModelEncoder) { - try self.value.encode(to: encoder) + if !(encoder is SQLEncoder) { + try value.encode(to: encoder) } } } +extension HasOneRelationship: Equatable where To: Equatable { + public static func == (lhs: HasOneRelationship, rhs: HasOneRelationship) -> Bool { + lhs.value == rhs.value + } +} + public extension KeyedEncodingContainer { // Only encode the underlying value if it exists. mutating func encode(_ value: HasOneRelationship, forKey key: Key) throws { diff --git a/Sources/Alchemy/Rune/Relationships/Relationship.swift b/Sources/Alchemy/SQL/Rune/Relationships/Relationship.swift similarity index 100% rename from Sources/Alchemy/Rune/Relationships/Relationship.swift rename to Sources/Alchemy/SQL/Rune/Relationships/Relationship.swift diff --git a/Sources/Alchemy/Rune/Relationships/RelationshipMapper.swift b/Sources/Alchemy/SQL/Rune/Relationships/RelationshipMapper.swift similarity index 80% rename from Sources/Alchemy/Rune/Relationships/RelationshipMapper.swift rename to Sources/Alchemy/SQL/Rune/Relationships/RelationshipMapper.swift index 43bdb43e..88e0fe1a 100644 --- a/Sources/Alchemy/Rune/Relationships/RelationshipMapper.swift +++ b/Sources/Alchemy/SQL/Rune/Relationships/RelationshipMapper.swift @@ -11,23 +11,23 @@ public final class RelationshipMapper { } func getConfig(for relation: KeyPath) -> RelationshipMapping { - if let rel = configs[relation] { - return rel as! RelationshipMapping - } else { + guard let rel = configs[relation] else { return R.defaultConfig() } + + return rel as! RelationshipMapping } } protocol AnyRelation {} /// Defines how a `Relationship` is mapped from it's `From` to `To`. -public final class RelationshipMapping: AnyRelation { +public final class RelationshipMapping: AnyRelation, Equatable { enum Kind { case has, belongs } - struct Through { + struct Through: Equatable { var table: String var fromKey: String var toKey: String @@ -43,26 +43,22 @@ public final class RelationshipMapping: AnyRelation { var toKey: String { toKeyOverride ?? toKeyAssumed } var type: Kind - var through: Through? { - didSet { - if oldValue != nil && through != nil { - fatalError("For now, only one through is allowed per relation.") - } - } - } + var through: Through? init( _ type: Kind, fromTable: String = From.tableName, fromKey: String = To.referenceKey, toTable: String = To.tableName, - toKey: String = From.referenceKey + toKey: String = From.referenceKey, + through: Through? = nil ) { self.type = type self.fromTable = fromTable self.fromKeyAssumed = fromKey self.toTable = toTable self.toKeyAssumed = toKey + self.through = through } @discardableResult @@ -103,6 +99,17 @@ public final class RelationshipMapping: AnyRelation { through = Through(table: table, fromKey: _from, toKey: _to) return self } + + public static func == (lhs: RelationshipMapping, rhs: RelationshipMapping) -> Bool { + lhs.fromTable == rhs.fromTable && + lhs.fromKeyAssumed == rhs.fromKeyAssumed && + lhs.fromKeyOverride == rhs.fromKeyOverride && + lhs.toTable == rhs.toTable && + lhs.toKeyAssumed == rhs.toKeyAssumed && + lhs.toKeyOverride == rhs.toKeyOverride && + lhs.type == rhs.type && + lhs.through == rhs.through + } } extension RelationshipMapping { diff --git a/Sources/Alchemy/Rune/RuneError.swift b/Sources/Alchemy/SQL/Rune/RuneError.swift similarity index 84% rename from Sources/Alchemy/Rune/RuneError.swift rename to Sources/Alchemy/SQL/Rune/RuneError.swift index 96f0c1db..1bcdea42 100644 --- a/Sources/Alchemy/Rune/RuneError.swift +++ b/Sources/Alchemy/SQL/Rune/RuneError.swift @@ -20,7 +20,8 @@ public struct RuneError: Error { public static let syncErrorNoId = RuneError("Can't .sync() an object with a nil `id`.") /// Failed to sync a model; it didn't exist in the database. - public static func syncErrorNoMatch(table: String, id: P) -> RuneError { - RuneError("Error syncing Model, didn't find a row with id '\(id)' on table '\(table)'.") + public static func syncErrorNoMatch(table: String, id: P?) -> RuneError { + let id = id.map { "\($0)" } ?? "nil" + return RuneError("Error syncing Model, didn't find a row with id '\(id)' on table '\(table)'.") } } diff --git a/Sources/Alchemy/Scheduler/DayOfWeek.swift b/Sources/Alchemy/Scheduler/DayOfWeek.swift new file mode 100644 index 00000000..b9e3fcfd --- /dev/null +++ b/Sources/Alchemy/Scheduler/DayOfWeek.swift @@ -0,0 +1,21 @@ +/// A day of the week. +public enum DayOfWeek: Int, ExpressibleByIntegerLiteral { + /// Sunday + case sun = 0 + /// Monday + case mon = 1 + /// Tuesday + case tue = 2 + /// Wednesday + case wed = 3 + /// Thursday + case thu = 4 + /// Friday + case fri = 5 + /// Saturday + case sat = 6 + + public init(integerLiteral value: Int) { + self = DayOfWeek(rawValue: value) ?? .sun + } +} diff --git a/Sources/Alchemy/Scheduler/Frequency.swift b/Sources/Alchemy/Scheduler/Frequency.swift deleted file mode 100644 index 8f6f7a15..00000000 --- a/Sources/Alchemy/Scheduler/Frequency.swift +++ /dev/null @@ -1,104 +0,0 @@ -import Foundation - -/// Represents a frequency that occurs at a `rate` and may have -/// specific requirements for when it should start running, -/// such as "every day at 9:30 am". -protocol Frequency { - /// A cron expression representing this frequency. - var cronExpression: String { get } -} - -// MARK: - TimeUnits - -/// A week of time. -public struct WeekUnit {} - -/// A day of time. -public struct DayUnit {} - -/// An hour of time. -public struct HourUnit {} - -/// A minute of time. -public struct MinuteUnit {} - -/// A second of time. -public struct SecondUnit {} - -// MARK: - Frequencies - -/// A generic frequency for handling amounts of time. -public struct FrequencyTyped: Frequency { - /// The frequency at which this work should be repeated. - let value: Int - - public var cronExpression: String - - fileprivate init(value: Int, cronExpression: String) { - self.value = value - self.cronExpression = cronExpression - } -} - -/// A frequency measured in a number of seconds. -public typealias Seconds = FrequencyTyped - -/// A frequency measured in a number of minutes. -public typealias Minutes = FrequencyTyped -extension Minutes { - /// When this frequency should first take place. - /// - /// - Parameter sec: A second of a minute (0-59). - /// - Returns: A minutely frequency that first takes place at the - /// given component. - public func at(sec: Int = 0) -> Minutes { - Minutes(value: self.value, cronExpression: "\(sec) */\(self.value) * * * *") - } -} - -/// A frequency measured in a number of hours. -public typealias Hours = FrequencyTyped -extension Hours { - /// When this frequency should first take place. - /// - /// - Parameters: - /// - min: A minute of an hour (0-59). - /// - sec: A second of a minute (0-59). - /// - Returns: An hourly frequency that first takes place at the - /// given components. - public func at(min: Int = 0, sec: Int = 0) -> Hours { - Hours(value: self.value, cronExpression: "\(sec) \(min) */\(self.value) * * * *") - } -} - -/// A frequency measured in a number of days. -public typealias Days = FrequencyTyped -extension Days { - /// When this frequency should first take place. - /// - /// - Parameters: - /// - hr: An hour of the day (0-23). - /// - min: A minute of an hour (0-59). - /// - sec: A second of a minute (0-59). - /// - Returns: A daily frequency that first takes place at the - /// given components. - public func at(hr: Int = 0, min: Int = 0, sec: Int = 0) -> Days { - Days(value: self.value, cronExpression: "\(sec) \(min) \(hr) */\(self.value) * * *") - } -} - -/// A frequency measured in a number of weeks. -public typealias Weeks = FrequencyTyped -extension Weeks { - /// When this frequency should first take place. - /// - /// - Parameters: - /// - hr: An hour of the day (0-23). - /// - min: A minute of an hour (0-59). - /// - sec: A second of a minute (0-59). - /// - Returns: A weekly frequency that first takes place at the - /// given components. - public func at(hr: Int = 0, min: Int = 0, sec: Int = 0) -> Weeks { - Weeks(value: self.value, cronExpression: "\(sec) \(min) \(hr) */\(self.value * 7) * * *") - } -} diff --git a/Sources/Alchemy/Scheduler/Month.swift b/Sources/Alchemy/Scheduler/Month.swift new file mode 100644 index 00000000..299dd61b --- /dev/null +++ b/Sources/Alchemy/Scheduler/Month.swift @@ -0,0 +1,31 @@ +/// A month of the year. +public enum Month: Int, ExpressibleByIntegerLiteral { + /// January + case jan = 1 + /// February + case feb = 2 + /// March + case mar = 3 + /// April + case apr = 4 + /// May + case may = 5 + /// June + case jun = 6 + /// July + case jul = 7 + /// August + case aug = 8 + /// September + case sep = 9 + /// October + case oct = 10 + /// November + case nov = 11 + /// December + case dec = 12 + + public init(integerLiteral value: Int) { + self = Month(rawValue: value) ?? .jan + } +} diff --git a/Sources/Alchemy/Scheduler/ScheduleBuilder.swift b/Sources/Alchemy/Scheduler/Schedule.swift similarity index 51% rename from Sources/Alchemy/Scheduler/ScheduleBuilder.swift rename to Sources/Alchemy/Scheduler/Schedule.swift index 3fa10185..c5eece27 100644 --- a/Sources/Alchemy/Scheduler/ScheduleBuilder.swift +++ b/Sources/Alchemy/Scheduler/Schedule.swift @@ -1,8 +1,21 @@ import Cron +import NIOCore /// Used to help build schedule frequencies for scheduled tasks. -public struct ScheduleBuilder { +public final class Schedule { private let buildingFinished: (Schedule) -> Void + private var pattern: DatePattern? = nil { + didSet { + if pattern != nil { + buildingFinished(self) + } + } + } + + /// {seconds} {minutes} {hour} {day of month} {month} {day of week} {year} + var cronExpression: String? { + pattern?.string + } init(_ buildingFinished: @escaping (Schedule) -> Void) { self.buildingFinished = buildingFinished @@ -17,8 +30,7 @@ public struct ScheduleBuilder { /// - min: The minute to run. /// - sec: The second to run. public func yearly(month: Month = .jan, day: Int = 1, hr: Int = 0, min: Int = 0, sec: Int = 0) { - let schedule = Schedule(second: "\(sec)", minute: "\(min)", hour: "\(hr)", dayOfMonth: "\(day)", month: "\(month.rawValue)") - self.buildingFinished(schedule) + pattern = DatePattern(second: "\(sec)", minute: "\(min)", hour: "\(hr)", dayOfMonth: "\(day)", month: "\(month.rawValue)") } /// Run this task monthly. @@ -29,8 +41,7 @@ public struct ScheduleBuilder { /// - min: The minute to run. /// - sec: The second to run. public func monthly(day: Int = 1, hr: Int = 0, min: Int = 0, sec: Int = 0) { - let schedule = Schedule(second: "\(sec)", minute: "\(min)", hour: "\(hr)", dayOfMonth: "\(day)") - self.buildingFinished(schedule) + pattern = DatePattern(second: "\(sec)", minute: "\(min)", hour: "\(hr)", dayOfMonth: "\(day)") } /// Run this task weekly. @@ -41,8 +52,7 @@ public struct ScheduleBuilder { /// - min: The minute to run. /// - sec: The second to run. public func weekly(day: DayOfWeek = .sun, hr: Int = 0, min: Int = 0, sec: Int = 0) { - let schedule = Schedule(second: "\(sec)", minute: "\(min)", hour: "\(hr)", dayOfWeek: "\(day.rawValue)") - self.buildingFinished(schedule) + pattern = DatePattern(second: "\(sec)", minute: "\(min)", hour: "\(hr)", dayOfWeek: "\(day.rawValue)") } /// Run this task daily. @@ -52,8 +62,7 @@ public struct ScheduleBuilder { /// - min: The minute to run. /// - sec: The second to run. public func daily(hr: Int = 0, min: Int = 0, sec: Int = 0) { - let schedule = Schedule(second: "\(sec)", minute: "\(min)", hour: "\(hr)") - self.buildingFinished(schedule) + pattern = DatePattern(second: "\(sec)", minute: "\(min)", hour: "\(hr)") } /// Run this task every hour. @@ -62,8 +71,7 @@ public struct ScheduleBuilder { /// - min: The minute to run. /// - sec: The second to run. public func hourly(min: Int = 0, sec: Int = 0) { - let schedule = Schedule(second: "\(sec)", minute: "\(min)", hour: "*/1") - self.buildingFinished(schedule) + pattern = DatePattern(second: "\(sec)", minute: "\(min)", hour: "*") } /// Run this task every minute. @@ -71,14 +79,12 @@ public struct ScheduleBuilder { /// - Parameters: /// - sec: The second to run. public func minutely(sec: Int = 0) { - let schedule = Schedule(second: "\(sec)") - self.buildingFinished(schedule) + pattern = DatePattern(second: "\(sec)") } /// Run this task every second. public func secondly() { - let schedule = Schedule() - self.buildingFinished(schedule) + pattern = DatePattern() } @@ -86,18 +92,33 @@ public struct ScheduleBuilder { /// and year fields are acceptable. /// /// - Parameter expression: A cron expression. - public func cron(_ expression: String) { - let schedule = Schedule(validate: expression) - self.buildingFinished(schedule) + public func expression(_ cronExpression: String) { + pattern = DatePattern(validate: cronExpression) } -} + + /// The delay after which this schedule will be run, if it will ever be run. + func next() -> TimeAmount? { + guard let next = pattern?.next(), let nextDate = next.date else { + return nil + } -typealias Schedule = DatePattern + var delay = Int64(nextDate.timeIntervalSinceNow * 1000) + // Occasionally Cron library returns the `next()` as fractions of a + // millisecond before or after now. If the delay is 0, get the next + // date and use that instead. + if delay == 0 { + let newDate = pattern?.next(next)?.date ?? Date().addingTimeInterval(1) + delay = Int64(newDate.timeIntervalSinceNow * 1000) + } + + return .milliseconds(delay) + } +} -extension Schedule { +extension DatePattern { /// Initialize with a cron expression. This will crash if the /// expression is invalid. - init(validate cronExpression: String) { + fileprivate init(validate cronExpression: String) { do { self = try DatePattern(cronExpression) } catch { @@ -108,7 +129,7 @@ extension Schedule { /// Initialize with pieces of a cron expression. Each piece /// defaults to `*`. This will fatal if a piece of the /// expression is invalid. - init( + fileprivate init( second: String = "*", minute: String = "*", hour: String = "*", @@ -117,7 +138,7 @@ extension Schedule { dayOfWeek: String = "*", year: String = "*" ) { - let string = [second, minute, hour, dayOfWeek, month, dayOfWeek, year].joined(separator: " ") + let string = [second, minute, hour, dayOfMonth, month, dayOfWeek, year].joined(separator: " ") do { self = try DatePattern(string) } catch { @@ -125,80 +146,3 @@ extension Schedule { } } } - -/// A day of the week. -public enum DayOfWeek: Int, ExpressibleByIntegerLiteral { - /// Sunday - case sun = 0 - /// Monday - case mon = 1 - /// Tuesday - case tue = 2 - /// Wednesday - case wed = 3 - /// Thursday - case thu = 4 - /// Friday - case fri = 5 - /// Saturday - case sat = 6 - - public init(integerLiteral value: Int) { - switch value { - case 0: self = .sun - case 1: self = .mon - case 2: self = .tue - case 3: self = .wed - case 4: self = .thu - case 5: self = .fri - case 6: self = .sat - default: fatalError("\(value) isn't a valid day of the week.") - } - } -} - -/// A month of the year. -public enum Month: Int, ExpressibleByIntegerLiteral { - /// January - case jan = 0 - /// February - case feb = 1 - /// March - case mar = 2 - /// April - case apr = 3 - /// May - case may = 4 - /// June - case jun = 5 - /// July - case jul = 6 - /// August - case aug = 7 - /// September - case sep = 8 - /// October - case oct = 9 - /// November - case nov = 10 - /// December - case dec = 11 - - public init(integerLiteral value: Int) { - switch value { - case 0: self = .jan - case 1: self = .feb - case 2: self = .mar - case 3: self = .apr - case 4: self = .may - case 5: self = .jun - case 6: self = .jul - case 7: self = .aug - case 8: self = .sep - case 9: self = .oct - case 10: self = .nov - case 11: self = .dec - default: fatalError("\(value) isn't a valid month.") - } - } -} diff --git a/Sources/Alchemy/Scheduler/Scheduler+Scheduling.swift b/Sources/Alchemy/Scheduler/Scheduler+Scheduling.swift new file mode 100644 index 00000000..dc6529d4 --- /dev/null +++ b/Sources/Alchemy/Scheduler/Scheduler+Scheduling.swift @@ -0,0 +1,35 @@ +import NIO + +extension Scheduler { + /// Schedule a recurring `Job`. + /// + /// - Parameters: + /// - job: The job to schedule. + /// - queue: The queue to schedule it on. + /// - channel: The queue channel to schedule it on. + /// - Returns: A builder for customizing the scheduling frequency. + public func job(_ job: @escaping @autoclosure () -> Job, queue: Queue = Q, channel: String = Queue.defaultChannel) -> Schedule { + Schedule { [weak self] schedule in + self?.addWork(schedule: schedule) { + do { + try await job().dispatch(on: queue, channel: channel) + } catch { + Log.error("[Scheduler] error scheduling Job: \(error)") + throw error + } + } + } + } + + /// Schedule a recurring task. + /// + /// - Parameter task: The task to run. + /// - Returns: A builder for customizing the scheduling frequency. + public func run(_ task: @escaping () async throws -> Void) -> Schedule { + Schedule { [weak self] schedule in + self?.addWork(schedule: schedule) { + try await task() + } + } + } +} diff --git a/Sources/Alchemy/Scheduler/Scheduler.swift b/Sources/Alchemy/Scheduler/Scheduler.swift index 176b1591..2335404d 100644 --- a/Sources/Alchemy/Scheduler/Scheduler.swift +++ b/Sources/Alchemy/Scheduler/Scheduler.swift @@ -1,13 +1,23 @@ +import NIOCore + /// A service for scheduling recurring work, in lieu of a separate /// cron task running apart from your server. -public final class Scheduler: Service { +public final class Scheduler { private struct WorkItem { let schedule: Schedule - let work: (EventLoop) throws -> Void + let work: () async throws -> Void } - + + public private(set) var isStarted: Bool = false private var workItems: [WorkItem] = [] - private var isStarted: Bool = false + private let isTesting: Bool + + /// Initialize this Scheduler, potentially flagging it for testing. If + /// testing is enabled, work items will only be run once, and not + /// rescheduled. + init(isTesting: Bool = false) { + self.isTesting = isTesting + } /// Start scheduling with the given loop. /// @@ -31,31 +41,24 @@ public final class Scheduler: Service { /// - Parameters: /// - schedule: The schedule to run this work. /// - work: The work to run. - func addWork(schedule: Schedule, work: @escaping (EventLoop) throws -> Void) { + func addWork(schedule: Schedule, work: @escaping () async throws -> Void) { workItems.append(WorkItem(schedule: schedule, work: work)) } - private func schedule(schedule: Schedule, task: @escaping (EventLoop) throws -> Void, on loop: EventLoop) { - guard - let next = schedule.next(), - let nextDate = next.date - else { - return Log.error("[Scheduler] schedule doesn't have a future date to run.") + private func schedule(schedule: Schedule, task: @escaping () async throws -> Void, on loop: EventLoop) { + guard let delay = schedule.next() else { + return Log.info("[Scheduler] scheduling finished; there's no future date to run.") } - - func scheduleNextAndRun() throws -> Void { - self.schedule(schedule: schedule, task: task, on: loop) - try task(loop) - } - - var delay = Int64(nextDate.timeIntervalSinceNow * 1000) - // Occasionally Cron library returns the `next()` as fractions of a - // millisecond before or after now. If the delay is 0, get the next - // date and use that instead. - if delay == 0 { - let newDate = schedule.next(next)?.date ?? Date().addingTimeInterval(1) - delay = Int64(newDate.timeIntervalSinceNow * 1000) + + loop.flatScheduleTask(in: delay) { + loop.asyncSubmit { + // Schedule next and run + if !self.isTesting { + self.schedule(schedule: schedule, task: task, on: loop) + } + + try await task() + } } - loop.scheduleTask(in: .milliseconds(delay), scheduleNextAndRun) } } diff --git a/Sources/Alchemy/Utilities/Aliases.swift b/Sources/Alchemy/Utilities/Aliases.swift new file mode 100644 index 00000000..c17da62c --- /dev/null +++ b/Sources/Alchemy/Utilities/Aliases.swift @@ -0,0 +1,23 @@ +// The default configured Client +public var Http: Client.Builder { Client.id(.default).builder() } +public func Http(_ id: Client.Identifier) -> Client.Builder { Client.id(id).builder() } + +// The default configured Database +public var DB: Database { .id(.default) } +public func DB(_ id: Database.Identifier) -> Database { .id(id) } + +// The default configured Filesystem +public var Storage: Filesystem { .id(.default) } +public func Storage(_ id: Filesystem.Identifier) -> Filesystem { .id(id) } + +// Your app's default Cache. +public var Stash: Cache { .id(.default) } +public func Stash(_ id: Cache.Identifier) -> Cache { .id(id) } + +// Your app's default Queue +public var Q: Queue { .id(.default) } +public func Q(_ id: Queue.Identifier) -> Queue { .id(id) } + +// Your app's default RedisClient +public var Redis: RedisClient { .id(.default) } +public func Redis(_ id: RedisClient.Identifier) -> RedisClient { .id(id) } diff --git a/Sources/Alchemy/Utilities/Vendor/BCrypt.swift b/Sources/Alchemy/Utilities/BCrypt.swift similarity index 89% rename from Sources/Alchemy/Utilities/Vendor/BCrypt.swift rename to Sources/Alchemy/Utilities/BCrypt.swift index e95e2eb9..039b5eb7 100644 --- a/Sources/Alchemy/Utilities/Vendor/BCrypt.swift +++ b/Sources/Alchemy/Utilities/BCrypt.swift @@ -52,15 +52,20 @@ public final class BCryptDigest { /// Creates a new `BCryptDigest`. Use the global `BCrypt` convenience variable. public init() { } - - public func hash(_ plaintext: String, cost: Int = 12) throws -> String { - guard cost >= BCRYPT_MINLOGROUNDS && cost <= 31 else { - throw BcryptError.invalidCost - } - return try self.hash(plaintext, salt: self.generateSalt(cost: cost)) + /// Asynchronously hashes a password on a separate thread. + /// + /// - Parameter password: The password to hash. + /// - Returns: The hashed password. + public func hash(_ password: String) async throws -> String { + try await Thread.run { try Bcrypt.hashSync(password) } + } + + public func hashSync(_ plaintext: String, cost: Int = 12) throws -> String { + guard cost >= BCRYPT_MINLOGROUNDS && cost <= 31 else { throw BcryptError.invalidCost } + return try self.hashSync(plaintext, salt: self.generateSalt(cost: cost)) } - public func hash(_ plaintext: String, salt: String) throws -> String { + public func hashSync(_ plaintext: String, salt: String) throws -> String { guard isSaltValid(salt) else { throw BcryptError.invalidSalt } @@ -104,6 +109,17 @@ public final class BCryptDigest { + String(cString: hashedBytes) .dropFirst(originalAlgorithm.revisionCount) } + + /// Asynchronously verifies a password & hash on a separate + /// thread. + /// + /// - Parameters: + /// - plaintext: The plaintext password. + /// - hashed: The hashed password to verify with. + /// - Returns: Whether the password and hash matched. + public func verify(plaintext: String, hashed: String) async throws -> Bool { + try await Thread.run { try Bcrypt.verifySync(plaintext, created: hashed) } + } /// Verifies an existing BCrypt hash matches the supplied plaintext value. Verification works by parsing the salt and version from /// the existing digest and using that information to hash the plaintext data. If hash digests match, this method returns `true`. @@ -117,7 +133,7 @@ public final class BCryptDigest { /// - hash: Existing BCrypt hash to parse version, salt, and existing digest from. /// - throws: `CryptoError` if hashing fails or if data conversion fails. /// - returns: `true` if the hash was created from the supplied plaintext data. - public func verify(_ plaintext: String, created hash: String) throws -> Bool { + public func verifySync(_ plaintext: String, created hash: String) throws -> Bool { guard let hashVersion = Algorithm(rawValue: String(hash.prefix(4))) else { throw BcryptError.invalidHash } @@ -132,7 +148,7 @@ public final class BCryptDigest { throw BcryptError.invalidHash } - let messageHash = try self.hash(plaintext, salt: hashSalt) + let messageHash = try self.hashSync(plaintext, salt: hashSalt) let messageHashChecksum = String(messageHash.suffix(hashVersion.checksumCount)) return messageHashChecksum.secureCompare(to: hashChecksum) } @@ -297,12 +313,6 @@ extension FixedWidthInteger { public static func random() -> Self { return Self.random(in: .min ... .max) } - - public static func random(using generator: inout T) -> Self - where T : RandomNumberGenerator - { - return Self.random(in: .min ... .max, using: &generator) - } } extension Array where Element: FixedWidthInteger { @@ -311,18 +321,4 @@ extension Array where Element: FixedWidthInteger { (0..(count: Int, using generator: inout T) -> [Element] - where T: RandomNumberGenerator - { - var array: [Element] = .init(repeating: 0, count: count) - (0.. Void) -> Self { + var _copy = self + build(&_copy) + return _copy + } +} diff --git a/Sources/Alchemy/Utilities/Codable/DecoderDelegate.swift b/Sources/Alchemy/Utilities/Codable/DecoderDelegate.swift new file mode 100644 index 00000000..4ad9c195 --- /dev/null +++ b/Sources/Alchemy/Utilities/Codable/DecoderDelegate.swift @@ -0,0 +1,63 @@ +protocol DecoderDelegate { + // Values + func decodeString(for key: CodingKey?) throws -> String + func decodeDouble(for key: CodingKey?) throws -> Double + func decodeInt(for key: CodingKey?) throws -> Int + func decodeBool(for key: CodingKey?) throws -> Bool + func decodeNil(for key: CodingKey?) -> Bool + + // Contains + func contains(key: CodingKey) -> Bool + var allKeys: [String] { get } + + // Array / Map + func map(for key: CodingKey) throws -> DecoderDelegate + func array(for key: CodingKey?) throws -> [DecoderDelegate] +} + +extension DecoderDelegate { + func _decode(_ type: T.Type = T.self, for key: CodingKey? = nil) throws -> T { + var value: Any? = nil + + if T.self is Int.Type { + value = try decodeInt(for: key) + } else if T.self is String.Type { + value = try decodeString(for: key) + } else if T.self is Bool.Type { + value = try decodeBool(for: key) + } else if T.self is Double.Type { + value = try decodeDouble(for: key) + } else if T.self is Float.Type { + value = Float(try decodeDouble(for: key)) + } else if T.self is Int8.Type { + value = Int8(try decodeInt(for: key)) + } else if T.self is Int16.Type { + value = Int16(try decodeInt(for: key)) + } else if T.self is Int32.Type { + value = Int32(try decodeInt(for: key)) + } else if T.self is Int64.Type { + value = Int64(try decodeInt(for: key)) + } else if T.self is UInt.Type { + value = UInt(try decodeInt(for: key)) + } else if T.self is UInt8.Type { + value = UInt8(try decodeInt(for: key)) + } else if T.self is UInt16.Type { + value = UInt16(try decodeInt(for: key)) + } else if T.self is UInt32.Type { + value = UInt32(try decodeInt(for: key)) + } else if T.self is UInt64.Type { + value = UInt64(try decodeInt(for: key)) + } else { + return try T(from: GenericDecoder(delegate: key.map { try map(for: $0) } ?? self)) + } + + guard let t = value as? T else { + throw DecodingError.dataCorrupted( + DecodingError.Context( + codingPath: [key].compactMap { $0 }, + debugDescription: "Unable to decode value of type \(T.self).")) + } + + return t + } +} diff --git a/Sources/Alchemy/Utilities/Codable/GenericDecoder.swift b/Sources/Alchemy/Utilities/Codable/GenericDecoder.swift new file mode 100644 index 00000000..e141a47c --- /dev/null +++ b/Sources/Alchemy/Utilities/Codable/GenericDecoder.swift @@ -0,0 +1,113 @@ +struct GenericDecoder: Decoder { + struct Keyed: KeyedDecodingContainerProtocol { + let delegate: DecoderDelegate + let codingPath: [CodingKey] = [] + var allKeys: [Key] { delegate.allKeys.compactMap { Key(stringValue: $0) } } + + func contains(_ key: Key) -> Bool { + delegate.contains(key: key) + } + + func decodeNil(forKey key: Key) throws -> Bool { + delegate.decodeNil(for: key) + } + + func decode(_ type: T.Type, forKey key: Key) throws -> T where T : Decodable { + try delegate._decode(type, for: key) + } + + func nestedContainer(keyedBy type: NestedKey.Type, forKey key: Key) throws -> KeyedDecodingContainer where NestedKey : CodingKey { + KeyedDecodingContainer(Keyed(delegate: try delegate.map(for: key))) + } + + func nestedUnkeyedContainer(forKey key: Key) throws -> UnkeyedDecodingContainer { + Unkeyed(delegate: try delegate.array(for: key)) + } + + func superDecoder() throws -> Decoder { + throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Super Decoder isn't supported.")) + } + + func superDecoder(forKey key: Key) throws -> Decoder { + throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Super Decoder isn't supported.")) + } + } + + struct Unkeyed: UnkeyedDecodingContainer { + let delegate: [DecoderDelegate] + let codingPath: [CodingKey] = [] + var count: Int? { delegate.count } + var isAtEnd: Bool { currentIndex == count } + var currentIndex: Int = 0 + + mutating func decodeNil() throws -> Bool { + defer { currentIndex += 1 } + return delegate[currentIndex].decodeNil(for: nil) + } + + mutating func decode(_ type: T.Type) throws -> T where T : Decodable { + defer { currentIndex += 1 } + return try delegate[currentIndex]._decode(type) + } + + mutating func nestedUnkeyedContainer() throws -> UnkeyedDecodingContainer { + defer { currentIndex += 1 } + return Unkeyed(delegate: try delegate[currentIndex].array(for: nil)) + } + + mutating func nestedContainer(keyedBy type: NestedKey.Type) throws -> KeyedDecodingContainer where NestedKey : CodingKey { + defer { currentIndex += 1 } + return KeyedDecodingContainer(Keyed(delegate: delegate[currentIndex])) + } + + func superDecoder() throws -> Decoder { + throw DecodingError.dataCorrupted(.init(codingPath: codingPath, debugDescription: "Super Decoder isn't supported.")) + } + } + + struct Single: SingleValueDecodingContainer { + let delegate: DecoderDelegate + let codingPath: [CodingKey] = [] + + func decodeNil() -> Bool { + delegate.decodeNil(for: nil) + } + + func decode(_ type: T.Type) throws -> T where T : Decodable { + try delegate._decode(type) + } + } + + // MARK: Decoder + + var delegate: DecoderDelegate + var codingPath: [CodingKey] = [] + var userInfo: [CodingUserInfoKey : Any] = [:] + + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key : CodingKey { + KeyedDecodingContainer(Keyed(delegate: delegate)) + } + + func unkeyedContainer() throws -> UnkeyedDecodingContainer { + Unkeyed(delegate: try delegate.array(for: nil)) + } + + func singleValueContainer() throws -> SingleValueDecodingContainer { + Single(delegate: delegate) + } +} + +struct GenericCodingKey: CodingKey { + var stringValue: String + var intValue: Int? + + init?(stringValue: String) { + self.stringValue = stringValue + self.intValue = Int(stringValue) + } + + init?(intValue: Int) { + self.stringValue = "\(intValue)" + self.intValue = intValue + } +} diff --git a/Sources/Alchemy/Utilities/Extendable.swift b/Sources/Alchemy/Utilities/Extendable.swift new file mode 100644 index 00000000..e623602b --- /dev/null +++ b/Sources/Alchemy/Utilities/Extendable.swift @@ -0,0 +1,38 @@ +public protocol Extendable { + var extensions: Extensions { get } +} + +public final class Extensions { + private var items: [PartialKeyPath: Any] + + /// Initialize extensions + public init() { + self.items = [:] + } + + /// Get optional extension from a `KeyPath` + public func get(_ key: KeyPath) -> Type? { + self.items[key] as? Type + } + + /// Get extension from a `KeyPath` + public func get(_ key: KeyPath, error: StaticString? = nil) -> Type { + guard let value = items[key] as? Type else { + preconditionFailure(error?.description ?? "Cannot get extension of type \(Type.self) without having set it") + } + return value + } + + /// Return if extension has been set + public func exists(_ key: KeyPath) -> Bool { + self.items[key] != nil + } + + /// Set extension for a `KeyPath` + /// - Parameters: + /// - key: KeyPath + /// - value: value to store in extension + public func set(_ key: KeyPath, value: Type) { + items[key] = value + } +} diff --git a/Sources/Alchemy/Utilities/Extensions/Bcrypt+Async.swift b/Sources/Alchemy/Utilities/Extensions/Bcrypt+Async.swift index 8aa0478f..acc4d857 100644 --- a/Sources/Alchemy/Utilities/Extensions/Bcrypt+Async.swift +++ b/Sources/Alchemy/Utilities/Extensions/Bcrypt+Async.swift @@ -1,26 +1,3 @@ import Foundation import NIO -extension BCryptDigest { - /// Asynchronously hashes a password on a separate thread. - /// - /// - Parameter password: The password to hash. - /// - Returns: A future containing the hashed password that will - /// resolve on the initiating `EventLoop`. - public func hashAsync(_ password: String) -> EventLoopFuture { - Thread.run { try Bcrypt.hash(password) } - } - - /// Asynchronously verifies a password & hash on a separate - /// thread. - /// - /// - Parameters: - /// - plaintext: The plaintext password. - /// - hashed: The hashed password to verify with. - /// - Returns: A future containing a `Bool` indicating whether the - /// password and hash matched. This will resolve on the - /// initiating `EventLoop`. - public func verifyAsync(plaintext: String, hashed: String) -> EventLoopFuture { - Thread.run { try Bcrypt.verify(plaintext, created: hashed) } - } -} diff --git a/Sources/Alchemy/Utilities/Extensions/ByteBuffer+Utilities.swift b/Sources/Alchemy/Utilities/Extensions/ByteBuffer+Utilities.swift new file mode 100644 index 00000000..836dd093 --- /dev/null +++ b/Sources/Alchemy/Utilities/Extensions/ByteBuffer+Utilities.swift @@ -0,0 +1,5 @@ +// Better way to do these? +extension ByteBuffer { + var data: Data { Data(buffer: self) } + var string: String { String(buffer: self) } +} diff --git a/Sources/Alchemy/Utilities/Extensions/EventLoop+Utilities.swift b/Sources/Alchemy/Utilities/Extensions/EventLoop+Utilities.swift new file mode 100644 index 00000000..06b13a56 --- /dev/null +++ b/Sources/Alchemy/Utilities/Extensions/EventLoop+Utilities.swift @@ -0,0 +1,9 @@ +import NIO + +extension EventLoop { + func asyncSubmit(_ action: @escaping () async throws -> T) -> EventLoopFuture { + let elp = makePromise(of: T.self) + elp.completeWithTask { try await action() } + return elp.futureResult + } +} diff --git a/Sources/Alchemy/Utilities/Extensions/EventLoopFuture+Utilities.swift b/Sources/Alchemy/Utilities/Extensions/EventLoopFuture+Utilities.swift deleted file mode 100644 index 76d47f76..00000000 --- a/Sources/Alchemy/Utilities/Extensions/EventLoopFuture+Utilities.swift +++ /dev/null @@ -1,57 +0,0 @@ -import NIO - -/// Convenient extensions for working with `EventLoopFuture`s. -extension EventLoopFuture { - /// Erases the type of the future to `Void` - /// - /// - Returns: An erased future of type `EventLoopFuture`. - public func voided() -> EventLoopFuture { - self.map { _ in () } - } - - /// Creates a new errored `EventLoopFuture` on the current - /// `EventLoop`. - /// - /// - Parameter error: The error to create the future with. - /// - Returns: A created future that will resolve to an error. - public static func new(error: Error) -> EventLoopFuture { - Loop.current.future(error: error) - } - - /// Creates a new successed `EventLoopFuture` on the current - /// `EventLoop`. - /// - /// - Parameter value: The value to create the future with. - /// - Returns: A created future that will resolve to the provided - /// value. - public static func new(_ value: T) -> EventLoopFuture { - Loop.current.future(value) - } -} - -extension EventLoopFuture where Value == Void { - /// Creates a new successed `EventLoopFuture` on the current - /// `EventLoop`. - /// - /// - Returns: A created future that will resolve immediately. - public static func new() -> EventLoopFuture { - .new(()) - } -} - -/// Takes a throwing block & returns either the `EventLoopFuture` -/// that block creates or an errored `EventLoopFuture` if the -/// closure threw an error. -/// -/// - Parameter closure: The throwing closure used to generate an -/// `EventLoopFuture`. -/// - Returns: A future with the given closure run with any errors -/// piped into the future. -public func catchError(_ closure: () throws -> EventLoopFuture) -> EventLoopFuture { - do { - return try closure() - } - catch { - return .new(error: error) - } -} diff --git a/Sources/Alchemy/Utilities/Extensions/EventLoopGroupConnectionPool+Async.swift b/Sources/Alchemy/Utilities/Extensions/EventLoopGroupConnectionPool+Async.swift new file mode 100644 index 00000000..bb23da44 --- /dev/null +++ b/Sources/Alchemy/Utilities/Extensions/EventLoopGroupConnectionPool+Async.swift @@ -0,0 +1,14 @@ +import AsyncKit + +extension EventLoopGroupConnectionPool { + /// Async wrapper around the future variant of `withConnection`. + func withConnection( + logger: Logger? = nil, + on eventLoop: EventLoop? = nil, + _ closure: @escaping (Source.Connection) async throws -> Result + ) async throws -> Result { + try await withConnection(logger: logger, on: eventLoop) { connection in + connection.eventLoop.asyncSubmit { try await closure(connection) } + }.get() + } +} diff --git a/Sources/Alchemy/Utilities/Extensions/Metatype+Utilities.swift b/Sources/Alchemy/Utilities/Extensions/Metatype+Utilities.swift index 6fef0f64..4c70e328 100644 --- a/Sources/Alchemy/Utilities/Extensions/Metatype+Utilities.swift +++ b/Sources/Alchemy/Utilities/Extensions/Metatype+Utilities.swift @@ -5,3 +5,11 @@ public func name(of metatype: T.Type) -> String { "\(metatype)" } + +/// Returns an id for the given type. +/// +/// - Parameter metatype: The type to identify. +/// - Returns: A unique id for the type. +public func id(of metatype: Any.Type) -> ObjectIdentifier { + ObjectIdentifier(metatype) +} diff --git a/Sources/Alchemy/Utilities/Extensions/String+Utilities.swift b/Sources/Alchemy/Utilities/Extensions/String+Utilities.swift new file mode 100644 index 00000000..76fd81ea --- /dev/null +++ b/Sources/Alchemy/Utilities/Extensions/String+Utilities.swift @@ -0,0 +1,19 @@ +extension String { + var trimmingQuotes: String { + trimmingCharacters(in: CharacterSet(charactersIn: "\"'")) + } + + var trimmingForwardSlash: String { + trimmingCharacters(in: CharacterSet(charactersIn: "/")) + } + + func droppingPrefix(_ prefix: String) -> String { + guard hasPrefix(prefix) else { return self } + return String(dropFirst(prefix.count)) + } + + func droppingSuffix(_ suffix: String) -> String { + guard hasSuffix(suffix) else { return self } + return String(dropLast(suffix.count)) + } +} diff --git a/Sources/Alchemy/Utilities/Extensions/TLSConfiguration+Utilities.swift b/Sources/Alchemy/Utilities/Extensions/TLSConfiguration+Utilities.swift new file mode 100644 index 00000000..a124473b --- /dev/null +++ b/Sources/Alchemy/Utilities/Extensions/TLSConfiguration+Utilities.swift @@ -0,0 +1,11 @@ +import NIOSSL + +extension TLSConfiguration { + static func makeServerConfiguration(key: String, cert: String) throws -> TLSConfiguration { + TLSConfiguration.makeServerConfiguration( + certificateChain: try NIOSSLCertificate + .fromPEMFile(cert) + .map { NIOSSLCertificateSource.certificate($0) }, + privateKey: .file(key)) + } +} diff --git a/Sources/Alchemy/Queue/TimeAmount+Utilities.swift b/Sources/Alchemy/Utilities/Extensions/TimeAmount+Utilities.swift similarity index 100% rename from Sources/Alchemy/Queue/TimeAmount+Utilities.swift rename to Sources/Alchemy/Utilities/Extensions/TimeAmount+Utilities.swift diff --git a/Sources/Alchemy/Utilities/Extensions/UUID+LosslessStringConvertible.swift b/Sources/Alchemy/Utilities/Extensions/UUID+LosslessStringConvertible.swift new file mode 100644 index 00000000..9aec1ff8 --- /dev/null +++ b/Sources/Alchemy/Utilities/Extensions/UUID+LosslessStringConvertible.swift @@ -0,0 +1,5 @@ +extension UUID: LosslessStringConvertible { + public init?(_ description: String) { + self.init(uuidString: description) + } +} diff --git a/Sources/Alchemy/Utilities/Headers/HTTPHeaders+ContentDisposition.swift b/Sources/Alchemy/Utilities/Headers/HTTPHeaders+ContentDisposition.swift new file mode 100644 index 00000000..7a55209a --- /dev/null +++ b/Sources/Alchemy/Utilities/Headers/HTTPHeaders+ContentDisposition.swift @@ -0,0 +1,64 @@ +extension HTTPHeaders { + public struct ContentDisposition { + public struct Value: ExpressibleByStringLiteral { + public let string: String + + public init(stringLiteral value: StringLiteralType) { + self.string = value + } + + public static let inline: Value = "inline" + public static let attachment: Value = "attachment" + public static let formData: Value = "form-data" + } + + public var value: Value + public var name: String? + public var filename: String? + } + + public var contentDisposition: ContentDisposition? { + get { + guard let disposition = self["Content-Disposition"].first else { + return nil + } + + let components = disposition.components(separatedBy: ";") + .map { $0.trimmingCharacters(in: .whitespaces) } + + guard let valueString = components.first else { + return nil + } + + var directives: [String: String] = [:] + components + .dropFirst() + .compactMap { pair -> (String, String)? in + let parts = pair.components(separatedBy: "=") + guard let key = parts[safe: 0], let value = parts[safe: 1] else { + return nil + } + + return (key.trimmingQuotes, value.trimmingQuotes) + } + .forEach { directives[$0] = $1 } + + let value = ContentDisposition.Value(stringLiteral: valueString) + return ContentDisposition(value: value, name: directives["name"], filename: directives["filename"]) + } + set { + if let disposition = newValue { + let value = [ + disposition.value.string, + disposition.name.map { "name=\($0)" }, + disposition.filename.map { "filename=\($0)" }, + ] + .compactMap { $0 } + .joined(separator: "; ") + replaceOrAdd(name: "Content-Disposition", value: value) + } else { + remove(name: "Content-Disposition") + } + } + } +} diff --git a/Sources/Alchemy/Utilities/Headers/HTTPHeaders+ContentInformation.swift b/Sources/Alchemy/Utilities/Headers/HTTPHeaders+ContentInformation.swift new file mode 100644 index 00000000..4260dbac --- /dev/null +++ b/Sources/Alchemy/Utilities/Headers/HTTPHeaders+ContentInformation.swift @@ -0,0 +1,25 @@ +extension HTTPHeaders { + public var contentType: ContentType? { + get { + first(name: "content-type").map(ContentType.init) + } + set { + if let contentType = newValue { + self.replaceOrAdd(name: "content-type", value: "\(contentType.string)") + } else { + self.remove(name: "content-type") + } + } + } + + public var contentLength: Int? { + get { first(name: "content-length").map { Int($0) } ?? nil } + set { + if let contentLength = newValue { + self.replaceOrAdd(name: "content-length", value: "\(contentLength)") + } else { + self.remove(name: "content-length") + } + } + } +} diff --git a/Sources/Alchemy/Utilities/IgnoreDecoding.swift b/Sources/Alchemy/Utilities/IgnoreDecoding.swift new file mode 100644 index 00000000..98d5bf2b --- /dev/null +++ b/Sources/Alchemy/Utilities/IgnoreDecoding.swift @@ -0,0 +1,12 @@ +@propertyWrapper +struct IgnoreDecoding: Decodable { + var wrappedValue: T? + + init(from decoder: Decoder) throws { + wrappedValue = nil + } + + init() { + wrappedValue = nil + } +} diff --git a/Sources/Alchemy/Utilities/Locked.swift b/Sources/Alchemy/Utilities/Locked.swift index 51c7608d..e844e85b 100644 --- a/Sources/Alchemy/Utilities/Locked.swift +++ b/Sources/Alchemy/Utilities/Locked.swift @@ -1,27 +1,20 @@ import Foundation +import NIOConcurrencyHelpers -/// Used for providing thread safe access to a property. +/// Used for providing thread safe access to a property. Doesn't work on +/// collections. @propertyWrapper public struct Locked { /// The threadsafe accessor for this property. public var wrappedValue: T { - get { - self.lock.lock() - defer { self.lock.unlock() } - return self.value - } - set { - self.lock.lock() - defer { self.lock.unlock() } - self.value = newValue - } + get { lock.withLock { value } } + set { lock.withLock { value = newValue } } } /// The underlying value of this property. private var value: T - /// The lock to protect this property. - private let lock = NSRecursiveLock() + private let lock = Lock() /// Initialize with the given value. /// diff --git a/Sources/Alchemy/Utilities/Loop.swift b/Sources/Alchemy/Utilities/Loop.swift index 89125c4d..464ebaa1 100644 --- a/Sources/Alchemy/Utilities/Loop.swift +++ b/Sources/Alchemy/Utilities/Loop.swift @@ -10,29 +10,36 @@ public struct Loop { /// The main `EventLoopGroup` of the Application. @Inject public static var group: EventLoopGroup - @Inject private static var lifecycle: ServiceLifecycle - /// Configure the Applications `EventLoopGroup` and `EventLoop`. static func config() { - Container.register(EventLoop.self) { _ in + Container.bind(to: EventLoop.self) { _ -> EventLoop in guard let current = MultiThreadedEventLoopGroup.currentEventLoop else { - fatalError("This code isn't running on an `EventLoop`!") + // With async/await there is no guarantee that you'll + // be running on an event loop. When one is needed, + // return a random one for now. + return Loop.group.next() } - + return current } - Container.register(singleton: EventLoopGroup.self) { _ in + Container.main.bind(.singleton, to: EventLoopGroup.self) { _ in MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) } + @Inject var lifecycle: ServiceLifecycle lifecycle.registerShutdown(label: name(of: EventLoopGroup.self), .sync(group.syncShutdownGracefully)) } /// Register mocks of `EventLoop` and `EventLoop` to the /// application container. static func mock() { - Container.register(EventLoop.self) { _ in EmbeddedEventLoop() } - Container.register(singleton: EventLoopGroup.self) { _ in MultiThreadedEventLoopGroup(numberOfThreads: 1) } + Container.bind(.singleton, to: EventLoopGroup.self) { _ in + MultiThreadedEventLoopGroup(numberOfThreads: 1) + } + + Container.bind(to: EventLoop.self) { _ in + group.next() + } } } diff --git a/Sources/Alchemy/Utilities/Service.swift b/Sources/Alchemy/Utilities/Service.swift deleted file mode 100644 index db208fb6..00000000 --- a/Sources/Alchemy/Utilities/Service.swift +++ /dev/null @@ -1,64 +0,0 @@ -import Fusion -import Lifecycle - -/// A protocol for registering and resolving a type through Alchemy's -/// dependency injection system, Fusion. Conform a type to this -/// to make it simple to inject and resolve around your app. -public protocol Service { - // Shutdown this service. Will be called when the application your - // service is registered to shuts down. - func shutdown() throws - - /// The default instance of this service. - static var `default`: Self { get } - - /// A named instance of this service. - /// - /// - Parameter name: The name of the service to fetch. - static func named(_ name: String) -> Self - - /// Register the default driver for this service. - static func config(default: Self) - - /// Register a named driver driver for this service. - static func config(_ name: String, _ driver: Self) -} - -// Default implementations. -extension Service { - public func shutdown() throws {} - - public static var `default`: Self { - Container.resolve(Self.self) - } - - public static func named(_ name: String) -> Self { - Container.resolve(Self.self, identifier: name) - } - - public static func config(default configuration: Self) { - _config(nil, configuration) - } - - public static func config(_ name: String, _ configuration: Self) { - _config(name, configuration) - } - - private static func _config(_ name: String? = nil, _ configuration: Self) { - let label: String - if let name = name { - label = "\(Alchemy.name(of: Self.self)):\(name)" - Container.register(singleton: Self.self, identifier: name) { _ in configuration } - } else { - label = "\(Alchemy.name(of: Self.self))" - Container.register(singleton: Self.self) { _ in configuration } - } - - if - !(configuration is ServiceLifecycle), - let lifecycle = Container.resolveOptional(ServiceLifecycle.self) - { - lifecycle.registerShutdown(label: label, .sync(configuration.shutdown)) - } - } -} diff --git a/Sources/Alchemy/Utilities/Socket.swift b/Sources/Alchemy/Utilities/Socket.swift index 6f64643e..7baa0904 100644 --- a/Sources/Alchemy/Utilities/Socket.swift +++ b/Sources/Alchemy/Utilities/Socket.swift @@ -5,7 +5,7 @@ import NIO /// (i.e. this is where the server can be reached). Other network /// interfaces can also be reached via a socket, such as a database. /// Either an ip host & port or a unix socket path. -public enum Socket { +public enum Socket: Equatable { /// An ip address `host` at port `port`. case ip(host: String, port: Int) /// A unix domain socket (IPC socket) at path `path`. diff --git a/Sources/Alchemy/Utilities/Thread.swift b/Sources/Alchemy/Utilities/Thread.swift index a4901b03..2c7824a0 100644 --- a/Sources/Alchemy/Utilities/Thread.swift +++ b/Sources/Alchemy/Utilities/Thread.swift @@ -3,14 +3,17 @@ import NIO /// A utility for running expensive CPU work on threads so as not to /// block the current `EventLoop`. public struct Thread { + /// The apps main thread pool for running expensive work. + @Inject public static var pool: NIOThreadPool + /// Runs an expensive bit of work on a thread that isn't backing /// an `EventLoop`, returning any value generated by that work /// back on the current `EventLoop`. /// /// - Parameter task: The work to run. - /// - Returns: A future containing the result of the expensive - /// work that completes on the current `EventLoop`. - public static func run(_ task: @escaping () throws -> T) -> EventLoopFuture { - return NIOThreadPool.default.runIfActive(eventLoop: Loop.current, task) + /// - Returns: The result of the expensive work that completes on + /// the current `EventLoop`. + public static func run(_ task: @escaping () throws -> T) async throws -> T { + try await pool.runIfActive(eventLoop: Loop.current, task).get() } } diff --git a/Sources/Alchemy/Utilities/Vendor/OrderedDictionary.swift b/Sources/Alchemy/Utilities/Vendor/OrderedDictionary.swift deleted file mode 100644 index 6549b7e8..00000000 --- a/Sources/Alchemy/Utilities/Vendor/OrderedDictionary.swift +++ /dev/null @@ -1,759 +0,0 @@ -/// The MIT License (MIT) -/// -/// Copyright © 2015-2020 Lukas Kubanek -/// -/// Permission is hereby granted, free of charge, to any person obtaining a copy -/// of this software and associated documentation files (the "Software"), to -/// deal in the Software without restriction, including without limitation the -/// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or -/// sell copies of the Software, and to permit persons to whom the Software is -/// furnished to do so, subject to the following conditions: -/// -/// The above copyright notice and this permission notice shall be included in -/// all copies or substantial portions of the Software. -/// -/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -/// -/// Courtesy of https://github.com/lukaskubanek/OrderedDictionary - -/// A generic collection for storing key-value pairs in an ordered manner. -/// -/// Same as in a dictionary all keys in the collection are unique and have an associated value. -/// Same as in an array, all key-value pairs (elements) are kept sorted and accessible by -/// a zero-based integer index. -public struct OrderedDictionary: BidirectionalCollection { - - // ======================================================= // - // MARK: - Type Aliases - // ======================================================= // - - /// The type of the key-value pair stored in the ordered dictionary. - public typealias Element = (key: Key, value: Value) - - /// The type of the index. - public typealias Index = Int - - /// The type of the indices collection. - public typealias Indices = CountableRange - - /// The type of the contiguous subrange of the ordered dictionary's elements. - /// - /// - SeeAlso: OrderedDictionarySlice - public typealias SubSequence = OrderedDictionarySlice - - // ======================================================= // - // MARK: - Initialization - // ======================================================= // - - /// Initializes an empty ordered dictionary. - public init() { - self._orderedKeys = [Key]() - self._keysToValues = [Key: Value]() - } - - /// Initializes an empty ordered dictionary with preallocated space for at least the specified - /// number of elements. - public init(minimumCapacity: Int) { - self.init() - self.reserveCapacity(minimumCapacity) - } - - /// Initializes an ordered dictionary from a regular unsorted dictionary by sorting it using - /// the given sort function. - /// - /// - Parameter unsorted: The unsorted dictionary. - /// - Parameter areInIncreasingOrder: The sort function which compares the key-value pairs. - public init( - unsorted: Dictionary, - areInIncreasingOrder: (Element, Element) throws -> Bool - ) rethrows { - let keysAndValues = try Array(unsorted).sorted(by: areInIncreasingOrder) - - self.init( - uniqueKeysWithValues: keysAndValues, - minimumCapacity: unsorted.count - ) - } - - /// Initializes an ordered dictionary from a sequence of values keyed by a unique key extracted - /// from the value using the given closure. - /// - /// - Parameter values: The sequence of values. - /// - Parameter extractKey: The closure which extracts a key from the value. The returned keys - /// must be unique for all values from the sequence. - public init( - values: S, - uniquelyKeyedBy extractKey: (Value) throws -> Key - ) rethrows where S.Element == Value { - self.init(uniqueKeysWithValues: try values.map { value in - return (try extractKey(value), value) - }) - } - - /// Initializes an ordered dictionary from a sequence of values keyed by a unique key extracted - /// from the value using the given key path. - /// - /// - Parameter values: The sequence of values. - /// - Parameter keyPath: The key path to use for extracting a key from the value. The extracted - /// keys must be unique for all values from the sequence. - public init( - values: S, - uniquelyKeyedBy keyPath: KeyPath - ) where S.Element == Value { - self.init(uniqueKeysWithValues: values.map { value in - return (value[keyPath: keyPath], value) - }) - } - - /// Initializes an ordered dictionary from a sequence of key-value pairs. - /// - /// - Parameter keysAndValues: A sequence of key-value pairs to use for the new ordered - /// dictionary. Every key in `keysAndValues` must be unique. - public init( - uniqueKeysWithValues keysAndValues: S - ) where S.Element == Element { - self.init( - uniqueKeysWithValues: keysAndValues, - minimumCapacity: keysAndValues.underestimatedCount - ) - } - - private init( - uniqueKeysWithValues keysAndValues: S, - minimumCapacity: Int - ) where S.Element == Element { - self.init(minimumCapacity: minimumCapacity) - - for (key, value) in keysAndValues { - precondition(!containsKey(key), "Sequence of key-value pairs contains duplicate keys") - self[key] = value - } - } - - // ======================================================= // - // MARK: - Ordered Keys & Values - // ======================================================= // - - /// A collection containing just the keys of the ordered dictionary in the correct order. - public var orderedKeys: OrderedDictionaryKeys { - return self.lazy.map { $0.key } - } - - /// A collection containing just the values of the ordered dictionary in the correct order. - public var orderedValues: OrderedDictionaryValues { - return self.lazy.map { $0.value } - } - - // ======================================================= // - // MARK: - Dictionary - // ======================================================= // - - /// Converts itself to a common unsorted dictionary. - public var unorderedDictionary: Dictionary { - return _keysToValues - } - - // ======================================================= // - // MARK: - Indices - // ======================================================= // - - /// The indices that are valid for subscripting the ordered dictionary. - public var indices: Indices { - return _orderedKeys.indices - } - - /// The position of the first key-value pair in a non-empty ordered dictionary. - public var startIndex: Index { - return _orderedKeys.startIndex - } - - /// The position which is one greater than the position of the last valid key-value pair in the - /// ordered dictionary. - public var endIndex: Index { - return _orderedKeys.endIndex - } - - /// Returns the position immediately after the given index. - public func index(after i: Index) -> Index { - return _orderedKeys.index(after: i) - } - - /// Returns the position immediately before the given index. - public func index(before i: Index) -> Index { - return _orderedKeys.index(before: i) - } - - // ======================================================= // - // MARK: - Slices - // ======================================================= // - - /// Accesses a contiguous subrange of the ordered dictionary. - /// - /// - Parameter bounds: A range of the ordered dictionary's indices. The bounds of the range - /// must be valid indices of the ordered dictionary. - /// - Returns: The slice view at the ordered dictionary in the specified subrange. - public subscript(bounds: Range) -> SubSequence { - return OrderedDictionarySlice(base: self, bounds: bounds) - } - - // ======================================================= // - // MARK: - Key-based Access - // ======================================================= // - - /// Accesses the value associated with the given key for reading and writing. - /// - /// This key-based subscript returns the value for the given key if the key is found in the - /// ordered dictionary, or `nil` if the key is not found. - /// - /// When you assign a value for a key and that key already exists, the ordered dictionary - /// overwrites the existing value and preservers the index of the key-value pair. If the ordered - /// dictionary does not contain the key, a new key-value pair is appended to the end of the - /// ordered dictionary. - /// - /// If you assign `nil` as the value for the given key, the ordered dictionary removes that key - /// and its associated value if it exists. - /// - /// - Parameter key: The key to find in the ordered dictionary. - /// - Returns: The value associated with `key` if `key` is in the ordered dictionary; otherwise, - /// `nil`. - public subscript(key: Key) -> Value? { - get { - return value(forKey: key) - } - set(newValue) { - if let newValue = newValue { - updateValue(newValue, forKey: key) - } else { - removeValue(forKey: key) - } - } - } - - /// Returns a Boolean value indicating whether the ordered dictionary contains the given key. - /// - /// - Parameter key: The key to be looked up. - /// - Returns: `true` if the ordered dictionary contains the given key; otherwise, `false`. - public func containsKey(_ key: Key) -> Bool { - return _keysToValues[key] != nil - } - - /// Returns the value associated with the given key if the key is found in the ordered - /// dictionary, or `nil` if the key is not found. - /// - /// - Parameter key: The key to find in the ordered dictionary. - /// - Returns: The value associated with `key` if `key` is in the ordered dictionary; otherwise, - /// `nil`. - public func value(forKey key: Key) -> Value? { - return _keysToValues[key] - } - - /// Updates the value stored in the ordered dictionary for the given key, or appends a new - /// key-value pair if the key does not exist. - /// - /// - Parameter value: The new value to add to the ordered dictionary. - /// - Parameter key: The key to associate with `value`. If `key` already exists in the ordered - /// dictionary, `value` replaces the existing associated value. If `key` is not already a key - /// of the ordered dictionary, the `(key, value)` pair is appended at the end of the ordered - /// dictionary. - @discardableResult - public mutating func updateValue(_ value: Value, forKey key: Key) -> Value? { - if containsKey(key) { - let currentValue = _unsafeValue(forKey: key) - - _keysToValues[key] = value - - return currentValue - } else { - _orderedKeys.append(key) - _keysToValues[key] = value - - return nil - } - } - - /// Removes the given key and its associated value from the ordered dictionary. - /// - /// If the key is found in the ordered dictionary, this method returns the key's associated - /// value. On removal, the indices of the ordered dictionary are invalidated. If the key is - /// not found in the ordered dictionary, this method returns `nil`. - /// - /// - Parameter key: The key to remove along with its associated value. - /// - Returns: The value that was removed, or `nil` if the key was not present in the - /// ordered dictionary. - /// - /// - SeeAlso: remove(at:) - @discardableResult - public mutating func removeValue(forKey key: Key) -> Value? { - guard let index = index(forKey: key) else { return nil } - - let currentValue = self[index].value - - _orderedKeys.remove(at: index) - _keysToValues[key] = nil - - return currentValue - } - - /// Removes all key-value pairs from the ordered dictionary and invalidates all indices. - /// - /// - Parameter keepCapacity: Whether the ordered dictionary should keep its underlying storage. - /// If you pass `true`, the operation preserves the storage capacity that the collection has, - /// otherwise the underlying storage is released. The default is `false`. - public mutating func removeAll(keepingCapacity keepCapacity: Bool = false) { - _orderedKeys.removeAll(keepingCapacity: keepCapacity) - _keysToValues.removeAll(keepingCapacity: keepCapacity) - } - - private func _unsafeValue(forKey key: Key) -> Value { - let value = _keysToValues[key] - precondition(value != nil, "Inconsistency error occurred in OrderedDictionary") - return value! - } - - // ======================================================= // - // MARK: - Index-based Access - // ======================================================= // - - /// Accesses the key-value pair at the specified position. - /// - /// The specified position has to be a valid index of the ordered dictionary. The index-base - /// subscript returns the key-value pair corresponding to the index. - /// - /// - Parameter position: The position of the key-value pair to access. `position` must be - /// a valid index of the ordered dictionary and not equal to `endIndex`. - /// - Returns: A tuple containing the key-value pair corresponding to `position`. - /// - /// - SeeAlso: update(:at:) - public subscript(position: Index) -> Element { - precondition(indices.contains(position), "OrderedDictionary index is out of range") - - let key = _orderedKeys[position] - let value = _unsafeValue(forKey: key) - - return (key, value) - } - - /// Returns the index for the given key. - /// - /// - Parameter key: The key to find in the ordered dictionary. - /// - Returns: The index for `key` and its associated value if `key` is in the ordered - /// dictionary; otherwise, `nil`. - public func index(forKey key: Key) -> Index? { - #if swift(>=5.0) - return _orderedKeys.firstIndex(of: key) - #else - return _orderedKeys.index(of: key) - #endif - } - - /// Returns the key-value pair at the specified index, or `nil` if there is no key-value pair - /// at that index. - /// - /// - Parameter index: The index of the key-value pair to be looked up. `index` does not have to - /// be a valid index. - /// - Returns: A tuple containing the key-value pair corresponding to `index` if the index is - /// valid; otherwise, `nil`. - public func elementAt(_ index: Index) -> Element? { - return indices.contains(index) ? self[index] : nil - } - - /// Checks whether the given key-value pair can be inserted into to ordered dictionary by - /// validating the presence of the key. - /// - /// - Parameter newElement: The key-value pair to be inserted into the ordered dictionary. - /// - Returns: `true` if the key-value pair can be safely inserted; otherwise, `false`. - /// - /// - SeeAlso: canInsert(key:) - /// - SeeAlso: canInsert(at:) - @available(*, deprecated, message: "Use canInsert(key:) with the element's key instead") - public func canInsert(_ newElement: Element) -> Bool { - return canInsert(key: newElement.key) - } - - /// Checks whether a key-value pair with the given key can be inserted into the ordered - /// dictionary by validating its presence. - /// - /// - Parameter key: The key to be inserted into the ordered dictionary. - /// - Returns: `true` if the key can safely be inserted; ortherwise, `false`. - /// - /// - SeeAlso: canInsert(at:) - public func canInsert(key: Key) -> Bool { - return !containsKey(key) - } - - /// Checks whether a new key-value pair can be inserted into the ordered dictionary at the - /// given index. - /// - /// - Parameter index: The index the new key-value pair should be inserted at. - /// - Returns: `true` if a new key-value pair can be inserted at the specified index; otherwise, - /// `false`. - /// - /// - SeeAlso: canInsert(key:) - public func canInsert(at index: Index) -> Bool { - return index >= startIndex && index <= endIndex - } - - /// Inserts a new key-value pair at the specified position. - /// - /// If the key of the inserted pair already exists in the ordered dictionary, a runtime error - /// is triggered. Use `canInsert(_:)` for performing a check first, so that this method can - /// be executed safely. - /// - /// - Parameter newElement: The new key-value pair to insert into the ordered dictionary. The - /// key contained in the pair must not be already present in the ordered dictionary. - /// - Parameter index: The position at which to insert the new key-value pair. `index` must be - /// a valid index of the ordered dictionary or equal to `endIndex` property. - /// - /// - SeeAlso: canInsert(key:) - /// - SeeAlso: canInsert(at:) - /// - SeeAlso: update(:at:) - public mutating func insert(_ newElement: Element, at index: Index) { - precondition(canInsert(key: newElement.key), "Cannot insert duplicate key in OrderedDictionary") - precondition(canInsert(at: index), "Cannot insert at invalid index in OrderedDictionary") - - let (key, value) = newElement - - _orderedKeys.insert(key, at: index) - _keysToValues[key] = value - } - - /// Checks whether the key-value pair at the given index can be updated with the given key-value - /// pair. This is not the case if the key of the updated element is already present in the - /// ordered dictionary and located at another index than the updated one. - /// - /// Although this is a checking method, a valid index has to be provided. - /// - /// - Parameter newElement: The key-value pair to be set at the specified position. - /// - Parameter index: The position at which to set the key-value pair. `index` must be a valid - /// index of the ordered dictionary. - public func canUpdate(_ newElement: Element, at index: Index) -> Bool { - var keyPresentAtIndex = false - return _canUpdate(newElement, at: index, keyPresentAtIndex: &keyPresentAtIndex) - } - - /// Updates the key-value pair located at the specified position. - /// - /// If the key of the updated pair already exists in the ordered dictionary *and* is located at - /// a different position than the specified one, a runtime error is triggered. Use - /// `canUpdate(_:at:)` for performing a check first, so that this method can be executed safely. - /// - /// - Parameter newElement: The key-value pair to be set at the specified position. - /// - Parameter index: The position at which to set the key-value pair. `index` must be a valid - /// index of the ordered dictionary. - /// - /// - SeeAlso: canUpdate(_:at:) - /// - SeeAlso: insert(:at:) - @discardableResult - public mutating func update(_ newElement: Element, at index: Index) -> Element? { - // Store the flag indicating whether the key of the inserted element - // is present at the updated index - var keyPresentAtIndex = false - - precondition( - _canUpdate(newElement, at: index, keyPresentAtIndex: &keyPresentAtIndex), - "OrderedDictionary update duplicates key" - ) - - // Decompose the element - let (key, value) = newElement - - // Load the current element at the index - let replacedElement = self[index] - - // If its key differs, remove its associated value - if (!keyPresentAtIndex) { - _keysToValues.removeValue(forKey: replacedElement.key) - } - - // Store the new position of the key and the new value associated with the key - _orderedKeys[index] = key - _keysToValues[key] = value - - return replacedElement - } - - /// Removes and returns the key-value pair at the specified position if there is any key-value - /// pair, or `nil` if there is none. - /// - /// - Parameter index: The position of the key-value pair to remove. - /// - Returns: The element at the specified index, or `nil` if the position is not taken. - /// - /// - SeeAlso: removeValue(forKey:) - @discardableResult - public mutating func remove(at index: Index) -> Element? { - guard let element = elementAt(index) else { return nil } - - _orderedKeys.remove(at: index) - _keysToValues.removeValue(forKey: element.key) - - return element - } - - private func _canUpdate( - _ newElement: Element, - at index: Index, - keyPresentAtIndex: inout Bool - ) -> Bool { - precondition(indices.contains(index), "OrderedDictionary index is out of range") - - let currentIndexOfKey = self.index(forKey: newElement.key) - - let keyNotPresent = currentIndexOfKey == nil - keyPresentAtIndex = currentIndexOfKey == index - - return keyNotPresent || keyPresentAtIndex - } - - // ======================================================= // - // MARK: - Removing First & Last Elements - // ======================================================= // - - /// Removes and returns the first key-value pair of the ordered dictionary if it is not empty. - public mutating func popFirst() -> Element? { - guard !isEmpty else { return nil } - return remove(at: startIndex) - } - - /// Removes and returns the last key-value pair of the ordered dictionary if it is not empty. - public mutating func popLast() -> Element? { - guard !isEmpty else { return nil } - return remove(at: index(before: endIndex)) - } - - /// Removes and returns the first key-value pair of the ordered dictionary. - public mutating func removeFirst() -> Element { - precondition(!isEmpty, "Cannot remove key-value pairs from empty OrderedDictionary") - return remove(at: startIndex)! - } - - /// Removes and returns the last key-value pair of the ordered dictionary. - public mutating func removeLast() -> Element { - precondition(!isEmpty, "Cannot remove key-value pairs from empty OrderedDictionary") - return remove(at: index(before: endIndex))! - } - - // ======================================================= // - // MARK: - Moving Elements - // ======================================================= // - - /// Moves an existing key-value pair specified by the given key to the new index by removing it - /// from its original index first and inserting it at the new index. If the movement is - /// actually performed, the previous index of the key-value pair is returned. Otherwise, `nil` - /// is returned. - /// - /// - Parameter key: The key specifying the key-value pair to move. - /// - Parameter newIndex: The new index the key-value pair should be moved to. - /// - Returns: The previous index of the key-value pair if it was sucessfully moved. - @discardableResult - public mutating func moveElement(forKey key: Key, to newIndex: Index) -> Index? { - // Load the previous index and return nil if the index is not found. - guard let previousIndex = index(forKey: key) else { return nil } - - // If the previous and new indices match, threat it as if the movement was already - // performed. - guard previousIndex != newIndex else { return previousIndex } - - // Remove the value for the key at its original index. - let value = removeValue(forKey: key)! - - // Validate the new index. - precondition(canInsert(at: newIndex), "Cannot move to invalid index in OrderedDictionary") - - // Insert the element at the new index. - insert((key: key, value: value), at: newIndex) - - return previousIndex - } - - // ======================================================= // - // MARK: - Sorting Elements - // ======================================================= // - - /// Sorts the ordered dictionary in place, using the given predicate as the comparison between - /// elements. - /// - /// The predicate must be a *strict weak ordering* over the elements. - /// - /// - Parameter areInIncreasingOrder: A predicate that returns `true` if its first argument - /// should be ordered before its second argument; otherwise, `false`. - /// - /// - SeeAlso: MutableCollection.sort(by:), sorted(by:) - public mutating func sort( - by areInIncreasingOrder: (Element, Element) throws -> Bool - ) rethrows { - _orderedKeys = try _sortedElements(by: areInIncreasingOrder).map { $0.key } - } - - /// Returns a new ordered dictionary, sorted using the given predicate as the comparison between - /// elements. - /// - /// The predicate must be a *strict weak ordering* over the elements. - /// - /// - Parameter areInIncreasingOrder: A predicate that returns `true` if its first argument - /// should be ordered before its second argument; otherwise, `false`. - /// - Returns: A new ordered dictionary sorted according to the predicate. - /// - /// - SeeAlso: MutableCollection.sorted(by:), sort(by:) - /// - MutatingVariant: sort - public func sorted( - by areInIncreasingOrder: (Element, Element) throws -> Bool - ) rethrows -> OrderedDictionary { - return OrderedDictionary(uniqueKeysWithValues: try _sortedElements(by: areInIncreasingOrder)) - } - - private func _sortedElements( - by areInIncreasingOrder: (Element, Element) throws -> Bool - ) rethrows -> [Element] { - return try sorted(by: areInIncreasingOrder) - } - - // ======================================================= // - // MARK: - Mapping Values - // ======================================================= // - - /// Returns a new ordered dictionary containing the keys of this ordered dictionary with the - /// values transformed by the given closure by preserving the original order. - public func mapValues( - _ transform: (Value) throws -> T - ) rethrows -> OrderedDictionary { - var result = OrderedDictionary() - - for (key, value) in self { - result[key] = try transform(value) - } - - return result - } - - /// Returns a new ordered dictionary containing only the key-value pairs that have non-nil - /// values as the result of transformation by the given closure by preserving the original - /// order. - public func compactMapValues( - _ transform: (Value) throws -> T? - ) rethrows -> OrderedDictionary { - var result = OrderedDictionary() - - for (key, value) in self { - if let transformedValue = try transform(value) { - result[key] = transformedValue - } - } - - return result - } - - // ======================================================= // - // MARK: - Capacity - // ======================================================= // - - /// The total number of elements that the ordered dictionary can contain without allocating - /// new storage. - public var capacity: Int { - return Swift.min(_orderedKeys.capacity, _keysToValues.capacity) - } - - /// Reserves enough space to store the specified number of elements, when appropriate - /// for the underlying types. - /// - /// If you are adding a known number of elements to an ordered dictionary, use this method - /// to avoid multiple reallocations. This method ensures that the underlying types of the - /// ordered dictionary have space allocated for at least the requested number of elements. - /// - /// - Parameter minimumCapacity: The requested number of elements to store. - public mutating func reserveCapacity(_ minimumCapacity: Int) { - _orderedKeys.reserveCapacity(minimumCapacity) - _keysToValues.reserveCapacity(minimumCapacity) - } - - // ======================================================= // - // MARK: - Internal Storage - // ======================================================= // - - /// The backing storage for the ordered keys. - fileprivate var _orderedKeys: [Key] - - /// The backing storage for the mapping of keys to values. - fileprivate var _keysToValues: [Key: Value] - -} - -// ======================================================= // -// MARK: - Subtypes -// ======================================================= // - -/// A view into an ordered dictionary whose indices are a subrange of the indices of the ordered -/// dictionary. -public typealias OrderedDictionarySlice = Slice> - -/// A collection containing the keys of the ordered dictionary. -/// -/// Under the hood this is a lazily evaluated bidirectional collection deriving the keys from -/// the base ordered dictionary on-the-fly. -public typealias OrderedDictionaryKeys = LazyMapCollection, Key> - -/// A collection containing the values of the ordered dictionary. -/// -/// Under the hood this is a lazily evaluated bidirectional collection deriving the values from -/// the base ordered dictionary on-the-fly. -public typealias OrderedDictionaryValues = LazyMapCollection, Value> - -// ======================================================= // -// MARK: - Literals -// ======================================================= // - -extension OrderedDictionary: ExpressibleByArrayLiteral { - - /// Initializes an ordered dictionary initialized from an array literal containing a list of - /// key-value pairs. Every key in `elements` must be unique. - public init(arrayLiteral elements: Element...) { - self.init(uniqueKeysWithValues: elements) - } - -} - -extension OrderedDictionary: ExpressibleByDictionaryLiteral { - - /// Initializes an ordered dictionary initialized from a dictionary literal. Every key in - /// `elements` must be unique. - public init(dictionaryLiteral elements: (Key, Value)...) { - self.init(uniqueKeysWithValues: elements.map { element in - let (key, value) = element - return (key: key, value: value) - }) - } - -} - -// ======================================================= // -// MARK: - Equatable Conformance -// ======================================================= // - -extension OrderedDictionary: Equatable where Value: Equatable {} - -// ======================================================= // -// MARK: - Dictionary Extension -// ======================================================= // - -extension Dictionary { - - /// Returns an ordered dictionary containing the key-value pairs from the dictionary, sorted - /// using the given sort function. - /// - /// - Parameter areInIncreasingOrder: The sort function which compares the key-value pairs. - /// - Returns: The ordered dictionary. - /// - SeeAlso: OrderedDictionary.init(unsorted:areInIncreasingOrder:) - public func sorted( - by areInIncreasingOrder: (Element, Element) throws -> Bool - ) rethrows -> OrderedDictionary { - return try OrderedDictionary( - unsorted: self, - areInIncreasingOrder: areInIncreasingOrder - ) - } - -} diff --git a/Sources/CAlchemy/bcrypt.c b/Sources/AlchemyC/bcrypt.c similarity index 100% rename from Sources/CAlchemy/bcrypt.c rename to Sources/AlchemyC/bcrypt.c diff --git a/Sources/CAlchemy/bcrypt.h b/Sources/AlchemyC/bcrypt.h similarity index 100% rename from Sources/CAlchemy/bcrypt.h rename to Sources/AlchemyC/bcrypt.h diff --git a/Sources/CAlchemy/blf.c b/Sources/AlchemyC/blf.c similarity index 100% rename from Sources/CAlchemy/blf.c rename to Sources/AlchemyC/blf.c diff --git a/Sources/CAlchemy/blf.h b/Sources/AlchemyC/blf.h similarity index 100% rename from Sources/CAlchemy/blf.h rename to Sources/AlchemyC/blf.h diff --git a/Sources/CAlchemy/include/module.modulemap b/Sources/AlchemyC/include/module.modulemap similarity index 100% rename from Sources/CAlchemy/include/module.modulemap rename to Sources/AlchemyC/include/module.modulemap diff --git a/Sources/AlchemyTest/Assertions/Client+Assertions.swift b/Sources/AlchemyTest/Assertions/Client+Assertions.swift new file mode 100644 index 00000000..de93b38c --- /dev/null +++ b/Sources/AlchemyTest/Assertions/Client+Assertions.swift @@ -0,0 +1,77 @@ +@testable import Alchemy +import AsyncHTTPClient +import XCTest + +extension Client.Builder { + public func assertNothingSent(file: StaticString = #filePath, line: UInt = #line) { + let stubbedRequests = client.stubs?.stubbedRequests ?? [] + XCTAssert(stubbedRequests.isEmpty, file: file, line: line) + } + + public func assertSent( + _ count: Int? = nil, + validate: ((Client.Request) throws -> Bool)? = nil, + file: StaticString = #filePath, + line: UInt = #line + ) { + let stubbedRequests = client.stubs?.stubbedRequests ?? [] + XCTAssertFalse(stubbedRequests.isEmpty, file: file, line: line) + if let count = count { + XCTAssertEqual(client.stubs?.stubbedRequests.count, count, file: file, line: line) + } + + if let validate = validate { + var foundMatch = false + for request in stubbedRequests where !foundMatch { + XCTAssertNoThrow(foundMatch = try validate(request)) + } + + AssertTrue(foundMatch, file: file, line: line) + } + } +} + +extension Client.Request { + public func hasHeader(_ name: String, value: String? = nil) -> Bool { + guard let header = headers.first(name: name) else { + return false + } + + if let value = value { + return header == value + } else { + return true + } + } + + public func hasQuery(_ name: String, value: L) -> Bool { + let components = URLComponents(string: url.absoluteString) + return components?.queryItems?.contains(where: { item in + guard + item.name == name, + let stringValue = item.value, + let itemValue = L(stringValue) + else { + return false + } + + return itemValue == value + }) ?? false + } + + public func hasPath(_ path: String) -> Bool { + URLComponents(string: url.absoluteString)?.path == path + } + + public func hasMethod(_ method: HTTPMethod) -> Bool { + self.method == method + } + + public func hasBody(string: String) -> Bool { + if let buffer = body?.buffer { + return buffer.string == string + } else { + return false + } + } +} diff --git a/Sources/AlchemyTest/Assertions/HTTP/ContentInspector+Assertions.swift b/Sources/AlchemyTest/Assertions/HTTP/ContentInspector+Assertions.swift new file mode 100644 index 00000000..e41936e5 --- /dev/null +++ b/Sources/AlchemyTest/Assertions/HTTP/ContentInspector+Assertions.swift @@ -0,0 +1,98 @@ +import Alchemy + +extension ContentInspector { + // MARK: Header Assertions + + @discardableResult + public func assertHeader(_ header: String, value: String, file: StaticString = #filePath, line: UInt = #line) -> Self { + let values = headers[header] + XCTAssertFalse(values.isEmpty, file: file, line: line) + for v in values { + XCTAssertEqual(v, value, file: file, line: line) + } + + return self + } + + @discardableResult + public func assertHeaderMissing(_ header: String, file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssert(headers[header].isEmpty, file: file, line: line) + return self + } + + @discardableResult + public func assertLocation(_ uri: String, file: StaticString = #filePath, line: UInt = #line) -> Self { + assertHeader("Location", value: uri, file: file, line: line) + } + + // MARK: Body Assertions + + @discardableResult + public func assertBody(_ string: String, file: StaticString = #filePath, line: UInt = #line) -> Self { + guard let body = self.body else { + XCTFail("Request body was nil.", file: file, line: line) + return self + } + + guard let decoded = body.string() else { + XCTFail("Request body was not a String.", file: file, line: line) + return self + } + + XCTAssertEqual(decoded, string, file: file, line: line) + return self + } + + @discardableResult + public func assertStream(_ assertChunk: @escaping (ByteBuffer) -> Void, file: StaticString = #filePath, line: UInt = #line) async throws -> Self { + guard let body = self.body else { + XCTFail("Request body was nil.", file: file, line: line) + return self + } + + try await body.stream.readAll(chunkHandler: assertChunk) + return self + } + + @discardableResult + public func assertJson(_ value: D, file: StaticString = #filePath, line: UInt = #line) -> Self { + guard body != nil else { + XCTFail("Request body was nil.", file: file, line: line) + return self + } + + XCTAssertNoThrow(try decode(D.self), file: file, line: line) + guard let decoded = try? decode(D.self) else { + return self + } + + XCTAssertEqual(decoded, value, file: file, line: line) + return self + } + + // Convert to anything? String, Int, Bool, Double, Array, Object... + @discardableResult + public func assertJson(_ value: [String: Any], file: StaticString = #filePath, line: UInt = #line) -> Self { + guard let body = self.body else { + XCTFail("Request body was nil.", file: file, line: line) + return self + } + + guard let dict = try? body.decodeJSONDictionary() else { + XCTFail("Request body wasn't a json object.", file: file, line: line) + return self + } + + XCTAssertEqual(NSDictionary(dictionary: dict), NSDictionary(dictionary: value), file: file, line: line) + return self + } + + @discardableResult + public func assertEmpty(file: StaticString = #filePath, line: UInt = #line) -> Self { + if body != nil { + XCTFail("The response body was not empty \(body?.string() ?? "nil")", file: file, line: line) + } + + return self + } +} diff --git a/Sources/AlchemyTest/Assertions/HTTP/RequestInspector+Assertions.swift b/Sources/AlchemyTest/Assertions/HTTP/RequestInspector+Assertions.swift new file mode 100644 index 00000000..1ce03243 --- /dev/null +++ b/Sources/AlchemyTest/Assertions/HTTP/RequestInspector+Assertions.swift @@ -0,0 +1,6 @@ +import Alchemy + +extension Client.Request: RequestInspector {} +extension RequestInspector { + +} diff --git a/Sources/AlchemyTest/Assertions/HTTP/ResponseInspector+Assertions.swift b/Sources/AlchemyTest/Assertions/HTTP/ResponseInspector+Assertions.swift new file mode 100644 index 00000000..26818377 --- /dev/null +++ b/Sources/AlchemyTest/Assertions/HTTP/ResponseInspector+Assertions.swift @@ -0,0 +1,66 @@ +import Alchemy +import XCTest + +extension Response: ResponseInspector {} +extension ResponseInspector { + // MARK: Status Assertions + + @discardableResult + public func assertCreated(file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertEqual(status, .created, file: file, line: line) + return self + } + + @discardableResult + public func assertForbidden(file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertEqual(status, .forbidden, file: file, line: line) + return self + } + + @discardableResult + public func assertNotFound(file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertEqual(status, .notFound, file: file, line: line) + return self + } + + @discardableResult + public func assertOk(file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertEqual(status, .ok, file: file, line: line) + return self + } + + @discardableResult + public func assertRedirect(to uri: String? = nil, file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertTrue((300...399).contains(status.code), file: file, line: line) + + if let uri = uri { + assertLocation(uri, file: file, line: line) + } + + return self + } + + @discardableResult + public func assertStatus(_ status: HTTPResponseStatus, file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertEqual(self.status, status, file: file, line: line) + return self + } + + @discardableResult + public func assertStatus(_ code: UInt, file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertEqual(status.code, code, file: file, line: line) + return self + } + + @discardableResult + public func assertSuccessful(file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertTrue((200...299).contains(status.code), file: file, line: line) + return self + } + + @discardableResult + public func assertUnauthorized(file: StaticString = #filePath, line: UInt = #line) -> Self { + XCTAssertEqual(status, .unauthorized, file: file, line: line) + return self + } +} diff --git a/Sources/AlchemyTest/Assertions/MemoryCache+Assertions.swift b/Sources/AlchemyTest/Assertions/MemoryCache+Assertions.swift new file mode 100644 index 00000000..95c27374 --- /dev/null +++ b/Sources/AlchemyTest/Assertions/MemoryCache+Assertions.swift @@ -0,0 +1,21 @@ +@testable import Alchemy +import XCTest + +extension MemoryCache { + public func assertSet(_ key: String, _ val: L? = nil) { + XCTAssertTrue(has(key)) + if let val = val { + XCTAssertNoThrow(try { + XCTAssertEqual(try get(key), val) + }()) + } + } + + public func assertNotSet(_ key: String) { + XCTAssertFalse(has(key)) + } + + public func assertEmpty() { + XCTAssertTrue(data.isEmpty) + } +} diff --git a/Sources/AlchemyTest/Assertions/MemoryQueue+Assertions.swift b/Sources/AlchemyTest/Assertions/MemoryQueue+Assertions.swift new file mode 100644 index 00000000..87a0bc5a --- /dev/null +++ b/Sources/AlchemyTest/Assertions/MemoryQueue+Assertions.swift @@ -0,0 +1,64 @@ +@testable import Alchemy +import XCTest + +extension MemoryQueue { + public func assertNothingPushed() { + XCTAssertTrue(jobs.isEmpty) + } + + public func assertNotPushed(_ type: J.Type, file: StaticString = #filePath, line: UInt = #line) { + XCTAssertFalse(jobs.values.contains { $0.jobName == J.name }, file: file, line: line) + } + + public func assertPushed( + on channel: String? = nil, + _ type: J.Type, + _ count: Int = 1, + file: StaticString = #filePath, + line: UInt = #line + ) { + let matches = jobs.values.filter { $0.jobName == J.name && $0.channel == channel ?? $0.channel } + XCTAssertEqual(matches.count, count, file: file, line: line) + } + + public func assertPushed( + on channel: String? = nil, + _ type: J.Type, + assertion: (J) -> Bool, + file: StaticString = #filePath, + line: UInt = #line + ) { + XCTAssertNoThrow(try { + let matches = try jobs.values.filter { + guard $0.jobName == J.name, $0.channel == channel ?? $0.channel else { + return false + } + + let job = try (JobDecoding.decode($0) as? J).unwrap(or: JobError.unknownType) + return assertion(job) + } + + XCTAssertFalse(matches.isEmpty, file: file, line: line) + }(), file: file, line: line) + } + + public func assertPushed( + on channel: String? = nil, + _ instance: J, + file: StaticString = #filePath, + line: UInt = #line + ) { + XCTAssertNoThrow(try { + let matches = try jobs.values.filter { + guard $0.jobName == J.name, $0.channel == channel ?? $0.channel else { + return false + } + + let job = try (JobDecoding.decode($0) as? J).unwrap(or: JobError.unknownType) + return job == instance + } + + XCTAssertFalse(matches.isEmpty, file: file, line: line) + }(), file: file, line: line) + } +} diff --git a/Sources/AlchemyTest/Exports.swift b/Sources/AlchemyTest/Exports.swift new file mode 100644 index 00000000..f63f7d78 --- /dev/null +++ b/Sources/AlchemyTest/Exports.swift @@ -0,0 +1,2 @@ +@_exported import Alchemy +@_exported import XCTest diff --git a/Sources/AlchemyTest/Fakes/Database+Fake.swift b/Sources/AlchemyTest/Fakes/Database+Fake.swift new file mode 100644 index 00000000..ed1580c4 --- /dev/null +++ b/Sources/AlchemyTest/Fakes/Database+Fake.swift @@ -0,0 +1,66 @@ +extension Database { + /// Fake the database with an in memory SQLite database. + /// + /// - Parameters: + /// - id: The identifier of the database to fake, defaults to `default`. + /// - seeds: Any migrations to set on the database, they will be run + /// before this function returns. + /// - seeders: Any seeders to set on the database, they will be run before + /// this function returns. + @discardableResult + public static func fake(_ id: Identifier = .default, migrations: [Migration] = [], seeders: [Seeder] = []) -> Database { + let db = Database.sqlite + db.migrations = migrations + db.seeders = seeders + bind(id, db) + + let sem = DispatchSemaphore(value: 0) + Task { + do { + if !migrations.isEmpty { try await db.migrate() } + if !seeders.isEmpty { try await db.seed() } + } catch { + Log.error("Error initializing fake database: \(error)") + } + + sem.signal() + } + + sem.wait() + return db + } + + /// Synchronously migrates the database, useful for setting up the database + /// before test cases. + public func syncMigrate() { + let sem = DispatchSemaphore(value: 0) + Task { + do { + if !migrations.isEmpty { try await migrate() } + } catch { + Log.error("Error migrating test database: \(error)") + } + + sem.signal() + } + + sem.wait() + } + + /// Synchronously seeds the database, useful for setting up the database + /// before test cases. + public func syncSeed() { + let sem = DispatchSemaphore(value: 0) + Task { + do { + if !seeders.isEmpty { try await seed() } + } catch { + Log.error("Error seeding test database: \(error)") + } + + sem.signal() + } + + sem.wait() + } +} diff --git a/Sources/AlchemyTest/Fixtures/Request+Fixture.swift b/Sources/AlchemyTest/Fixtures/Request+Fixture.swift new file mode 100644 index 00000000..c4cad7d0 --- /dev/null +++ b/Sources/AlchemyTest/Fixtures/Request+Fixture.swift @@ -0,0 +1,27 @@ +@testable +import Alchemy +import Hummingbird +import NIOCore + +extension Request { + /// Initialize a request fixture with the given data. + public static func fixture( + remoteAddress: SocketAddress? = nil, + version: HTTPVersion = .http1_1, + method: HTTPMethod = .GET, + uri: String = "foo", + headers: HTTPHeaders = [:], + body: ByteContent? = nil + ) -> Request { + struct DummyContext: HBRequestContext { + let eventLoop: EventLoop = EmbeddedEventLoop() + let allocator: ByteBufferAllocator = .init() + let remoteAddress: SocketAddress? = nil + } + + let dummyApp = HBApplication() + let head = HTTPRequestHead(version: version, method: method, uri: uri, headers: headers) + let req = HBRequest(head: head, body: .byteBuffer(body?.buffer), application: dummyApp, context: DummyContext()) + return Request(hbRequest: req) + } +} diff --git a/Sources/AlchemyTest/Fixtures/TestApp.swift b/Sources/AlchemyTest/Fixtures/TestApp.swift new file mode 100644 index 00000000..4251772a --- /dev/null +++ b/Sources/AlchemyTest/Fixtures/TestApp.swift @@ -0,0 +1,7 @@ +import Alchemy + +/// An app that does nothing, for testing. +public struct TestApp: Application { + public init() {} + public func boot() throws {} +} diff --git a/Sources/AlchemyTest/Stubs/Database/Database+Stub.swift b/Sources/AlchemyTest/Stubs/Database/Database+Stub.swift new file mode 100644 index 00000000..ac852084 --- /dev/null +++ b/Sources/AlchemyTest/Stubs/Database/Database+Stub.swift @@ -0,0 +1,12 @@ +extension Database { + /// Mock the database with a database for stubbing specific queries. + /// + /// - Parameter id: The identifier of the database to stub, defaults to + /// `default`. + @discardableResult + public static func stub(_ id: Identifier = .default) -> StubDatabase { + let stub = StubDatabase() + bind(id, Database(provider: stub)) + return stub + } +} diff --git a/Sources/AlchemyTest/Stubs/Database/StubDatabase.swift b/Sources/AlchemyTest/Stubs/Database/StubDatabase.swift new file mode 100644 index 00000000..bb7fb5de --- /dev/null +++ b/Sources/AlchemyTest/Stubs/Database/StubDatabase.swift @@ -0,0 +1,64 @@ +public final class StubDatabase: DatabaseProvider { + private var isShutdown = false + private var stubs: [[SQLRow]] = [] + + public let grammar = Grammar() + + init() {} + + public func query(_ sql: String, values: [SQLValue]) async throws -> [SQLRow] { + guard !isShutdown else { + throw StubDatabaseError("This stubbed database has been shutdown.") + } + + guard let mockedRows = stubs.first else { + throw StubDatabaseError("Before running a query on a stubbed database, please stub it's resposne with `stub()`.") + } + + return mockedRows + } + + public func raw(_ sql: String) async throws -> [SQLRow] { + try await query(sql, values: []) + } + + public func transaction(_ action: @escaping (DatabaseProvider) async throws -> T) async throws -> T { + try await action(self) + } + + public func shutdown() throws { + isShutdown = true + } + + public func stub(_ rows: [StubDatabaseRow]) { + stubs.append(rows) + } +} + +public struct StubDatabaseRow: SQLRow { + public let data: [String: SQLValueConvertible] + public let columns: Set + + public init(data: [String: SQLValueConvertible] = [:]) { + self.data = data + self.columns = Set(data.keys) + } + + public func get(_ column: String) throws -> SQLValue { + try data[column].unwrap(or: StubDatabaseError("Stubbed database row had no column `\(column)`.")).value + } +} + +/// An error encountered when interacting with a `StubDatabase`. +public struct StubDatabaseError: Error { + /// What went wrong. + let message: String + + /// Initialize a `DatabaseError` with a message detailing what + /// went wrong. + /// + /// - Parameter message: Why this error was thrown. + init(_ message: String) { + self.message = message + } +} diff --git a/Sources/AlchemyTest/Stubs/Redis/Redis+Stub.swift b/Sources/AlchemyTest/Stubs/Redis/Redis+Stub.swift new file mode 100644 index 00000000..bc0ebe35 --- /dev/null +++ b/Sources/AlchemyTest/Stubs/Redis/Redis+Stub.swift @@ -0,0 +1,13 @@ +import NIO + +extension RedisClient { + /// Mock Redis with a provider for stubbing specific commands. + /// + /// - Parameter id: The id of the redis client to stub, defaults to + /// `default`. + public static func stub(_ id: Identifier = .default) -> StubRedis { + let provider = StubRedis() + bind(id, RedisClient(provider: provider)) + return provider + } +} diff --git a/Sources/AlchemyTest/Stubs/Redis/StubRedis.swift b/Sources/AlchemyTest/Stubs/Redis/StubRedis.swift new file mode 100644 index 00000000..13ca1a4d --- /dev/null +++ b/Sources/AlchemyTest/Stubs/Redis/StubRedis.swift @@ -0,0 +1,72 @@ +import NIOCore +import RediStack + +public final class StubRedis: RedisProvider { + private var isShutdown = false + + var stubs: [String: RESPValue] = [:] + + public func stub(_ command: String, response: RESPValue) { + stubs[command] = response + } + + // MARK: RedisProvider + + public func getClient() -> RediStack.RedisClient { + self + } + + public func transaction(_ transaction: @escaping (RedisProvider) async throws -> T) async throws -> T { + try await transaction(self) + } + + public func shutdown() throws { + isShutdown = true + } +} + +extension StubRedis: RediStack.RedisClient { + public var eventLoop: EventLoop { Loop.current } + + public func send(command: String, with arguments: [RESPValue]) -> EventLoopFuture { + guard !isShutdown else { + return eventLoop.future(error: RedisError(reason: "This stubbed redis client has been shutdown.")) + } + + guard let stub = stubs.removeValue(forKey: command) else { + return eventLoop.future(error: RedisError(reason: "No stub found for Redis command \(command). Please stub it's response with `stub()`.")) + } + + return eventLoop.future(stub) + } + + public func subscribe( + to channels: [RedisChannelName], + messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver, + onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?, + onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler? + ) -> EventLoopFuture { + eventLoop.future(error: RedisError(reason: "pub/sub stubbing isn't supported, yet")) + } + + public func psubscribe( + to patterns: [String], + messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver, + onSubscribe subscribeHandler: RedisSubscriptionChangeHandler?, + onUnsubscribe unsubscribeHandler: RedisSubscriptionChangeHandler? + ) -> EventLoopFuture { + eventLoop.future(error: RedisError(reason: "pub/sub stubbing isn't supported, yet")) + } + + public func unsubscribe(from channels: [RedisChannelName]) -> EventLoopFuture { + eventLoop.future(error: RedisError(reason: "pub/sub stubbing isn't supported, yet")) + } + + public func punsubscribe(from patterns: [String]) -> EventLoopFuture { + eventLoop.future(error: RedisError(reason: "pub/sub stubbing isn't supported, yet")) + } + + public func logging(to logger: Logger) -> RediStack.RedisClient { + self + } +} diff --git a/Sources/AlchemyTest/TestCase/TestCase+FakeTLS.swift b/Sources/AlchemyTest/TestCase/TestCase+FakeTLS.swift new file mode 100644 index 00000000..c3e34dda --- /dev/null +++ b/Sources/AlchemyTest/TestCase/TestCase+FakeTLS.swift @@ -0,0 +1,82 @@ +extension TestCase { + /// Creates a fake certificate chain and private key in a temporary + /// directory. Useful for faking TLS configurations in tests. + /// + /// final class MyAppTests: TestCase { + /// func testConfigureTLS() { + /// XCTAssertNil(app.tlsConfig) + /// let (key, cert) = app.generateFakeTLSCertificate() + /// try app.useHTTPS(key: key, cert: cert) + /// XCTAssertNotNil(app.tlsConfig) + /// } + /// } + /// + /// - Returns: Paths to the fake key and certificate chain, respectively. + public func generateFakeTLSCertificate() -> (keyPath: String, certPath: String) { + return ( + createTempFile("fake_private_key.pem", contents: samplePKCS8PemPrivateKey), + createTempFile("fake_cert.pem", contents: samplePemCert) + ) + } + + public func createTempFile(_ name: String, contents: String) -> String { + let dirPath = NSTemporaryDirectory() + FileManager.default.createFile(atPath: dirPath + name, contents: contents.data(using: .utf8)) + return dirPath + name + } + + private var samplePemCert: String { + """ + -----BEGIN CERTIFICATE----- + MIIC+zCCAeOgAwIBAgIJANG6W1v704/aMA0GCSqGSIb3DQEBBQUAMBQxEjAQBgNV + BAMMCWxvY2FsaG9zdDAeFw0xOTA4MDExMDMzMjhaFw0yOTA3MjkxMDMzMjhaMBQx + EjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC + ggEBAMLw9InBMGKUNZKXFIpjUYt+Tby42GRQaRFmHfUrlYkvI9L7i9cLqltX/Pta + XL9zISJIEgIgOW1R3pQ4xRP3xC+C3lKpo5lnD9gaMnDIsXhXLQzvTo+tFgtShXsU + /xGl4U2Oc2BbPmydd+sfOPKXOYk/0TJsuSb1U5pA8FClyJUrUlykHkN120s5GhfA + P2KYP+RMZuaW7gNlDEhiInqYUxBpLE+qutAYIDdpKWgxmHKW1oLhZ70TT1Zs7tUI + 22ydjo81vbtB4214EDX0KRRBep+Kq9vTigss34CwhYvyhaCP6l305Z9Vjtu1q1vp + a3nfMeVtcg6JDn3eogv0CevZMc0CAwEAAaNQME4wHQYDVR0OBBYEFK6KIoQAlLog + bBT3snTQ22x5gmXQMB8GA1UdIwQYMBaAFK6KIoQAlLogbBT3snTQ22x5gmXQMAwG + A1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEBAEgoqcGDMriG4cCxWzuiXuV7 + 3TthA8TbdHQOeucNvXt9b3HUG1fQo7a0Tv4X3136SfCy3SsXXJr43snzVUK9SuNb + ntqhAOIudZNw8KSwe+qJwmSEO4y3Lwr5pFfUUkGkV4K86wv3LmBpo3jep5hbkpAc + kvbzTynFrOILV0TaDkF46KHIoyAb5vPneapdW7rXbX1Jba3jo9PyvHRMeoh/I8zt + 4g+Je2PpH1TJ/GT9dmYhYgJaIssVpv/fWkWphVXwMmpqiH9vEbln8piXHxvCj9XU + y7uc6N1VUvIvygzUsR+20wjODoGiXp0g0cj+38n3oG5S9rBd1iGEPMAA/2lQS/0= + -----END CERTIFICATE----- + """ + } + + private var samplePKCS8PemPrivateKey: String { + """ + -----BEGIN RSA PRIVATE KEY----- + MIIEowIBAAKCAQEAwvD0icEwYpQ1kpcUimNRi35NvLjYZFBpEWYd9SuViS8j0vuL + 1wuqW1f8+1pcv3MhIkgSAiA5bVHelDjFE/fEL4LeUqmjmWcP2BoycMixeFctDO9O + j60WC1KFexT/EaXhTY5zYFs+bJ136x848pc5iT/RMmy5JvVTmkDwUKXIlStSXKQe + Q3XbSzkaF8A/Ypg/5Exm5pbuA2UMSGIiephTEGksT6q60BggN2kpaDGYcpbWguFn + vRNPVmzu1QjbbJ2OjzW9u0HjbXgQNfQpFEF6n4qr29OKCyzfgLCFi/KFoI/qXfTl + n1WO27WrW+lred8x5W1yDokOfd6iC/QJ69kxzQIDAQABAoIBAQCX+KZ62cuxnh8h + l3wg4oqIt788l9HCallugfBq2D5sQv6nlQiQbfyx1ydWgDx71/IFuq+nTp3WVpOx + c4xYI7ii3WAaizsJ9SmJ6+pUuHB6A2QQiGLzaRkdXIjIyjaK+IlrH9lcTeWdYSlC + eAW6QSBOmhypNc8lyu0Q/P0bshJsDow5iuy3d8PeT3klxgRPWjgjLZj0eUA0Orfp + s6rC3t7wq8S8+YscKNS6dO0Vp2rF5ZHYYZ9kG5Y0PbAx24hDoYcgMJYJSw5LuR9D + TkNcstHI8aKM7t9TZN0eXeLmzKXAbkD0uyaK0ZwI2panFDBjkjnkwS7FjHDusk1S + Or36zCV1AoGBAOj8ALqa5y4HHl2QF8+dkH7eEFnKmExd1YX90eUuO1v7oTW4iQN+ + Z/me45exNDrG27+w8JqF66zH+WAfHv5Va0AUnTuFAyBmOEqit0m2vFzOLBgDGub1 + xOVYQQ5LetIbiXYU4H3IQDSO+UY27u1yYsgYMrO1qiyGgEkFSbK5xh6HAoGBANYy + 3rv9ULu3ZzeLqmkO+uYxBaEzAzubahgcDniKrsKfLVywYlF1bzplgT9OdGRkwMR8 + K7K5s+6ehrIu8pOadP1fZO7GC7w5lYypbrH74E7mBXSP53NOOebKYpojPhxjMrtI + HLOxGg742WY5MTtDZ81Va0TrhErb4PxccVQEIY4LAoGAc8TMw+y21Ps6jnlMK6D6 + rN/BNiziUogJ0qPWCVBYtJMrftssUe0c0z+tjbHC5zXq+ax9UfsbqWZQtv+f0fc1 + 7MiRfILSk+XXMNb7xogjvuW/qUrZskwLQ38ADI9a/04pluA20KmRpcwpd0dSn/BH + v2+uufeaELfgxOf4v/Npy78CgYBqmqzB8QQCOPg069znJp52fEVqAgKE4wd9clE9 + awApOqGP9PUpx4GRFb2qrTg+Uuqhn478B3Jmux0ch0MRdRjulVCdiZGDn0Ev3Y+L + I2lyuwZSCeDOQUuN8oH6Zrnd1P0FupEWWXk3pGBGgQZgkV6TEgUuKu0PeLlTwApj + Hx84GwKBgHWqSoiaBml/KX+GBUDu8Yp0v+7dkNaiU/RVaSEOFl2wHkJ+bq4V+DX1 + lgofMC2QvBrSinEjHrQPZILl+lOq/ppDcnxhY/3bljsutcgHhIT7PKYDOxFqflMi + ahwyQwRg2oQ2rBrBevgOKFEuIV62WfDYXi8SlT8QaZpTt2r4PYt4 + -----END RSA PRIVATE KEY----- + """ + } +} diff --git a/Sources/AlchemyTest/TestCase/TestCase.swift b/Sources/AlchemyTest/TestCase/TestCase.swift new file mode 100644 index 00000000..745149ac --- /dev/null +++ b/Sources/AlchemyTest/TestCase/TestCase.swift @@ -0,0 +1,56 @@ +@testable +import Alchemy +import NIOCore +import XCTest + +/// A test case class that makes it easy for you to test your app. By default +/// a new instance of your application will be setup before and shutdown +/// after each test. +/// +/// You may also use this class to build & send mock http requests to your app. +open class TestCase: XCTestCase { + public final class Builder: RequestBuilder { + public var urlComponents = URLComponents() + public var method: HTTPMethod = .GET + public var headers: HTTPHeaders = [:] + public var body: ByteContent? = nil + private var version: HTTPVersion = .http1_1 + private var remoteAddress: SocketAddress? = nil + + /// Set the http version of the mock request. + public func withHttpVersion(_ version: HTTPVersion) -> Self { + with { $0.version = version } + } + + /// Set the remote address of the mock request. + public func withRemoteAddress(_ address: SocketAddress) -> Self { + with { $0.remoteAddress = address } + } + + public func execute() async throws -> Response { + await A.current.router.handle( + request: .fixture( + remoteAddress: remoteAddress, + version: version, + method: method, + uri: urlComponents.path, + headers: headers, + body: body)) + } + } + + /// An instance of your app, reset and configured before each test. + public var app = A() + public var Test: Builder { Builder() } + + open override func setUpWithError() throws { + try super.setUpWithError() + app = A() + try app.setup() + } + + open override func tearDownWithError() throws { + try super.tearDownWithError() + try app.stop() + } +} diff --git a/Sources/AlchemyTest/Utilities/AsyncAsserts.swift b/Sources/AlchemyTest/Utilities/AsyncAsserts.swift new file mode 100644 index 00000000..c4038e07 --- /dev/null +++ b/Sources/AlchemyTest/Utilities/AsyncAsserts.swift @@ -0,0 +1,21 @@ +import XCTest + +public func AssertEqual(_ expression1: T, _ expression2: T, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, line: UInt = #line) { + XCTAssertEqual(expression1, expression2, message(), file: file, line: line) +} + +public func AssertNotEqual(_ expression1: T, _ expression2: T, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, line: UInt = #line) { + XCTAssertNotEqual(expression1, expression2, message(), file: file, line: line) +} + +public func AssertNil(_ expression: Any?, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, line: UInt = #line) { + XCTAssertNil(expression, message(), file: file, line: line) +} + +public func AssertFalse(_ expression: Bool, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, line: UInt = #line) { + XCTAssertFalse(expression, message(), file: file, line: line) +} + +public func AssertTrue(_ expression: Bool, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, line: UInt = #line) { + XCTAssertTrue(expression, message(), file: file, line: line) +} diff --git a/Sources/AlchemyTest/Utilities/ByteBuffer+ExpressibleByStringLiteral.swift b/Sources/AlchemyTest/Utilities/ByteBuffer+ExpressibleByStringLiteral.swift new file mode 100644 index 00000000..db55e42d --- /dev/null +++ b/Sources/AlchemyTest/Utilities/ByteBuffer+ExpressibleByStringLiteral.swift @@ -0,0 +1,5 @@ +extension ByteBuffer: ExpressibleByStringLiteral { + public init(stringLiteral value: StringLiteralType) { + self.init(string: value) + } +} diff --git a/Tests/Alchemy/Alchemy+Papyrus/PapyrusRequestTests.swift b/Tests/Alchemy/Alchemy+Papyrus/PapyrusRequestTests.swift new file mode 100644 index 00000000..9b0559ea --- /dev/null +++ b/Tests/Alchemy/Alchemy+Papyrus/PapyrusRequestTests.swift @@ -0,0 +1,71 @@ +import AlchemyTest +import Papyrus + +final class PapyrusRequestTests: TestCase { + private let api = Provider(api: SampleAPI(baseURL: "http://localhost:3000")) + + func testRequest() async throws { + Http.stub() + _ = try await api.createTest.request(CreateTestReq(foo: "one", bar: "two", baz: "three")) + Http.assertSent { + $0.hasMethod(.POST) && + $0.hasPath("/create") && + $0.hasHeader("foo", value: "one") && + $0.hasHeader("bar", value: "two") && + $0.hasQuery("baz", value: "three") + } + } + + func testResponse() async throws { + Http.stub([ + "localhost:3000/get": .stub(body: "\"testing\"") + ]) + let response = try await api.getTest.request().response + XCTAssertEqual(response, "testing") + Http.assertSent(1) { + $0.hasMethod(.GET) && + $0.hasPath("/get") + } + } + + func testUrlEncode() async throws { + Http.stub() + _ = try await api.urlEncode.request(UrlEncodeReq()) + Http.assertSent(1) { + print($0.body?.string() ?? "N/A") + return $0.hasMethod(.PUT) && + $0.hasPath("/url")// && +// $0.hasBody(string: "foo=one") + } + } +} + +private struct SampleAPI: API { + let baseURL: String + + @POST("/create") + var createTest = Endpoint() + + @GET("/get") + var getTest = Endpoint() + + @URLForm + @PUT("/url") + var urlEncode = Endpoint() +} + +private struct CreateTestReq: EndpointRequest { + @Header var foo: String + @Header var bar: String + @RequestQuery var baz: String +} + +private struct UrlEncodeReq: EndpointRequest { + struct Content: Codable { + var foo = "one" + } + + @Body var body = Content() +} + +extension String: EndpointResponse {} diff --git a/Tests/Alchemy/Alchemy+Papyrus/PapyrusRoutingTests.swift b/Tests/Alchemy/Alchemy+Papyrus/PapyrusRoutingTests.swift new file mode 100644 index 00000000..81579026 --- /dev/null +++ b/Tests/Alchemy/Alchemy+Papyrus/PapyrusRoutingTests.swift @@ -0,0 +1,57 @@ +import AlchemyTest +import Papyrus + +final class PapyrusRoutingTests: TestCase { + private let api = TestAPI(baseURL: "https://localhost:3000") + + func testTypedReqTypedRes() async throws { + app.on(api.createTest) { request, content in + return "foo" + } + + let res = try await Test.post("/test") + res.assertSuccessful() + res.assertJson("foo") + } + + func testEmptyReqTypedRes() async throws { + app.on(api.getTest) { request in + return "foo" + } + + let res = try await Test.get("/test") + res.assertSuccessful() + res.assertJson("foo") + } + + func testTypedReqEmptyRes() async throws { + app.on(api.updateTests) { request, content in + return + } + + let res = try await Test.patch("/test") + res.assertSuccessful() + res.assertEmpty() + } + + func testEmptyReqEmptyRes() async throws { + app.on(api.deleteTests) { request in + return + } + + let res = try await Test.delete("/test") + res.assertSuccessful() + res.assertEmpty() + } +} + +private struct TestAPI: API { + let baseURL: String + @POST("/test") var createTest = Endpoint() + @GET("/test") var getTest = Endpoint() + @PATCH("/test") var updateTests = Endpoint() + @DELETE("/test") var deleteTests = Endpoint() +} + +private struct CreateTestReq: EndpointRequest {} +private struct UpdateTestsReq: EndpointRequest {} diff --git a/Tests/Alchemy/Alchemy+Plot/PlotTests.swift b/Tests/Alchemy/Alchemy+Plot/PlotTests.swift new file mode 100644 index 00000000..b753b5b8 --- /dev/null +++ b/Tests/Alchemy/Alchemy+Plot/PlotTests.swift @@ -0,0 +1,51 @@ +@testable import Alchemy +import Plot +import XCTest + +final class PlotTests: XCTestCase { + func testHTMLView() { + let home = HomeView(title: "Welcome", favoriteAnimals: ["Kiwi", "Dolphin"]) + let res = home.response() + XCTAssertEqual(res.status, .ok) + XCTAssertEqual(res.headers.contentType, .html) + XCTAssertEqual(res.body?.string(), home.content.render()) + } + + func testHTMLConversion() { + let html = HomeView(title: "Welcome", favoriteAnimals: ["Kiwi", "Dolphin"]).content + let res = html.response() + XCTAssertEqual(res.status, .ok) + XCTAssertEqual(res.headers.contentType, .html) + XCTAssertEqual(res.body?.string(), html.render()) + } + + func testXMLConversion() { + let xml = XML(.attribute(named: "attribute"), .element(named: "element")) + let res = xml.response() + XCTAssertEqual(res.status, .ok) + XCTAssertEqual(res.headers.contentType, .xml) + XCTAssertEqual(res.body?.string(), xml.render()) + } +} + +struct HomeView: HTMLView { + let title: String + let favoriteAnimals: [String] + + var content: HTML { + HTML( + .head( + .title(self.title), + .stylesheet("styles.css") + ), + .body( + .div( + .h1("My favorite animals are"), + .ul(.forEach(self.favoriteAnimals) { + .li(.class("name"), .text($0)) + }) + ) + ) + ) + } +} diff --git a/Tests/Alchemy/Application/ApplicationCommandTests.swift b/Tests/Alchemy/Application/ApplicationCommandTests.swift new file mode 100644 index 00000000..c1da9db6 --- /dev/null +++ b/Tests/Alchemy/Application/ApplicationCommandTests.swift @@ -0,0 +1,23 @@ +@testable +import Alchemy +import AlchemyTest + +final class ApplicationCommandTests: TestCase { + func testCommandRegistration() throws { + try app.start() + XCTAssertTrue(Launch.customCommands.contains { + id(of: $0) == id(of: TestCommand.self) + }) + } +} + +struct CommandApp: Application { + var commands: [Command.Type] = [TestCommand.self] + func boot() throws {} +} + +private struct TestCommand: Command { + static var configuration = CommandConfiguration(commandName: "command:test") + + func start() async throws {} +} diff --git a/Tests/Alchemy/Application/ApplicationControllerTests.swift b/Tests/Alchemy/Application/ApplicationControllerTests.swift new file mode 100644 index 00000000..4b249d2e --- /dev/null +++ b/Tests/Alchemy/Application/ApplicationControllerTests.swift @@ -0,0 +1,75 @@ +import AlchemyTest + +final class ApplicationControllerTests: TestCase { + func testController() async throws { + try await Test.get("/test").assertNotFound() + app.controller(TestController()) + try await Test.get("/test").assertOk() + } + + func testControllerMiddleware() async throws { + let expect = Expect() + let controller = MiddlewareController(middlewares: [ + ActionMiddleware { await expect.signalOne() }, + ActionMiddleware { await expect.signalTwo() }, + ActionMiddleware { await expect.signalThree() } + ]) + app.controller(controller) + try await Test.get("/middleware").assertOk() + + AssertTrue(await expect.one) + AssertTrue(await expect.two) + AssertTrue(await expect.three) + } + + func testControllerMiddlewareRemoved() async throws { + let expect = Expect() + let controller = MiddlewareController(middlewares: [ + ActionMiddleware { await expect.signalOne() }, + ActionMiddleware { await expect.signalTwo() }, + ActionMiddleware { await expect.signalThree() }, + ]) + + app + .controller(controller) + .get("/outside") { _ async -> String in + await expect.signalFour() + return "foo" + } + + try await Test.get("/outside").assertOk() + AssertFalse(await expect.one) + AssertFalse(await expect.two) + AssertFalse(await expect.three) + AssertTrue(await expect.four) + } +} + +struct ActionMiddleware: Middleware { + let action: () async -> Void + + func intercept(_ request: Request, next: (Request) async throws -> Response) async throws -> Response { + await action() + return try await next(request) + } +} + +struct MiddlewareController: Controller { + let middlewares: [Middleware] + + func route(_ app: Application) { + app + .use(middlewares) + .get("/middleware") { _ in + "Hello, world!" + } + } +} + +struct TestController: Controller { + func route(_ app: Application) { + app.get("/test") { req -> String in + return "Hello, world!" + } + } +} diff --git a/Tests/Alchemy/Application/ApplicationErrorRouteTests.swift b/Tests/Alchemy/Application/ApplicationErrorRouteTests.swift new file mode 100644 index 00000000..9921646b --- /dev/null +++ b/Tests/Alchemy/Application/ApplicationErrorRouteTests.swift @@ -0,0 +1,44 @@ +import AlchemyTest + +final class ApplicationErrorRouteTests: TestCase { + func testCustomNotFound() async throws { + try await Test.get("/not_found").assertBody(HTTPResponseStatus.notFound.reasonPhrase).assertNotFound() + app.notFound { _ in + "Hello, world!" + } + + try await Test.get("/not_found").assertBody("Hello, world!").assertOk() + } + + func testCustomInternalError() async throws { + struct TestError: Error {} + + app.get("/error") { _ -> String in + throw TestError() + } + + let status = HTTPResponseStatus.internalServerError + try await Test.get("/error").assertBody(status.reasonPhrase).assertStatus(status) + + app.internalError { _, _ in + "Nothing to see here." + } + + try await Test.get("/error").assertBody("Nothing to see here.").assertOk() + } + + func testThrowingCustomInternalError() async throws { + struct TestError: Error {} + + app.get("/error") { _ -> String in + throw TestError() + } + + app.internalError { _, _ in + throw TestError() + } + + let status = HTTPResponseStatus.internalServerError + try await Test.get("/error").assertBody(status.reasonPhrase).assertStatus(.internalServerError) + } +} diff --git a/Tests/Alchemy/Application/ApplicationHTTP2Tests.swift b/Tests/Alchemy/Application/ApplicationHTTP2Tests.swift new file mode 100644 index 00000000..e84dc71a --- /dev/null +++ b/Tests/Alchemy/Application/ApplicationHTTP2Tests.swift @@ -0,0 +1,7 @@ +import AlchemyTest + +final class ApplicationHTTP2Tests: TestCase { + func testConfigureHTTP2() throws { + throw XCTSkip() + } +} diff --git a/Tests/Alchemy/Application/ApplicationJobTests.swift b/Tests/Alchemy/Application/ApplicationJobTests.swift new file mode 100644 index 00000000..a7d5a0e2 --- /dev/null +++ b/Tests/Alchemy/Application/ApplicationJobTests.swift @@ -0,0 +1,21 @@ +@testable +import Alchemy +import AlchemyTest + +final class ApplicationJobTests: TestCase { + override func tearDown() { + super.tearDown() + JobDecoding.reset() + } + + func testRegisterJob() { + app.registerJob(TestJob.self) + XCTAssertTrue(app.registeredJobs.contains(where: { + id(of: $0) == id(of: TestJob.self) + })) + } +} + +private struct TestJob: Job { + func run() async throws {} +} diff --git a/Tests/Alchemy/Application/ApplicationTLSTests.swift b/Tests/Alchemy/Application/ApplicationTLSTests.swift new file mode 100644 index 00000000..167e897c --- /dev/null +++ b/Tests/Alchemy/Application/ApplicationTLSTests.swift @@ -0,0 +1,7 @@ +import AlchemyTest + +final class ApplicationTLSTests: TestCase { + func testConfigureTLS() throws { + throw XCTSkip() + } +} diff --git a/Tests/Alchemy/Auth/BasicAuthableTests.swift b/Tests/Alchemy/Auth/BasicAuthableTests.swift new file mode 100644 index 00000000..7c38220e --- /dev/null +++ b/Tests/Alchemy/Auth/BasicAuthableTests.swift @@ -0,0 +1,27 @@ +import AlchemyTest + +final class BasicAuthableTests: TestCase { + func testBasicAuthable() async throws { + Database.fake(migrations: [AuthModel.Migrate()]) + + app.use(AuthModel.basicAuthMiddleware()) + app.get("/user") { try $0.get(AuthModel.self) } + + try await AuthModel(email: "test@withapollo.com", password: Bcrypt.hash("password")).insert() + + try await Test.get("/user") + .assertUnauthorized() + + try await Test.withBasicAuth(username: "test@withapollo.com", password: "password") + .get("/user") + .assertOk() + + try await Test.withBasicAuth(username: "test@withapollo.com", password: "foo") + .get("/user") + .assertUnauthorized() + + try await Test.withBasicAuth(username: "josh@withapollo.com", password: "password") + .get("/user") + .assertUnauthorized() + } +} diff --git a/Tests/Alchemy/Auth/Fixtures/AuthableModel.swift b/Tests/Alchemy/Auth/Fixtures/AuthableModel.swift new file mode 100644 index 00000000..9c265ad9 --- /dev/null +++ b/Tests/Alchemy/Auth/Fixtures/AuthableModel.swift @@ -0,0 +1,53 @@ +import Alchemy + +struct AuthModel: BasicAuthable { + var id: Int? + let email: String + let password: String + + struct Migrate: Migration { + func up(schema: Schema) { + schema.create(table: AuthModel.tableName) { + $0.increments("id") + .primary() + $0.string("email") + .notNull() + .unique() + $0.string("password") + .notNull() + } + } + + func down(schema: Schema) { + schema.drop(table: AuthModel.tableName) + } + } +} + +struct TokenModel: Model, TokenAuthable { + static var userKey = \TokenModel.$authModel + + var id: Int? + var value = UUID() + + @BelongsTo + var authModel: AuthModel + + struct Migrate: Migration { + func up(schema: Schema) { + schema.create(table: TokenModel.tableName) { + $0.increments("id") + .primary() + $0.uuid("value") + .notNull() + $0.bigInt("auth_model_id") + .notNull() + .references("id", on: "auth_models") + } + } + + func down(schema: Schema) { + schema.drop(table: TokenModel.tableName) + } + } +} diff --git a/Tests/Alchemy/Auth/TokenAuthableTests.swift b/Tests/Alchemy/Auth/TokenAuthableTests.swift new file mode 100644 index 00000000..486d06a6 --- /dev/null +++ b/Tests/Alchemy/Auth/TokenAuthableTests.swift @@ -0,0 +1,28 @@ +import AlchemyTest + +final class TokenAuthableTests: TestCase { + func testTokenAuthable() async throws { + Database.fake(migrations: [AuthModel.Migrate(), TokenModel.Migrate()]) + + app.use(TokenModel.tokenAuthMiddleware()) + app.get("/user") { req -> UUID in + _ = try req.get(AuthModel.self) + return try req.get(TokenModel.self).value + } + + let auth = try await AuthModel(email: "test@withapollo.com", password: Bcrypt.hash("password")).insertReturn() + let token = try await TokenModel(authModel: auth).insertReturn() + + try await Test.get("/user") + .assertUnauthorized() + + try await Test.withToken(token.value.uuidString) + .get("/user") + .assertOk() + .assertJson(token.value) + + try await Test.withToken(UUID().uuidString) + .get("/user") + .assertUnauthorized() + } +} diff --git a/Tests/Alchemy/Cache/CacheTests.swift b/Tests/Alchemy/Cache/CacheTests.swift new file mode 100644 index 00000000..e20d5d38 --- /dev/null +++ b/Tests/Alchemy/Cache/CacheTests.swift @@ -0,0 +1,108 @@ +import AlchemyTest +import XCTest + +final class CacheTests: TestCase { + private lazy var allTests = [ + _testSet, + _testExpire, + _testHas, + _testRemove, + _testDelete, + _testIncrement, + _testWipe, + ] + + override func tearDownWithError() throws { + // Redis seems to throw on shutdown if it could never connect in the + // first place. While this shouldn't be necessary, it is a stopgap + // for throwing an error when shutting down unconnected redis. + try? app.stop() + } + + func testConfig() { + let config = Cache.Config(caches: [.default: .memory, 1: .memory, 2: .memory]) + Cache.configure(with: config) + XCTAssertNotNil(Container.resolve(Cache.self, identifier: Cache.Identifier.default)) + XCTAssertNotNil(Container.resolve(Cache.self, identifier: 1)) + XCTAssertNotNil(Container.resolve(Cache.self, identifier: 2)) + } + + func testDatabaseCache() async throws { + for test in allTests { + Database.fake(migrations: [Cache.AddCacheMigration()]) + Cache.bind(.database) + try await test() + } + } + + func testMemoryCache() async throws { + for test in allTests { + Cache.fake() + try await test() + } + } + + func testRedisCache() async throws { + for test in allTests { + RedisClient.bind(.testing) + Cache.bind(.redis) + + guard await Redis.checkAvailable() else { + throw XCTSkip() + } + + try await test() + try await Stash.wipe() + } + } + + private func _testSet() async throws { + AssertNil(try await Stash.get("foo", as: String.self)) + try await Stash.set("foo", value: "bar") + AssertEqual(try await Stash.get("foo"), "bar") + try await Stash.set("foo", value: "baz") + AssertEqual(try await Stash.get("foo"), "baz") + } + + private func _testExpire() async throws { + AssertNil(try await Stash.get("foo", as: String.self)) + try await Stash.set("foo", value: "bar", for: .zero) + AssertNil(try await Stash.get("foo", as: String.self)) + } + + private func _testHas() async throws { + AssertFalse(try await Stash.has("foo")) + try await Stash.set("foo", value: "bar") + AssertTrue(try await Stash.has("foo")) + } + + private func _testRemove() async throws { + try await Stash.set("foo", value: "bar") + AssertEqual(try await Stash.remove("foo"), "bar") + AssertFalse(try await Stash.has("foo")) + AssertEqual(try await Stash.remove("foo", as: String.self), nil) + } + + private func _testDelete() async throws { + try await Stash.set("foo", value: "bar") + try await Stash.delete("foo") + AssertFalse(try await Stash.has("foo")) + } + + private func _testIncrement() async throws { + AssertEqual(try await Stash.increment("foo"), 1) + AssertEqual(try await Stash.increment("foo", by: 10), 11) + AssertEqual(try await Stash.decrement("foo"), 10) + AssertEqual(try await Stash.decrement("foo", by: 19), -9) + } + + private func _testWipe() async throws { + try await Stash.set("foo", value: 1) + try await Stash.set("bar", value: 2) + try await Stash.set("baz", value: 3) + try await Stash.wipe() + AssertNil(try await Stash.get("foo", as: String.self)) + AssertNil(try await Stash.get("bar", as: String.self)) + AssertNil(try await Stash.get("baz", as: String.self)) + } +} diff --git a/Tests/Alchemy/Client/ClientErrorTests.swift b/Tests/Alchemy/Client/ClientErrorTests.swift new file mode 100644 index 00000000..263e6a53 --- /dev/null +++ b/Tests/Alchemy/Client/ClientErrorTests.swift @@ -0,0 +1,31 @@ +@testable +import Alchemy +import AlchemyTest +import AsyncHTTPClient + +final class ClientErrorTests: TestCase { + func testClientError() async throws { + let request = Client.Request(url: "http://localhost/foo", method: .POST, headers: ["foo": "bar"], body: .string("foo")) + let response = Client.Response(request: request, host: "alchemy", status: .conflict, version: .http1_1, headers: ["foo": "bar"], body: .string("bar")) + + let error = ClientError(message: "foo", request: request, response: response) + AssertEqual(error.description, """ + *** HTTP Client Error *** + foo + + *** Request *** + URL: POST http://localhost/foo + Headers: [ + foo + ] + Body: <3 bytes> + + *** Response *** + Status: 409 Conflict + Headers: [ + foo + ] + Body: <3 bytes> + """) + } +} diff --git a/Tests/Alchemy/Client/ClientResponseTests.swift b/Tests/Alchemy/Client/ClientResponseTests.swift new file mode 100644 index 00000000..ea66b1b2 --- /dev/null +++ b/Tests/Alchemy/Client/ClientResponseTests.swift @@ -0,0 +1,47 @@ +@testable +import Alchemy +import AlchemyTest +import AsyncHTTPClient + +final class ClientResponseTests: XCTestCase { + func testStatusCodes() { + XCTAssertTrue(Client.Response(.ok).isOk) + XCTAssertTrue(Client.Response(.created).isSuccessful) + XCTAssertTrue(Client.Response(.badRequest).isClientError) + XCTAssertTrue(Client.Response(.badGateway).isServerError) + XCTAssertTrue(Client.Response(.internalServerError).isFailed) + XCTAssertThrowsError(try Client.Response(.internalServerError).validateSuccessful()) + XCTAssertNoThrow(try Client.Response(.ok).validateSuccessful()) + } + + func testHeaders() { + let headers: HTTPHeaders = ["foo":"bar"] + XCTAssertEqual(Client.Response(headers: headers).headers, headers) + XCTAssertEqual(Client.Response(headers: headers).header("foo"), "bar") + XCTAssertEqual(Client.Response(headers: headers).header("baz"), nil) + } + + func testBody() { + struct SampleJson: Codable, Equatable { + var foo: String = "bar" + } + + let jsonString = """ + {"foo":"bar"} + """ + let jsonData = jsonString.data(using: .utf8) ?? Data() + let body = ByteContent.string(jsonString) + XCTAssertEqual(Client.Response(body: body).body?.buffer, body.buffer) + XCTAssertEqual(Client.Response(body: body).data, jsonData) + XCTAssertEqual(Client.Response(body: body).string, jsonString) + XCTAssertEqual(try Client.Response(body: body).decode(), SampleJson()) + XCTAssertThrowsError(try Client.Response().decode(SampleJson.self)) + XCTAssertThrowsError(try Client.Response(body: body).decode(String.self)) + } +} + +extension Client.Response { + fileprivate init(_ status: HTTPResponseStatus = .ok, headers: HTTPHeaders = [:], body: ByteContent? = nil) { + self.init(request: Client.Request(url: ""), host: "https://example.com", status: status, version: .http1_1, headers: headers, body: body) + } +} diff --git a/Tests/Alchemy/Client/ClientTests.swift b/Tests/Alchemy/Client/ClientTests.swift new file mode 100644 index 00000000..8a66861a --- /dev/null +++ b/Tests/Alchemy/Client/ClientTests.swift @@ -0,0 +1,24 @@ +@testable +import Alchemy +import AlchemyTest + +final class ClientTests: TestCase { + func testQueries() async throws { + Http.stub([ + "localhost/foo": .stub(.unauthorized), + "localhost/*": .stub(.ok), + "*": .stub(.ok), + ]) + try await Http.withQueries(["foo":"bar"]).get("https://localhost/baz") + .assertOk() + + try await Http.withQueries(["bar":"2"]).get("https://localhost/foo?baz=1") + .assertUnauthorized() + + try await Http.get("https://example.com") + .assertOk() + + Http.assertSent { $0.hasQuery("foo", value: "bar") } + Http.assertSent { $0.hasQuery("bar", value: 2) && $0.hasQuery("baz", value: 1) } + } +} diff --git a/Tests/Alchemy/Commands/CommandTests.swift b/Tests/Alchemy/Commands/CommandTests.swift new file mode 100644 index 00000000..73556652 --- /dev/null +++ b/Tests/Alchemy/Commands/CommandTests.swift @@ -0,0 +1,25 @@ +import AlchemyTest + +final class CommandTests: TestCase { + func testCommandRuns() async throws { + struct TestCommand: Command { + static var action: (() async -> Void)? = nil + + func start() async throws { + await TestCommand.action?() + } + } + + let expect = Expect() + TestCommand.action = { + await expect.signalOne() + } + + try TestCommand().run() + + @Inject var lifecycle: ServiceLifecycle + try lifecycle.startAndWait() + + AssertTrue(await expect.one) + } +} diff --git a/Tests/Alchemy/Commands/LaunchTests.swift b/Tests/Alchemy/Commands/LaunchTests.swift new file mode 100644 index 00000000..526e61c9 --- /dev/null +++ b/Tests/Alchemy/Commands/LaunchTests.swift @@ -0,0 +1,12 @@ +@testable +import Alchemy +import AlchemyTest + +final class LaunchTests: TestCase { + func testLaunch() async throws { + let fileName = UUID().uuidString + Launch.main(["make:job", fileName]) + try app.lifecycle.startAndWait() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Jobs/\(fileName).swift")) + } +} diff --git a/Tests/Alchemy/Commands/Make/MakeCommandTests.swift b/Tests/Alchemy/Commands/Make/MakeCommandTests.swift new file mode 100644 index 00000000..d67de958 --- /dev/null +++ b/Tests/Alchemy/Commands/Make/MakeCommandTests.swift @@ -0,0 +1,70 @@ +@testable +import Alchemy +import AlchemyTest + +final class MakeCommandTests: TestCase { + var fileName: String = UUID().uuidString + + override func setUp() { + super.setUp() + fileName = UUID().uuidString + } + + func testColumnData() { + XCTAssertThrowsError(try ColumnData(from: "foo")) + XCTAssertThrowsError(try ColumnData(from: "foo:bar")) + XCTAssertEqual(try ColumnData(from: "foo:string:primary"), ColumnData(name: "foo", type: "string", modifiers: ["primary"])) + XCTAssertEqual(try ColumnData(from: "foo:bigint"), ColumnData(name: "foo", type: "bigInt", modifiers: [])) + } + + func testMakeController() throws { + try MakeController(name: fileName).start() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Controllers/\(fileName).swift")) + + try MakeController(model: fileName).start() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Controllers/\(fileName)Controller.swift")) + } + + func testMakeJob() throws { + try MakeJob(name: fileName).start() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Jobs/\(fileName).swift")) + } + + func testMakeMiddleware() throws { + try MakeMiddleware(name: fileName).start() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Middleware/\(fileName).swift")) + } + + func testMakeMigration() throws { + try MakeMigration(name: fileName, table: "users", columns: .testData).start() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Database/Migrations/\(fileName).swift")) + XCTAssertThrowsError(try MakeMigration(name: fileName + ":", table: "users", columns: .testData).start()) + } + + func testMakeModel() throws { + try MakeModel(name: fileName, columns: .testData, migration: true, controller: true).start() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Models/\(fileName).swift")) + XCTAssertTrue(FileCreator.shared.fileExists(at: "Database/Migrations/Create\(fileName)s.swift")) + XCTAssertTrue(FileCreator.shared.fileExists(at: "Controllers/\(fileName)Controller.swift")) + XCTAssertThrowsError(try MakeModel(name: fileName + ":").start()) + } + + func testMakeView() throws { + try MakeView(name: fileName).start() + XCTAssertTrue(FileCreator.shared.fileExists(at: "Views/\(fileName).swift")) + } +} + +extension Array where Element == ColumnData { + static let testData: [ColumnData] = [ + ColumnData(name: "id", type: "increments", modifiers: ["primary"]), + ColumnData(name: "email", type: "string", modifiers: ["notNull", "unique"]), + ColumnData(name: "password", type: "string", modifiers: ["notNull"]), + ColumnData(name: "parent_id", type: "bigint", modifiers: ["references.users.id"]), + ColumnData(name: "uuid", type: "uuid", modifiers: ["notNull"]), + ColumnData(name: "double", type: "double", modifiers: ["notNull"]), + ColumnData(name: "bool", type: "bool", modifiers: ["notNull"]), + ColumnData(name: "date", type: "date", modifiers: ["notNull"]), + ColumnData(name: "json", type: "json", modifiers: ["notNull"]), + ] +} diff --git a/Tests/Alchemy/Commands/Migrate/RunMigrateTests.swift b/Tests/Alchemy/Commands/Migrate/RunMigrateTests.swift new file mode 100644 index 00000000..e4654439 --- /dev/null +++ b/Tests/Alchemy/Commands/Migrate/RunMigrateTests.swift @@ -0,0 +1,28 @@ +@testable +import Alchemy +import AlchemyTest + +final class RunMigrateTests: TestCase { + func testRun() async throws { + let db = Database.fake() + db.migrations = [MigrationA()] + XCTAssertFalse(MigrationA.didUp) + XCTAssertFalse(MigrationA.didDown) + + try await RunMigrate(rollback: false).start() + XCTAssertTrue(MigrationA.didUp) + XCTAssertFalse(MigrationA.didDown) + + try app.start("migrate", "--rollback") + app.wait() + + XCTAssertTrue(MigrationA.didDown) + } +} + +private struct MigrationA: Migration { + static var didUp: Bool = false + static var didDown: Bool = false + func up(schema: Schema) { MigrationA.didUp = true } + func down(schema: Schema) { MigrationA.didDown = true } +} diff --git a/Tests/Alchemy/Commands/Queue/RunWorkerTests.swift b/Tests/Alchemy/Commands/Queue/RunWorkerTests.swift new file mode 100644 index 00000000..c958113a --- /dev/null +++ b/Tests/Alchemy/Commands/Queue/RunWorkerTests.swift @@ -0,0 +1,44 @@ +@testable +import Alchemy +import AlchemyTest + +final class RunWorkerTests: TestCase { + override func setUp() { + super.setUp() + Queue.fake() + } + + func testRun() throws { + let exp = expectation(description: "") + + try RunWorker(name: nil, workers: 5, schedule: false).run() + app.lifecycle.start { _ in + XCTAssertEqual(Q.workers.count, 5) + XCTAssertFalse(self.app.scheduler.isStarted) + exp.fulfill() + } + + waitForExpectations(timeout: kMinTimeout) + } + + func testRunName() throws { + let exp = expectation(description: "") + Queue.fake("a") + try RunWorker(name: "a", workers: 5, schedule: false).run() + + app.lifecycle.start { _ in + XCTAssertEqual(Q.workers.count, 0) + XCTAssertEqual(Q("a").workers.count, 5) + XCTAssertFalse(self.app.scheduler.isStarted) + exp.fulfill() + } + + waitForExpectations(timeout: kMinTimeout) + } + + func testRunCLI() async throws { + try app.start("worker", "--workers", "3", "--schedule") + XCTAssertEqual(Q.workers.count, 3) + XCTAssertTrue(app.scheduler.isStarted) + } +} diff --git a/Tests/Alchemy/Commands/Seed/SeedDatabaseTests.swift b/Tests/Alchemy/Commands/Seed/SeedDatabaseTests.swift new file mode 100644 index 00000000..8df52703 --- /dev/null +++ b/Tests/Alchemy/Commands/Seed/SeedDatabaseTests.swift @@ -0,0 +1,45 @@ +@testable +import Alchemy +import AlchemyTest + +final class SeedDatabaseTests: TestCase { + func testSeed() async throws { + let db = Database.fake(migrations: [SeedModel.Migrate()]) + db.seeders = [Seeder1(), Seeder2()] + try SeedDatabase(database: nil).run() + try app.lifecycle.startAndWait() + XCTAssertTrue(Seeder1.didRun) + XCTAssertTrue(Seeder2.didRun) + } + + func testNamedSeed() async throws { + let db = Database.fake("a", migrations: [SeedModel.Migrate()]) + db.seeders = [Seeder3(), Seeder4()] + + try app.start("db:seed", "seeder3", "--database", "a") + app.wait() + + XCTAssertTrue(Seeder3.didRun) + XCTAssertFalse(Seeder4.didRun) + } +} + +private struct Seeder1: Seeder { + static var didRun: Bool = false + func run() async throws { Seeder1.didRun = true } +} + +private struct Seeder2: Seeder { + static var didRun: Bool = false + func run() async throws { Seeder2.didRun = true } +} + +private struct Seeder3: Seeder { + static var didRun: Bool = false + func run() async throws { Seeder3.didRun = true } +} + +private struct Seeder4: Seeder { + static var didRun: Bool = false + func run() async throws { Seeder4.didRun = true } +} diff --git a/Tests/Alchemy/Commands/Serve/RunServeTests.swift b/Tests/Alchemy/Commands/Serve/RunServeTests.swift new file mode 100644 index 00000000..b8265a57 --- /dev/null +++ b/Tests/Alchemy/Commands/Serve/RunServeTests.swift @@ -0,0 +1,37 @@ +@testable +import Alchemy +import AlchemyTest + +final class RunServeTests: TestCase { + override func setUp() { + super.setUp() + Database.fake() + Queue.fake() + } + + func testServe() async throws { + app.get("/foo", use: { _ in "hello" }) + try RunServe(host: "127.0.0.1", port: 1234).run() + app.lifecycle.start { _ in } + + try await Http.get("http://127.0.0.1:1234/foo") + .assertBody("hello") + + XCTAssertEqual(Q.workers.count, 0) + XCTAssertFalse(app.scheduler.isStarted) + XCTAssertFalse(DB.didRunMigrations) + } + + func testServeWithSideEffects() async throws { + app.get("/foo", use: { _ in "hello" }) + try RunServe(host: "127.0.0.1", port: 1234, workers: 2, schedule: true, migrate: true).run() + app.lifecycle.start { _ in } + + try await Http.get("http://127.0.0.1:1234/foo") + .assertBody("hello") + + XCTAssertEqual(Q.workers.count, 2) + XCTAssertTrue(app.scheduler.isStarted) + XCTAssertTrue(DB.didRunMigrations) + } +} diff --git a/Tests/Alchemy/Config/ConfigurableTests.swift b/Tests/Alchemy/Config/ConfigurableTests.swift new file mode 100644 index 00000000..0cf1111c --- /dev/null +++ b/Tests/Alchemy/Config/ConfigurableTests.swift @@ -0,0 +1,9 @@ +import AlchemyTest + +final class ConfigurableTests: XCTestCase { + func testDefaults() { + XCTAssertEqual(TestService.foo, "bar") + TestService.configureDefaults() + XCTAssertEqual(TestService.foo, "baz") + } +} diff --git a/Tests/Alchemy/Config/Fixtures/TestService.swift b/Tests/Alchemy/Config/Fixtures/TestService.swift new file mode 100644 index 00000000..35f85d19 --- /dev/null +++ b/Tests/Alchemy/Config/Fixtures/TestService.swift @@ -0,0 +1,25 @@ +import Alchemy + +struct TestService: Service, Configurable { + public struct Identifier: ServiceIdentifier { + private let hashable: AnyHashable + public init(hashable: AnyHashable) { self.hashable = hashable } + } + + struct Config { + let foo: String + } + + static var config = Config(foo: "baz") + static var foo: String = "bar" + + let bar: String + + static func configure(with config: Config) { + foo = config.foo + } +} + +extension TestService.Identifier { + static var foo: Self { "foo" } +} diff --git a/Tests/Alchemy/Config/ServiceIdentifierTests.swift b/Tests/Alchemy/Config/ServiceIdentifierTests.swift new file mode 100644 index 00000000..0d1d8933 --- /dev/null +++ b/Tests/Alchemy/Config/ServiceIdentifierTests.swift @@ -0,0 +1,19 @@ +import AlchemyTest + +final class ServiceIdentifierTests: XCTestCase { + func testServiceIdentifier() { + struct TestIdentifier: ServiceIdentifier { + private let hashable: AnyHashable + init(hashable: AnyHashable) { self.hashable = hashable } + } + + let intId: TestIdentifier = 1 + let stringId: TestIdentifier = "one" + let nilId: TestIdentifier = .init(hashable: AnyHashable(nil as AnyHashable?)) + + XCTAssertNotEqual(intId, .default) + XCTAssertNotEqual(stringId, .default) + XCTAssertEqual(nilId, .default) + XCTAssertEqual(1.hashValue, TestIdentifier(hashable: 1).hashValue) + } +} diff --git a/Tests/Alchemy/Config/ServiceTests.swift b/Tests/Alchemy/Config/ServiceTests.swift new file mode 100644 index 00000000..f58b5848 --- /dev/null +++ b/Tests/Alchemy/Config/ServiceTests.swift @@ -0,0 +1,14 @@ +import AlchemyTest + +final class ServiceTests: TestCase { + func testAlchemyInject() { + TestService.bind(TestService(bar: "one")) + TestService.bind(.foo, TestService(bar: "two")) + + @Inject var one: TestService + @Inject(.foo) var two: TestService + + XCTAssertEqual(one.bar, "one") + XCTAssertEqual(two.bar, "two") + } +} diff --git a/Tests/Alchemy/Env/EnvTests.swift b/Tests/Alchemy/Env/EnvTests.swift new file mode 100644 index 00000000..5a736806 --- /dev/null +++ b/Tests/Alchemy/Env/EnvTests.swift @@ -0,0 +1,75 @@ +@testable +import Alchemy +import AlchemyTest + +final class EnvTests: TestCase { + private let sampleEnvFile = """ + #TEST=ignore + FOO=1 + BAR=two + + BAZ= + fake + QUOTES="three" + """ + + func testIsRunningTests() { + XCTAssertTrue(Env.isTest) + } + + func testEnvLookup() { + let env = Env(name: "test", dotEnvVariables: ["foo": "bar"]) + XCTAssertEqual(env.get("foo"), "bar") + } + + func testStaticLookup() { + Env.current = Env(name: "test", dotEnvVariables: [ + "foo": "one", + "bar": "two", + ]) + XCTAssertEqual(Env.get("foo"), "one") + XCTAssertEqual(Env.bar, "two") + let wrongCase: String? = Env.BAR + XCTAssertEqual(wrongCase, nil) + } + + func testEnvNameProcess() { + Env.boot(processEnv: ["APP_ENV": "foo"]) + XCTAssertEqual(Env.current.name, "foo") + } + + func testEnvNameArgs() { + Env.boot(args: ["-e", "foo"]) + XCTAssertEqual(Env.current.name, "foo") + Env.boot(args: ["--env", "bar"]) + XCTAssertEqual(Env.current.name, "bar") + Env.boot(args: ["--env", "baz"], processEnv: ["APP_ENV": "test"]) + XCTAssertEqual(Env.current.name, "baz") + } + + func testEnvArgsPrecedence() { + Env.boot(args: ["--env", "baz"], processEnv: ["APP_ENV": "test"]) + XCTAssertEqual(Env.current.name, "baz") + } + + func testLoadEnvFile() { + let path = createTempFile(".env-fake-\(UUID().uuidString)", contents: sampleEnvFile) + Env.loadDotEnv(path) + XCTAssertEqual(Env.FOO, "1") + XCTAssertEqual(Env.BAR, "two") + XCTAssertEqual(Env.get("TEST", as: String.self), nil) + XCTAssertEqual(Env.get("fake", as: String.self), nil) + XCTAssertEqual(Env.get("BAZ", as: String.self), nil) + XCTAssertEqual(Env.QUOTES, "three") + } + + func testProcessPrecedence() { + let path = createTempFile(".env-fake-\(UUID().uuidString)", contents: sampleEnvFile) + Env.boot(args: ["-e", path], processEnv: ["FOO": "2"]) + XCTAssertEqual(Env.FOO, "2") + } + + func testWarnDerivedData() { + Env.warnIfUsingDerivedData("/Xcode/DerivedData") + } +} diff --git a/Tests/Alchemy/Filesystem/FileTests.swift b/Tests/Alchemy/Filesystem/FileTests.swift new file mode 100644 index 00000000..1978b038 --- /dev/null +++ b/Tests/Alchemy/Filesystem/FileTests.swift @@ -0,0 +1,17 @@ +@testable +import Alchemy +import AlchemyTest + +final class FileTests: XCTestCase { + func testFile() { + let file = File(name: "foo.html", size: 10, content: .buffer("

foo

")) + XCTAssertEqual(file.extension, "html") + XCTAssertEqual(file.size, 10) + XCTAssertEqual(file.contentType, .html) + } + + func testInvalidURL() { + let file = File(name: "", size: 3, content: .buffer("foo")) + XCTAssertEqual(file.extension, "") + } +} diff --git a/Tests/Alchemy/Filesystem/FilesystemTests.swift b/Tests/Alchemy/Filesystem/FilesystemTests.swift new file mode 100644 index 00000000..a118361e --- /dev/null +++ b/Tests/Alchemy/Filesystem/FilesystemTests.swift @@ -0,0 +1,89 @@ +@testable +import Alchemy +import AlchemyTest + +final class FilesystemTests: TestCase { + private var filePath: String = "" + + private lazy var allTests = [ + _testCreate, + _testDelete, + _testPut, + _testPathing, + _testFileStore, + _testInvalidURL, + ] + + func testConfig() { + let config = Filesystem.Config(disks: [.default: .local, 1: .local, 2: .local]) + Filesystem.configure(with: config) + XCTAssertNotNil(Container.resolve(Filesystem.self, identifier: Filesystem.Identifier.default)) + XCTAssertNotNil(Container.resolve(Filesystem.self, identifier: 1)) + XCTAssertNotNil(Container.resolve(Filesystem.self, identifier: 2)) + } + + func testLocal() async throws { + let root = NSTemporaryDirectory() + UUID().uuidString + Filesystem.bind(.local(root: root)) + XCTAssertEqual(root, Storage.root) + for test in allTests { + filePath = UUID().uuidString + ".txt" + try await test() + } + } + + func _testCreate() async throws { + AssertFalse(try await Storage.exists(filePath)) + do { + _ = try await Storage.get(filePath) + XCTFail("Should throw an error") + } catch {} + try await Storage.create(filePath, content: "1;2;3") + AssertTrue(try await Storage.exists(filePath)) + let file = try await Storage.get(filePath) + AssertEqual(file.name, filePath) + AssertEqual(try await file.content.collect(), "1;2;3") + } + + func _testDelete() async throws { + do { + try await Storage.delete(filePath) + XCTFail("Should throw an error") + } catch {} + try await Storage.create(filePath, content: "123") + try await Storage.delete(filePath) + AssertFalse(try await Storage.exists(filePath)) + } + + func _testPut() async throws { + let file = File(name: filePath, size: 3, content: "foo") + try await Storage.put(file) + AssertTrue(try await Storage.exists(filePath)) + try await Storage.put(file, in: "foo/bar") + AssertTrue(try await Storage.exists("foo/bar/\(filePath)")) + } + + func _testPathing() async throws { + try await Storage.create("foo/bar/baz/\(filePath)", content: "foo") + AssertFalse(try await Storage.exists(filePath)) + AssertTrue(try await Storage.exists("foo/bar/baz/\(filePath)")) + let file = try await Storage.get("foo/bar/baz/\(filePath)") + AssertEqual(file.name, filePath) + AssertEqual(try await file.content.collect(), "foo") + try await Storage.delete("foo/bar/baz/\(filePath)") + AssertFalse(try await Storage.exists("foo/bar/baz/\(filePath)")) + } + + func _testFileStore() async throws { + try await File(name: filePath, size: 3, content: "bar").store() + AssertTrue(try await Storage.exists(filePath)) + } + + func _testInvalidURL() async throws { + do { + let store: Filesystem = .local(root: "\\") + _ = try await store.exists("foo") + XCTFail("Should throw an error") + } catch {} + } +} diff --git a/Tests/Alchemy/Fixtures.swift b/Tests/Alchemy/Fixtures.swift new file mode 100644 index 00000000..bda5d4cb --- /dev/null +++ b/Tests/Alchemy/Fixtures.swift @@ -0,0 +1,11 @@ +// Used for verifying expectations (XCTExpectation isn't as needed since things are async now). +actor Expect { + var one = false, two = false, three = false, four = false, five = false, six = false + + func signalOne() async { one = true } + func signalTwo() async { two = true } + func signalThree() async { three = true } + func signalFour() async { four = true } + func signalFive() async { five = true } + func signalSix() async { six = true } +} diff --git a/Tests/Alchemy/HTTP/Content/ContentTests.swift b/Tests/Alchemy/HTTP/Content/ContentTests.swift new file mode 100644 index 00000000..18eee545 --- /dev/null +++ b/Tests/Alchemy/HTTP/Content/ContentTests.swift @@ -0,0 +1,258 @@ +@testable +import Alchemy +import AlchemyTest +import HummingbirdFoundation +import MultipartKit + +final class ContentTests: XCTestCase { + private lazy var allTests = [ + _testAccess, + _testNestedAccess, + _testEnumAccess, + _testFlatten, + _testDecode, + ] + + func testDict() throws { + let content = Content(root: .any(Fixtures.dictContent)) + for test in allTests { + try test(content, true) + } + try _testNestedArray(content: content) + try _testNestedDecode(content: content) + } + + func testMultipart() throws { + let buffer = ByteBuffer(string: Fixtures.multipartContent) + let content = FormDataDecoder().content(from: buffer, contentType: .multipart(boundary: Fixtures.multipartBoundary)) + try _testAccess(content: content, allowsNull: false) + try _testMultipart(content: content) + } + + func testJson() throws { + let buffer = ByteBuffer(string: Fixtures.jsonContent) + let content = JSONDecoder().content(from: buffer, contentType: .json) + for test in allTests { + try test(content, true) + } + try _testNestedArray(content: content) + try _testNestedDecode(content: content) + } + + func testUrl() throws { + let buffer = ByteBuffer(string: Fixtures.urlContent) + let content = URLEncodedFormDecoder().content(from: buffer, contentType: .urlForm) + for test in allTests { + try test(content, false) + } + try _testNestedDecode(content: content) + } + + func _testAccess(content: Content, allowsNull: Bool) throws { + AssertTrue(content["foo"] == nil) + AssertEqual(try content["string"].stringThrowing, "string") + AssertEqual(try content["string"].decode(String.self), "string") + AssertEqual(try content["int"].intThrowing, 0) + AssertEqual(try content["bool"].boolThrowing, true) + AssertEqual(try content["double"].doubleThrowing, 1.23) + } + + func _testNestedAccess(content: Content, allowsNull: Bool) throws { + AssertTrue(content.object.four.isNull) + XCTAssertThrowsError(try content["array"].stringThrowing) + AssertEqual(try content["array"].arrayThrowing.count, 3) + XCTAssertThrowsError(try content["array"][0].arrayThrowing) + AssertEqual(try content["array"][0].intThrowing, 1) + AssertEqual(try content["array"][1].intThrowing, 2) + AssertEqual(try content["array"][2].intThrowing, 3) + AssertEqual(try content["object"]["one"].stringThrowing, "one") + AssertEqual(try content["object"]["two"].stringThrowing, "two") + AssertEqual(try content["object"]["three"].stringThrowing, "three") + } + + func _testEnumAccess(content: Content, allowsNull: Bool) throws { + enum Test: String, Decodable { + case one, two, three + } + + var expectedDict: [String: Test?] = ["one": .one, "two": .two, "three": .three] + if allowsNull { expectedDict = ["one": .one, "two": .two, "three": .three, "four": nil] } + + AssertEqual(try content.object.one.decode(Test?.self), .one) + AssertEqual(try content.object.decode([String: Test?].self), expectedDict) + } + + func _testMultipart(content: Content) throws { + let file = try content["file"].fileThrowing + AssertEqual(file.name, "a.txt") + AssertEqual(file.content.buffer.string, "Content of a.txt.\n") + } + + func _testFlatten(content: Content, allowsNull: Bool) throws { + var expectedArray: [String?] = ["one", "three", "two"] + if allowsNull { expectedArray.append(nil) } + AssertEqual(try content["object"][*].decode(Optional.self).sorted(), expectedArray) + } + + func _testDecode(content: Content, allowsNull: Bool) throws { + struct TopLevelType: Codable, Equatable { + var string: String = "string" + var int: Int = 0 + var bool: Bool = true + var double: Double = 1.23 + } + + AssertEqual(try content.decode(TopLevelType.self), TopLevelType()) + } + + func _testNestedDecode(content: Content) throws { + struct NestedType: Codable, Equatable { + let one: String + let two: String + let three: String + } + + let expectedStruct = NestedType(one: "one", two: "two", three: "three") + AssertEqual(try content["object"].decode(NestedType.self), expectedStruct) + AssertEqual(try content["array"].decode([Int].self), [1, 2, 3]) + AssertEqual(try content["array"].decode([Int8].self), [1, 2, 3]) + } + + func _test(content: Content, allowsNull: Bool) throws { + struct DecodableType: Codable, Equatable { + let one: String + let two: String + let three: String + } + + struct TopLevelType: Codable, Equatable { + var string: String = "string" + var int: Int = 0 + var bool: Bool = false + var double: Double = 1.23 + } + + let expectedStruct = DecodableType(one: "one", two: "two", three: "three") + AssertEqual(try content.decode(TopLevelType.self), TopLevelType()) + AssertEqual(try content["object"].decode(DecodableType.self), expectedStruct) + AssertEqual(try content["array"].decode([Int].self), [1, 2, 3]) + AssertEqual(try content["array"].decode([Int8].self), [1, 2, 3]) + } + + func _testNestedArray(content: Content) throws { + struct ArrayType: Codable, Equatable { + let foo: String + } + + AssertEqual(try content["objectArray"][*]["foo"].stringThrowing, ["bar", "baz", "tiz"]) + let expectedArray = [ArrayType(foo: "bar"), ArrayType(foo: "baz"), ArrayType(foo: "tiz")] + AssertEqual(try content.objectArray.decode([ArrayType].self), expectedArray) + } +} + +private struct Fixtures { + static let dictContent: [String: Any] = [ + "string": "string", + "int": 0, + "bool": true, + "double": 1.23, + "array": [ + 1, + 2, + 3 + ], + "object": [ + "one": "one", + "two": "two", + "three": "three", + "four": nil + ], + "objectArray": [ + [ + "foo": "bar" + ], + [ + "foo": "baz" + ], + [ + "foo": "tiz" + ] + ] + ] + + static let multipartBoundary = "---------------------------9051914041544843365972754266" + static let multipartContent = """ + + -----------------------------9051914041544843365972754266\r + Content-Disposition: form-data; name="string"\r + \r + string\r + -----------------------------9051914041544843365972754266\r + Content-Disposition: form-data; name="int"\r + \r + 0\r + -----------------------------9051914041544843365972754266\r + Content-Disposition: form-data; name="bool"\r + \r + true\r + -----------------------------9051914041544843365972754266\r + Content-Disposition: form-data; name="double"\r + \r + 1.23\r + -----------------------------9051914041544843365972754266\r + Content-Disposition: form-data; name="file"; filename="a.txt"\r + Content-Type: text/plain\r + \r + Content of a.txt. + \r + -----------------------------9051914041544843365972754266--\r + + """ + + static let jsonContent = """ + { + "string": "string", + "int": 0, + "bool": true, + "double": 1.23, + "array": [ + 1, + 2, + 3 + ], + "object": { + "one": "one", + "two": "two", + "three": "three", + "four": null + }, + "objectArray": [ + { + "foo": "bar" + }, + { + "foo": "baz" + }, + { + "foo": "tiz" + } + ] + } + """ + + static let urlContent = """ + string=string&int=0&bool=true&double=1.23&array[]=1&array[]=2&array[]=3&object[one]=one&object[two]=two&object[three]=three + """ +} + +extension Optional: Comparable where Wrapped == String { + public static func < (lhs: Self, rhs: Self) -> Bool { + if let lhs = lhs, let rhs = rhs { + return lhs < rhs + } else if rhs == nil { + return true + } else { + return false + } + } +} diff --git a/Tests/Alchemy/HTTP/Content/ContentTypeTests.swift b/Tests/Alchemy/HTTP/Content/ContentTypeTests.swift new file mode 100644 index 00000000..3796ba5d --- /dev/null +++ b/Tests/Alchemy/HTTP/Content/ContentTypeTests.swift @@ -0,0 +1,23 @@ +import AlchemyTest + +final class ContentTypeTests: XCTestCase { + func testFileExtension() { + XCTAssertEqual(ContentType(fileExtension: ".html"), .html) + } + + func testInvalidFileExtension() { + XCTAssertEqual(ContentType(fileExtension: ".sc2save"), nil) + } + + func testParameters() { + let type = ContentType.multipart(boundary: "foo") + XCTAssertEqual(type.value, "multipart/form-data") + XCTAssertEqual(type.string, "multipart/form-data; boundary=foo") + } + + func testEquality() { + let first = ContentType.multipart(boundary: "foo") + let second = ContentType.multipart(boundary: "bar") + XCTAssertEqual(first, second) + } +} diff --git a/Tests/Alchemy/HTTP/Content/StreamTests.swift b/Tests/Alchemy/HTTP/Content/StreamTests.swift new file mode 100644 index 00000000..8aaa3ffe --- /dev/null +++ b/Tests/Alchemy/HTTP/Content/StreamTests.swift @@ -0,0 +1,9 @@ +@testable +import Alchemy +import AlchemyTest + +final class StreamTests: TestCase { + func testUnusedDoesntCrash() throws { + _ = ByteStream(eventLoop: Loop.current) + } +} diff --git a/Tests/Alchemy/HTTP/HTTPErrorTests.swift b/Tests/Alchemy/HTTP/HTTPErrorTests.swift new file mode 100644 index 00000000..b92fc8bd --- /dev/null +++ b/Tests/Alchemy/HTTP/HTTPErrorTests.swift @@ -0,0 +1,10 @@ +import AlchemyTest + +final class HTTPErrorTests: XCTestCase { + func testConvertResponse() throws { + try HTTPError(.badGateway, message: "foo") + .response() + .assertStatus(.badGateway) + .assertJson(["message": "foo"]) + } +} diff --git a/Tests/Alchemy/HTTP/Request/ParameterTests.swift b/Tests/Alchemy/HTTP/Request/ParameterTests.swift new file mode 100644 index 00000000..ee562ceb --- /dev/null +++ b/Tests/Alchemy/HTTP/Request/ParameterTests.swift @@ -0,0 +1,20 @@ +@testable +import Alchemy +import AlchemyTest + +final class ParameterTests: XCTestCase { + func testStringConversion() { + XCTAssertEqual(Parameter(key: "foo", value: "bar").string(), "bar") + } + + func testIntConversion() throws { + XCTAssertEqual(try Parameter(key: "foo", value: "1").int(), 1) + XCTAssertThrowsError(try Parameter(key: "foo", value: "foo").int()) + } + + func testUuidConversion() throws { + let uuid = UUID() + XCTAssertEqual(try Parameter(key: "foo", value: uuid.uuidString).uuid(), uuid) + XCTAssertThrowsError(try Parameter(key: "foo", value: "foo").uuid()) + } +} diff --git a/Tests/Alchemy/HTTP/Request/RequestAssociatedValueTests.swift b/Tests/Alchemy/HTTP/Request/RequestAssociatedValueTests.swift new file mode 100644 index 00000000..0ae09a50 --- /dev/null +++ b/Tests/Alchemy/HTTP/Request/RequestAssociatedValueTests.swift @@ -0,0 +1,24 @@ +@testable +import Alchemy +import XCTest + +final class RequestAssociatedValueTests: XCTestCase { + func testValue() { + let request = Request.fixture() + request.set("foo") + XCTAssertEqual(try request.get(), "foo") + } + + func testOverwite() { + let request = Request.fixture() + request.set("foo") + request.set("bar") + XCTAssertEqual(try request.get(), "bar") + } + + func testNoValue() { + let request = Request.fixture() + request.set(1) + XCTAssertThrowsError(try request.get(String.self)) + } +} diff --git a/Tests/Alchemy/HTTP/Request/RequestAuthTests.swift b/Tests/Alchemy/HTTP/Request/RequestAuthTests.swift new file mode 100644 index 00000000..b6deb1e5 --- /dev/null +++ b/Tests/Alchemy/HTTP/Request/RequestAuthTests.swift @@ -0,0 +1,41 @@ +@testable +import Alchemy +import NIOHTTP1 +import XCTest + +final class RequestAuthTests: XCTestCase { + private let sampleBase64Credentials = Data("username:password".utf8).base64EncodedString() + private let sampleToken = UUID().uuidString + + func testNoAuth() { + XCTAssertNil(Request.fixture().basicAuth()) + XCTAssertNil(Request.fixture().bearerAuth()) + XCTAssertNil(Request.fixture().getAuth()) + } + + func testUnknownAuth() { + let request = Request.fixture(headers: ["Authorization": "Foo \(sampleToken)"]) + XCTAssertNil(request.getAuth()) + } + + func testBearerAuth() { + let request = Request.fixture(headers: ["Authorization": "Bearer \(sampleToken)"]) + XCTAssertNil(request.basicAuth()) + XCTAssertNotNil(request.bearerAuth()) + XCTAssertEqual(request.bearerAuth()?.token, sampleToken) + } + + func testBasicAuth() { + let request = Request.fixture(headers: ["Authorization": "Basic \(sampleBase64Credentials)"]) + XCTAssertNil(request.bearerAuth()) + XCTAssertNotNil(request.basicAuth()) + XCTAssertEqual(request.basicAuth(), HTTPAuth.Basic(username: "username", password: "password")) + } + + func testMalformedBasicAuth() { + let notBase64Encoded = Request.fixture(headers: ["Authorization": "Basic user:pass"]) + XCTAssertNil(notBase64Encoded.basicAuth()) + let empty = Request.fixture(headers: ["Authorization": "Basic "]) + XCTAssertNil(empty.basicAuth()) + } +} diff --git a/Tests/Alchemy/HTTP/Request/RequestUtilitiesTests.swift b/Tests/Alchemy/HTTP/Request/RequestUtilitiesTests.swift new file mode 100644 index 00000000..253f362b --- /dev/null +++ b/Tests/Alchemy/HTTP/Request/RequestUtilitiesTests.swift @@ -0,0 +1,63 @@ +@testable +import Alchemy +import XCTest + +final class RequestUtilitiesTests: XCTestCase { + func testPath() { + XCTAssertEqual(Request.fixture(uri: "/foo/bar").path, "/foo/bar") + } + + func testInvalidPath() { + XCTAssertEqual(Request.fixture(uri: "%").path, "") + } + + func testQueryItems() { + XCTAssertEqual(Request.fixture(uri: "/path").queryItems, nil) + XCTAssertEqual(Request.fixture(uri: "/path?foo=1&bar=2").queryItems, [ + URLQueryItem(name: "foo", value: "1"), + URLQueryItem(name: "bar", value: "2") + ]) + } + + func testParameter() { + let request = Request.fixture() + request.parameters = [ + Parameter(key: "foo", value: "one"), + Parameter(key: "bar", value: "two"), + Parameter(key: "baz", value: "three"), + ] + XCTAssertEqual(try request.parameter("foo"), "one") + XCTAssertEqual(try request.parameter("bar"), "two") + XCTAssertEqual(try request.parameter("baz"), "three") + XCTAssertThrowsError(try request.parameter("fake", as: String.self)) + XCTAssertThrowsError(try request.parameter("foo", as: Int.self)) + XCTAssertTrue(request.parameters.contains(Parameter(key: "foo", value: "one"))) + } + + func testBody() { + XCTAssertNil(Request.fixture(body: nil).body) + XCTAssertNotNil(Request.fixture(body: .empty).body) + } + + func testDecodeBodyJSON() { + struct ExpectedJSON: Codable, Equatable { + var foo = "bar" + } + + XCTAssertThrowsError(try Request.fixture(body: nil).decode(ExpectedJSON.self)) + XCTAssertThrowsError(try Request.fixture(body: .empty).decode(ExpectedJSON.self)) + XCTAssertEqual(try Request.fixture(body: .json).decode(), ExpectedJSON()) + } +} + +extension ByteContent { + fileprivate static var empty: ByteContent { + .buffer(ByteBuffer()) + } + + fileprivate static var json: ByteContent { + .string(""" + {"foo":"bar"} + """) + } +} diff --git a/Tests/Alchemy/HTTP/Response/ResponseTests.swift b/Tests/Alchemy/HTTP/Response/ResponseTests.swift new file mode 100644 index 00000000..fe0e1cd3 --- /dev/null +++ b/Tests/Alchemy/HTTP/Response/ResponseTests.swift @@ -0,0 +1,96 @@ +@testable +import Alchemy +import AlchemyTest +import MultipartKit + +final class ResponseTests: XCTestCase { + override class func setUp() { + super.setUp() + FormDataEncoder.boundary = { Fixtures.multipartBoundary } + } + + func testInit() throws { + Response(status: .created, headers: ["foo": "1", "bar": "2"]) + .assertHeader("foo", value: "1") + .assertHeader("bar", value: "2") + .assertHeader("Content-Length", value: "0") + .assertCreated() + } + + func testInitContentLength() { + Response(status: .ok) + .withString("foo") + .assertHeader("Content-Length", value: "3") + .assertBody("foo") + .assertOk() + } + + func testJSONEncode() throws { + let res = try Response().withValue(Fixtures.object, encoder: .json) + XCTAssertEqual(res.headers.contentType, .json) + // Linux doesn't guarantee json coding order. + XCTAssertTrue(res.body?.string() == Fixtures.jsonString || res.body?.string() == Fixtures.altJsonString) + } + + func testJSONDecode() throws { + let res = Response().withString(Fixtures.jsonString, type: .json) + XCTAssertEqual(try res.decode(), Fixtures.object) + } + + func testURLEncode() throws { + let res = try Response().withValue(Fixtures.object, encoder: .urlForm) + XCTAssertEqual(res.headers.contentType, .urlForm) + XCTAssertTrue(res.body?.string() == Fixtures.urlString || res.body?.string() == Fixtures.urlStringAlternate) + } + + func testURLDecode() throws { + let res = Response().withString(Fixtures.urlString, type: .urlForm) + XCTAssertEqual(try res.decode(), Fixtures.object) + } + + func testMultipartEncode() throws { + let res = try Response().withValue(Fixtures.object, encoder: .multipart) + XCTAssertEqual(res.headers.contentType, .multipart(boundary: Fixtures.multipartBoundary)) + XCTAssertEqual(res.body?.string(), Fixtures.multipartString) + } + + func testMultipartDecode() throws { + let res = Response().withString(Fixtures.multipartString, type: .multipart(boundary: Fixtures.multipartBoundary)) + XCTAssertEqual(try res.decode(), Fixtures.object) + } +} + +private struct Fixtures { + struct Test: Codable, Equatable { + var foo = "foo" + var bar = "bar" + } + + static let jsonString = """ + {"foo":"foo","bar":"bar"} + """ + + static let altJsonString = """ + {"bar":"bar","foo":"foo"} + """ + + static let urlString = "foo=foo&bar=bar" + static let urlStringAlternate = "bar=bar&foo=foo" + + static let multipartBoundary = "foo123" + + static let multipartString = """ + --foo123\r + Content-Disposition: form-data; name=\"foo\"\r + \r + foo\r + --foo123\r + Content-Disposition: form-data; name=\"bar\"\r + \r + bar\r + --foo123--\r + + """ + + static let object = Test() +} diff --git a/Tests/Alchemy/HTTP/StreamingTests.swift b/Tests/Alchemy/HTTP/StreamingTests.swift new file mode 100644 index 00000000..cc7ac314 --- /dev/null +++ b/Tests/Alchemy/HTTP/StreamingTests.swift @@ -0,0 +1,74 @@ +@testable +import Alchemy +import AlchemyTest +import NIOCore + +final class StreamingTests: TestCase { + + // MARK: - Client + + func testClientResponseStream() async throws { + let streamResponse: Client.Response = .stub(body: .stream { + try await $0.write("foo") + try await $0.write("bar") + try await $0.write("baz") + }) + + Http.stub(["example.com/*": streamResponse]) + + var res = try await Http.get("https://example.com/foo") + try await res.collect() + .assertOk() + .assertBody("foobarbaz") + } + + func testServerResponseStream() async throws { + app.get("/stream") { _ in + Response { + try await $0.write("foo") + try await $0.write("bar") + try await $0.write("baz") + } + } + + try await Test.get("/stream") + .collect() + .assertOk() + .assertBody("foobarbaz") + } + + func testEndToEndStream() async throws { + app.get("/stream", options: .stream) { _ in + Response { + try await $0.write("foo") + try await $0.write("bar") + try await $0.write("baz") + } + } + + try app.start() + var expected = ["foo", "bar", "baz"] + try await Http + .withStream() + .get("http://localhost:3000/stream") + .assertStream { + guard expected.first != nil else { + XCTFail("There were too many stream elements.") + return + } + + XCTAssertEqual($0.string, expected.removeFirst()) + } + .assertOk() + } + + func testFileRequest() { + app.get("/stream") { _ in + Response { + try await $0.write("foo") + try await $0.write("bar") + try await $0.write("baz") + } + } + } +} diff --git a/Tests/Alchemy/HTTP/ValidationErrorTests.swift b/Tests/Alchemy/HTTP/ValidationErrorTests.swift new file mode 100644 index 00000000..fe2e8b62 --- /dev/null +++ b/Tests/Alchemy/HTTP/ValidationErrorTests.swift @@ -0,0 +1,10 @@ +import AlchemyTest + +final class ValidationErrorTests: XCTestCase { + func testConvertResponse() throws { + try ValidationError("bar") + .response() + .assertStatus(.badRequest) + .assertJson(["validation_error": "bar"]) + } +} diff --git a/Tests/Alchemy/Middleware/Concrete/CORSMiddlewareTests.swift b/Tests/Alchemy/Middleware/Concrete/CORSMiddlewareTests.swift new file mode 100644 index 00000000..5995b2b6 --- /dev/null +++ b/Tests/Alchemy/Middleware/Concrete/CORSMiddlewareTests.swift @@ -0,0 +1,75 @@ +@testable +import Alchemy +import AlchemyTest + +final class CORSMiddlewareTests: TestCase { + func testDefault() async throws { + let cors = CORSMiddleware() + app.useAll(cors) + + try await Test.get("/hello") + .assertHeaderMissing("Access-Control-Allow-Origin") + + try await Test.withHeader("Origin", value: "https://foo.example") + .get("/hello") + .assertHeader("Access-Control-Allow-Origin", value: "https://foo.example") + .assertHeader("Access-Control-Allow-Headers", value: "Accept, Authorization, Content-Type, Origin, X-Requested-With") + .assertHeader("Access-Control-Allow-Methods", value: "GET, POST, PUT, OPTIONS, DELETE, PATCH") + .assertHeader("Access-Control-Max-Age", value: "600") + .assertHeaderMissing("Access-Control-Expose-Headers") + .assertHeaderMissing("Access-Control-Allow-Credentials") + } + + func testCustom() async throws { + let cors = CORSMiddleware(configuration: .init( + allowedOrigin: .originBased, + allowedMethods: [.GET, .POST], + allowedHeaders: ["foo", "bar"], + allowCredentials: true, + cacheExpiration: 123, + exposedHeaders: ["baz"] + )) + app.useAll(cors) + + try await Test.get("/hello") + .assertHeaderMissing("Access-Control-Allow-Origin") + + try await Test.withHeader("Origin", value: "https://foo.example") + .get("/hello") + .assertHeader("Access-Control-Allow-Origin", value: "https://foo.example") + .assertHeader("Access-Control-Allow-Headers", value: "foo, bar") + .assertHeader("Access-Control-Allow-Methods", value: "GET, POST") + .assertHeader("Access-Control-Expose-Headers", value: "baz") + .assertHeader("Access-Control-Max-Age", value: "123") + .assertHeader("Access-Control-Allow-Credentials", value: "true") + } + + func testPreflight() async throws { + let cors = CORSMiddleware() + app.useAll(cors) + + try await Test.options("/hello") + .assertHeaderMissing("Access-Control-Allow-Origin") + + try await Test.withHeader("Origin", value: "https://foo.example") + .withHeader("Access-Control-Request-Method", value: "PUT") + .options("/hello") + .assertOk() + .assertHeader("Access-Control-Allow-Origin", value: "https://foo.example") + .assertHeader("Access-Control-Allow-Headers", value: "Accept, Authorization, Content-Type, Origin, X-Requested-With") + .assertHeader("Access-Control-Allow-Methods", value: "GET, POST, PUT, OPTIONS, DELETE, PATCH") + .assertHeader("Access-Control-Max-Age", value: "600") + .assertHeaderMissing("Access-Control-Expose-Headers") + .assertHeaderMissing("Access-Control-Allow-Credentials") + } + + func testOriginSettings() { + let origin = "https://foo.example" + XCTAssertEqual(CORSMiddleware.AllowOriginSetting.none.header(forOrigin: origin), "") + XCTAssertEqual(CORSMiddleware.AllowOriginSetting.originBased.header(forOrigin: origin), origin) + XCTAssertEqual(CORSMiddleware.AllowOriginSetting.all.header(forOrigin: origin), "*") + XCTAssertEqual(CORSMiddleware.AllowOriginSetting.any([origin]).header(forOrigin: origin), origin) + XCTAssertEqual(CORSMiddleware.AllowOriginSetting.any(["foo"]).header(forOrigin: origin), "") + XCTAssertEqual(CORSMiddleware.AllowOriginSetting.custom(origin).header(forOrigin: origin), origin) + } +} diff --git a/Tests/Alchemy/Middleware/Concrete/FileMiddlewareTests.swift b/Tests/Alchemy/Middleware/Concrete/FileMiddlewareTests.swift new file mode 100644 index 00000000..a64067d4 --- /dev/null +++ b/Tests/Alchemy/Middleware/Concrete/FileMiddlewareTests.swift @@ -0,0 +1,89 @@ +@testable +import Alchemy +import AlchemyTest + +final class FileMiddlewareTests: TestCase { + var middleware: FileMiddleware! + var fileName = UUID().uuidString + + override func setUp() { + super.setUp() + middleware = FileMiddleware(from: FileCreator.shared.rootPath + "Public", extensions: ["html"]) + fileName = UUID().uuidString + } + + func testDirectorySanitize() async throws { + middleware = FileMiddleware(from: FileCreator.shared.rootPath + "Public/", extensions: ["html"]) + try FileCreator.shared.create(fileName: fileName, extension: "html", contents: "foo;bar;baz", in: "Public") + + try await middleware + .intercept(.get(fileName), next: { _ in .default }) + .collect() + .assertBody("foo;bar;baz") + + try await middleware + .intercept(.get("//////\(fileName)"), next: { _ in .default }) + .collect() + .assertBody("foo;bar;baz") + + do { + _ = try await middleware.intercept(.get("../foo"), next: { _ in .default }) + XCTFail("An error should be thrown") + } catch {} + } + + func testGetOnly() async throws { + try await middleware + .intercept(.post(fileName), next: { _ in .default }) + .assertBody("bar") + } + + func testRedirectIndex() async throws { + try FileCreator.shared.create(fileName: "index", extension: "html", contents: "foo;bar;baz", in: "Public") + try await middleware + .intercept(.get(""), next: { _ in .default }) + .collect() + .assertBody("foo;bar;baz") + } + + func testLoadingFile() async throws { + try FileCreator.shared.create(fileName: fileName, extension: "txt", contents: "foo;bar;baz", in: "Public") + + try await middleware + .intercept(.get("\(fileName).txt"), next: { _ in .default }) + .collect() + .assertBody("foo;bar;baz") + + try await middleware + .intercept(.get(fileName), next: { _ in .default }) + .assertBody("bar") + } + + func testLoadingAlternateExtension() async throws { + try FileCreator.shared.create(fileName: fileName, extension: "html", contents: "foo;bar;baz", in: "Public") + + try await middleware + .intercept(.get(fileName), next: { _ in .default }) + .collect() + .assertBody("foo;bar;baz") + + try await middleware + .intercept(.get("\(fileName).html"), next: { _ in .default }) + .collect() + .assertBody("foo;bar;baz") + } +} + +extension Request { + fileprivate static func get(_ uri: String) -> Request { + .fixture(method: .GET, uri: uri) + } + + fileprivate static func post(_ uri: String) -> Request { + .fixture(method: .POST, uri: uri) + } +} + +extension Response { + static let `default` = Response(status: .ok).withString("bar") +} diff --git a/Tests/Alchemy/Middleware/MiddlewareTests.swift b/Tests/Alchemy/Middleware/MiddlewareTests.swift new file mode 100644 index 00000000..ee0c4c83 --- /dev/null +++ b/Tests/Alchemy/Middleware/MiddlewareTests.swift @@ -0,0 +1,132 @@ +import AlchemyTest + +final class MiddlewareTests: TestCase { + func testMiddlewareCalling() async throws { + let expect = Expect() + let mw1 = TestMiddleware(req: { _ in await expect.signalOne() }) + let mw2 = TestMiddleware(req: { _ in await expect.signalTwo() }) + + app.use(mw1) + .get("/foo") { _ in } + .use(mw2) + .post("/foo") { _ in } + + _ = try await Test.get("/foo") + + AssertTrue(await expect.one) + AssertFalse(await expect.two) + } + + func testMiddlewareCalledWhenError() async throws { + let expect = Expect() + let global = TestMiddleware(res: { _ in await expect.signalOne() }) + let mw1 = TestMiddleware(res: { _ in await expect.signalTwo() }) + let mw2 = TestMiddleware(req: { _ in + struct SomeError: Error {} + await expect.signalThree() + throw SomeError() + }) + + app.useAll(global) + .use(mw1) + .use(mw2) + .get("/foo") { _ in } + + _ = try await Test.get("/foo") + + AssertTrue(await expect.one) + AssertTrue(await expect.two) + AssertTrue(await expect.three) + } + + func testGroupMiddleware() async throws { + let expect = Expect() + let mw = TestMiddleware(req: { request in + XCTAssertEqual(request.path, "/foo") + XCTAssertEqual(request.method, .POST) + await expect.signalOne() + }) + + app.group(mw) { + $0.post("/foo") { _ in 1 } + } + .get("/foo") { _ in 2 } + + try await Test.get("/foo").assertOk().assertBody("2") + try await Test.post("/foo").assertOk().assertBody("1") + AssertTrue(await expect.one) + } + + func testGroupMiddlewareRemoved() async throws { + let exp = Expect() + let mw = ActionMiddleware { await exp.signalOne() } + + app.group(mw) { + $0.get("/foo") { _ in 1 } + } + .get("/bar") { _ async -> Int in + await exp.signalTwo() + return 2 + } + + try await Test.get("/bar").assertOk() + AssertFalse(await exp.one) + AssertTrue(await exp.two) + } + + func testMiddlewareOrder() async throws { + var stack = [Int]() + let expect = Expect() + let mw1 = TestMiddleware { _ in + XCTAssertEqual(stack, []) + await expect.signalOne() + stack.append(0) + } res: { _ in + XCTAssertEqual(stack, [0,1,2,3,4]) + await expect.signalTwo() + } + + let mw2 = TestMiddleware { _ in + XCTAssertEqual(stack, [0]) + await expect.signalThree() + stack.append(1) + } res: { _ in + XCTAssertEqual(stack, [0,1,2,3]) + await expect.signalFour() + stack.append(4) + } + + let mw3 = TestMiddleware { _ in + XCTAssertEqual(stack, [0,1]) + await expect.signalFive() + stack.append(2) + } res: { _ in + XCTAssertEqual(stack, [0,1,2]) + await expect.signalSix() + stack.append(3) + } + + app.use(mw1, mw2, mw3).get("/foo") { _ in } + _ = try await Test.get("/foo") + AssertTrue(await expect.one) + AssertTrue(await expect.two) + AssertTrue(await expect.three) + AssertTrue(await expect.four) + AssertTrue(await expect.five) + AssertTrue(await expect.six) + } +} + +/// Runs the specified callback on a request / response. +struct TestMiddleware: Middleware { + var req: ((Request) async throws -> Void)? + var res: ((Response) async throws -> Void)? + + func intercept(_ request: Request, next: Next) async throws -> Response { + try await req?(request) + let response = try await next(request) + try await res?(response) + return response + } +} + diff --git a/Tests/Alchemy/Queue/QueueTests.swift b/Tests/Alchemy/Queue/QueueTests.swift new file mode 100644 index 00000000..dce1436e --- /dev/null +++ b/Tests/Alchemy/Queue/QueueTests.swift @@ -0,0 +1,181 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueueTests: TestCase { + private lazy var allTests = [ + _testEnqueue, + _testWorker, + _testFailure, + _testRetry, + ] + + override func tearDownWithError() throws { + // Redis seems to throw on shutdown if it could never connect in the + // first place. While this shouldn't be necessary, it is a stopgap + // for throwing an error when shutting down unconnected redis. + try? app.stop() + JobDecoding.reset() + } + + func testConfig() { + let config = Queue.Config(queues: [.default: .memory, 1: .memory, 2: .memory], jobs: [.job(TestJob.self)]) + Queue.configure(with: config) + XCTAssertNotNil(Container.resolve(Queue.self, identifier: Queue.Identifier.default)) + XCTAssertNotNil(Container.resolve(Queue.self, identifier: 1)) + XCTAssertNotNil(Container.resolve(Queue.self, identifier: 2)) + XCTAssertTrue(app.registeredJobs.contains(where: { ObjectIdentifier($0) == ObjectIdentifier(TestJob.self) })) + } + + func testJobDecoding() { + let fakeData = JobData(id: UUID().uuidString, json: "", jobName: "foo", channel: "bar", recoveryStrategy: .none, retryBackoff: .zero, attempts: 0, backoffUntil: nil) + XCTAssertThrowsError(try JobDecoding.decode(fakeData)) + + struct TestJob: Job { + let foo: String + func run() async throws {} + } + + JobDecoding.register(TestJob.self) + let invalidData = JobData(id: "foo", json: "bar", jobName: "TestJob", channel: "foo", recoveryStrategy: .none, retryBackoff: .zero, attempts: 0, backoffUntil: nil) + XCTAssertThrowsError(try JobDecoding.decode(invalidData)) + } + + func testDatabaseQueue() async throws { + for test in allTests { + Database.fake(migrations: [Queue.AddJobsMigration()]) + Queue.bind(.database) + try await test(#filePath, #line) + } + } + + func testMemoryQueue() async throws { + for test in allTests { + Queue.fake() + try await test(#filePath, #line) + } + } + + func testRedisQueue() async throws { + for test in allTests { + RedisClient.bind(.testing) + Queue.bind(.redis) + + guard await Redis.checkAvailable() else { + throw XCTSkip() + } + + try await test(#filePath, #line) + _ = try await Redis.send(command: "FLUSHDB").get() + } + } + + private func _testEnqueue(file: StaticString = #filePath, line: UInt = #line) async throws { + try await TestJob(foo: "bar").dispatch() + guard let jobData = try await Q.dequeue(from: ["default"]) else { + XCTFail("Failed to dequeue a job.", file: file, line: line) + return + } + + XCTAssertEqual(jobData.jobName, "TestJob", file: file, line: line) + XCTAssertEqual(jobData.recoveryStrategy, .retry(3), file: file, line: line) + XCTAssertEqual(jobData.backoff, .seconds(0), file: file, line: line) + + let decodedJob = try JobDecoding.decode(jobData) + guard let testJob = decodedJob as? TestJob else { + XCTFail("Failed to decode TestJob \(jobData.jobName) \(type(of: decodedJob))", file: file, line: line) + return + } + + XCTAssertEqual(testJob.foo, "bar", file: file, line: line) + } + + private func _testWorker(file: StaticString = #filePath, line: UInt = #line) async throws { + try await ConfirmableJob().dispatch() + + let sema = DispatchSemaphore(value: 0) + ConfirmableJob.didRun = { + sema.signal() + } + + let loop = EmbeddedEventLoop() + Q.startWorker(on: loop) + loop.advanceTime(by: .seconds(5)) + sema.wait() + } + + private func _testFailure(file: StaticString = #filePath, line: UInt = #line) async throws { + try await FailureJob().dispatch() + + let sema = DispatchSemaphore(value: 0) + FailureJob.didFinish = { + sema.signal() + } + + let loop = EmbeddedEventLoop() + Q.startWorker(on: loop) + loop.advanceTime(by: .seconds(5)) + + sema.wait() + AssertNil(try await Q.dequeue(from: ["default"])) + } + + private func _testRetry(file: StaticString = #filePath, line: UInt = #line) async throws { + try await TestJob(foo: "bar").dispatch() + + let sema = DispatchSemaphore(value: 0) + TestJob.didFail = { + sema.signal() + } + + let loop = EmbeddedEventLoop() + Q.startWorker(untilEmpty: false, on: loop) + loop.advanceTime(by: .seconds(5)) + + sema.wait() + + guard let jobData = try await Q.dequeue(from: ["default"]) else { + XCTFail("Failed to dequeue a job.", file: file, line: line) + return + } + + XCTAssertEqual(jobData.jobName, "TestJob", file: file, line: line) + XCTAssertEqual(jobData.attempts, 1, file: file, line: line) + } +} + +private struct FailureJob: Job { + static var didFinish: (() -> Void)? = nil + + func run() async throws { + throw JobError("foo") + } + + func finished(result: Result) { + FailureJob.didFinish?() + } +} + +private struct ConfirmableJob: Job { + static var didRun: (() -> Void)? = nil + + func run() async throws { + ConfirmableJob.didRun?() + } +} + +private struct TestJob: Job { + static var didFail: (() -> Void)? = nil + + let foo: String + var recoveryStrategy: RecoveryStrategy = .retry(3) + var retryBackoff: TimeAmount = .seconds(0) + + func run() async throws { + throw JobError("foo") + } + + func failed(error: Error) { + TestJob.didFail?() + } +} diff --git a/Tests/Alchemy/Redis/Redis+Testing.swift b/Tests/Alchemy/Redis/Redis+Testing.swift new file mode 100644 index 00000000..dd47fb04 --- /dev/null +++ b/Tests/Alchemy/Redis/Redis+Testing.swift @@ -0,0 +1,24 @@ +import Alchemy +import RediStack + +extension Alchemy.RedisClient { + static var testing: Alchemy.RedisClient { + .configuration(RedisConnectionPool.Configuration( + initialServerConnectionAddresses: [ + try! .makeAddressResolvingHost("localhost", port: 6379) + ], + maximumConnectionCount: .maximumActiveConnections(1), + connectionFactoryConfiguration: RedisConnectionPool.ConnectionFactoryConfiguration(connectionDefaultLogger: Log.logger), + connectionRetryTimeout: .milliseconds(100) + )) + } + + func checkAvailable() async -> Bool { + do { + _ = try await ping().get() + return true + } catch { + return false + } + } +} diff --git a/Tests/Alchemy/Routing/ResponseConvertibleTests.swift b/Tests/Alchemy/Routing/ResponseConvertibleTests.swift new file mode 100644 index 00000000..9ffa7297 --- /dev/null +++ b/Tests/Alchemy/Routing/ResponseConvertibleTests.swift @@ -0,0 +1,8 @@ +import AlchemyTest + +final class ResponseConvertibleTests: XCTestCase { + func testConvertArray() throws { + let array = ["one", "two"] + try array.response().assertOk().assertJson(array) + } +} diff --git a/Tests/Alchemy/Routing/RouterTests.swift b/Tests/Alchemy/Routing/RouterTests.swift new file mode 100644 index 00000000..effd7a22 --- /dev/null +++ b/Tests/Alchemy/Routing/RouterTests.swift @@ -0,0 +1,167 @@ +@testable +import Alchemy +import AlchemyTest + +let kMinTimeout: TimeInterval = 0.01 + +final class RouterTests: TestCase { + func testResponseConvertibleHandlers() async throws { + app.get("/string") { _ in "one" } + app.post("/string") { _ in "two" } + app.put("/string") { _ in "three" } + app.patch("/string") { _ in "four" } + app.delete("/string") { _ in "five" } + app.options("/string") { _ in "six" } + app.head("/string") { _ in "seven" } + + try await Test.get("/string").assertBody("one").assertOk() + try await Test.post("/string").assertBody("two").assertOk() + try await Test.put("/string").assertBody("three").assertOk() + try await Test.patch("/string").assertBody("four").assertOk() + try await Test.delete("/string").assertBody("five").assertOk() + try await Test.options("/string").assertBody("six").assertOk() + try await Test.head("/string").assertBody("seven").assertOk() + } + + func testVoidHandlers() async throws { + app.get("/void") { _ in } + app.post("/void") { _ in } + app.put("/void") { _ in } + app.patch("/void") { _ in } + app.delete("/void") { _ in } + app.options("/void") { _ in } + app.head("/void") { _ in } + + try await Test.get("/void").assertEmpty().assertOk() + try await Test.post("/void").assertEmpty().assertOk() + try await Test.put("/void").assertEmpty().assertOk() + try await Test.patch("/void").assertEmpty().assertOk() + try await Test.delete("/void").assertEmpty().assertOk() + try await Test.options("/void").assertEmpty().assertOk() + try await Test.head("/void").assertEmpty().assertOk() + } + + func testEncodableHandlers() async throws { + app.get("/encodable") { _ in 1 } + app.post("/encodable") { _ in 2 } + app.put("/encodable") { _ in 3 } + app.patch("/encodable") { _ in 4 } + app.delete("/encodable") { _ in 5 } + app.options("/encodable") { _ in 6 } + app.head("/encodable") { _ in 7 } + + try await Test.get("/encodable").assertBody("1").assertOk() + try await Test.post("/encodable").assertBody("2").assertOk() + try await Test.put("/encodable").assertBody("3").assertOk() + try await Test.patch("/encodable").assertBody("4").assertOk() + try await Test.delete("/encodable").assertBody("5").assertOk() + try await Test.options("/encodable").assertBody("6").assertOk() + try await Test.head("/encodable").assertBody("7").assertOk() + } + + func testMissing() async throws { + app.get("/foo") { _ in } + app.post("/bar") { _ in } + try await Test.post("/foo").assertNotFound() + } + + func testQueriesIgnored() async throws { + app.get("/foo") { _ in } + try await Test.get("/foo?query=1").assertEmpty().assertOk() + } + + func testPathParametersMatch() async throws { + let expect = Expect() + let uuidString = UUID().uuidString + app.get("/v1/some_path/:uuid/:user_id") { request async -> ResponseConvertible in + XCTAssertEqual(request.parameters, [ + Parameter(key: "uuid", value: uuidString), + Parameter(key: "user_id", value: "123"), + ]) + await expect.signalOne() + return "foo" + } + + try await Test.get("/v1/some_path/\(uuidString)/123").assertBody("foo").assertOk() + AssertTrue(await expect.one) + } + + func testMultipleRequests() async throws { + app.get("/foo") { _ in 1 } + app.get("/foo") { _ in 2 } + try await Test.get("/foo").assertOk().assertBody("2") + } + + func testInvalidPath() throws { + throw XCTSkip() + } + + func testForwardSlashIssues() async throws { + app.get("noslash") { _ in 1 } + app.get("wrongslash/") { _ in 2 } + app.get("//////////manyslash//////////////") { _ in 3 } + app.get("split/path") { _ in 4 } + try await Test.get("/noslash").assertOk().assertBody("1") + try await Test.get("/wrongslash").assertOk().assertBody("2") + try await Test.get("/manyslash").assertOk().assertBody("3") + try await Test.get("/splitpath").assertNotFound() + try await Test.get("/split/path").assertOk().assertBody("4") + } + + func testGroupedPathPrefix() async throws { + app + .grouped("group") { app in + app + .get("/foo") { _ in 1 } + .get("/bar") { _ in 2 } + .grouped("/nested") { app in + app.post("/baz") { _ in 3 } + } + .post("/bar") { _ in 4 } + } + .put("/foo") { _ in 5 } + + try await Test.get("/group/foo").assertOk().assertBody("1") + try await Test.get("/group/bar").assertOk().assertBody("2") + try await Test.post("/group/nested/baz").assertOk().assertBody("3") + try await Test.post("/group/bar").assertOk().assertBody("4") + + // defined outside group -> still available without group prefix + try await Test.put("/foo").assertOk().assertBody("5") + + // only available under group prefix + try await Test.get("/bar").assertNotFound() + try await Test.post("/baz").assertNotFound() + try await Test.post("/bar").assertNotFound() + try await Test.get("/foo").assertNotFound() + } + + func testError() async throws { + app.get("/error") { _ -> Void in throw TestError() } + let status = HTTPResponseStatus.internalServerError + try await Test.get("/error").assertStatus(status).assertBody(status.reasonPhrase) + } + + func testErrorHandling() async throws { + app.get("/error_convert") { _ -> Void in throw TestConvertibleError() } + app.get("/error_convert_error") { _ -> Void in throw TestThrowingConvertibleError() } + + let errorStatus = HTTPResponseStatus.internalServerError + try await Test.get("/error_convert").assertStatus(.badGateway).assertEmpty() + try await Test.get("/error_convert_error").assertStatus(errorStatus).assertBody(errorStatus.reasonPhrase) + } +} + +private struct TestError: Error {} + +private struct TestConvertibleError: Error, ResponseConvertible { + func response() async throws -> Response { + Response(status: .badGateway, body: nil) + } +} + +private struct TestThrowingConvertibleError: Error, ResponseConvertible { + func response() async throws -> Response { + throw TestError() + } +} diff --git a/Tests/Alchemy/Routing/TrieTests.swift b/Tests/Alchemy/Routing/TrieTests.swift new file mode 100644 index 00000000..4ef1d880 --- /dev/null +++ b/Tests/Alchemy/Routing/TrieTests.swift @@ -0,0 +1,50 @@ +@testable import Alchemy +import XCTest + +final class TrieTests: XCTestCase { + func testTrie() { + let trie = Trie() + + trie.insert(path: ["one"], value: "foo") + trie.insert(path: ["one", "two"], value: "bar") + trie.insert(path: ["one", "two", "three"], value: "baz") + trie.insert(path: ["one", ":id"], value: "doo") + trie.insert(path: ["one", ":id", "two"], value: "dar") + trie.insert(path: [], value: "daz") + trie.insert(path: [":id0", ":id1", ":id2", ":id3"], value: "hmm") + + let result1 = trie.search(path: ["one"]) + let result2 = trie.search(path: ["one", "two"]) + let result3 = trie.search(path: ["one", "two", "three"]) + let result4 = trie.search(path: ["one", "zonk"]) + let result5 = trie.search(path: ["one", "fail", "two"]) + let result6 = trie.search(path: ["one", "aaa", "two"]) + let result7 = trie.search(path: ["one", "bbb", "two"]) + let result8 = trie.search(path: ["1", "2", "3", "4"]) + let result9 = trie.search(path: ["1", "2", "3", "5", "6"]) + + XCTAssertEqual(result1?.value, "foo") + XCTAssertEqual(result1?.parameters, []) + XCTAssertEqual(result2?.value, "bar") + XCTAssertEqual(result2?.parameters, []) + XCTAssertEqual(result3?.value, "baz") + XCTAssertEqual(result3?.parameters, []) + XCTAssertEqual(result4?.value, "doo") + XCTAssertEqual(result4?.parameters, [Parameter(key: "id", value: "zonk")]) + XCTAssertEqual(result5?.value, "dar") + XCTAssertEqual(result5?.parameters, [Parameter(key: "id", value: "fail")]) + XCTAssertEqual(result6?.value, "dar") + XCTAssertEqual(result6?.parameters, [Parameter(key: "id", value: "aaa")]) + XCTAssertEqual(result7?.value, "dar") + XCTAssertEqual(result7?.parameters, [Parameter(key: "id", value: "bbb")]) + XCTAssertEqual(result8?.value, "hmm") + XCTAssertEqual(result8?.parameters, [ + Parameter(key: "id0", value: "1"), + Parameter(key: "id1", value: "2"), + Parameter(key: "id2", value: "3"), + Parameter(key: "id3", value: "4"), + ]) + XCTAssertEqual(result9?.0, nil) + XCTAssertEqual(result9?.1, nil) + } +} diff --git a/Tests/Alchemy/SQL/Database/Core/DatabaseConfigTests.swift b/Tests/Alchemy/SQL/Database/Core/DatabaseConfigTests.swift new file mode 100644 index 00000000..b542f873 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Core/DatabaseConfigTests.swift @@ -0,0 +1,32 @@ +import AlchemyTest + +final class DatabaseConfigTests: TestCase { + func testConfig() { + let config = Database.Config( + databases: [ + .default: .memory, + 1: .memory, + 2: .memory + ], + migrations: [Migration1()], + seeders: [TestSeeder()], + redis: [ + .default: .testing, + 1: .testing, + 2: .testing + ]) + Database.configure(with: config) + XCTAssertNotNil(Container.resolve(Database.self, identifier: Database.Identifier.default)) + XCTAssertNotNil(Container.resolve(Database.self, identifier: 1)) + XCTAssertNotNil(Container.resolve(Database.self, identifier: 2)) + XCTAssertNotNil(Container.resolve(RedisClient.self, identifier: RedisClient.Identifier.default)) + XCTAssertNotNil(Container.resolve(RedisClient.self, identifier: 1)) + XCTAssertNotNil(Container.resolve(RedisClient.self, identifier: 2)) + XCTAssertEqual(DB.migrations.count, 1) + XCTAssertEqual(DB.seeders.count, 1) + } +} + +private struct TestSeeder: Seeder { + func run() async throws {} +} diff --git a/Tests/Alchemy/SQL/Database/Core/DatabaseKeyMappingTests.swift b/Tests/Alchemy/SQL/Database/Core/DatabaseKeyMappingTests.swift new file mode 100644 index 00000000..136e9e68 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Core/DatabaseKeyMappingTests.swift @@ -0,0 +1,19 @@ +import Alchemy +import XCTest + +final class DatabaseKeyMappingTests: XCTestCase { + func testCustom() { + let custom = DatabaseKeyMapping.custom { "\($0)_1" } + XCTAssertEqual(custom.map(input: "foo"), "foo_1") + } + + func testSnakeCase() { + let snakeCase = DatabaseKeyMapping.convertToSnakeCase + XCTAssertEqual(snakeCase.map(input: ""), "") + XCTAssertEqual(snakeCase.map(input: "foo"), "foo") + XCTAssertEqual(snakeCase.map(input: "fooBar"), "foo_bar") + XCTAssertEqual(snakeCase.map(input: "AI"), "a_i") + XCTAssertEqual(snakeCase.map(input: "testJSON"), "test_json") + XCTAssertEqual(snakeCase.map(input: "testNumbers123"), "test_numbers123") + } +} diff --git a/Tests/Alchemy/SQL/Database/Core/SQLRowTests.swift b/Tests/Alchemy/SQL/Database/Core/SQLRowTests.swift new file mode 100644 index 00000000..f80006dd --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Core/SQLRowTests.swift @@ -0,0 +1,104 @@ +@testable +import Alchemy +import AlchemyTest + +final class SQLRowTests: XCTestCase { + func testDecode() { + struct Test: Decodable, Equatable { + let foo: Int + let bar: String + } + + let row: SQLRow = StubDatabaseRow(data: [ + "foo": 1, + "bar": "two" + ]) + XCTAssertEqual(try row.decode(Test.self), Test(foo: 1, bar: "two")) + } + + func testModel() { + let date = Date() + let uuid = UUID() + let row: SQLRow = StubDatabaseRow(data: [ + "id": SQLValue.null, + "bool": false, + "string": "foo", + "double": 0.0, + "float": 0.0, + "int": 0, + "int8": 0, + "int16": 0, + "int32": 0, + "int64": 0, + "uint": 0, + "uint8": 0, + "uint16": 0, + "uint32": 0, + "uint64": 0, + "string_enum": "one", + "int_enum": 2, + "double_enum": 3.0, + "nested": SQLValue.json(""" + {"string":"foo","int":1} + """.data(using: .utf8) ?? Data()), + "date": SQLValue.date(date), + "uuid": SQLValue.uuid(uuid), + "belongs_to_id": 1 + ]) + XCTAssertEqual(try row.decode(EverythingModel.self), EverythingModel(date: date, uuid: uuid, belongsTo: .pk(1))) + } + + func testSubscript() { + let row: SQLRow = StubDatabaseRow(data: ["foo": 1]) + XCTAssertEqual(row["foo"], .int(1)) + XCTAssertEqual(row["bar"], nil) + } +} + +struct EverythingModel: Model, Equatable { + struct Nested: Codable, Equatable { + let string: String + let int: Int + } + enum StringEnum: String, ModelEnum { case one } + enum IntEnum: Int, ModelEnum { case two = 2 } + enum DoubleEnum: Double, ModelEnum { case three = 3.0 } + + var id: Int? + + // Enum + var stringEnum: StringEnum = .one + var intEnum: IntEnum = .two + var doubleEnum: DoubleEnum = .three + + // Keyed + var bool: Bool = false + var string: String = "foo" + var double: Double = 0 + var float: Float = 0 + var int: Int = 0 + var int8: Int8 = 0 + var int16: Int16 = 0 + var int32: Int32 = 0 + var int64: Int64 = 0 + var uint: UInt = 0 + var uint8: UInt8 = 0 + var uint16: UInt16 = 0 + var uint32: UInt32 = 0 + var uint64: UInt64 = 0 + var nested: Nested = Nested(string: "foo", int: 1) + var date: Date = Date() + var uuid: UUID = UUID() + + @HasMany var hasMany: [EverythingModel] + @HasOne var hasOne: EverythingModel + @HasOne var hasOneOptional: EverythingModel? + @BelongsTo var belongsTo: EverythingModel + @BelongsTo var belongsToOptional: EverythingModel? + + static var jsonEncoder: JSONEncoder = { + let encoder = JSONEncoder() + encoder.outputFormatting = [.sortedKeys] + return encoder + }() +} diff --git a/Tests/Alchemy/SQL/Database/Core/SQLTests.swift b/Tests/Alchemy/SQL/Database/Core/SQLTests.swift new file mode 100644 index 00000000..53b8135a --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Core/SQLTests.swift @@ -0,0 +1,9 @@ +import Alchemy +import XCTest + +final class SQLTests: XCTestCase { + func testValueConvertible() { + let sql: SQL = "NOW()" + XCTAssertEqual(sql.value, .string("NOW()")) + } +} diff --git a/Tests/Alchemy/SQL/Database/Core/SQLValueConvertibleTests.swift b/Tests/Alchemy/SQL/Database/Core/SQLValueConvertibleTests.swift new file mode 100644 index 00000000..e71b73be --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Core/SQLValueConvertibleTests.swift @@ -0,0 +1,18 @@ +import Alchemy +import XCTest + +final class SQLValueConvertibleTests: XCTestCase { + func testValueLiteral() { + let jsonString = """ + {"foo":"bar"} + """ + let jsonData = jsonString.data(using: .utf8) ?? Data() + XCTAssertEqual(SQLValue.json(jsonData).sqlLiteral, "'\(jsonString)'") + XCTAssertEqual(SQLValue.null.sqlLiteral, "NULL") + } + + func testSQL() { + XCTAssertEqual(SQLValue.string("foo").sql, SQL("'foo'")) + XCTAssertEqual(SQL("foo", bindings: [.string("bar")]).sql, SQL("foo", bindings: [.string("bar")])) + } +} diff --git a/Tests/Alchemy/SQL/Database/Core/SQLValueTests.swift b/Tests/Alchemy/SQL/Database/Core/SQLValueTests.swift new file mode 100644 index 00000000..fcc5c300 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Core/SQLValueTests.swift @@ -0,0 +1,83 @@ +import AlchemyTest + +final class SQLValueTests: XCTestCase { + func testNull() { + XCTAssertThrowsError(try SQLValue.null.int()) + XCTAssertThrowsError(try SQLValue.null.double()) + XCTAssertThrowsError(try SQLValue.null.bool()) + XCTAssertThrowsError(try SQLValue.null.string()) + XCTAssertThrowsError(try SQLValue.null.json()) + XCTAssertThrowsError(try SQLValue.null.date()) + XCTAssertThrowsError(try SQLValue.null.uuid("foo")) + } + + func testInt() { + XCTAssertEqual(try SQLValue.int(1).int(), 1) + XCTAssertThrowsError(try SQLValue.string("foo").int()) + } + + func testDouble() { + XCTAssertEqual(try SQLValue.double(1.0).double(), 1.0) + XCTAssertThrowsError(try SQLValue.string("foo").double()) + } + + func testBool() { + XCTAssertEqual(try SQLValue.bool(false).bool(), false) + XCTAssertEqual(try SQLValue.int(1).bool(), true) + XCTAssertThrowsError(try SQLValue.string("foo").bool()) + } + + func testString() { + XCTAssertEqual(try SQLValue.string("foo").string(), "foo") + XCTAssertThrowsError(try SQLValue.int(1).string()) + } + + func testDate() { + let date = Date() + XCTAssertEqual(try SQLValue.date(date).date(), date) + XCTAssertThrowsError(try SQLValue.int(1).date()) + } + + func testDateIso8601() { + let date = Date() + let formatter = ISO8601DateFormatter() + let dateString = formatter.string(from: date) + let roundedDate = formatter.date(from: dateString) ?? Date() + XCTAssertEqual(try SQLValue.string(formatter.string(from: date)).date(), roundedDate) + XCTAssertThrowsError(try SQLValue.string("").date()) + } + + func testJson() { + let jsonString = """ + {"foo":1} + """ + XCTAssertEqual(try SQLValue.json(Data()).json(), Data()) + XCTAssertEqual(try SQLValue.string(jsonString).json(), jsonString.data(using: .utf8)) + XCTAssertThrowsError(try SQLValue.int(1).json()) + } + + func testUuid() { + let uuid = UUID() + XCTAssertEqual(try SQLValue.uuid(uuid).uuid(), uuid) + XCTAssertEqual(try SQLValue.string(uuid.uuidString).uuid(), uuid) + XCTAssertThrowsError(try SQLValue.string("").uuid()) + XCTAssertThrowsError(try SQLValue.int(1).uuid("foo")) + } + + func testDescription() { + XCTAssertEqual(SQLValue.int(0).description, "SQLValue.int(0)") + XCTAssertEqual(SQLValue.double(1.23).description, "SQLValue.double(1.23)") + XCTAssertEqual(SQLValue.bool(true).description, "SQLValue.bool(true)") + XCTAssertEqual(SQLValue.string("foo").description, "SQLValue.string(`foo`)") + let date = Date() + XCTAssertEqual(SQLValue.date(date).description, "SQLValue.date(\(date))") + let jsonString = """ + {"foo":"bar"} + """ + let jsonData = jsonString.data(using: .utf8) ?? Data() + XCTAssertEqual(SQLValue.json(jsonData).description, "SQLValue.json(\(jsonString))") + let uuid = UUID() + XCTAssertEqual(SQLValue.uuid(uuid).description, "SQLValue.uuid(\(uuid.uuidString))") + XCTAssertEqual(SQLValue.null.description, "SQLValue.null") + } +} diff --git a/Tests/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseRowTests.swift b/Tests/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseRowTests.swift new file mode 100644 index 00000000..47395150 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseRowTests.swift @@ -0,0 +1,95 @@ +@testable import MySQLNIO +@testable import Alchemy +import AlchemyTest + +final class MySQLDatabaseRowTests: TestCase { + func testGet() { + let row = MySQLDatabaseRow(.fooOneBar2) + XCTAssertEqual(try row.get("foo"), .string("one")) + XCTAssertEqual(try row.get("bar"), .int(2)) + XCTAssertThrowsError(try row.get("baz")) + } + + func testNil() { + XCTAssertEqual(try MySQLData(.null).toSQLValue(), .null) + } + + func testString() { + XCTAssertEqual(try MySQLData(.string("foo")).toSQLValue(), .string("foo")) + XCTAssertEqual(try MySQLData(type: .string, buffer: nil).toSQLValue(), .null) + } + + func testInt() { + XCTAssertEqual(try MySQLData(.int(1)).toSQLValue(), .int(1)) + XCTAssertEqual(try MySQLData(type: .long, buffer: nil).toSQLValue(), .null) + } + + func testDouble() { + XCTAssertEqual(try MySQLData(.double(2.0)).toSQLValue(), .double(2.0)) + XCTAssertEqual(try MySQLData(type: .float, buffer: nil).toSQLValue(), .null) + } + + func testBool() { + XCTAssertEqual(try MySQLData(.bool(false)).toSQLValue(), .bool(false)) + XCTAssertEqual(try MySQLData(type: .tiny, buffer: nil).toSQLValue(), .null) + } + + func testDate() throws { + let date = Date() + // MySQLNIO occasionally loses some millisecond precision; round off. + let roundedDate = Date(timeIntervalSince1970: TimeInterval((Int(date.timeIntervalSince1970) / 1000) * 1000)) + XCTAssertEqual(try MySQLData(.date(roundedDate)).toSQLValue(), .date(roundedDate)) + XCTAssertEqual(try MySQLData(type: .date, buffer: nil).toSQLValue(), .null) + } + + func testJson() { + XCTAssertEqual(try MySQLData(.json(Data())).toSQLValue(), .json(Data())) + XCTAssertEqual(try MySQLData(type: .json, buffer: nil).toSQLValue(), .null) + } + + func testUuid() { + let uuid = UUID() + // Store as a string in MySQL + XCTAssertEqual(try MySQLData(.uuid(uuid)).toSQLValue(), .string(uuid.uuidString)) + } + + func testUnsupportedTypeThrows() { + XCTAssertThrowsError(try MySQLData(type: .time, buffer: nil).toSQLValue()) + XCTAssertThrowsError(try MySQLData(type: .time, buffer: nil).toSQLValue("fake")) + } +} + +extension MySQLRow { + static let fooOneBar2 = MySQLRow( + format: .text, + columnDefinitions: [ + .init( + catalog: "", + schema: "", + table: "", + orgTable: "", + name: "foo", + orgName: "", + characterSet: .utf8, + columnLength: 3, + columnType: .varchar, + flags: [], + decimals: 0), + .init( + catalog: "", + schema: "", + table: "", + orgTable: "", + name: "bar", + orgName: "", + characterSet: .utf8, + columnLength: 8, + columnType: .long, + flags: [], + decimals: 0) + ], + values: [ + .init(string: "one"), + .init(string: "2"), + ]) +} diff --git a/Tests/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseTests.swift b/Tests/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseTests.swift new file mode 100644 index 00000000..19930218 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Drivers/MySQL/MySQLDatabaseTests.swift @@ -0,0 +1,59 @@ +@testable +import Alchemy +import AlchemyTest +import NIOSSL + +final class MySQLDatabaseTests: TestCase { + func testDatabase() throws { + let db = Database.mysql(host: "127.0.0.1", database: "foo", username: "bar", password: "baz") + guard let provider = db.provider as? Alchemy.MySQLDatabase else { + XCTFail("The database provider should be MySQL.") + return + } + + XCTAssertEqual(try provider.pool.source.configuration.address().ipAddress, "127.0.0.1") + XCTAssertEqual(try provider.pool.source.configuration.address().port, 3306) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration == nil) + try db.shutdown() + } + + func testConfigIp() throws { + let socket: Socket = .ip(host: "127.0.0.1", port: 1234) + let provider = MySQLDatabase(socket: socket, database: "foo", username: "bar", password: "baz") + XCTAssertEqual(try provider.pool.source.configuration.address().ipAddress, "127.0.0.1") + XCTAssertEqual(try provider.pool.source.configuration.address().port, 1234) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration == nil) + try provider.shutdown() + } + + func testConfigSSL() throws { + let socket: Socket = .ip(host: "127.0.0.1", port: 1234) + let tlsConfig = TLSConfiguration.makeClientConfiguration() + let provider = MySQLDatabase(socket: socket, database: "foo", username: "bar", password: "baz", tlsConfiguration: tlsConfig) + XCTAssertEqual(try provider.pool.source.configuration.address().ipAddress, "127.0.0.1") + XCTAssertEqual(try provider.pool.source.configuration.address().port, 1234) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration != nil) + try provider.shutdown() + } + + func testConfigPath() throws { + let socket: Socket = .unix(path: "/test") + let provider = MySQLDatabase(socket: socket, database: "foo", username: "bar", password: "baz") + XCTAssertEqual(try provider.pool.source.configuration.address().pathname, "/test") + XCTAssertEqual(try provider.pool.source.configuration.address().port, nil) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration == nil) + try provider.shutdown() + } +} diff --git a/Tests/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabaseRowTests.swift b/Tests/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabaseRowTests.swift new file mode 100644 index 00000000..5b66015e --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabaseRowTests.swift @@ -0,0 +1,89 @@ +@testable import PostgresNIO +@testable import Alchemy +import AlchemyTest + +final class PostgresDatabaseRowTests: TestCase { + func testGet() { + let row = PostgresDatabaseRow(.fooOneBar2) + XCTAssertEqual(try row.get("foo"), .string("one")) + XCTAssertEqual(try row.get("bar"), .int(2)) + XCTAssertThrowsError(try row.get("baz")) + } + + func testNull() { + XCTAssertEqual(try PostgresData(.null).toSQLValue(), .null) + } + + func testString() { + XCTAssertEqual(try PostgresData(.string("foo")).toSQLValue(), .string("foo")) + XCTAssertEqual(try PostgresData(type: .varchar).toSQLValue(), .null) + } + + func testInt() { + XCTAssertEqual(try PostgresData(.int(1)).toSQLValue(), .int(1)) + XCTAssertEqual(try PostgresData(type: .int8).toSQLValue(), .null) + } + + func testDouble() { + XCTAssertEqual(try PostgresData(.double(2.0)).toSQLValue(), .double(2.0)) + XCTAssertEqual(try PostgresData(type: .float8).toSQLValue(), .null) + } + + func testBool() { + XCTAssertEqual(try PostgresData(.bool(false)).toSQLValue(), .bool(false)) + XCTAssertEqual(try PostgresData(type: .bool).toSQLValue(), .null) + } + + func testDate() { + let date = Date() + XCTAssertEqual(try PostgresData(.date(date)).toSQLValue(), .date(date)) + XCTAssertEqual(try PostgresData(type: .date).toSQLValue(), .null) + } + + func testJson() { + XCTAssertEqual(try PostgresData(.json(Data())).toSQLValue(), .json(Data())) + XCTAssertEqual(try PostgresData(type: .json).toSQLValue(), .null) + } + + func testUuid() { + let uuid = UUID() + XCTAssertEqual(try PostgresData(.uuid(uuid)).toSQLValue(), .uuid(uuid)) + XCTAssertEqual(try PostgresData(type: .uuid).toSQLValue(), .null) + } + + func testUnsupportedTypeThrows() { + XCTAssertThrowsError(try PostgresData(type: .time).toSQLValue()) + XCTAssertThrowsError(try PostgresData(type: .point).toSQLValue("column")) + } +} + +extension PostgresRow { + static let fooOneBar2 = PostgresRow( + dataRow: .init(columns: [ + .init(value: ByteBuffer(string: "one")), + .init(value: ByteBuffer(integer: 2)) + ]), + lookupTable: .init( + rowDescription: .init( + fields: [ + .init( + name: "foo", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .varchar, + dataTypeSize: 3, + dataTypeModifier: 0, + formatCode: .text + ), + .init( + name: "bar", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .int8, + dataTypeSize: 8, + dataTypeModifier: 0, + formatCode: .binary + ), + ]), + resultFormat: [.binary])) +} diff --git a/Tests/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabaseTests.swift b/Tests/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabaseTests.swift new file mode 100644 index 00000000..1455fd8b --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Drivers/Postgres/PostgresDatabaseTests.swift @@ -0,0 +1,64 @@ +@testable +import Alchemy +import AlchemyTest +import NIOSSL + +final class PostgresDatabaseTests: TestCase { + func testDatabase() throws { + let db = Database.postgres(host: "127.0.0.1", database: "foo", username: "bar", password: "baz") + guard let provider = db.provider as? Alchemy.PostgresDatabase else { + XCTFail("The database provider should be PostgreSQL.") + return + } + + XCTAssertEqual(try provider.pool.source.configuration.address().ipAddress, "127.0.0.1") + XCTAssertEqual(try provider.pool.source.configuration.address().port, 5432) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration == nil) + try db.shutdown() + } + + func testConfigIp() throws { + let socket: Socket = .ip(host: "127.0.0.1", port: 1234) + let provider = PostgresDatabase(socket: socket, database: "foo", username: "bar", password: "baz") + XCTAssertEqual(try provider.pool.source.configuration.address().ipAddress, "127.0.0.1") + XCTAssertEqual(try provider.pool.source.configuration.address().port, 1234) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration == nil) + try provider.shutdown() + } + + func testConfigSSL() throws { + let socket: Socket = .ip(host: "127.0.0.1", port: 1234) + let tlsConfig = TLSConfiguration.makeClientConfiguration() + let provider = PostgresDatabase(socket: socket, database: "foo", username: "bar", password: "baz", tlsConfiguration: tlsConfig) + XCTAssertEqual(try provider.pool.source.configuration.address().ipAddress, "127.0.0.1") + XCTAssertEqual(try provider.pool.source.configuration.address().port, 1234) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration != nil) + try provider.shutdown() + } + + func testConfigPath() throws { + let socket: Socket = .unix(path: "/test") + let provider = PostgresDatabase(socket: socket, database: "foo", username: "bar", password: "baz") + XCTAssertEqual(try provider.pool.source.configuration.address().pathname, "/test") + XCTAssertEqual(try provider.pool.source.configuration.address().port, nil) + XCTAssertEqual(provider.pool.source.configuration.database, "foo") + XCTAssertEqual(provider.pool.source.configuration.username, "bar") + XCTAssertEqual(provider.pool.source.configuration.password, "baz") + XCTAssertTrue(provider.pool.source.configuration.tlsConfiguration == nil) + try provider.shutdown() + } + + func testPositionBindings() { + let query = "select * from cats where name = ? and age > ?" + XCTAssertEqual(query.positionPostgresBindings(), "select * from cats where name = $1 and age > $2") + } +} diff --git a/Tests/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabaseTests.swift b/Tests/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabaseTests.swift new file mode 100644 index 00000000..b3eaa6f3 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Drivers/SQLite/SQLiteDatabaseTests.swift @@ -0,0 +1,35 @@ +@testable +import Alchemy +import AlchemyTest + +final class SQLiteDatabaseTests: TestCase { + func testDatabase() throws { + let memory = Database.memory + guard memory.provider as? Alchemy.SQLiteDatabase != nil else { + XCTFail("The database provider should be SQLite.") + return + } + + let path = Database.sqlite(path: "foo") + guard path.provider as? Alchemy.SQLiteDatabase != nil else { + XCTFail("The database provider should be SQLite.") + return + } + + try memory.shutdown() + try path.shutdown() + } + + func testConfigPath() throws { + let provider = SQLiteDatabase(config: .file("foo")) + XCTAssertEqual(provider.config, .file("foo")) + try provider.shutdown() + } + + func testConfigMemory() throws { + let id = UUID().uuidString + let provider = SQLiteDatabase(config: .memory(identifier: id)) + XCTAssertEqual(provider.config, .memory(identifier: id)) + try provider.shutdown() + } +} diff --git a/Tests/Alchemy/SQL/Database/Drivers/SQLite/SQLiteRowTests.swift b/Tests/Alchemy/SQL/Database/Drivers/SQLite/SQLiteRowTests.swift new file mode 100644 index 00000000..6b8b7466 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Drivers/SQLite/SQLiteRowTests.swift @@ -0,0 +1,70 @@ +@testable import SQLiteNIO +@testable import Alchemy +import AlchemyTest + +final class SQLiteRowTests: TestCase { + func testGet() { + let row = SQLiteDatabaseRow(.fooOneBar2) + XCTAssertEqual(try row.get("foo"), .string("one")) + XCTAssertEqual(try row.get("bar"), .int(2)) + XCTAssertThrowsError(try row.get("baz")) + } + + func testNull() { + XCTAssertEqual(try SQLiteData(.null).toSQLValue(), .null) + } + + func testString() { + XCTAssertEqual(try SQLiteData(.string("foo")).toSQLValue(), .string("foo")) + } + + func testInt() { + XCTAssertEqual(try SQLiteData(.int(1)).toSQLValue(), .int(1)) + } + + func testDouble() { + XCTAssertEqual(try SQLiteData(.double(2.0)).toSQLValue(), .double(2.0)) + } + + func testBool() { + XCTAssertEqual(try SQLiteData(.bool(false)).toSQLValue(), .int(0)) + XCTAssertEqual(try SQLiteData(.bool(true)).toSQLValue(), .int(1)) + } + + func testDate() { + let date = Date() + let dateString = SQLValue.iso8601DateFormatter.string(from: date) + XCTAssertEqual(try SQLiteData(.date(date)).toSQLValue(), .string(dateString)) + } + + func testJson() { + let jsonString = """ + {"foo":"one","bar":2} + """ + let jsonData = jsonString.data(using: .utf8) ?? Data() + XCTAssertEqual(try SQLiteData(.json(jsonData)).toSQLValue(), .string(jsonString)) + let invalidBytes: [UInt8] = [0xFF, 0xD9] + XCTAssertEqual(try SQLiteData(.json(Data(bytes: invalidBytes, count: 2))).toSQLValue(), .null) + } + + func testUuid() { + let uuid = UUID() + XCTAssertEqual(try SQLiteData(.uuid(uuid)).toSQLValue(), .string(uuid.uuidString)) + } + + func testUnsupportedTypeThrows() { + XCTAssertThrowsError(try SQLiteData.blob(ByteBuffer()).toSQLValue()) + } +} + +extension SQLiteRow { + static let fooOneBar2 = SQLiteRow( + columnOffsets: .init(offsets: [ + ("foo", 0), + ("bar", 1), + ]), + data: [ + .text("one"), + .integer(2) + ]) +} diff --git a/Tests/Alchemy/SQL/Database/Fixtures/Models.swift b/Tests/Alchemy/SQL/Database/Fixtures/Models.swift new file mode 100644 index 00000000..f776e07a --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Fixtures/Models.swift @@ -0,0 +1,49 @@ +import Alchemy + +struct SeedModel: Model, Seedable { + struct Migrate: Migration { + func up(schema: Schema) { + schema.create(table: "seed_models") { + $0.increments("id").primary() + $0.string("name").notNull() + $0.string("email").notNull().unique() + } + } + + func down(schema: Schema) { + schema.drop(table: "seed_models") + } + } + + var id: Int? + let name: String + let email: String + + static func generate() -> SeedModel { + SeedModel(name: faker.name.name(), email: faker.internet.email()) + } +} + +struct OtherSeedModel: Model, Seedable { + struct Migrate: Migration { + func up(schema: Schema) { + schema.create(table: "other_seed_models") { + $0.uuid("id").primary() + $0.int("foo").notNull() + $0.bool("bar").notNull() + } + } + + func down(schema: Schema) { + schema.drop(table: "seed_models") + } + } + + var id: UUID? = UUID() + let foo: Int + let bar: Bool + + static func generate() -> OtherSeedModel { + OtherSeedModel(foo: .random(), bar: .random()) + } +} diff --git a/Tests/Alchemy/SQL/Database/Seeding/DatabaseSeederTests.swift b/Tests/Alchemy/SQL/Database/Seeding/DatabaseSeederTests.swift new file mode 100644 index 00000000..246a679f --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Seeding/DatabaseSeederTests.swift @@ -0,0 +1,52 @@ +@testable +import Alchemy +import AlchemyTest + +final class DatabaseSeederTests: TestCase { + func testSeeder() async throws { + Database.fake( + migrations: [ + SeedModel.Migrate(), + OtherSeedModel.Migrate()], + seeders: [TestSeeder()]) + + AssertEqual(try await SeedModel.all().count, 10) + AssertEqual(try await OtherSeedModel.all().count, 0) + + try await DB.seed(with: OtherSeeder()) + AssertEqual(try await OtherSeedModel.all().count, 999) + } + + func testSeedWithNames() async throws { + Database.fake( + migrations: [ + SeedModel.Migrate(), + OtherSeedModel.Migrate()]) + + DB.seeders = [ + TestSeeder(), + OtherSeeder() + ] + + try await DB.seed(names: ["otherseeder"]) + AssertEqual(try await SeedModel.all().count, 0) + AssertEqual(try await OtherSeedModel.all().count, 999) + + do { + try await DB.seed(names: ["foo"]) + XCTFail("Unknown seeder name should throw") + } catch {} + } +} + +private struct TestSeeder: Seeder { + func run() async throws { + try await SeedModel.seed(10) + } +} + +private struct OtherSeeder: Seeder { + func run() async throws { + try await OtherSeedModel.seed(999) + } +} diff --git a/Tests/Alchemy/SQL/Database/Seeding/SeederTests.swift b/Tests/Alchemy/SQL/Database/Seeding/SeederTests.swift new file mode 100644 index 00000000..c0e29ab7 --- /dev/null +++ b/Tests/Alchemy/SQL/Database/Seeding/SeederTests.swift @@ -0,0 +1,13 @@ +import AlchemyTest + +final class SeederTests: TestCase { + func testSeeder() async throws { + Database.fake(migrations: [SeedModel.Migrate()]) + + try await SeedModel.seed() + AssertEqual(try await SeedModel.all().count, 1) + + try await SeedModel.seed(10) + AssertEqual(try await SeedModel.all().count, 11) + } +} diff --git a/Tests/Alchemy/SQL/Migrations/DatabaseMigrationTests.swift b/Tests/Alchemy/SQL/Migrations/DatabaseMigrationTests.swift new file mode 100644 index 00000000..5a46ab42 --- /dev/null +++ b/Tests/Alchemy/SQL/Migrations/DatabaseMigrationTests.swift @@ -0,0 +1,28 @@ +@testable +import Alchemy +import AlchemyTest + +final class DatabaseMigrationTests: TestCase { + func testMigration() async throws { + let db = Database.fake() + try await db.rollbackMigrations() + db.migrations = [MigrationA()] + try await db.migrate() + AssertEqual(try await AlchemyMigration.all().count, 1) + db.migrations.append(MigrationB()) + try await db.migrate() + AssertEqual(try await AlchemyMigration.all().count, 2) + try await db.rollbackMigrations() + AssertEqual(try await AlchemyMigration.all().count, 1) + } +} + +private struct MigrationA: Migration { + func up(schema: Schema) {} + func down(schema: Schema) {} +} + +private struct MigrationB: Migration { + func up(schema: Schema) {} + func down(schema: Schema) {} +} diff --git a/Tests/AlchemyTests/SQL/Migrations/MigrationTests.swift b/Tests/Alchemy/SQL/Migrations/MigrationTests.swift similarity index 100% rename from Tests/AlchemyTests/SQL/Migrations/MigrationTests.swift rename to Tests/Alchemy/SQL/Migrations/MigrationTests.swift diff --git a/Tests/AlchemyTests/SQL/Migrations/SampleMigrations.swift b/Tests/Alchemy/SQL/Migrations/SampleMigrations.swift similarity index 98% rename from Tests/AlchemyTests/SQL/Migrations/SampleMigrations.swift rename to Tests/Alchemy/SQL/Migrations/SampleMigrations.swift index ea06ca86..d9ddf7d3 100644 --- a/Tests/AlchemyTests/SQL/Migrations/SampleMigrations.swift +++ b/Tests/Alchemy/SQL/Migrations/SampleMigrations.swift @@ -57,7 +57,7 @@ struct Migration1: TestMigration { "counter" serial, "is_pro" bool DEFAULT false, "created_at" timestamptz, - "date_default" timestamptz DEFAULT '1970-01-01T00:00:00', + "date_default" timestamptz DEFAULT '1970-01-01 00:00:00 +0000', "uuid_default" uuid DEFAULT '\(kFixedUUID.uuidString)', "some_json" json DEFAULT '{"age":27,"name":"Josh"}'::jsonb, "other_json" json DEFAULT '{}'::jsonb, @@ -92,7 +92,7 @@ struct Migration1: TestMigration { "counter" serial, "is_pro" boolean DEFAULT false, "created_at" datetime, - "date_default" datetime DEFAULT '1970-01-01T00:00:00', + "date_default" datetime DEFAULT '1970-01-01 00:00:00 +0000', "uuid_default" varchar(36) DEFAULT '\(kFixedUUID.uuidString)', "some_json" json DEFAULT ('{"age":27,"name":"Josh"}'), "other_json" json DEFAULT ('{}'), diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryCrudTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryCrudTests.swift new file mode 100644 index 00000000..494ddd74 --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryCrudTests.swift @@ -0,0 +1,46 @@ +import AlchemyTest + +final class QueryCrudTests: TestCase { + var db: Database! + + override func setUp() { + super.setUp() + db = Database.fake(migrations: [TestModelMigration()]) + } + + func testFind() async throws { + AssertTrue(try await db.table("test_models").findRow("foo", equals: .string("bar")) == nil) + try await TestModel(foo: "bar", bar: false).insert() + AssertTrue(try await db.table("test_models").findRow("foo", equals: .string("bar")) != nil) + } + + func testCount() async throws { + AssertEqual(try await db.table("test_models").count(), 0) + try await TestModel(foo: "bar", bar: false).insert() + AssertEqual(try await db.table("test_models").count(), 1) + } +} + +private struct TestModel: Model, Seedable, Equatable { + var id: Int? + var foo: String + var bar: Bool + + static func generate() async throws -> TestModel { + TestModel(foo: faker.lorem.word(), bar: faker.number.randomBool()) + } +} + +private struct TestModelMigration: Migration { + func up(schema: Schema) { + schema.create(table: "test_models") { + $0.increments("id").primary() + $0.string("foo").notNull() + $0.bool("bar").notNull() + } + } + + func down(schema: Schema) { + schema.drop(table: "test_models") + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryGroupingTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryGroupingTests.swift new file mode 100644 index 00000000..59aeaa13 --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryGroupingTests.swift @@ -0,0 +1,32 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryGroupingTests: TestCase { + private let sampleWhere = Query.Where( + type: .value(key: "id", op: .equals, value: .int(1)), + boolean: .and) + + override func setUp() { + super.setUp() + Database.stub() + } + + func testGroupBy() { + XCTAssertEqual(DB.table("foo").groupBy("bar").groups, ["bar"]) + XCTAssertEqual(DB.table("foo").groupBy("bar").groupBy("baz").groups, ["bar", "baz"]) + } + + func testHaving() { + let orWhere = Query.Where(type: sampleWhere.type, boolean: .or) + let query = DB.table("foo") + .having(sampleWhere) + .orHaving(orWhere) + .having(key: "bar", op: .like, value: "baz", boolean: .or) + XCTAssertEqual(query.havings, [ + sampleWhere, + orWhere, + Query.Where(type: .value(key: "bar", op: .like, value: .string("baz")), boolean: .or) + ]) + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryJoinTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryJoinTests.swift new file mode 100644 index 00000000..2b7f6219 --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryJoinTests.swift @@ -0,0 +1,60 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryJoinTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testJoin() { + let query = DB.table("foo").join(table: "bar", first: "id1", second: "id2") + XCTAssertEqual(query.joins, [sampleJoin(of: .inner)]) + XCTAssertEqual(query.wheres, []) + } + + func testLeftJoin() { + let query = DB.table("foo").leftJoin(table: "bar", first: "id1", second: "id2") + XCTAssertEqual(query.joins, [sampleJoin(of: .left)]) + XCTAssertEqual(query.wheres, []) + } + + func testRightJoin() { + let query = DB.table("foo").rightJoin(table: "bar", first: "id1", second: "id2") + XCTAssertEqual(query.joins, [sampleJoin(of: .right)]) + XCTAssertEqual(query.wheres, []) + } + + func testCrossJoin() { + let query = DB.table("foo").crossJoin(table: "bar", first: "id1", second: "id2") + XCTAssertEqual(query.joins, [sampleJoin(of: .cross)]) + XCTAssertEqual(query.wheres, []) + } + + func testOn() { + let query = DB.table("foo").join(table: "bar") { + $0.on(first: "id1", op: .equals, second: "id2") + .orOn(first: "id3", op: .greaterThan, second: "id4") + } + + let expectedJoin = Query.Join(database: DB.provider, table: "foo", type: .inner, joinTable: "bar") + expectedJoin.joinWheres = [ + Query.Where(type: .column(first: "id1", op: .equals, second: "id2"), boolean: .and), + Query.Where(type: .column(first: "id3", op: .greaterThan, second: "id4"), boolean: .or) + ] + XCTAssertEqual(query.joins, [expectedJoin]) + XCTAssertEqual(query.wheres, []) + } + + func testEquality() { + XCTAssertEqual(sampleJoin(of: .inner), sampleJoin(of: .inner)) + XCTAssertNotEqual(sampleJoin(of: .inner), sampleJoin(of: .cross)) + XCTAssertNotEqual(sampleJoin(of: .inner), DB.table("foo")) + } + + private func sampleJoin(of type: Query.JoinType) -> Query.Join { + return Query.Join(database: DB.provider, table: "foo", type: type, joinTable: "bar") + .on(first: "id1", op: .equals, second: "id2") + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryLockTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryLockTests.swift new file mode 100644 index 00000000..8281f32d --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryLockTests.swift @@ -0,0 +1,18 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryLockTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testLock() { + XCTAssertNil(DB.table("foo").lock) + XCTAssertEqual(DB.table("foo").lock(for: .update).lock, Query.Lock(strength: .update, option: nil)) + XCTAssertEqual(DB.table("foo").lock(for: .share).lock, Query.Lock(strength: .share, option: nil)) + XCTAssertEqual(DB.table("foo").lock(for: .update, option: .noWait).lock, Query.Lock(strength: .update, option: .noWait)) + XCTAssertEqual(DB.table("foo").lock(for: .update, option: .skipLocked).lock, Query.Lock(strength: .update, option: .skipLocked)) + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryOperatorTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryOperatorTests.swift new file mode 100644 index 00000000..8926e394 --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryOperatorTests.swift @@ -0,0 +1,22 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryOperatorTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testOperatorDescriptions() { + XCTAssertEqual(Query.Operator.equals.description, "=") + XCTAssertEqual(Query.Operator.lessThan.description, "<") + XCTAssertEqual(Query.Operator.greaterThan.description, ">") + XCTAssertEqual(Query.Operator.lessThanOrEqualTo.description, "<=") + XCTAssertEqual(Query.Operator.greaterThanOrEqualTo.description, ">=") + XCTAssertEqual(Query.Operator.notEqualTo.description, "!=") + XCTAssertEqual(Query.Operator.like.description, "LIKE") + XCTAssertEqual(Query.Operator.notLike.description, "NOT LIKE") + XCTAssertEqual(Query.Operator.raw("foo").description, "foo") + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryOrderTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryOrderTests.swift new file mode 100644 index 00000000..93ddff9e --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryOrderTests.swift @@ -0,0 +1,20 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryOrderTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testOrderBy() { + let query = DB.table("foo") + .orderBy("bar") + .orderBy("baz", direction: .desc) + XCTAssertEqual(query.orders, [ + Query.Order(column: "bar", direction: .asc), + Query.Order(column: "baz", direction: .desc), + ]) + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryPagingTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryPagingTests.swift new file mode 100644 index 00000000..65154788 --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryPagingTests.swift @@ -0,0 +1,28 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryPagingTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testLimit() { + XCTAssertEqual(DB.table("foo").distinct().isDistinct, true) + } + + func testOffset() { + XCTAssertEqual(DB.table("foo").distinct().isDistinct, true) + } + + func testPaging() { + let standardPage = DB.table("foo").forPage(4) + XCTAssertEqual(standardPage.limit, 25) + XCTAssertEqual(standardPage.offset, 75) + + let customPage = DB.table("foo").forPage(2, perPage: 10) + XCTAssertEqual(customPage.limit, 10) + XCTAssertEqual(customPage.offset, 10) + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QuerySelectTests.swift b/Tests/Alchemy/SQL/Query/Builder/QuerySelectTests.swift new file mode 100644 index 00000000..b9c36fe7 --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QuerySelectTests.swift @@ -0,0 +1,36 @@ +@testable +import Alchemy +import AlchemyTest + +final class QuerySelectTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testStartsEmpty() { + let query = DB.table("foo") + XCTAssertEqual(query.table, "foo") + XCTAssertEqual(query.columns, ["*"]) + XCTAssertEqual(query.isDistinct, false) + XCTAssertNil(query.limit) + XCTAssertNil(query.offset) + XCTAssertNil(query.lock) + XCTAssertEqual(query.joins, []) + XCTAssertEqual(query.wheres, []) + XCTAssertEqual(query.groups, []) + XCTAssertEqual(query.havings, []) + XCTAssertEqual(query.orders, []) + } + + func testSelect() { + let specific = DB.table("foo").select(["bar", "baz"]) + XCTAssertEqual(specific.columns, ["bar", "baz"]) + let all = DB.table("foo").select() + XCTAssertEqual(all.columns, ["*"]) + } + + func testDistinct() { + XCTAssertEqual(DB.table("foo").distinct().isDistinct, true) + } +} diff --git a/Tests/Alchemy/SQL/Query/Builder/QueryWhereTests.swift b/Tests/Alchemy/SQL/Query/Builder/QueryWhereTests.swift new file mode 100644 index 00000000..ba6eab0c --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Builder/QueryWhereTests.swift @@ -0,0 +1,119 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryWhereTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testWhere() { + let query = DB.table("foo") + .where("foo" == 1) + .orWhere("bar" == 2) + XCTAssertEqual(query.wheres, [_andWhere(), _orWhere(key: "bar", value: 2)]) + } + + func testNestedWhere() { + let query = DB.table("foo") + .where { $0.where("foo" == 1).orWhere("bar" == 2) } + .orWhere { $0.where("baz" == 3).orWhere("fiz" == 4) } + XCTAssertEqual(query.wheres, [ + _andWhere(.nested(wheres: [ + _andWhere(), + _orWhere(key: "bar", value: 2) + ])), + _orWhere(.nested(wheres: [ + _andWhere(key: "baz", value: 3), + _orWhere(key: "fiz", value: 4) + ])) + ]) + } + + func testWhereIn() { + let query = DB.table("foo") + .where(key: "foo", in: [1]) + .orWhere(key: "bar", in: [2]) + XCTAssertEqual(query.wheres, [ + _andWhere(.in(key: "foo", values: [.int(1)], type: .in)), + _orWhere(.in(key: "bar", values: [.int(2)], type: .in)), + ]) + } + + func testWhereNotIn() { + let query = DB.table("foo") + .whereNot(key: "foo", in: [1]) + .orWhereNot(key: "bar", in: [2]) + XCTAssertEqual(query.wheres, [ + _andWhere(.in(key: "foo", values: [.int(1)], type: .notIn)), + _orWhere(.in(key: "bar", values: [.int(2)], type: .notIn)), + ]) + } + + func testWhereRaw() { + let query = DB.table("foo") + .whereRaw(sql: "foo", bindings: [1]) + .orWhereRaw(sql: "bar", bindings: [2]) + XCTAssertEqual(query.wheres, [ + _andWhere(.raw(SQL("foo", bindings: [.int(1)]))), + _orWhere(.raw(SQL("bar", bindings: [.int(2)]))), + ]) + } + + func testWhereColumn() { + let query = DB.table("foo") + .whereColumn(first: "foo", op: .equals, second: "bar") + .orWhereColumn(first: "baz", op: .like, second: "fiz") + XCTAssertEqual(query.wheres, [ + _andWhere(.column(first: "foo", op: .equals, second: "bar")), + _orWhere(.column(first: "baz", op: .like, second: "fiz")), + ]) + } + + func testWhereNull() { + let query = DB.table("foo") + .whereNull(key: "foo") + .orWhereNull(key: "bar") + XCTAssertEqual(query.wheres, [ + _andWhere(.raw(SQL("foo IS NULL"))), + _orWhere(.raw(SQL("bar IS NULL"))), + ]) + } + + func testWhereNotNull() { + let query = DB.table("foo") + .whereNotNull(key: "foo") + .orWhereNotNull(key: "bar") + XCTAssertEqual(query.wheres, [ + _andWhere(.raw(SQL("foo IS NOT NULL"))), + _orWhere(.raw(SQL("bar IS NOT NULL"))), + ]) + } + + func testCustomOperators() { + XCTAssertEqual("foo" == 1, _andWhere(op: .equals)) + XCTAssertEqual("foo" != 1, _andWhere(op: .notEqualTo)) + XCTAssertEqual("foo" < 1, _andWhere(op: .lessThan)) + XCTAssertEqual("foo" > 1, _andWhere(op: .greaterThan)) + XCTAssertEqual("foo" <= 1, _andWhere(op: .lessThanOrEqualTo)) + XCTAssertEqual("foo" >= 1, _andWhere(op: .greaterThanOrEqualTo)) + XCTAssertEqual("foo" ~= 1, _andWhere(op: .like)) + } + + private func _andWhere(key: String = "foo", op: Query.Operator = .equals, value: SQLValueConvertible = 1) -> Query.Where { + _andWhere(.value(key: key, op: op, value: value.value)) + } + + private func _orWhere(key: String = "foo", op: Query.Operator = .equals, value: SQLValueConvertible = 1) -> Query.Where { + _orWhere(.value(key: key, op: op, value: value.value)) + } + + private func _andWhere(_ type: Query.WhereType) -> Query.Where { + Query.Where(type: type, boolean: .and) + } + + private func _orWhere(_ type: Query.WhereType) -> Query.Where { + Query.Where(type: type, boolean: .or) + } +} diff --git a/Tests/Alchemy/SQL/Query/DatabaseQueryTests.swift b/Tests/Alchemy/SQL/Query/DatabaseQueryTests.swift new file mode 100644 index 00000000..e1c3f542 --- /dev/null +++ b/Tests/Alchemy/SQL/Query/DatabaseQueryTests.swift @@ -0,0 +1,18 @@ +@testable +import Alchemy +import AlchemyTest + +final class DatabaseQueryTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testTable() { + XCTAssertEqual(DB.from("foo").table, "foo") + } + + func testAlias() { + XCTAssertEqual(DB.from("foo", as: "bar").table, "foo as bar") + } +} diff --git a/Tests/Alchemy/SQL/Query/Grammar/GrammarTests.swift b/Tests/Alchemy/SQL/Query/Grammar/GrammarTests.swift new file mode 100644 index 00000000..64bf930a --- /dev/null +++ b/Tests/Alchemy/SQL/Query/Grammar/GrammarTests.swift @@ -0,0 +1,125 @@ +@testable +import Alchemy +import AlchemyTest + +final class GrammarTests: XCTestCase { + private let grammar = Grammar() + + func testCompileSelect() { + + } + + func testCompileJoins() { + + } + + func testCompileWheres() { + + } + + func testCompileGroups() { + XCTAssertEqual(grammar.compileGroups(["foo, bar, baz"]), "group by foo, bar, baz") + XCTAssertEqual(grammar.compileGroups([]), nil) + } + + func testCompileHavings() { + + } + + func testCompileOrders() { + XCTAssertEqual(grammar.compileOrders([ + Query.Order(column: "foo", direction: .asc), + Query.Order(column: "bar", direction: .desc) + ]), "order by foo asc, bar desc") + XCTAssertEqual(grammar.compileOrders([]), nil) + } + + func testCompileLimit() { + XCTAssertEqual(grammar.compileLimit(1), "limit 1") + XCTAssertEqual(grammar.compileLimit(nil), nil) + } + + func testCompileOffset() { + XCTAssertEqual(grammar.compileOffset(1), "offset 1") + XCTAssertEqual(grammar.compileOffset(nil), nil) + } + + func testCompileInsert() { + + } + + func testCompileInsertAndReturn() { + + } + + func testCompileUpdate() { + + } + + func testCompileDelete() { + + } + + func testCompileLock() { + XCTAssertEqual(grammar.compileLock(nil), nil) + XCTAssertEqual(grammar.compileLock(Query.Lock(strength: .update, option: nil)), "FOR UPDATE") + XCTAssertEqual(grammar.compileLock(Query.Lock(strength: .share, option: nil)), "FOR SHARE") + XCTAssertEqual(grammar.compileLock(Query.Lock(strength: .update, option: .skipLocked)), "FOR UPDATE SKIP LOCKED") + XCTAssertEqual(grammar.compileLock(Query.Lock(strength: .update, option: .noWait)), "FOR UPDATE NO WAIT") + } + + func testCompileCreateTable() { + + } + + func testCompileRenameTable() { + XCTAssertEqual(grammar.compileRenameTable("foo", to: "bar"), """ + ALTER TABLE foo RENAME TO bar + """) + } + + func testCompileDropTable() { + XCTAssertEqual(grammar.compileDropTable("foo"), """ + DROP TABLE foo + """) + } + + func testCompileAlterTable() { + + } + + func testCompileRenameColumn() { + XCTAssertEqual(grammar.compileRenameColumn(on: "foo", column: "bar", to: "baz"), """ + ALTER TABLE foo RENAME COLUMN "bar" TO "baz" + """) + } + + func testCompileCreateIndexes() { + + } + + func testCompileDropIndex() { + XCTAssertEqual(grammar.compileDropIndex(on: "foo", indexName: "bar"), "DROP INDEX bar") + } + + func testColumnTypeString() { + XCTAssertEqual(grammar.columnTypeString(for: .increments), "serial") + XCTAssertEqual(grammar.columnTypeString(for: .int), "int") + XCTAssertEqual(grammar.columnTypeString(for: .bigInt), "bigint") + XCTAssertEqual(grammar.columnTypeString(for: .double), "float8") + XCTAssertEqual(grammar.columnTypeString(for: .string(.limit(10))), "varchar(10)") + XCTAssertEqual(grammar.columnTypeString(for: .string(.unlimited)), "text") + XCTAssertEqual(grammar.columnTypeString(for: .uuid), "uuid") + XCTAssertEqual(grammar.columnTypeString(for: .bool), "bool") + XCTAssertEqual(grammar.columnTypeString(for: .date), "timestamptz") + XCTAssertEqual(grammar.columnTypeString(for: .json), "json") + } + + func testCreateColumnString() { + + } + + func testJsonLiteral() { + XCTAssertEqual(grammar.jsonLiteral(for: "foo"), "'foo'::jsonb") + } +} diff --git a/Tests/Alchemy/SQL/Query/QueryTests.swift b/Tests/Alchemy/SQL/Query/QueryTests.swift new file mode 100644 index 00000000..c0f5537f --- /dev/null +++ b/Tests/Alchemy/SQL/Query/QueryTests.swift @@ -0,0 +1,30 @@ +@testable +import Alchemy +import AlchemyTest + +final class QueryTests: TestCase { + override func setUp() { + super.setUp() + Database.stub() + } + + func testStartsEmpty() { + let query = DB.table("foo") + XCTAssertEqual(query.table, "foo") + XCTAssertEqual(query.columns, ["*"]) + XCTAssertEqual(query.isDistinct, false) + XCTAssertNil(query.limit) + XCTAssertNil(query.offset) + XCTAssertNil(query.lock) + XCTAssertEqual(query.joins, []) + XCTAssertEqual(query.wheres, []) + XCTAssertEqual(query.groups, []) + XCTAssertEqual(query.havings, []) + XCTAssertEqual(query.orders, []) + } + + func testEquality() { + XCTAssertEqual(DB.table("foo"), DB.table("foo")) + XCTAssertNotEqual(DB.table("foo"), DB.table("bar")) + } +} diff --git a/Tests/Alchemy/SQL/Query/SQLUtilitiesTests.swift b/Tests/Alchemy/SQL/Query/SQLUtilitiesTests.swift new file mode 100644 index 00000000..367c656d --- /dev/null +++ b/Tests/Alchemy/SQL/Query/SQLUtilitiesTests.swift @@ -0,0 +1,19 @@ +@testable +import Alchemy +import XCTest + +final class SQLUtilitiesTests: XCTestCase { + func testJoined() { + XCTAssertEqual([ + SQL("where foo = ?", bindings: [.int(1)]), + SQL("bar"), + SQL("where baz = ?", bindings: [.string("two")]) + ].joinedSQL(), SQL("where foo = ? bar where baz = ?", bindings: [.int(1), .string("two")])) + } + + func testDropLeadingBoolean() { + XCTAssertEqual(SQL("foo").droppingLeadingBoolean().statement, "foo") + XCTAssertEqual(SQL("and bar").droppingLeadingBoolean().statement, "bar") + XCTAssertEqual(SQL("or baz").droppingLeadingBoolean().statement, "baz") + } +} diff --git a/Tests/Alchemy/SQL/Rune/Model/Decoding/SQLRowDecoderTests.swift b/Tests/Alchemy/SQL/Rune/Model/Decoding/SQLRowDecoderTests.swift new file mode 100644 index 00000000..77b2a8d7 --- /dev/null +++ b/Tests/Alchemy/SQL/Rune/Model/Decoding/SQLRowDecoderTests.swift @@ -0,0 +1,22 @@ +@testable +import Alchemy +import AlchemyTest + +final class SQLRowDecoderTests: XCTestCase { + func testDecodeThrowing() throws { + let row = StubDatabaseRow() + let decoder = SQLRowDecoder(row: row, keyMapping: .useDefaultKeys, jsonDecoder: JSONDecoder()) + XCTAssertThrowsError(try decoder.singleValueContainer()) + XCTAssertThrowsError(try decoder.unkeyedContainer()) + + let keyed = try decoder.container(keyedBy: DummyKeys.self) + XCTAssertThrowsError(try keyed.nestedUnkeyedContainer(forKey: .foo)) + XCTAssertThrowsError(try keyed.nestedContainer(keyedBy: DummyKeys.self, forKey: .foo)) + XCTAssertThrowsError(try keyed.superDecoder()) + XCTAssertThrowsError(try keyed.superDecoder(forKey: .foo)) + } +} + +private enum DummyKeys: String, CodingKey { + case foo +} diff --git a/Tests/Alchemy/SQL/Rune/Model/Fields/ModelFieldsTests.swift b/Tests/Alchemy/SQL/Rune/Model/Fields/ModelFieldsTests.swift new file mode 100644 index 00000000..22224dce --- /dev/null +++ b/Tests/Alchemy/SQL/Rune/Model/Fields/ModelFieldsTests.swift @@ -0,0 +1,113 @@ +@testable import Alchemy +import XCTest + +final class ModelFieldsTests: XCTestCase { + func testEncoding() throws { + let uuid = UUID() + let date = Date() + let json = EverythingModel.Nested(string: "foo", int: 1) + let model = EverythingModel( + stringEnum: .one, + intEnum: .two, + doubleEnum: .three, + bool: true, + string: "foo", + double: 1.23, + float: 2.0, + int: 1, + int8: 2, + int16: 3, + int32: 4, + int64: 5, + uint: 6, + uint8: 7, + uint16: 8, + uint32: 9, + uint64: 10, + nested: EverythingModel.Nested(string: "foo", int: 1), + date: date, + uuid: uuid, + belongsTo: .pk(1) + ) + + let jsonData = try EverythingModel.jsonEncoder.encode(json) + let expectedFields: [String: SQLValueConvertible] = [ + "string_enum": "one", + "int_enum": 2, + "double_enum": 3.0, + "bool": true, + "string": "foo", + "double": 1.23, + "float": 2.0, + "int": 1, + "int8": 2, + "int16": 3, + "int32": 4, + "int64": 5, + "uint": 6, + "uint8": 7, + "uint16": 8, + "uint32": 9, + "uint64": 10, + "nested": SQLValue.json(jsonData), + "date": SQLValue.date(date), + "uuid": SQLValue.uuid(uuid), + "belongs_to_id": 1, + "belongs_to_optional_id": SQLValue.null, + ] + + XCTAssertEqual("everything_models", EverythingModel.tableName) + XCTAssertEqual(expectedFields.mapValues(\.value), try model.fields()) + } + + func testKeyMapping() throws { + let model = CustomKeyedModel.pk(0) + let fields = try model.fields() + XCTAssertEqual("CustomKeyedModels", CustomKeyedModel.tableName) + XCTAssertEqual([ + "id", + "val1", + "valueTwo", + "valueThreeInt", + "snake_case" + ].sorted(), fields.map { $0.key }.sorted()) + } + + func testCustomJSONEncoder() throws { + let json = DatabaseJSON(val1: "one", val2: Date()) + let jsonData = try CustomDecoderModel.jsonEncoder.encode(json) + let model = CustomDecoderModel(json: json) + + XCTAssertEqual("custom_decoder_models", CustomDecoderModel.tableName) + XCTAssertEqual(try model.fields(), [ + "json": .json(jsonData) + ]) + } +} + +private struct DatabaseJSON: Codable { + var val1: String + var val2: Date +} + +private struct CustomKeyedModel: Model { + static var keyMapping: DatabaseKeyMapping = .useDefaultKeys + + var id: Int? + var val1: String = "foo" + var valueTwo: Int = 0 + var valueThreeInt: Int = 1 + var snake_case: String = "bar" +} + +private struct CustomDecoderModel: Model { + static var jsonEncoder: JSONEncoder = { + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .iso8601 + encoder.outputFormatting = .sortedKeys + return encoder + }() + + var id: Int? + var json: DatabaseJSON +} diff --git a/Tests/Alchemy/SQL/Rune/Model/ModelCrudTests.swift b/Tests/Alchemy/SQL/Rune/Model/ModelCrudTests.swift new file mode 100644 index 00000000..3100b0e8 --- /dev/null +++ b/Tests/Alchemy/SQL/Rune/Model/ModelCrudTests.swift @@ -0,0 +1,186 @@ +import AlchemyTest + +final class ModelCrudTests: TestCase { + override func setUp() { + super.setUp() + Database.fake(migrations: [TestModelMigration(), TestModelCustomIdMigration()]) + } + + func testAll() async throws { + let all = try await TestModel.all() + XCTAssertEqual(all, []) + + try await TestModel.seed(5) + + let newAll = try await TestModel.all() + XCTAssertEqual(newAll.count, 5) + } + + func testSearch() async throws { + let first = try await TestModel.first() + XCTAssertEqual(first, nil) + + let model = try await TestModel(foo: "baz", bar: false).insertReturn() + + let findById = try await TestModel.find(model.getID()) + XCTAssertEqual(findById, model) + + do { + _ = try await TestModel.find(999, or: TestError()) + XCTFail("`find(_:or:)` should throw on a missing element.") + } catch { + // do nothing + } + + let missingId = try await TestModel.find(999) + XCTAssertEqual(missingId, nil) + + let findByWhere = try await TestModel.find("foo" == "baz") + XCTAssertEqual(findByWhere, model) + + let newFirst = try await TestModel.first() + XCTAssertEqual(newFirst, model) + + let unwrappedFirst = try await TestModel.unwrapFirstWhere("bar" == false, or: TestError()) + XCTAssertEqual(unwrappedFirst, model) + + let allWhere = try await TestModel.allWhere("bar" == false) + XCTAssertEqual(allWhere, [model]) + + do { + _ = try await TestModel.ensureNotExists("id" == model.id, else: TestError()) + XCTFail("`ensureNotExists` should throw on a matching element.") + } catch { + // do nothing + } + } + + func testRandom() async throws { + let random = try await TestModel.random() + XCTAssertEqual(random, nil) + + try await TestModel.seed() + + let newRandom = try await TestModel.random() + XCTAssertNotNil(newRandom) + } + + func testDelete() async throws { + let models = try await TestModel.seed(5) + guard let first = models.first else { + XCTFail("There should be 5 models in the database.") + return + } + + try await TestModel.delete(first.getID()) + + let count = try await TestModel.all().count + XCTAssertEqual(count, 4) + + try await TestModel.deleteAll() + let newCount = try await TestModel.all().count + XCTAssertEqual(newCount, 0) + + let model = try await TestModel.seed() + try await TestModel.delete("foo" == model.foo) + AssertEqual(try await TestModel.all().count, 0) + + let modelNew = try await TestModel.seed() + try await TestModel.deleteAll(where: "foo" == modelNew.foo) + AssertEqual(try await TestModel.all().count, 0) + } + + func testDeleteAll() async throws { + let models = try await TestModel.seed(5) + try await models.deleteAll() + AssertEqual(try await TestModel.all().count, 0) + } + + func testInsertReturn() async throws { + let model = try await TestModel(foo: "bar", bar: false).insertReturn() + XCTAssertEqual(model.foo, "bar") + XCTAssertEqual(model.bar, false) + + let customId = try await TestModelCustomId(foo: "bar").insertReturn() + XCTAssertEqual(customId.foo, "bar") + } + + func testUpdate() async throws { + var model = try await TestModel.seed() + let id = try model.getID() + model.foo = "baz" + AssertNotEqual(try await TestModel.find(id), model) + + _ = try await model.save() + AssertEqual(try await TestModel.find(id), model) + + _ = try await model.update(with: ["foo": "foo"]) + AssertEqual(try await TestModel.find(id)?.foo, "foo") + + _ = try await TestModel.update(id, with: ["foo": "qux"]) + AssertEqual(try await TestModel.find(id)?.foo, "qux") + } + + func testSync() async throws { + let model = try await TestModel.seed() + _ = try await model.update { $0.foo = "bar" } + AssertNotEqual(model.foo, "bar") + AssertEqual(try await model.sync().foo, "bar") + + do { + let unsavedModel = TestModel(id: 12345, foo: "one", bar: false) + _ = try await unsavedModel.sync() + XCTFail("Syncing an unsaved model should throw") + } catch {} + + do { + let unsavedModel = TestModel(foo: "two", bar: true) + _ = try await unsavedModel.sync() + XCTFail("Syncing an unsaved model should throw") + } catch {} + } +} + +private struct TestError: Error {} + +private struct TestModelCustomId: Model { + var id: UUID? = UUID() + var foo: String +} + +private struct TestModel: Model, Seedable, Equatable { + var id: Int? + var foo: String + var bar: Bool + + static func generate() async throws -> TestModel { + TestModel(foo: faker.lorem.word(), bar: faker.number.randomBool()) + } +} + +private struct TestModelMigration: Migration { + func up(schema: Schema) { + schema.create(table: "test_models") { + $0.increments("id").primary() + $0.string("foo").notNull() + $0.bool("bar").notNull() + } + } + + func down(schema: Schema) { + schema.drop(table: "test_models") + } +} + +private struct TestModelCustomIdMigration: Migration { + func up(schema: Schema) { + schema.create(table: "test_model_custom_ids") { + $0.uuid("id").primary() + $0.string("foo").notNull() + } + } + + func down(schema: Schema) { + schema.drop(table: "test_model_custom_ids") + } +} diff --git a/Tests/Alchemy/SQL/Rune/Model/ModelPrimaryKeyTests.swift b/Tests/Alchemy/SQL/Rune/Model/ModelPrimaryKeyTests.swift new file mode 100644 index 00000000..1ea95630 --- /dev/null +++ b/Tests/Alchemy/SQL/Rune/Model/ModelPrimaryKeyTests.swift @@ -0,0 +1,84 @@ +@testable +import Alchemy +import AlchemyTest + +final class ModelPrimaryKeyTests: XCTestCase { + func testPrimaryKeyFromSqlValue() { + let uuid = UUID() + XCTAssertEqual(try UUID(value: .string(uuid.uuidString)), uuid) + XCTAssertThrowsError(try UUID(value: .int(1))) + XCTAssertEqual(try Int(value: .int(1)), 1) + XCTAssertThrowsError(try Int(value: .string("foo"))) + XCTAssertEqual(try String(value: .string("foo")), "foo") + XCTAssertThrowsError(try String(value: .bool(false))) + } + + func testPk() { + XCTAssertEqual(TestModel.pk(123).id, 123) + } + + func testDummyDecoderThrowing() throws { + let decoder = DummyDecoder() + XCTAssertThrowsError(try decoder.singleValueContainer()) + XCTAssertThrowsError(try decoder.unkeyedContainer()) + + let keyed = try decoder.container(keyedBy: DummyKeys.self) + XCTAssertThrowsError(try keyed.nestedUnkeyedContainer(forKey: .one)) + XCTAssertThrowsError(try keyed.nestedContainer(keyedBy: DummyKeys.self, forKey: .one)) + XCTAssertThrowsError(try keyed.superDecoder()) + XCTAssertThrowsError(try keyed.superDecoder(forKey: .one)) + } +} + +private enum DummyKeys: String, CodingKey { + case one +} + +private struct TestModel: Model { + struct Nested: Codable { + let string: String + } + + enum Enum: String, ModelEnum { + case one, two, three + } + + var id: Int? + + // Enum + let `enum`: Enum + + // Keyed + let bool: Bool + let string: String + let double: Double + let float: Float + let int: Int + let int8: Int8 + let int16: Int16 + let int32: Int32 + let int64: Int64 + let uint: UInt + let uint8: UInt8 + let uint16: UInt16 + let uint32: UInt32 + let uint64: UInt64 + let nested: Nested + + // Arrays + let boolArray: [Bool] + let stringArray: [String] + let doubleArray: [Double] + let floatArray: [Float] + let intArray: [Int] + let int8Array: [Int8] + let int16Array: [Int16] + let int32Array: [Int32] + let int64Array: [Int64] + let uintArray: [UInt] + let uint8Array: [UInt8] + let uint16Array: [UInt16] + let uint32Array: [UInt32] + let uint64Array: [UInt64] + let nestedArray: [Nested] +} diff --git a/Tests/Alchemy/SQL/Rune/Model/ModelQueryTests.swift b/Tests/Alchemy/SQL/Rune/Model/ModelQueryTests.swift new file mode 100644 index 00000000..0e4a80dc --- /dev/null +++ b/Tests/Alchemy/SQL/Rune/Model/ModelQueryTests.swift @@ -0,0 +1,83 @@ +import AlchemyTest + +final class ModelQueryTests: TestCase { + override func setUp() { + super.setUp() + Database.fake(migrations: [ + TestModelMigration(), + TestParentMigration() + ]) + } + + func testWith() async throws { + try await TestParent.seed() + let child = try await TestModel.seed() + let fetchedChild = try await TestModel.query().with(\.$testParent).first() + XCTAssertEqual(fetchedChild, child) + } +} + +private struct TestError: Error {} + +private struct TestParent: Model, Seedable, Equatable { + var id: Int? + var baz: String + + static func generate() async throws -> TestParent { + TestParent(baz: faker.lorem.word()) + } +} + +private struct TestModel: Model, Seedable, Equatable { + var id: Int? + var foo: String + var bar: Bool + + @BelongsTo var testParent: TestParent + + static func generate() async throws -> TestModel { + let parent: TestParent + if let random = try await TestParent.random() { + parent = random + } else { + parent = try await .seed() + } + + return TestModel(foo: faker.lorem.word(), bar: faker.number.randomBool(), testParent: parent) + } + + static func == (lhs: TestModel, rhs: TestModel) -> Bool { + lhs.id == rhs.id && + lhs.foo == rhs.foo && + lhs.bar == rhs.bar && + lhs.$testParent.id == rhs.$testParent.id + } +} + +private struct TestParentMigration: Migration { + func up(schema: Schema) { + schema.create(table: "test_parents") { + $0.increments("id").primary() + $0.string("baz").notNull() + } + } + + func down(schema: Schema) { + schema.drop(table: "test_parents") + } +} + +private struct TestModelMigration: Migration { + func up(schema: Schema) { + schema.create(table: "test_models") { + $0.increments("id").primary() + $0.string("foo").notNull() + $0.bool("bar").notNull() + $0.bigInt("test_parent_id").references("id", on: "test_parents").notNull() + } + } + + func down(schema: Schema) { + schema.drop(table: "test_models") + } +} diff --git a/Tests/Alchemy/SQL/Rune/Relationships/RelationshipMapperTests.swift b/Tests/Alchemy/SQL/Rune/Relationships/RelationshipMapperTests.swift new file mode 100644 index 00000000..9938f0fd --- /dev/null +++ b/Tests/Alchemy/SQL/Rune/Relationships/RelationshipMapperTests.swift @@ -0,0 +1,86 @@ +@testable +import Alchemy +import XCTest + +final class RelationshipMapperTests: XCTestCase { + func testGetSet() { + let mapper = RelationshipMapper() + XCTAssertEqual(mapper.getConfig(for: \.$belongsTo), .defaultBelongsTo()) + XCTAssertEqual(mapper.getConfig(for: \.$hasMany), .defaultHas()) + XCTAssertEqual(mapper.getConfig(for: \.$hasOne), .defaultHas()) + let defaultHas = mapper.getConfig(for: \.$hasOne) + XCTAssertEqual(defaultHas.fromKey, "id") + XCTAssertEqual(defaultHas.toKey, "mapper_model_id") + let val = mapper.config(\.$hasOne) + .from("foo") + .to("bar") + XCTAssertNotEqual(mapper.getConfig(for: \.$hasOne), .defaultHas()) + XCTAssertEqual(mapper.getConfig(for: \.$hasOne), val) + XCTAssertEqual(val.fromKey, "foo") + XCTAssertEqual(val.toKey, "bar") + } + + func testHasThrough() { + let mapper = RelationshipMapper() + let mapping = mapper.config(\.$hasMany).through("foo", from: "bar", to: "baz") + let expected = RelationshipMapping( + .has, + fromTable: "mapper_models", + fromKey: "id", + toTable: "mapper_models", + toKey: "foo_id", + through: .init( + table: "foo", + fromKey: "bar", + toKey: "baz")) + XCTAssertEqual(mapping, expected) + let mappingDefault = mapper.config(\.$hasMany).through("foo") + XCTAssertEqual(mappingDefault.through?.fromKey, "mapper_model_id") + XCTAssertEqual(mappingDefault.through?.toKey, "id") + } + + func testBelongsThrough() { + let mapper = RelationshipMapper() + let mapping = mapper.config(\.$belongsTo).through("foo", from: "bar", to: "baz") + let expected = RelationshipMapping( + .belongs, + fromTable: "mapper_models", + fromKey: "foo_id", + toTable: "mapper_models", + toKey: "id", + through: .init( + table: "foo", + fromKey: "bar", + toKey: "baz")) + XCTAssertEqual(mapping, expected) + let mappingDefault = mapper.config(\.$belongsTo).through("foo") + XCTAssertEqual(mappingDefault.through?.fromKey, "id") + XCTAssertEqual(mappingDefault.through?.toKey, "mapper_model_id") + } + + func testThroughPivot() { + let mapper = RelationshipMapper() + let mapping = mapper.config(\.$hasMany).throughPivot("foo", from: "bar", to: "baz") + let expected = RelationshipMapping( + .has, + fromTable: "mapper_models", + fromKey: "id", + toTable: "mapper_models", + toKey: "id", + through: .init( + table: "foo", + fromKey: "bar", + toKey: "baz")) + XCTAssertEqual(mapping, expected) + } +} + +struct MapperModel: Model { + var id: Int? + + @BelongsTo var belongsTo: MapperModel + @BelongsTo var belongsToOptional: MapperModel? + @HasOne var hasOne: MapperModel + @HasOne var hasOneOptional: MapperModel? + @HasMany var hasMany: [MapperModel] +} diff --git a/Tests/Alchemy/SQL/Rune/Relationships/RelationshipTests.swift b/Tests/Alchemy/SQL/Rune/Relationships/RelationshipTests.swift new file mode 100644 index 00000000..341912bb --- /dev/null +++ b/Tests/Alchemy/SQL/Rune/Relationships/RelationshipTests.swift @@ -0,0 +1,28 @@ +@testable +import Alchemy +import XCTest + +final class RelationshipTests: XCTestCase { + func testModelMaybeOptional() throws { + let nilModel: TestModel? = nil + let doubleOptionalNilModel: TestModel?? = nil + XCTAssertEqual(nilModel.id, nil) + XCTAssertEqual(try Optional.from(nilModel), nil) + XCTAssertEqual(try Optional.from(doubleOptionalNilModel), nil) + + let optionalModel: TestModel? = TestModel(id: 1) + let doubleOptionalModel: TestModel?? = TestModel(id: 1) + XCTAssertEqual(optionalModel.id, 1) + XCTAssertEqual(try Optional.from(optionalModel), optionalModel) + XCTAssertEqual(try Optional.from(doubleOptionalModel), optionalModel) + + let model: TestModel = TestModel(id: 1) + XCTAssertEqual(model.id, 1) + XCTAssertEqual(try TestModel.from(model), model) + XCTAssertThrowsError(try TestModel.from(nil)) + } +} + +private struct TestModel: Model, Equatable { + var id: Int? +} diff --git a/Tests/Alchemy/Scheduler/ScheduleTests.swift b/Tests/Alchemy/Scheduler/ScheduleTests.swift new file mode 100644 index 00000000..e9ce8fe1 --- /dev/null +++ b/Tests/Alchemy/Scheduler/ScheduleTests.swift @@ -0,0 +1,81 @@ +@testable import Alchemy +import XCTest + +final class ScheduleTests: XCTestCase { + func testDayOfWeek() { + XCTAssertEqual([DayOfWeek.sun, .mon, .tue, .wed, .thu, .fri, .sat, .sun], [0, 1, 2, 3, 4, 5, 6, 7]) + } + + func testMonth() { + XCTAssertEqual( + [Month.jan, .feb, .mar, .apr, .may, .jun, .jul, .aug, .sep, .oct, .nov, .dec, .jan], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + ) + } + + func testScheduleSecondly() { + Schedule("* * * * * * *").secondly() + } + + func testScheduleMinutely() { + Schedule("0 * * * * * *").minutely() + Schedule("1 * * * * * *").minutely(sec: 1) + } + + func testScheduleHourly() { + Schedule("0 0 * * * * *").hourly() + Schedule("1 2 * * * * *").hourly(min: 2, sec: 1) + } + + func testScheduleDaily() { + Schedule("0 0 0 * * * *").daily() + Schedule("1 2 3 * * * *").daily(hr: 3, min: 2, sec: 1) + } + + func testScheduleWeekly() { + Schedule("0 0 0 * * 0 *").weekly() + Schedule("1 2 3 * * 4 *").weekly(day: .thu, hr: 3, min: 2, sec: 1) + } + + func testScheduleMonthly() { + Schedule("0 0 0 1 * * *").monthly() + Schedule("1 2 3 4 * * *").monthly(day: 4, hr: 3, min: 2, sec: 1) + } + + func testScheduleYearly() { + Schedule("0 0 0 1 1 * *").yearly() + Schedule("1 2 3 4 5 * *").yearly(month: .may, day: 4, hr: 3, min: 2, sec: 1) + } + + func testCustomSchedule() { + Schedule("0 0 22 * * 1-5 *").expression("0 0 22 * * 1-5 *") + } + + func testNext() { + Schedule { schedule in + let next = schedule.next() + XCTAssertNotNil(next) + if let next = next { + XCTAssertLessThanOrEqual(next, .seconds(1)) + } + }.secondly() + + Schedule { schedule in + let next = schedule.next() + XCTAssertNotNil(next) + if let next = next { + XCTAssertGreaterThan(next, .hours(24 * 365 * 10)) + } + }.expression("0 0 0 1 * * 2060") + } + + func testNoNext() { + Schedule { XCTAssertNil($0.next()) }.expression("0 0 0 11 9 * 1993") + } +} + +extension Schedule { + fileprivate convenience init(_ expectedExpression: String) { + self.init { XCTAssertEqual($0.cronExpression, expectedExpression) } + } +} diff --git a/Tests/Alchemy/Scheduler/SchedulerTests.swift b/Tests/Alchemy/Scheduler/SchedulerTests.swift new file mode 100644 index 00000000..046e8df4 --- /dev/null +++ b/Tests/Alchemy/Scheduler/SchedulerTests.swift @@ -0,0 +1,84 @@ +@testable +import Alchemy +import AlchemyTest + +final class SchedulerTests: TestCase { + private var scheduler = Scheduler(isTesting: true) + private var loop = EmbeddedEventLoop() + + override func setUp() { + super.setUp() + self.scheduler = Scheduler(isTesting: true) + self.loop = EmbeddedEventLoop() + } + + func testScheduleTask() { + let exp = expectation(description: "") + scheduler.run { exp.fulfill() }.daily() + + let loop = EmbeddedEventLoop() + scheduler.start(on: loop) + loop.advanceTime(by: .hours(24)) + + waitForExpectations(timeout: 0.1) + } + + func testScheduleJob() { + struct ScheduledJob: Job, Equatable { + func run() async throws {} + } + + let queue = Queue.fake() + let loop = EmbeddedEventLoop() + + scheduler.job(ScheduledJob()).daily() + scheduler.start(on: loop) + loop.advanceTime(by: .hours(24)) + + let exp = expectation(description: "") + DispatchQueue.global().asyncAfter(deadline: .now() + 0.05) { + queue.assertPushed(ScheduledJob.self) + exp.fulfill() + } + + waitForExpectations(timeout: 0.1) + } + + func testNoRunWithoutStart() { + makeSchedule(invertExpect: true).daily() + waitForExpectations(timeout: kMinTimeout) + } + + func testStart() { + makeSchedule().daily() + scheduler.start(on: loop) + loop.advanceTime(by: .hours(24)) + waitForExpectations(timeout: kMinTimeout) + } + + func testStartTwiceRunsOnce() { + makeSchedule().daily() + scheduler.start(on: loop) + scheduler.start(on: loop) + loop.advanceTime(by: .hours(24)) + waitForExpectations(timeout: kMinTimeout) + } + + func testDoesntRunNoNext() { + makeSchedule(invertExpect: true).expression("0 0 0 11 9 * 1993") + scheduler.start(on: loop) + loop.advanceTime(by: .hours(24)) + + waitForExpectations(timeout: kMinTimeout) + } + + private func makeSchedule(invertExpect: Bool = false) -> Schedule { + let exp = expectation(description: "") + exp.isInverted = invertExpect + return Schedule { + self.scheduler.addWork(schedule: $0) { + exp.fulfill() + } + } + } +} diff --git a/Tests/Alchemy/Utilities/BCryptTests.swift b/Tests/Alchemy/Utilities/BCryptTests.swift new file mode 100644 index 00000000..662571ec --- /dev/null +++ b/Tests/Alchemy/Utilities/BCryptTests.swift @@ -0,0 +1,13 @@ +import AlchemyTest + +final class BcryptTests: TestCase { + func testBcrypt() async throws { + let hashed = try await Bcrypt.hash("foo") + let verify = try await Bcrypt.verify(plaintext: "foo", hashed: hashed) + XCTAssertTrue(verify) + } + + func testCostTooLow() { + XCTAssertThrowsError(try Bcrypt.hashSync("foo", cost: 1)) + } +} diff --git a/Tests/Alchemy/Utilities/UUIDLosslessStringConvertibleTests.swift b/Tests/Alchemy/Utilities/UUIDLosslessStringConvertibleTests.swift new file mode 100644 index 00000000..5c7bb6b6 --- /dev/null +++ b/Tests/Alchemy/Utilities/UUIDLosslessStringConvertibleTests.swift @@ -0,0 +1,12 @@ +import AlchemyTest + +final class UUIDLosslessStringConvertibleTests: XCTestCase { + func testValidUUID() { + let uuid = UUID() + XCTAssertEqual(UUID(uuid.uuidString), uuid) + } + + func testInvalidUUID() { + XCTAssertEqual(UUID("foo"), nil) + } +} diff --git a/Tests/AlchemyTest/Assertions/ClientAssertionTests.swift b/Tests/AlchemyTest/Assertions/ClientAssertionTests.swift new file mode 100644 index 00000000..77f783cd --- /dev/null +++ b/Tests/AlchemyTest/Assertions/ClientAssertionTests.swift @@ -0,0 +1,33 @@ +import AlchemyTest + +final class ClientAssertionTests: TestCase { + func testAssertNothingSent() { + Http.assertNothingSent() + } + + func testAssertSent() async throws { + Http.stub() + _ = try await Http.get("https://localhost:3000/foo?bar=baz") + Http.assertSent(1) { + $0.hasPath("/foo") && + $0.hasQuery("bar", value: "baz") + } + + struct User: Codable { + let name: String + let age: Int + } + + let user = User(name: "Cyanea", age: 35) + _ = try await Http + .withJSON(user) + .post("https://localhost:3000/bar") + + Http.assertSent(2) { + $0.hasMethod(.POST) && + $0.hasPath("/bar") && + $0["name"].string == "Cyanea" && + $0["age"].int == 35 + } + } +} diff --git a/Tests/AlchemyTests/Routing/RouterTests.swift b/Tests/AlchemyTests/Routing/RouterTests.swift deleted file mode 100644 index 074f675f..00000000 --- a/Tests/AlchemyTests/Routing/RouterTests.swift +++ /dev/null @@ -1,300 +0,0 @@ -import NIO -import NIOHTTP1 -import XCTest -@testable import Alchemy - -let kMinTimeout: TimeInterval = 0.01 - -final class RouterTests: XCTestCase { - private var app = TestApp() - - override func setUp() { - super.setUp() - app = TestApp() - app.mockServices() - } - - func testMatch() throws { - self.app.get { _ in "Hello, world!" } - self.app.post { _ in 1 } - self.app.register(.get1) - self.app.register(.post1) - XCTAssertEqual(try self.app.request(TestRequest(method: .GET, path: "", response: "")), "Hello, world!") - XCTAssertEqual(try self.app.request(TestRequest(method: .POST, path: "", response: "")), "1") - XCTAssertEqual(try self.app.request(.get1), TestRequest.get1.response) - XCTAssertEqual(try self.app.request(.post1), TestRequest.post1.response) - } - - func testMissing() throws { - self.app.register(.getEmpty) - self.app.register(.get1) - self.app.register(.post1) - XCTAssertEqual(try self.app.request(.get2), "Not Found") - XCTAssertEqual(try self.app.request(.postEmpty), "Not Found") - } - - func testMiddlewareCalling() throws { - let shouldFulfull = expectation(description: "The middleware should be called.") - - let mw1 = TestMiddleware(req: { request in - shouldFulfull.fulfill() - }) - - let mw2 = TestMiddleware(req: { request in - XCTFail("This middleware should not be called.") - }) - - self.app - .use(mw1) - .register(.get1) - .use(mw2) - .register(.post1) - - _ = try self.app.request(.get1) - - wait(for: [shouldFulfull], timeout: kMinTimeout) - } - - func testMiddlewareCalledWhenError() throws { - let globalFulfill = expectation(description: "") - let global = TestMiddleware(res: { _ in globalFulfill.fulfill() }) - - let mw1Fulfill = expectation(description: "") - let mw1 = TestMiddleware(res: { _ in mw1Fulfill.fulfill() }) - - let mw2Fulfill = expectation(description: "") - let mw2 = TestMiddleware(req: { _ in - struct SomeError: Error {} - mw2Fulfill.fulfill() - throw SomeError() - }) - - app.useAll(global) - .use(mw1) - .use(mw2) - .register(.get1) - - _ = try app.request(.get1) - - wait(for: [globalFulfill, mw1Fulfill, mw2Fulfill], timeout: kMinTimeout) - } - - func testGroupMiddleware() { - let expect = expectation(description: "The middleware should be called once.") - let mw = TestMiddleware(req: { request in - XCTAssertEqual(request.head.uri, TestRequest.post1.path) - XCTAssertEqual(request.head.method, TestRequest.post1.method) - expect.fulfill() - }) - - self.app - .group(middleware: mw) { newRouter in - newRouter.register(.post1) - } - .register(.get1) - - XCTAssertEqual(try self.app.request(.get1), TestRequest.get1.response) - XCTAssertEqual(try self.app.request(.post1), TestRequest.post1.response) - waitForExpectations(timeout: kMinTimeout) - } - - func testMiddlewareOrder() throws { - var stack = [Int]() - let mw1Req = expectation(description: "") - let mw1Res = expectation(description: "") - let mw1 = TestMiddleware { _ in - XCTAssertEqual(stack, []) - mw1Req.fulfill() - stack.append(0) - } res: { _ in - XCTAssertEqual(stack, [0,1,2,3,4]) - mw1Res.fulfill() - } - - let mw2Req = expectation(description: "") - let mw2Res = expectation(description: "") - let mw2 = TestMiddleware { _ in - XCTAssertEqual(stack, [0]) - mw2Req.fulfill() - stack.append(1) - } res: { _ in - XCTAssertEqual(stack, [0,1,2,3]) - mw2Res.fulfill() - stack.append(4) - } - - let mw3Req = expectation(description: "") - let mw3Res = expectation(description: "") - let mw3 = TestMiddleware { _ in - XCTAssertEqual(stack, [0,1]) - mw3Req.fulfill() - stack.append(2) - } res: { _ in - XCTAssertEqual(stack, [0,1,2]) - mw3Res.fulfill() - stack.append(3) - } - - self.app - .use(mw1) - .use(mw2) - .use(mw3) - .register(.getEmpty) - - _ = try self.app.request(.getEmpty) - - waitForExpectations(timeout: kMinTimeout) - } - - func testQueriesIgnored() { - self.app.register(.get1) - XCTAssertEqual(try self.app.request(.get1Queries), TestRequest.get1.response) - } - - func testPathParametersMatch() throws { - let expect = expectation(description: "The handler should be called.") - - let uuidString = UUID().uuidString - let orderedExpectedParameters = [ - PathParameter(parameter: "uuid", stringValue: uuidString), - PathParameter(parameter: "user_id", stringValue: "123"), - ] - - let routeMethod = HTTPMethod.GET - let routeToRegister = "/v1/some_path/:uuid/:user_id" - let routeToCall = "/v1/some_path/\(uuidString)/123" - let routeResponse = "some response" - - self.app.on(routeMethod, at: routeToRegister) { request -> ResponseConvertible in - XCTAssertEqual(request.pathParameters, orderedExpectedParameters) - expect.fulfill() - - return routeResponse - } - - let res = try self.app.request(TestRequest(method: routeMethod, path: routeToCall, response: "")) - print(res ?? "N/A") - - XCTAssertEqual(res, routeResponse) - waitForExpectations(timeout: kMinTimeout) - } - - func testMultipleRequests() { - // What happens if a user registers the same route twice? - } - - func testInvalidPath() { - // What happens if a user registers an invalid path string? - } - - func testForwardSlashIssues() { - // Could update the router to automatically add "/" if URI strings are missing them, - // automatically add/remove trailing "/", etc. - } - - func testGroupedPathPrefix() throws { - self.app - .grouped("group") { app in - app - .register(.get1) - .register(.get2) - .grouped("nested") { app in - app.register(.post1) - } - .register(.post2) - } - .register(.get3) - - XCTAssertEqual(try self.app.request(TestRequest( - method: .GET, - path: "/group\(TestRequest.get1.path)", - response: TestRequest.get1.path - )), TestRequest.get1.response) - - XCTAssertEqual(try self.app.request(TestRequest( - method: .GET, - path: "/group\(TestRequest.get2.path)", - response: TestRequest.get2.path - )), TestRequest.get2.response) - - XCTAssertEqual(try self.app.request(TestRequest( - method: .POST, - path: "/group/nested\(TestRequest.post1.path)", - response: TestRequest.post1.path - )), TestRequest.post1.response) - - XCTAssertEqual(try self.app.request(TestRequest( - method: .POST, - path: "/group\(TestRequest.post2.path)", - response: TestRequest.post2.path - )), TestRequest.post2.response) - - // only available under group prefix - XCTAssertEqual(try self.app.request(TestRequest.get1), "Not Found") - XCTAssertEqual(try self.app.request(TestRequest.get2), "Not Found") - XCTAssertEqual(try self.app.request(TestRequest.post1), "Not Found") - XCTAssertEqual(try self.app.request(TestRequest.post2), "Not Found") - - // defined outside group --> still available without group prefix - XCTAssertEqual(try self.app.request(TestRequest.get3), TestRequest.get3.response) - } -} - -/// Runs the specified callback on a request / response. -struct TestMiddleware: Middleware { - var req: ((Request) throws -> Void)? - var res: ((Response) throws -> Void)? - - func intercept(_ request: Request, next: @escaping Next) throws -> EventLoopFuture { - try req?(request) - return next(request) - .flatMapThrowing { response in - try res?(response) - return response - } - } -} - -extension Application { - @discardableResult - func register(_ test: TestRequest) -> Self { - self.on(test.method, at: test.path, handler: { _ in test.response }) - } - - func request(_ test: TestRequest) throws -> String? { - return try Router.default.handle( - request: Request( - head: .init( - version: .init( - major: 1, - minor: 1 - ), - method: test.method, - uri: test.path, - headers: .init()), - bodyBuffer: nil - ) - ).wait().body?.decodeString() - } -} - -struct TestApp: Application { - func boot() {} -} - -struct TestRequest { - let method: HTTPMethod - let path: String - let response: String - - static let postEmpty = TestRequest(method: .POST, path: "", response: "post empty") - static let post1 = TestRequest(method: .POST, path: "/something", response: "post 1") - static let post2 = TestRequest(method: .POST, path: "/something/else", response: "post 2") - static let post3 = TestRequest(method: .POST, path: "/something_else", response: "post 3") - - static let getEmpty = TestRequest(method: .GET, path: "", response: "get empty") - static let get1 = TestRequest(method: .GET, path: "/something", response: "get 1") - static let get1Queries = TestRequest(method: .GET, path: "/something?some=value&other=2", response: "get 1") - static let get2 = TestRequest(method: .GET, path: "/something/else", response: "get 2") - static let get3 = TestRequest(method: .GET, path: "/something_else", response: "get 3") -} diff --git a/Tests/AlchemyTests/Routing/TrieTests.swift b/Tests/AlchemyTests/Routing/TrieTests.swift deleted file mode 100644 index 4a708dc9..00000000 --- a/Tests/AlchemyTests/Routing/TrieTests.swift +++ /dev/null @@ -1,53 +0,0 @@ -@testable import Alchemy -import XCTest - -final class TrieTests: XCTestCase { - func testTrie() { - let trie = RouterTrieNode() - - trie.insert(path: ["one"], storageKey: 0, value: "foo") - trie.insert(path: ["one", "two"], storageKey: 1, value: "bar") - trie.insert(path: ["one", "two", "three"], storageKey: 1, value: "baz") - trie.insert(path: ["one", ":id"], storageKey: 1, value: "doo") - trie.insert(path: ["one", ":id", "two"], storageKey: 2, value: "dar") - trie.insert(path: [], storageKey: 2, value: "daz") - trie.insert(path: ["one", ":id", "two"], storageKey: 3, value: "zoo") - trie.insert(path: ["one", ":id", "two"], storageKey: 4, value: "zar") - trie.insert(path: ["one", ":id", "two"], storageKey: 3, value: "zaz") - trie.insert(path: [":id0", ":id1", ":id2", ":id3"], storageKey: 0, value: "hmm") - - let result1 = trie.search(path: ["one"], storageKey: 0) - let result2 = trie.search(path: ["one", "two"], storageKey: 1) - let result3 = trie.search(path: ["one", "two", "three"], storageKey: 1) - let result4 = trie.search(path: ["one", "zonk"], storageKey: 1) - let result5 = trie.search(path: ["one", "fail", "two"], storageKey: 2) - let result6 = trie.search(path: ["one", "aaa", "two"], storageKey: 3) - let result7 = trie.search(path: ["one", "bbb", "two"], storageKey: 4) - let result8 = trie.search(path: ["1", "2", "3", "4"], storageKey: 0) - let result9 = trie.search(path: ["1", "2", "3", "5", "6"], storageKey: 0) - - XCTAssertEqual(result1?.0, "foo") - XCTAssertEqual(result1?.1, []) - XCTAssertEqual(result2?.0, "bar") - XCTAssertEqual(result2?.1, []) - XCTAssertEqual(result3?.0, "baz") - XCTAssertEqual(result3?.1, []) - XCTAssertEqual(result4?.0, "doo") - XCTAssertEqual(result4?.1, [PathParameter(parameter: "id", stringValue: "zonk")]) - XCTAssertEqual(result5?.0, "dar") - XCTAssertEqual(result5?.1, [PathParameter(parameter: "id", stringValue: "fail")]) - XCTAssertEqual(result6?.0, "zaz") - XCTAssertEqual(result6?.1, [PathParameter(parameter: "id", stringValue: "aaa")]) - XCTAssertEqual(result7?.0, "zar") - XCTAssertEqual(result7?.1, [PathParameter(parameter: "id", stringValue: "bbb")]) - XCTAssertEqual(result8?.0, "hmm") - XCTAssertEqual(result8?.1, [ - PathParameter(parameter: "id0", stringValue: "1"), - PathParameter(parameter: "id1", stringValue: "2"), - PathParameter(parameter: "id2", stringValue: "3"), - PathParameter(parameter: "id3", stringValue: "4"), - ]) - XCTAssertEqual(result9?.0, nil) - XCTAssertEqual(result9?.1, nil) - } -} diff --git a/Tests/AlchemyTests/SQL/Abstract/DatabaseEncodingTests.swift b/Tests/AlchemyTests/SQL/Abstract/DatabaseEncodingTests.swift deleted file mode 100644 index 58d74072..00000000 --- a/Tests/AlchemyTests/SQL/Abstract/DatabaseEncodingTests.swift +++ /dev/null @@ -1,120 +0,0 @@ -@testable import Alchemy -import XCTest - -final class DatabaseEncodingTests: XCTestCase { - func testEncoding() throws { - let uuid = UUID() - let date = Date() - let json = DatabaseJSON(val1: "sample", val2: Date()) - let model = TestModel( - string: "one", - int: 2, - uuid: uuid, - date: date, - bool: true, - double: 3.14159, - json: json, - stringEnum: .third, - intEnum: .two - ) - - let jsonData = try TestModel.jsonEncoder.encode(json) - let expectedFields: [DatabaseField] = [ - DatabaseField(column: "string", value: .string("one")), - DatabaseField(column: "int", value: .int(2)), - DatabaseField(column: "uuid", value: .uuid(uuid)), - DatabaseField(column: "date", value: .date(date)), - DatabaseField(column: "bool", value: .bool(true)), - DatabaseField(column: "double", value: .double(3.14159)), - DatabaseField(column: "json", value: .json(jsonData)), - DatabaseField(column: "string_enum", value: .string("third")), - DatabaseField(column: "int_enum", value: .int(1)), - DatabaseField(column: "test_conversion_caps_test", value: .string("")), - DatabaseField(column: "test_conversion123", value: .string("")), - ] - - XCTAssertEqual("test_models", TestModel.tableName) - XCTAssertEqual(expectedFields, try model.fields()) - } - - func testKeyMapping() throws { - let model = CustomKeyedModel.pk(0) - let fields = try model.fields() - XCTAssertEqual("CustomKeyedModels", CustomKeyedModel.tableName) - XCTAssertEqual([ - "id", - "val1", - "valueTwo", - "valueThreeInt", - "snake_case" - ], fields.map(\.column)) - } - - func testCustomJSONEncoder() throws { - let json = DatabaseJSON(val1: "one", val2: Date()) - let jsonData = try CustomDecoderModel.jsonEncoder.encode(json) - let model = CustomDecoderModel(json: json) - let expectedFields: [DatabaseField] = [ - DatabaseField(column: "json", value: .json(jsonData)) - ] - - XCTAssertEqual("custom_decoder_models", CustomDecoderModel.tableName) - XCTAssertEqual(expectedFields, try model.fields()) - } -} - -private struct DatabaseJSON: Codable { - var val1: String - var val2: Date -} - -private enum IntEnum: Int, ModelEnum { - case one, two, three -} - -private enum StringEnum: String, ModelEnum { - case first, second, third -} - -private struct TestModel: Model { - var id: Int? - var string: String - var int: Int - var uuid: UUID - var date: Date - var bool: Bool - var double: Double - var json: DatabaseJSON - var stringEnum: StringEnum - var intEnum: IntEnum - var testConversionCAPSTest: String = "" - var testConversion123: String = "" - - static var jsonEncoder: JSONEncoder = { - let encoder = JSONEncoder() - encoder.outputFormatting = [.sortedKeys] - return encoder - }() -} - -private struct CustomKeyedModel: Model { - static var keyMapping: DatabaseKeyMapping = .useDefaultKeys - - var id: Int? - var val1: String = "foo" - var valueTwo: Int = 0 - var valueThreeInt: Int = 1 - var snake_case: String = "bar" -} - -private struct CustomDecoderModel: Model { - static var jsonEncoder: JSONEncoder = { - let encoder = JSONEncoder() - encoder.dateEncodingStrategy = .iso8601 - encoder.outputFormatting = .sortedKeys - return encoder - }() - - var id: Int? - var json: DatabaseJSON -}