Skip to content

Commit

Permalink
Merge pull request #2070 from Shopify/uk-improve-where-chain
Browse files Browse the repository at this point in the history
Use method overloading to improve WhereChain definition
  • Loading branch information
egiurleo authored Nov 18, 2024
2 parents dd9dc0e + 519d65a commit 063f134
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 39 deletions.
70 changes: 43 additions & 27 deletions lib/tapioca/dsl/compilers/active_record_relations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,6 @@ def gather_constants
query_methods |= ActiveRecord::SpawnMethods.instance_methods(false)
# Remove the ones we know are private API
query_methods -= [:all, :arel, :build_subquery, :construct_join_dependency, :extensions, :spawn]
# Remove "group" which needs a custom return type for GroupChains
query_methods -= [:group]
# Remove "where" which needs a custom return type for WhereChains
query_methods -= [:where]
# Remove the methods that ...
query_methods
.grep_v(/_clause$/) # end with "_clause"
Expand Down Expand Up @@ -419,18 +415,15 @@ def create_group_chain_methods(klass)

sig { void }
def create_relation_where_chain_class
model.create_class(RelationWhereChainClassName, superclass_name: RelationClassName) do |klass|
model.create_class(RelationWhereChainClassName) do |klass|
create_where_chain_methods(klass, RelationClassName)
klass.create_type_variable("Elem", type: "type_member", fixed: constant_name)
end
end

sig { void }
def create_association_relation_where_chain_class
model.create_class(
AssociationRelationWhereChainClassName,
superclass_name: AssociationRelationClassName,
) do |klass|
model.create_class(AssociationRelationWhereChainClassName) do |klass|
create_where_chain_methods(klass, AssociationRelationClassName)
klass.create_type_variable("Elem", type: "type_member", fixed: constant_name)
end
Expand Down Expand Up @@ -560,27 +553,21 @@ def create_collection_proxy_methods(klass)
sig { void }
def create_relation_methods
create_relation_method("all")
create_relation_method(
"group",
parameters: [
create_rest_param("args", type: "T.untyped"),
create_block_param("blk", type: "T.untyped"),
],
relation_return_type: RelationGroupChainClassName,
association_return_type: AssociationRelationGroupChainClassName,
)
create_relation_method(
"where",
parameters: [
create_rest_param("args", type: "T.untyped"),
create_block_param("blk", type: "T.untyped"),
],
relation_return_type: RelationWhereChainClassName,
association_return_type: AssociationRelationWhereChainClassName,
)

QUERY_METHODS.each do |method_name|
case method_name
when :where
create_where_relation_method
when :group
create_relation_method(
"group",
parameters: [
create_rest_param("args", type: "T.untyped"),
create_block_param("blk", type: "T.untyped"),
],
relation_return_type: RelationGroupChainClassName,
association_return_type: AssociationRelationGroupChainClassName,
)
when :distinct
create_relation_method(
method_name.to_s,
Expand Down Expand Up @@ -1056,6 +1043,35 @@ def create_common_method(name, parameters: [], return_type: nil)
)
end

sig { void }
def create_where_relation_method
relation_methods_module.create_method("where") do |method|
method.add_rest_param("args")

method.add_sig do |sig|
sig.return_type = RelationWhereChainClassName
end

method.add_sig do |sig|
sig.add_param("args", "T.untyped")
sig.return_type = RelationClassName
end
end

association_relation_methods_module.create_method("where") do |method|
method.add_rest_param("args")

method.add_sig do |sig|
sig.return_type = AssociationRelationWhereChainClassName
end

method.add_sig do |sig|
sig.add_param("args", "T.untyped")
sig.return_type = AssociationRelationClassName
end
end
end

sig do
params(
name: T.any(Symbol, String),
Expand Down
28 changes: 16 additions & 12 deletions spec/tapioca/dsl/compilers/active_record_relations_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,9 @@ def upsert(attributes, returning: nil, unique_by: nil); end
sig { params(attributes: T::Array[Hash], returning: T.nilable(T.any(T::Array[Symbol], FalseClass)), unique_by: T.nilable(T.any(T::Array[Symbol], Symbol))).returns(ActiveRecord::Result) }
def upsert_all(attributes, returning: nil, unique_by: nil); end
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelationWhereChain) }
def where(*args, &blk); end
sig { returns(PrivateAssociationRelationWhereChain) }
sig { params(args: T.untyped).returns(PrivateAssociationRelation) }
def where(*args); end
<% if rails_version(">= 7.1") %>
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelation) }
Expand Down Expand Up @@ -608,8 +609,9 @@ def uniq!(*args, &blk); end
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) }
def unscope(*args, &blk); end
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelationWhereChain) }
def where(*args, &blk); end
sig { returns(PrivateRelationWhereChain) }
sig { params(args: T.untyped).returns(PrivateRelation) }
def where(*args); end
<% if rails_version(">= 7.1") %>
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) }
Expand Down Expand Up @@ -660,7 +662,7 @@ def minimum(column_name); end
def sum(column_name = nil, &block); end
end
class PrivateAssociationRelationWhereChain < PrivateAssociationRelation
class PrivateAssociationRelationWhereChain
Elem = type_member { { fixed: ::Post } }
<% if rails_version(">= 7.0") %>
Expand Down Expand Up @@ -762,7 +764,7 @@ def minimum(column_name); end
def sum(column_name = nil, &block); end
end
class PrivateRelationWhereChain < PrivateRelation
class PrivateRelationWhereChain
Elem = type_member { { fixed: ::Post } }
<% if rails_version(">= 7.0") %>
Expand Down Expand Up @@ -1171,8 +1173,9 @@ def upsert(attributes, returning: nil, unique_by: nil); end
sig { params(attributes: T::Array[Hash], returning: T.nilable(T.any(T::Array[Symbol], FalseClass)), unique_by: T.nilable(T.any(T::Array[Symbol], Symbol))).returns(ActiveRecord::Result) }
def upsert_all(attributes, returning: nil, unique_by: nil); end
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelationWhereChain) }
def where(*args, &blk); end
sig { returns(PrivateAssociationRelationWhereChain) }
sig { params(args: T.untyped).returns(PrivateAssociationRelation) }
def where(*args); end
<% if rails_version(">= 7.1") %>
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateAssociationRelation) }
Expand Down Expand Up @@ -1324,8 +1327,9 @@ def uniq!(*args, &blk); end
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) }
def unscope(*args, &blk); end
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelationWhereChain) }
def where(*args, &blk); end
sig { returns(PrivateRelationWhereChain) }
sig { params(args: T.untyped).returns(PrivateRelation) }
def where(*args); end
<% if rails_version(">= 7.1") %>
sig { params(args: T.untyped, blk: T.untyped).returns(PrivateRelation) }
Expand Down Expand Up @@ -1376,7 +1380,7 @@ def minimum(column_name); end
def sum(column_name = nil, &block); end
end
class PrivateAssociationRelationWhereChain < PrivateAssociationRelation
class PrivateAssociationRelationWhereChain
Elem = type_member { { fixed: ::Post } }
<% if rails_version(">= 7.0") %>
Expand Down Expand Up @@ -1478,7 +1482,7 @@ def minimum(column_name); end
def sum(column_name = nil, &block); end
end
class PrivateRelationWhereChain < PrivateRelation
class PrivateRelationWhereChain
Elem = type_member { { fixed: ::Post } }
<% if rails_version(">= 7.0") %>
Expand Down

0 comments on commit 063f134

Please sign in to comment.