From 519d65aa0719c5205598ddf2bf5ef3e1830b11d2 Mon Sep 17 00:00:00 2001 From: Ufuk Kayserilioglu Date: Fri, 8 Nov 2024 19:29:48 +0300 Subject: [PATCH] Use method overloading to improve WhereChain definition `ActiveRecord::QueryMethods#where` returns a `WhereChain` instance only when it is called with no arguments. When called with arguments, it returns the same `ActiveRecord::Relation` instance. Previously, we didn't have proper overloading support in RBI files, so the best thing we could do was act like `WhereChain` was a subclass of the related relation class. This was not ideal because it would allow users to call methods on `WhereChain` that were not actually defined on it. For example, one could do `User.all.where(name: "John").not` which is not valid. Now that we have support for overloading, we can define `where` to return a `WhereChain` instance only when called with no arguments, and in all other cases, we can say that it returns the same relation instance. --- .../dsl/compilers/active_record_relations.rb | 70 ++++++++++++------- .../compilers/active_record_relations_spec.rb | 28 ++++---- 2 files changed, 59 insertions(+), 39 deletions(-) diff --git a/lib/tapioca/dsl/compilers/active_record_relations.rb b/lib/tapioca/dsl/compilers/active_record_relations.rb index 0e009186d..f42345297 100644 --- a/lib/tapioca/dsl/compilers/active_record_relations.rb +++ b/lib/tapioca/dsl/compilers/active_record_relations.rb @@ -210,10 +210,6 @@ def gather_constants query_methods |= ActiveRecord::SpawnMethods.instance_methods(false) # Remove the ones we know are private API query_methods -= [:all, :arel, :build_subquery, :construct_join_dependency, :extensions, :spawn] - # Remove "group" which needs a custom return type for GroupChains - query_methods -= [:group] - # Remove "where" which needs a custom return type for WhereChains - query_methods -= [:where] # Remove the methods that ... query_methods .grep_v(/_clause$/) # end with "_clause" @@ -419,7 +415,7 @@ def create_group_chain_methods(klass) sig { void } def create_relation_where_chain_class - model.create_class(RelationWhereChainClassName, superclass_name: RelationClassName) do |klass| + model.create_class(RelationWhereChainClassName) do |klass| create_where_chain_methods(klass, RelationClassName) klass.create_type_variable("Elem", type: "type_member", fixed: constant_name) end @@ -427,10 +423,7 @@ def create_relation_where_chain_class sig { void } def create_association_relation_where_chain_class - model.create_class( - AssociationRelationWhereChainClassName, - superclass_name: AssociationRelationClassName, - ) do |klass| + model.create_class(AssociationRelationWhereChainClassName) do |klass| create_where_chain_methods(klass, AssociationRelationClassName) klass.create_type_variable("Elem", type: "type_member", fixed: constant_name) end @@ -560,27 +553,21 @@ def create_collection_proxy_methods(klass) sig { void } def create_relation_methods create_relation_method("all") - create_relation_method( - "group", - parameters: [ - create_rest_param("args", type: "T.untyped"), - create_block_param("blk", type: "T.untyped"), - ], - relation_return_type: RelationGroupChainClassName, - association_return_type: AssociationRelationGroupChainClassName, - ) - create_relation_method( - "where", - parameters: [ - create_rest_param("args", type: "T.untyped"), - create_block_param("blk", type: "T.untyped"), - ], - relation_return_type: RelationWhereChainClassName, - association_return_type: AssociationRelationWhereChainClassName, - ) QUERY_METHODS.each do |method_name| case method_name + when :where + create_where_relation_method + when :group + create_relation_method( + "group", + parameters: [ + create_rest_param("args", type: "T.untyped"), + create_block_param("blk", type: "T.untyped"), + ], + relation_return_type: RelationGroupChainClassName, + association_return_type: AssociationRelationGroupChainClassName, + ) when :distinct create_relation_method( method_name.to_s, @@ -1056,6 +1043,35 @@ def create_common_method(name, parameters: [], return_type: nil) ) end + sig { void } + def create_where_relation_method + relation_methods_module.create_method("where") do |method| + method.add_rest_param("args") + + method.add_sig do |sig| + sig.return_type = RelationWhereChainClassName + end + + method.add_sig do |sig| + sig.add_param("args", "T.untyped") + sig.return_type = RelationClassName + end + end + + association_relation_methods_module.create_method("where") do |method| + method.add_rest_param("args") + + method.add_sig do |sig| + sig.return_type = AssociationRelationWhereChainClassName + end + + method.add_sig do |sig| + sig.add_param("args", "T.untyped") + sig.return_type = AssociationRelationClassName + end + end + end + sig do params( name: T.any(Symbol, String), diff --git a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb index 9d9b9462c..e5defa1cb 100644 --- a/spec/tapioca/dsl/compilers/active_record_relations_spec.rb +++ b/spec/tapioca/dsl/compilers/active_record_relations_spec.rb @@ -455,8 +455,9 @@ def upsert(attributes, returning: nil, unique_by: nil); end sig { params(attributes: T::Array[Hash], returning: T.nilable(T.any(T::Array[Symbol], FalseClass)), unique_by: T.nilable(T.any(T::Array[Symbol], Symbol))).returns(ActiveRecord::Result) } def upsert_all(attributes, returning: nil, unique_by: nil); end - sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelationWhereChain) } - def where(*args, &blk); end + sig { returns(PrivateAssociationRelationWhereChain) } + sig { params(args: T.untyped).returns(PrivateAssociationRelation) } + def where(*args); end <% if rails_version(">= 7.1") %> sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelation) } @@ -608,8 +609,9 @@ def uniq!(*args, &blk); end sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) } def unscope(*args, &blk); end - sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelationWhereChain) } - def where(*args, &blk); end + sig { returns(PrivateRelationWhereChain) } + sig { params(args: T.untyped).returns(PrivateRelation) } + def where(*args); end <% if rails_version(">= 7.1") %> sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) } @@ -660,7 +662,7 @@ def minimum(column_name); end def sum(column_name = nil, &block); end end - class PrivateAssociationRelationWhereChain < PrivateAssociationRelation + class PrivateAssociationRelationWhereChain Elem = type_member { { fixed: ::Post } } <% if rails_version(">= 7.0") %> @@ -762,7 +764,7 @@ def minimum(column_name); end def sum(column_name = nil, &block); end end - class PrivateRelationWhereChain < PrivateRelation + class PrivateRelationWhereChain Elem = type_member { { fixed: ::Post } } <% if rails_version(">= 7.0") %> @@ -1171,8 +1173,9 @@ def upsert(attributes, returning: nil, unique_by: nil); end sig { params(attributes: T::Array[Hash], returning: T.nilable(T.any(T::Array[Symbol], FalseClass)), unique_by: T.nilable(T.any(T::Array[Symbol], Symbol))).returns(ActiveRecord::Result) } def upsert_all(attributes, returning: nil, unique_by: nil); end - sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelationWhereChain) } - def where(*args, &blk); end + sig { returns(PrivateAssociationRelationWhereChain) } + sig { params(args: T.untyped).returns(PrivateAssociationRelation) } + def where(*args); end <% if rails_version(">= 7.1") %> sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelation) } @@ -1324,8 +1327,9 @@ def uniq!(*args, &blk); end sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) } def unscope(*args, &blk); end - sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelationWhereChain) } - def where(*args, &blk); end + sig { returns(PrivateRelationWhereChain) } + sig { params(args: T.untyped).returns(PrivateRelation) } + def where(*args); end <% if rails_version(">= 7.1") %> sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) } @@ -1376,7 +1380,7 @@ def minimum(column_name); end def sum(column_name = nil, &block); end end - class PrivateAssociationRelationWhereChain < PrivateAssociationRelation + class PrivateAssociationRelationWhereChain Elem = type_member { { fixed: ::Post } } <% if rails_version(">= 7.0") %> @@ -1478,7 +1482,7 @@ def minimum(column_name); end def sum(column_name = nil, &block); end end - class PrivateRelationWhereChain < PrivateRelation + class PrivateRelationWhereChain Elem = type_member { { fixed: ::Post } } <% if rails_version(">= 7.0") %>