Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change comparisons in cursor-based pagination #4287

Closed
wants to merge 14 commits into from
90 changes: 51 additions & 39 deletions crates/db_views/src/post_view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ use lemmy_db_schema::{
ListFn,
Queries,
ReadFn,
FETCH_LIMIT_MAX,
},
ListingType,
SortType,
Expand All @@ -63,9 +64,9 @@ enum Ord {
struct PaginationCursorField<Q, QS> {
then_order_by_desc: fn(Q) -> Q,
then_order_by_asc: fn(Q) -> Q,
le: fn(&PostAggregates) -> Box<dyn BoxableExpression<QS, Pg, SqlType = sql_types::Bool>>,
ge: fn(&PostAggregates) -> Box<dyn BoxableExpression<QS, Pg, SqlType = sql_types::Bool>>,
ne: fn(&PostAggregates) -> Box<dyn BoxableExpression<QS, Pg, SqlType = sql_types::Bool>>,
lt: fn(&PostAggregates) -> Box<dyn BoxableExpression<QS, Pg, SqlType = sql_types::Bool>>,
gt: fn(&PostAggregates) -> Box<dyn BoxableExpression<QS, Pg, SqlType = sql_types::Bool>>,
eq: fn(&PostAggregates) -> Box<dyn BoxableExpression<QS, Pg, SqlType = sql_types::Bool>>,
}

/// Returns `PaginationCursorField<_, _>` for the given name
Expand All @@ -75,9 +76,9 @@ macro_rules! field {
PaginationCursorField {
then_order_by_desc: |query| QueryDsl::then_order_by(query, post_aggregates::$name.desc()),
then_order_by_asc: |query| QueryDsl::then_order_by(query, post_aggregates::$name.asc()),
le: |e| Box::new(post_aggregates::$name.le(e.$name)),
ge: |e| Box::new(post_aggregates::$name.ge(e.$name)),
ne: |e| Box::new(post_aggregates::$name.ne(e.$name)),
lt: |e| Box::new(post_aggregates::$name.lt(e.$name)),
gt: |e| Box::new(post_aggregates::$name.gt(e.$name)),
eq: |e| Box::new(post_aggregates::$name.eq(e.$name)),
}
};
}
Expand Down Expand Up @@ -498,7 +499,14 @@ fn queries<'a>() -> Queries<
];
let sorts_iter = sorts.iter().flatten();

// This loop does almost the same thing as sorting by and comparing tuples. If the rows were
for (order, field) in sorts_iter.clone() {
query = match order {
Ord::Desc => (field.then_order_by_desc)(query),
Ord::Asc => (field.then_order_by_asc)(query),
};
}

// This loop does almost the same thing as comparing tuples. If the rows were
// only sorted by 1 field called `foo` in descending order, then it would be like this:
//
// ```
Expand All @@ -515,45 +523,43 @@ fn queries<'a>() -> Queries<
// grouped together, and the rows in that group are sorted by the next fields.
// When checking if a row is within the range determined by the cursors, a field
// that's sorted after other fields is only compared if the row and the cursor
// are in the same group created by the previous sort, which is checked by using
// `or` to skip the comparison if any previously sorted field is not equal.
for (i, (order, field)) in sorts_iter.clone().enumerate() {
// Both cursors are treated as inclusive here. `page_after` is made exclusive
// by adding `1` to the offset.
let (then_order_by_field, compare_first, compare_last) = match order {
Ord::Desc => (field.then_order_by_desc, field.le, field.ge),
Ord::Asc => (field.then_order_by_asc, field.ge, field.le),
// are in the same group created by the previous sort. This is checked with a
// condition like `(a > 0) OR (a = 0 AND b > 1) OR (a = 0 AND b = 1 AND c > 2)`.
for (cursor_data, reverse_direction) in
[(&options.page_after, false), (&options.page_before, true)]
{
let Some(cursor_data) = cursor_data else {
continue;
};

query = then_order_by_field(query);
// Combines each `subcondition` using `or`
let mut condition: Box<dyn BoxableExpression<_, Pg, SqlType = sql_types::Bool>> =
Box::new(false.into_sql::<sql_types::Bool>());

for (cursor_data, compare) in [
(&options.page_after, compare_first),
(&options.page_before_or_equal, compare_last),
] {
let Some(cursor_data) = cursor_data else {
continue;
for (i, (order, field)) in sorts_iter.clone().enumerate() {
let compare = if (*order == Ord::Desc) ^ reverse_direction {
field.lt
} else {
field.gt
};
let mut condition: Box<dyn BoxableExpression<_, Pg, SqlType = sql_types::Bool>> =

// Combines comparisons using `and`
let mut subcondition: Box<dyn BoxableExpression<_, Pg, SqlType = sql_types::Bool>> =
Box::new(compare(&cursor_data.0));

// For each field that was sorted before the current one, skip the filter by changing
// `condition` to `true` if the row's value doesn't equal the cursor's value.
// For each field that was sorted before the current one, require it to equal the cursor's
// corresponding value for `subcondition` to be true.
for (_, other_field) in sorts_iter.clone().take(i) {
condition = Box::new(condition.or((other_field.ne)(&cursor_data.0)));
subcondition = Box::new(subcondition.and((other_field.eq)(&cursor_data.0)));
}

query = query.filter(condition);
condition = Box::new(condition.or(subcondition));
}
}

let (limit, mut offset) = limit_and_offset(options.page, options.limit)?;
if options.page_after.is_some() {
// always skip exactly one post because that's the last post of the previous page
// fixing the where clause is more difficult because we'd have to change only the last order-by-where clause
// e.g. WHERE (featured_local<=, hot_rank<=, published<=) to WHERE (<=, <=, <)
offset = 1;
query = query.filter(condition)
}

let (limit, offset) = limit_and_offset(options.page, options.limit)?;
query = query.limit(limit).offset(offset);

debug!("Post View Query: {:?}", debug_query::<Pg, _>(&query));
Expand Down Expand Up @@ -624,7 +630,7 @@ pub struct PostQuery<'a> {
pub page: Option<i64>,
pub limit: Option<i64>,
pub page_after: Option<PaginationCursorData>,
pub page_before_or_equal: Option<PaginationCursorData>,
pub page_before: Option<PaginationCursorData>,
}

impl<'a> PostQuery<'a> {
Expand All @@ -649,12 +655,15 @@ impl<'a> PostQuery<'a> {
person_id,
},
};
let (limit, offset) = limit_and_offset(self.page, self.limit)?;
let (mut limit, offset) = limit_and_offset(self.page, self.limit)?;
if offset != 0 && self.page_after.is_some() {
return Err(Error::QueryBuilderError(
"legacy pagination cannot be combined with v2 pagination".into(),
));
}
// Include the first post after the current page so it can be used as an exclusive bound.
// An inclusive bound would be more complicated because only the last comparison would be inclusive.
limit += 1;
let self_person_id = self
.local_user
.expect("part of the above if")
Expand Down Expand Up @@ -683,6 +692,7 @@ impl<'a> PostQuery<'a> {
PostQuery {
community_id: Some(largest_subscribed),
community_id_just_for_prefetch: true,
limit: Some(limit),
..self.clone()
},
)
Expand All @@ -692,9 +702,9 @@ impl<'a> PostQuery<'a> {
if (v.len() as i64) < limit {
Ok(Some(self.clone()))
} else {
let page_before_or_equal = Some(PaginationCursorData(v.pop().expect("else case").counts));
let page_before = Some(PaginationCursorData(v.pop().expect("else case").counts));
Ok(Some(PostQuery {
page_before_or_equal,
page_before,
..self.clone()
}))
}
Expand All @@ -704,7 +714,9 @@ impl<'a> PostQuery<'a> {
if self.listing_type == Some(ListingType::Subscribed)
&& self.community_id.is_none()
&& self.local_user.is_some()
&& self.page_before_or_equal.is_none()
&& self.page_before.is_none()
// prevent `limit + 1` from exceeding max
&& self.limit != Some(FETCH_LIMIT_MAX)
{
if let Some(query) = self.prefetch_upper_bound_for_page_before(pool).await? {
queries().list(pool, query).await
Expand Down