diff --git a/db/migrations/20240224173636_create_transactions.cr b/db/migrations/20240224173636_create_transactions.cr new file mode 100644 index 000000000..419ac3d67 --- /dev/null +++ b/db/migrations/20240224173636_create_transactions.cr @@ -0,0 +1,15 @@ +class CreateTransactions::V20240224173636 < Avram::Migrator::Migration::V1 + def migrate + create table_for(Transaction) do + primary_key id : Int64 + add_timestamps + add type : Int32 + add soft_deleted_at : Time? + add_belongs_to user : User, on_delete: :cascade + end + end + + def rollback + drop table_for(Transaction) + end +end diff --git a/db/migrations/20240225160631_create_follows.cr b/db/migrations/20240225160631_create_follows.cr new file mode 100644 index 000000000..62f72b6f6 --- /dev/null +++ b/db/migrations/20240225160631_create_follows.cr @@ -0,0 +1,15 @@ +class CreateFollows::V20240225160631 < Avram::Migrator::Migration::V1 + def migrate + create table_for(Follow) do + primary_key id : Int64 + add_timestamps + add soft_deleted_at : Time? + add_belongs_to follower : User, on_delete: :cascade + add_belongs_to followee : User, on_delete: :cascade + end + end + + def rollback + drop table_for(Follow) + end +end diff --git a/spec/avram/preloading/preloading_has_many_spec.cr b/spec/avram/preloading/preloading_has_many_spec.cr index 739c8d5c2..7a7b88e70 100644 --- a/spec/avram/preloading/preloading_has_many_spec.cr +++ b/spec/avram/preloading/preloading_has_many_spec.cr @@ -270,4 +270,49 @@ describe "Preloading has_many associations" do end end end + + describe "override base_query_class" do + it "uses the custom query class to ignore soft_deleted records" do + user = UserFactory.create + good_txn = TransactionFactory.create(&.user(user)) + deleted_txn = TransactionFactory.create(&.user(user).soft_deleted_at(1.day.ago)) + + u = UserQuery.new.preload_transactions.first + u.transactions_count.should eq(1) + u.transactions.size.should eq(1) + ids = u.transactions.map(&.id) + ids.should contain(good_txn.id) + ids.should_not contain(deleted_txn.id) + end + + it "has an escape hatch" do + user = UserFactory.create + good_txn = TransactionFactory.create(&.user(user)) + deleted_txn = TransactionFactory.create(&.user(user).soft_deleted_at(1.day.ago)) + + u = UserQuery.new.preload_transactions(Transaction::BaseQuery.new).first + # NOTE: This is an edge case. It uses the base_query_class defined on the association + # not what was preloaded. + u.transactions_count.should eq(1) + u.transactions.size.should eq(2) + ids = u.transactions.map(&.id) + ids.should contain(good_txn.id) + ids.should contain(deleted_txn.id) + end + + it "extends the custom base query class" do + user = UserFactory.create + non_special_txn = TransactionFactory.create(&.user(user)) + deleted_txn = TransactionFactory.create(&.user(user).soft_deleted_at(1.day.ago)) + special_txn = TransactionFactory.create(&.user(user).type(Transaction::Type::Special)) + + u = UserQuery.new.preload_transactions(&.special).first + u.transactions_count.should eq(2) + u.transactions.size.should eq(1) + ids = u.transactions.map(&.id) + ids.should_not contain(non_special_txn.id) + ids.should_not contain(deleted_txn.id) + ids.should contain(special_txn.id) + end + end end diff --git a/spec/avram/preloading/preloading_has_many_through_spec.cr b/spec/avram/preloading/preloading_has_many_through_spec.cr index 7bba77dae..645eec173 100644 --- a/spec/avram/preloading/preloading_has_many_through_spec.cr +++ b/spec/avram/preloading/preloading_has_many_through_spec.cr @@ -210,4 +210,44 @@ describe "Preloading has_many through associations" do end end end + + describe "override base_query_class" do + it "uses the custom query class to ignore soft_deleted records" do + user = UserFactory.create + new_friend = UserFactory.create + not_friend = UserFactory.create + FollowFactory.create(&.followee(user).follower(new_friend)) + FollowFactory.create(&.followee(user).follower(not_friend).soft_deleted_at(1.day.ago)) + + u = UserQuery.new.preload_followers.first + ids = u.followers.map(&.id) + ids.should contain(new_friend.id) + ids.should_not contain(not_friend.id) + end + + it "has an escape hatch" do + user = UserFactory.create + new_friend = UserFactory.create + not_friend = UserFactory.create + FollowFactory.create(&.followee(user).follower(new_friend)) + FollowFactory.create(&.followee(user).follower(not_friend).soft_deleted_at(1.day.ago)) + + # Preloads work in a lot of different ways, so we need to account for + # all of the different method options + u = UserQuery.new.preload_followers(through: Follow::BaseQuery.new).id(user.id).first + ids = u.followers.map(&.id) + ids.should contain(new_friend.id) + ids.should contain(not_friend.id) + + u = UserQuery.new.preload_followers(User::BaseQuery.new, through: Follow::BaseQuery.new).id(user.id).first + ids = u.followers.map(&.id) + ids.should contain(new_friend.id) + ids.should contain(not_friend.id) + + u = UserQuery.new.preload_followers(through: Follow::BaseQuery.new, &.preload_follows).id(user.id).first + ids = u.followers.map(&.id) + ids.should contain(new_friend.id) + ids.should contain(not_friend.id) + end + end end diff --git a/spec/avram/queryable_spec.cr b/spec/avram/queryable_spec.cr index 3c690042c..4cd363c14 100644 --- a/spec/avram/queryable_spec.cr +++ b/spec/avram/queryable_spec.cr @@ -1247,7 +1247,7 @@ describe Avram::Queryable do UserQuery.new.select_count.should eq 10 # NOTE: we don't test rows_affected here because this isn't # available with a truncate statement - UserQuery.truncate + UserQuery.truncate(cascade: true) UserQuery.new.select_count.should eq 0 end diff --git a/spec/support/factories/follow_factory.cr b/spec/support/factories/follow_factory.cr new file mode 100644 index 000000000..4dfcf7e6f --- /dev/null +++ b/spec/support/factories/follow_factory.cr @@ -0,0 +1,20 @@ +class FollowFactory < BaseFactory + def initialize + before_save do + if operation.follower_id.value.nil? + follower(UserFactory.create) + end + if operation.followee_id.value.nil? + followee(UserFactory.create) + end + end + end + + def follower(u : User) + follower_id(u.id) + end + + def followee(u : User) + followee_id(u.id) + end +end diff --git a/spec/support/factories/transaction_factory.cr b/spec/support/factories/transaction_factory.cr new file mode 100644 index 000000000..88ff4a7d0 --- /dev/null +++ b/spec/support/factories/transaction_factory.cr @@ -0,0 +1,13 @@ +class TransactionFactory < BaseFactory + def initialize + before_save do + if operation.user_id.value.nil? + user(UserFactory.create) + end + end + end + + def user(u : User) + user_id(u.id) + end +end diff --git a/spec/support/models/follow.cr b/spec/support/models/follow.cr new file mode 100644 index 000000000..34c466c30 --- /dev/null +++ b/spec/support/models/follow.cr @@ -0,0 +1,18 @@ +class Follow < BaseModel + include Avram::SoftDelete::Model + + table do + column soft_deleted_at : Time? + + belongs_to follower : User + belongs_to followee : User + end +end + +class FollowQuery < Follow::BaseQuery + include Avram::SoftDelete::Query + + def initialize + defaults &.only_kept + end +end diff --git a/spec/support/models/transaction.cr b/spec/support/models/transaction.cr new file mode 100644 index 000000000..583b6f6e0 --- /dev/null +++ b/spec/support/models/transaction.cr @@ -0,0 +1,26 @@ +class Transaction < BaseModel + include Avram::SoftDelete::Model + + enum Type + Unknown + Special + end + + table do + column type : Transaction::Type = Transaction::Type::Unknown + column soft_deleted_at : Time? + belongs_to user : User + end +end + +class TransactionQuery < Transaction::BaseQuery + include Avram::SoftDelete::Query + + def initialize + defaults &.only_kept + end + + def special + type(Transaction::Type::Special) + end +end diff --git a/spec/support/models/user.cr b/spec/support/models/user.cr index 0561006d4..a2ce6f1c9 100644 --- a/spec/support/models/user.cr +++ b/spec/support/models/user.cr @@ -11,6 +11,9 @@ class User < BaseModel column average_score : Float64? column available_for_hire : Bool? has_one sign_in_credential : SignInCredential? + has_many transactions : Transaction, base_query_class: TransactionQuery + has_many follows : Follow, foreign_key: :followee_id, base_query_class: FollowQuery + has_many followers : User, through: [:follows, :follower] end end diff --git a/src/avram/associations/has_many.cr b/src/avram/associations/has_many.cr index b4a951441..a70251861 100644 --- a/src/avram/associations/has_many.cr +++ b/src/avram/associations/has_many.cr @@ -1,5 +1,5 @@ module Avram::Associations::HasMany - macro has_many(type_declaration, through = nil, foreign_key = nil) + macro has_many(type_declaration, through = nil, foreign_key = nil, base_query_class = nil) {% if !through.is_a?(NilLiteral) && (!through.is_a?(ArrayLiteral) || through.any? { |item| !item.is_a?(SymbolLiteral) }) %} {% through.raise <<-ERROR 'through' on #{@type.name} must be given an Array(Symbol). Instead, got: #{through} @@ -31,28 +31,29 @@ module Avram::Associations::HasMany {% end %} {% foreign_key = foreign_key.id %} + {% model = type_declaration.type %} + {% query_class = base_query_class || "#{model}::BaseQuery".id %} association \ assoc_name: :{{ assoc_name }}, type: {{ type_declaration.type }}, foreign_key: :{{ foreign_key }}, through: {{ through }}, - relationship_type: :has_many - - {% model = type_declaration.type %} + relationship_type: :has_many, + base_query_class: {{ query_class }} define_has_many_lazy_loading({{ assoc_name }}, {{ model }}, {{ foreign_key }}, {{ through }}) - define_has_many_base_query({{ @type }}, {{ assoc_name }}, {{ model }}, {{ foreign_key }}, {{ through }}) + define_has_many_base_query({{ @type }}, {{ assoc_name }}, {{ model }}, {{ foreign_key }}, {{ through }}, {{ query_class }}) end - private macro define_has_many_base_query(class_type, assoc_name, model, foreign_key, through) + private macro define_has_many_base_query(class_type, assoc_name, model, foreign_key, through, query_class) class BaseQuery def self.preload_{{ assoc_name }}(record : {{ class_type }}, force : Bool = false) : {{ class_type }} - preload_{{ assoc_name }}(record: record, preload_query: {{ model }}::BaseQuery.new, force: force) + preload_{{ assoc_name }}(record: record, preload_query: {{ query_class }}.new, force: force) end def self.preload_{{ assoc_name }}(record : {{ class_type }}, force : Bool = false) : {{ class_type }} - modified_query = yield {{ model }}::BaseQuery.new + modified_query = yield {{ query_class }}.new preload_{{ assoc_name }}(record: record, preload_query: modified_query, force: force) end @@ -73,11 +74,11 @@ module Avram::Associations::HasMany {% end %} def self.preload_{{ assoc_name }}(records : Enumerable({{ class_type }}), force : Bool = false) : Array({{ class_type }}) - preload_{{ assoc_name }}(records: records, preload_query: {{ model }}::BaseQuery.new, force: force) + preload_{{ assoc_name }}(records: records, preload_query: {{ query_class }}.new, force: force) end def self.preload_{{ assoc_name }}(records : Enumerable({{ class_type }}), force : Bool = false) : Array({{ class_type }}) - modified_query = yield {{ model }}::BaseQuery.new + modified_query = yield {{ query_class }}.new preload_{{ assoc_name }}(records: records, preload_query: modified_query, force: force) end @@ -122,19 +123,36 @@ module Avram::Associations::HasMany end {% end %} + {% if through %} + def preload_{{ assoc_name }}(*, through : Avram::Queryable? = nil) : self + preload_{{ assoc_name }}({{ query_class }}.new, through: through) + end + {% else %} def preload_{{ assoc_name }} : self - preload_{{ assoc_name }}({{ model }}::BaseQuery.new) + preload_{{ assoc_name }}({{ query_class }}.new) end + {% end %} - def preload_{{ assoc_name }} : self - modified_query = yield {{ model }}::BaseQuery.new + {% if through %} + def preload_{{ assoc_name }}(*, through : Avram::Queryable? = nil, &) : self + modified_query = yield {{ query_class }}.new + preload_{{ assoc_name }}(modified_query, through: through) + end + {% else %} + def preload_{{ assoc_name }}(&) : self + modified_query = yield {{ query_class }}.new preload_{{ assoc_name }}(modified_query) end + {% end %} {% if through %} - def preload_{{ assoc_name }}(preload_query : {{ model }}::BaseQuery) : self + def preload_{{ assoc_name }}(preload_query : {{ model }}::BaseQuery, *, through : Avram::Queryable? = nil) : self preload_{{ through.first.id }} do |through_query| - through_query.preload_{{ through[1].id }}(preload_query) + if base_q = through + base_q.preload_{{ through[1].id }}(preload_query) + else + through_query.preload_{{ through[1].id }}(preload_query) + end end add_preload do |records| records.each do |record| @@ -196,10 +214,7 @@ module Avram::Associations::HasMany .map(&.{{ through[1].id }}_count) .sum {% else %} - {{ model }}::BaseQuery - .new - .{{ foreign_key }}(id) - .select_count + {{ assoc_name.id }}_query.select_count {% end %} end @@ -218,10 +233,7 @@ module Avram::Associations::HasMany assoc_results.is_a?(Array) ? assoc_results : [assoc_results] end.compact {% else %} - {{ model }}::BaseQuery - .new - .{{ foreign_key }}(id) - .results + {{ assoc_name.id }}_query.results {% end %} end end diff --git a/src/avram/model.cr b/src/avram/model.cr index c9c2fbb90..c69ffd41e 100644 --- a/src/avram/model.cr +++ b/src/avram/model.cr @@ -178,7 +178,7 @@ abstract class Avram::Model {% for assoc in associations %} def {{ assoc[:assoc_name] }}_query {% if assoc[:relationship_type] == :has_many %} - {{ assoc[:type] }}::BaseQuery.new.{{ assoc[:foreign_key].id }}(id) + {{ assoc[:base_query_class] }}.new.{{ assoc[:foreign_key].id }}(id) {% elsif assoc[:relationship_type] == :belongs_to %} {{ assoc[:type] }}::BaseQuery.new.id({{ assoc[:foreign_key].id }}) {% else %} @@ -261,7 +261,7 @@ abstract class Avram::Model end end - macro association(assoc_name, type, relationship_type, foreign_key = nil, through = nil) - {% ASSOCIATIONS << {type: type, assoc_name: assoc_name.id, foreign_key: foreign_key, relationship_type: relationship_type, through: through} %} + macro association(assoc_name, type, relationship_type, foreign_key = nil, through = nil, base_query_class = nil) + {% ASSOCIATIONS << {type: type, assoc_name: assoc_name.id, foreign_key: foreign_key, relationship_type: relationship_type, through: through, base_query_class: base_query_class} %} end end