Skip to content

Commit f147b42

Browse files
authored
Merge pull request #2315 from Shopify/seb-batch-methods-cursor
Add support for cursor parameter on ActiveRecord's batch methods in Rails 8.0
2 parents ac1d7a1 + c0b0267 commit f147b42

File tree

2 files changed

+87
-90
lines changed

2 files changed

+87
-90
lines changed

lib/tapioca/dsl/compilers/active_record_relations.rb

Lines changed: 53 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,17 @@ def gather_constants
217217
[]
218218
end #: Array[Symbol]
219219
BATCHES_METHODS = ActiveRecord::Batches.instance_methods(false) #: Array[Symbol]
220+
BATCHES_METHODS_PARAMETERS = {
221+
start: ["T.untyped", "nil"],
222+
finish: ["T.untyped", "nil"],
223+
load: ["T.untyped", "false"],
224+
batch_size: ["Integer", "1000"],
225+
of: ["Integer", "1000"],
226+
error_on_ignore: ["T.untyped", "nil"],
227+
order: ["Symbol", ":asc"],
228+
cursor: ["T.untyped", "primary_key"],
229+
use_ranges: ["T.untyped", "nil"],
230+
} #: Hash[Symbol, [String, String]]
220231
CALCULATION_METHODS = ActiveRecord::Calculations.instance_methods(false) #: Array[Symbol]
221232
ENUMERABLE_QUERY_METHODS = [:any?, :many?, :none?, :one?] #: Array[Symbol]
222233
FIND_OR_CREATE_METHODS = [
@@ -855,102 +866,31 @@ def create_common_methods
855866
end
856867

857868
BATCHES_METHODS.each do |method_name|
858-
case method_name
859-
when :find_each
860-
order = ActiveRecord::Batches.instance_method(:find_each).parameters.include?([:key, :order])
861-
862-
common_relation_methods_module.create_method("find_each") do |method|
863-
method.add_kw_opt_param("start", "nil")
864-
method.add_kw_opt_param("finish", "nil")
865-
method.add_kw_opt_param("batch_size", "1000")
866-
method.add_kw_opt_param("error_on_ignore", "nil")
867-
method.add_kw_opt_param("order", ":asc") if order
868-
method.add_block_param("block")
869+
block_param, return_type, parameters = batch_method_configs(method_name)
870+
next if block_param.nil? || return_type.nil? || parameters.nil?
869871

870-
method.add_sig do |sig|
871-
sig.add_param("start", "T.untyped")
872-
sig.add_param("finish", "T.untyped")
873-
sig.add_param("batch_size", "Integer")
874-
sig.add_param("error_on_ignore", "T.untyped")
875-
sig.add_param("order", "Symbol") if order
876-
sig.add_param("block", "T.proc.params(object: #{constant_name}).void")
877-
sig.return_type = "void"
878-
end
872+
common_relation_methods_module.create_method(method_name.to_s) do |method|
873+
parameters.each do |name, (style, _type, default)|
874+
# The style is always "key", but this is a safeguard to prevent confusing errors in the future.
875+
raise "Unexpected style #{style} for #{name}" unless style == :key
879876

880-
method.add_sig do |sig|
881-
sig.add_param("start", "T.untyped")
882-
sig.add_param("finish", "T.untyped")
883-
sig.add_param("batch_size", "Integer")
884-
sig.add_param("error_on_ignore", "T.untyped")
885-
sig.add_param("order", "Symbol") if order
886-
sig.return_type = "T::Enumerator[#{constant_name}]"
887-
end
877+
method.add_kw_opt_param(name, T.must(default))
888878
end
889-
when :find_in_batches
890-
order = ActiveRecord::Batches.instance_method(:find_in_batches).parameters.include?([:key, :order])
891-
common_relation_methods_module.create_method("find_in_batches") do |method|
892-
method.add_kw_opt_param("start", "nil")
893-
method.add_kw_opt_param("finish", "nil")
894-
method.add_kw_opt_param("batch_size", "1000")
895-
method.add_kw_opt_param("error_on_ignore", "nil")
896-
method.add_kw_opt_param("order", ":asc") if order
897-
method.add_block_param("block")
898-
899-
method.add_sig do |sig|
900-
sig.add_param("start", "T.untyped")
901-
sig.add_param("finish", "T.untyped")
902-
sig.add_param("batch_size", "Integer")
903-
sig.add_param("error_on_ignore", "T.untyped")
904-
sig.add_param("order", "Symbol") if order
905-
sig.add_param("block", "T.proc.params(object: T::Array[#{constant_name}]).void")
906-
sig.return_type = "void"
907-
end
879+
method.add_block_param("block")
908880

909-
method.add_sig do |sig|
910-
sig.add_param("start", "T.untyped")
911-
sig.add_param("finish", "T.untyped")
912-
sig.add_param("batch_size", "Integer")
913-
sig.add_param("error_on_ignore", "T.untyped")
914-
sig.add_param("order", "Symbol") if order
915-
sig.return_type = "T::Enumerator[T::Enumerator[#{constant_name}]]"
881+
method.add_sig do |sig|
882+
parameters.each do |name, (_style, type, _default)|
883+
sig.add_param(name, type)
916884
end
885+
sig.add_param("block", "T.proc.params(object: #{block_param}).void")
886+
sig.return_type = "void"
917887
end
918-
when :in_batches
919-
order = ActiveRecord::Batches.instance_method(:in_batches).parameters.include?([:key, :order])
920-
use_ranges = ActiveRecord::Batches.instance_method(:in_batches).parameters.include?([:key, :use_ranges])
921-
922-
common_relation_methods_module.create_method("in_batches") do |method|
923-
method.add_kw_opt_param("of", "1000")
924-
method.add_kw_opt_param("start", "nil")
925-
method.add_kw_opt_param("finish", "nil")
926-
method.add_kw_opt_param("load", "false")
927-
method.add_kw_opt_param("error_on_ignore", "nil")
928-
method.add_kw_opt_param("order", ":asc") if order
929-
method.add_kw_opt_param("use_ranges", "nil") if use_ranges
930-
method.add_block_param("block")
931-
932-
method.add_sig do |sig|
933-
sig.add_param("of", "Integer")
934-
sig.add_param("start", "T.untyped")
935-
sig.add_param("finish", "T.untyped")
936-
sig.add_param("load", "T.untyped")
937-
sig.add_param("error_on_ignore", "T.untyped")
938-
sig.add_param("order", "Symbol") if order
939-
sig.add_param("use_ranges", "T.untyped") if use_ranges
940-
sig.add_param("block", "T.proc.params(object: #{RelationClassName}).void")
941-
sig.return_type = "void"
942-
end
943888

944-
method.add_sig do |sig|
945-
sig.add_param("of", "Integer")
946-
sig.add_param("start", "T.untyped")
947-
sig.add_param("finish", "T.untyped")
948-
sig.add_param("load", "T.untyped")
949-
sig.add_param("error_on_ignore", "T.untyped")
950-
sig.add_param("order", "Symbol") if order
951-
sig.add_param("use_ranges", "T.untyped") if use_ranges
952-
sig.return_type = "::ActiveRecord::Batches::BatchEnumerator"
889+
method.add_sig do |sig|
890+
parameters.each do |name, (_style, type, _default)|
891+
sig.add_param(name, type)
953892
end
893+
sig.return_type = return_type
954894
end
955895
end
956896
end
@@ -1029,6 +969,31 @@ def create_common_methods
1029969
end
1030970
end
1031971

972+
#: (Symbol) -> [String, String, Hash[String, [Symbol, String, String?]]]?
973+
def batch_method_configs(method_name)
974+
block_param, return_type = case method_name
975+
when :find_each
976+
[constant_name, "T::Enumerator[#{constant_name}]"]
977+
when :find_in_batches
978+
["T::Array[#{constant_name}]", "T::Enumerator[T::Enumerator[#{constant_name}]]"]
979+
when :in_batches
980+
[RelationClassName, "::ActiveRecord::Batches::BatchEnumerator"]
981+
else
982+
return
983+
end
984+
985+
parameters = {}
986+
987+
ActiveRecord::Batches.instance_method(method_name).parameters.each do |style, name|
988+
type, default = BATCHES_METHODS_PARAMETERS[name]
989+
next if type.nil?
990+
991+
parameters[name.to_s] = [style, type, default]
992+
end
993+
994+
[block_param, return_type, parameters]
995+
end
996+
1032997
#: ((Symbol | String) name, ?parameters: Array[RBI::TypedParam], ?return_type: String?) -> void
1033998
def create_common_method(name, parameters: [], return_type: nil)
1034999
common_relation_methods_module.create_method(

spec/tapioca/dsl/compilers/active_record_relations_spec.rb

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,25 @@ def find_by(*args); end
155155
sig { params(args: T.untyped).returns(::Post) }
156156
def find_by!(*args); end
157157
158+
<% if rails_version(">= 8.0") %>
159+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, block: T.proc.params(object: ::Post).void).void }
160+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol).returns(T::Enumerator[::Post]) }
161+
def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, cursor: primary_key, order: :asc, &block); end
162+
<% else %>
158163
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol, block: T.proc.params(object: ::Post).void).void }
159164
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[::Post]) }
160165
def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
166+
<% end %>
161167
168+
<% if rails_version(">= 8.0") %>
169+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, block: T.proc.params(object: T::Array[::Post]).void).void }
170+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
171+
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, cursor: primary_key, order: :asc, &block); end
172+
<% else %>
162173
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol, block: T.proc.params(object: T::Array[::Post]).void).void }
163174
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
164175
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
176+
<% end %>
165177
166178
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
167179
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
@@ -208,7 +220,11 @@ def fourth!; end
208220
sig { returns(Array) }
209221
def ids; end
210222
211-
<% if rails_version(">= 7.1") %>
223+
<% if rails_version(">= 8.0") %>
224+
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, use_ranges: T.untyped, block: T.proc.params(object: PrivateRelation).void).void }
225+
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, use_ranges: T.untyped).returns(::ActiveRecord::Batches::BatchEnumerator) }
226+
def in_batches(of: 1000, start: nil, finish: nil, load: false, error_on_ignore: nil, cursor: primary_key, order: :asc, use_ranges: nil, &block); end
227+
<% elsif rails_version(">= 7.1") %>
212228
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, order: Symbol, use_ranges: T.untyped, block: T.proc.params(object: PrivateRelation).void).void }
213229
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, order: Symbol, use_ranges: T.untyped).returns(::ActiveRecord::Batches::BatchEnumerator) }
214230
def in_batches(of: 1000, start: nil, finish: nil, load: false, error_on_ignore: nil, order: :asc, use_ranges: nil, &block); end
@@ -881,13 +897,25 @@ def find_by(*args); end
881897
sig { params(args: T.untyped).returns(::Post) }
882898
def find_by!(*args); end
883899
900+
<% if rails_version(">= 8.0") %>
901+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, block: T.proc.params(object: ::Post).void).void }
902+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol).returns(T::Enumerator[::Post]) }
903+
def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, cursor: primary_key, order: :asc, &block); end
904+
<% else %>
884905
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol, block: T.proc.params(object: ::Post).void).void }
885906
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[::Post]) }
886907
def find_each(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
908+
<% end %>
887909
910+
<% if rails_version(">= 8.0") %>
911+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, block: T.proc.params(object: T::Array[::Post]).void).void }
912+
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
913+
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, cursor: primary_key, order: :asc, &block); end
914+
<% else %>
888915
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol, block: T.proc.params(object: T::Array[::Post]).void).void }
889916
sig { params(start: T.untyped, finish: T.untyped, batch_size: Integer, error_on_ignore: T.untyped, order: Symbol).returns(T::Enumerator[T::Enumerator[::Post]]) }
890917
def find_in_batches(start: nil, finish: nil, batch_size: 1000, error_on_ignore: nil, order: :asc, &block); end
918+
<% end %>
891919
892920
sig { params(attributes: T::Array[T.untyped], block: T.nilable(T.proc.params(object: ::Post).void)).returns(T::Array[::Post]) }
893921
sig { params(attributes: T.untyped, block: T.nilable(T.proc.params(object: ::Post).void)).returns(::Post) }
@@ -934,7 +962,11 @@ def fourth!; end
934962
sig { returns(Array) }
935963
def ids; end
936964
937-
<% if rails_version(">= 7.1") %>
965+
<% if rails_version(">= 8.0") %>
966+
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, use_ranges: T.untyped, block: T.proc.params(object: PrivateRelation).void).void }
967+
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, cursor: T.untyped, order: Symbol, use_ranges: T.untyped).returns(::ActiveRecord::Batches::BatchEnumerator) }
968+
def in_batches(of: 1000, start: nil, finish: nil, load: false, error_on_ignore: nil, cursor: primary_key, order: :asc, use_ranges: nil, &block); end
969+
<% elsif rails_version(">= 7.1") %>
938970
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, order: Symbol, use_ranges: T.untyped, block: T.proc.params(object: PrivateRelation).void).void }
939971
sig { params(of: Integer, start: T.untyped, finish: T.untyped, load: T.untyped, error_on_ignore: T.untyped, order: Symbol, use_ranges: T.untyped).returns(::ActiveRecord::Batches::BatchEnumerator) }
940972
def in_batches(of: 1000, start: nil, finish: nil, load: false, error_on_ignore: nil, order: :asc, use_ranges: nil, &block); end

0 commit comments

Comments
 (0)