-
Notifications
You must be signed in to change notification settings - Fork 40
/
update_and_check.rs
257 lines (234 loc) · 8 KB
/
update_and_check.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
//! CTE implementation for "UPDATE with extended return status".
use super::pool::DbConnection;
use async_bb8_diesel::{AsyncRunQueryDsl, ConnectionManager, PoolError};
use diesel::associations::HasTable;
use diesel::pg::Pg;
use diesel::prelude::*;
use diesel::query_builder::*;
use diesel::query_dsl::methods::LoadQuery;
use diesel::query_source::Table;
use diesel::sql_types::Nullable;
use diesel::QuerySource;
use std::marker::PhantomData;
/// A simple wrapper trait for Diesel's [`UpdateStatement`], which
/// allows referencing its generic parameters by name (and extending usage
/// without re-stating those generic parameters everywhere).
pub trait UpdateStatementExt {
    /// The table being updated (`T` in `UpdateStatement<T, U, V>`).
    type Table: QuerySource;
    /// The WHERE clause of the update (`U` in `UpdateStatement<T, U, V>`).
    type WhereClause;
    /// The changeset applied by the update (`V` in `UpdateStatement<T, U, V>`).
    type Changeset;

    /// Consumes `self` and returns the underlying [`UpdateStatement`].
    fn statement(
        self,
    ) -> UpdateStatement<Self::Table, Self::WhereClause, Self::Changeset>;
}
// Any `UpdateStatement` that still uses the default returning clause exposes
// its own table / WHERE clause / changeset parameters through
// `UpdateStatementExt`; converting back is the identity.
impl<T: QuerySource, U, V> UpdateStatementExt for UpdateStatement<T, U, V> {
    type Table = T;
    type WhereClause = U;
    type Changeset = V;

    /// Hands the statement back unchanged: it already is the wrapped type.
    fn statement(self) -> UpdateStatement<T, U, V> {
        self
    }
}
/// Wrapper around [`diesel::update`] for a Table, which allows
/// callers to distinguish between "not found", "found but not updated", and
/// "updated".
///
/// Type parameters:
/// - `US`: the [`UpdateStatement`] (viewed through [`UpdateStatementExt`])
///   which we are extending.
/// - `K`: the primary key type used to look up the row.
pub trait UpdateAndCheck<US, K>
where
    US: UpdateStatementExt,
{
    /// Nests the existing update statement in a CTE which
    /// identifies if the row exists (by ID), even if the row
    /// cannot be successfully updated.
    ///
    /// `key` is the primary key of the row to look up. `Q` is the type the
    /// looked-up row is loaded into when the returned statement is executed.
    /// This only builds the statement; it must still be executed.
    fn check_if_exists<Q>(self, key: K) -> UpdateAndQueryStatement<US, K, Q>;
}
// `UpdateStatement` carries four generic parameters:
// - T:   the table being updated
// - U:   the WHERE clause
// - V:   the changeset to apply (default = SetNotCalled)
// - Ret: the RETURNING clause (default = NoReturningClause)
//
// `UpdateStatementExt` is only implemented for statements whose `Ret` is
// still the default. As a result, `check_if_exists` can only be called on
// update statements that have no RETURNING clause yet. That restriction is
// deliberate: the CTE built from this statement supplies its own RETURNING
// clause and controls what the combined statement returns.
impl<US, K> UpdateAndCheck<US, K> for US
where
    US: UpdateStatementExt,
    US::Table: HasTable<Table = US::Table>
        + Table
        + diesel::query_dsl::methods::FindDsl<K>,
    <US::Table as diesel::query_dsl::methods::FindDsl<K>>::Output:
        QueryFragment<Pg> + Send + 'static,
    K: 'static + Copy + Send,
{
    fn check_if_exists<Q>(self, key: K) -> UpdateAndQueryStatement<US, K, Q> {
        // Look the row up purely by primary key, independently of whether
        // the update's own WHERE clause will match it.
        let lookup = <US::Table as HasTable>::table().find(key);
        UpdateAndQueryStatement {
            find_subquery: Box::new(lookup),
            update_statement: self.statement(),
            key_type: PhantomData,
            query_type: PhantomData,
        }
    }
}
/// An UPDATE statement which can be combined (via a CTE)
/// with other statements to also SELECT a row.
///
/// Built by [`UpdateAndCheck::check_if_exists`].
#[must_use = "Queries must be executed"]
pub struct UpdateAndQueryStatement<US, K, Q>
where
    US: UpdateStatementExt,
{
    // The caller's UPDATE; emitted inside the `updated` CTE.
    update_statement:
        UpdateStatement<US::Table, US::WhereClause, US::Changeset>,
    // Lookup of the target row by primary key; emitted inside the `found`
    // CTE. Boxed so the concrete `FindDsl` output type need not be named.
    find_subquery: Box<dyn QueryFragment<Pg> + Send>,
    // Carries the primary key type `K` without storing a value.
    key_type: PhantomData<K>,
    // Carries the result row type `Q` without storing a value.
    query_type: PhantomData<Q>,
}
// The boxed `find_subquery` erases its concrete type, so the SQL this
// statement generates cannot be identified from the type alone. Opt out of
// a static query ID.
impl<US: UpdateStatementExt, K, Q> QueryId
    for UpdateAndQueryStatement<US, K, Q>
{
    const HAS_STATIC_QUERY_ID: bool = false;
    type QueryId = ();
}
/// Result of [`UpdateAndQueryStatement`].
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct UpdateAndQueryResult<Q> {
    /// Whether the row was updated, or exists but was not updated.
    pub status: UpdateStatus,
    /// The row read through the `found` (lookup-by-primary-key) CTE.
    ///
    /// NOTE(review): all sub-statements of a single WITH query share one
    /// snapshot, so this is presumably the row as it was *before* the
    /// update, even when `status` is `Updated`. Confirm against the target
    /// database.
    pub found: Q,
}
/// Status of [`UpdateAndQueryResult`].
///
/// There is deliberately no "not found" variant: a missing row is reported
/// as an error by `execute_and_check` instead.
#[derive(Copy, Clone, PartialEq, Debug)]
pub enum UpdateStatus {
    /// The row exists and was updated.
    Updated,
    /// The row exists, but was not updated.
    NotUpdatedButExists,
}
// The table targeted by an `UpdateStatementExt`.
type UpdateTable<US> = <US as UpdateStatementExt>::Table;
// The Diesel expression type naming that table's primary key. It is not the
// key's value type; SQL generation additionally requires it to be a single
// `Column`.
type PrimaryKey<US> = <UpdateTable<US> as diesel::Table>::PrimaryKey;
// The SQL type of that primary key, used to describe the rows the CTE returns.
type SerializedPrimaryKey<US> = <PrimaryKey<US> as diesel::Expression>::SqlType;
impl<US, K, Q> UpdateAndQueryStatement<US, K, Q>
where
    Self: Send,
    US: 'static + UpdateStatementExt,
    K: 'static + Copy + PartialEq + Send,
    US::Table: 'static + Table + Send,
    US::WhereClause: 'static + Send,
    US::Changeset: 'static + Send,
    Q: std::fmt::Debug + Send + 'static,
{
    /// Runs the CTE on a connection from `pool` and classifies the outcome.
    ///
    /// Returns:
    /// - `Ok` with [`UpdateStatus::Updated`] when the row exists and the
    ///   update touched it;
    /// - `Ok` with [`UpdateStatus::NotUpdatedButExists`] when the row exists
    ///   but the update did not touch it;
    /// - `Err` when the row does not exist (the query yields no row), or on
    ///   any other database error.
    pub async fn execute_and_check(
        self,
        pool: &bb8::Pool<ConnectionManager<DbConnection>>,
    ) -> Result<UpdateAndQueryResult<Q>, PoolError>
    where
        // Guarantees that `Self` can actually be loaded as a query yielding
        // (found key, updated key, found row).
        Self: LoadQuery<'static, DbConnection, (Option<K>, Option<K>, Q)>,
    {
        let row = self
            .get_result_async::<(Option<K>, Option<K>, Q)>(pool)
            .await?;
        let (found_key, updated_key, found) = row;

        // The LEFT JOIN only fills in `updated_key` when the UPDATE returned
        // the looked-up row; otherwise it is NULL and differs from the key.
        let was_updated = found_key == updated_key;
        let status = if was_updated {
            UpdateStatus::Updated
        } else {
            UpdateStatus::NotUpdatedButExists
        };

        Ok(UpdateAndQueryResult { status, found })
    }
}
// The SQL type of `Q`'s `Selectable` select expression on Postgres; this is
// what the third element of each result row is described as.
type SelectableSqlType<Q> =
    <<Q as diesel::Selectable<Pg>>::SelectExpression as Expression>::SqlType;
// Each result row is (found key, updated key, found row). Both keys are
// declared nullable; the `updated` key is NULL whenever the LEFT JOIN finds
// no updated row.
impl<US: UpdateStatementExt, K, Q: Selectable<Pg>> Query
    for UpdateAndQueryStatement<US, K, Q>
where
    US::Table: Table,
{
    type SqlType = (
        Nullable<SerializedPrimaryKey<US>>,
        Nullable<SerializedPrimaryKey<US>>,
        SelectableSqlType<Q>,
    );
}
// Opts the statement in to Diesel's query-running methods on `DbConnection`.
impl<US: UpdateStatementExt, K, Q> RunQueryDsl<DbConnection>
    for UpdateAndQueryStatement<US, K, Q>
where
    US::Table: Table,
{
}
/// Emits the following CTE (no trailing semicolon):
///
/// ```text
/// WITH found AS (<find subquery>), updated AS (<update statement> RETURNING *)
/// SELECT
///     found.<primary key>,
///     updated.<primary key>,
///     found.*
/// FROM
///     found
/// LEFT JOIN
///     updated
/// ON
///     found.<primary key> = updated.<primary key>
/// ```
///
/// Here `<find subquery>` is `T::table().find(key)`. That is roughly
/// `SELECT <T's default columns> FROM T WHERE <primary key> = key`, so `found`
/// has the full row and not just the key, which `found.*` relies on.
/// `<update statement>` is the caller's `UPDATE T SET ... WHERE ...`.
///
/// Resulting rows:
/// - The first column is always the looked-up key.
/// - The second column is NULL unless the UPDATE also returned that row.
/// - If the key does not exist, `found` is empty and the query yields no rows.
///
/// NOTE(review):
/// - `found.*` comes out in `T`'s column order and repeats the primary key.
///   It must line up with `Q`'s select expression for loading to succeed.
/// - Under PostgreSQL's WITH semantics, all sub-statements share one
///   snapshot. `found.*` is therefore presumably the pre-update row; confirm
///   against the target database.
impl<US, K, Q> QueryFragment<Pg> for UpdateAndQueryStatement<US, K, Q>
where
    US: UpdateStatementExt,
    US::Table: HasTable<Table = US::Table> + Table,
    PrimaryKey<US>: diesel::Column,
    UpdateStatement<US::Table, US::WhereClause, US::Changeset>:
        QueryFragment<Pg>,
{
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
        out.push_sql("WITH found AS (");
        self.find_subquery.walk_ast(out.reborrow())?;
        out.push_sql("), updated AS (");
        self.update_statement.walk_ast(out.reborrow())?;
        // TODO: Only the primary key is read back from `updated`. Returning
        // every column is more than needed, unless the updated row should
        // also be passed back through the result.
        out.push_sql(" RETURNING *) ");

        out.push_sql("SELECT");
        let name = <PrimaryKey<US> as Column>::NAME;
        out.push_sql(" found.");
        out.push_identifier(name)?;
        out.push_sql(", updated.");
        out.push_identifier(name)?;
        // TODO: List all columns explicitly instead of `found.*`. The column
        // types exist within `Table::AllColumns`, and each has a name via
        // `<C as Column>::NAME`, but `AllColumns` is a tuple, which makes
        // iterating over it awkward. Note that `found.*` also repeats the
        // primary key selected above.
        out.push_sql(", found.*");
        out.push_sql(" FROM found LEFT JOIN updated ON");
        out.push_sql(" found.");
        out.push_identifier(name)?;
        out.push_sql(" = ");
        out.push_sql("updated.");
        out.push_identifier(name)?;
        Ok(())
    }
}