Skip to content

Commit 79fc821

Browse files
committed
Extract OverloadSet class to cache overload comparison results
1 parent cdb6174 commit 79fc821

2 files changed

Lines changed: 188 additions & 69 deletions

File tree

lib/typeprof/core/ast/sig_decl.rb

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,16 @@ def initialize(raw_decl, lenv)
193193
@mid_code_range = lenv.code_range_from_node(raw_decl.location[:name])
194194
@singleton = raw_decl.singleton?
195195
@instance = raw_decl.instance?
196-
@method_types = raw_decl.overloads.map do |overload|
196+
@method_types = OverloadSet.new(raw_decl.overloads.map do |overload|
197197
method_type = overload.method_type
198198
AST.create_rbs_func_type(method_type, method_type.type_params, method_type.block, lenv)
199-
end
199+
end)
200200
@overloading = raw_decl.overloading
201201
end
202202

203203
attr_reader :mid, :singleton, :instance, :method_types, :overloading, :mid_code_range
204204

205-
def subnodes = { method_types: }
205+
def subnodes = { method_types: @method_types.to_a }
206206
def attrs = { mid:, mid_code_range:, singleton:, instance:, overloading: }
207207

208208
def mname_code_range(_name) = @mid_code_range
@@ -346,6 +346,7 @@ def initialize(raw_decl, lenv)
346346
location: raw_decl.type.location
347347
)
348348
@method_type = AST.create_rbs_func_type(rbs_method_type, [], nil, lenv)
349+
@method_types = OverloadSet.new([@method_type])
349350
end
350351

351352
attr_reader :mid, :method_type
@@ -354,7 +355,7 @@ def subnodes = { method_type: }
354355
def attrs = { mid: }
355356

356357
def install0(genv)
357-
@changes.add_method_decl_box(genv, @lenv.cref.cpath, false, @mid, [@method_type], false)
358+
@changes.add_method_decl_box(genv, @lenv.cref.cpath, false, @mid, @method_types, false)
358359
Source.new
359360
end
360361
end
@@ -381,6 +382,7 @@ def initialize(raw_decl, lenv)
381382
location: raw_decl.type.location
382383
)
383384
@method_type = AST.create_rbs_func_type(rbs_method_type, [], nil, lenv)
385+
@method_types = OverloadSet.new([@method_type])
384386
end
385387

386388
attr_reader :mid, :method_type
@@ -389,7 +391,7 @@ def subnodes = { method_type: }
389391
def attrs = { mid: }
390392

391393
def install0(genv)
392-
@changes.add_method_decl_box(genv, @lenv.cref.cpath, false, @mid, [@method_type], false)
394+
@changes.add_method_decl_box(genv, @lenv.cref.cpath, false, @mid, @method_types, false)
393395
Source.new
394396
end
395397
end

lib/typeprof/core/graph/box.rb

Lines changed: 181 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,141 @@ def run0(genv, changes)
9999
end
100100
end
101101

102+
class OverloadSet
103+
include Enumerable
104+
105+
def initialize(method_types)
106+
@method_types = method_types
107+
end
108+
109+
def each(&blk) = @method_types.each(&blk)
110+
def map(&blk) = @method_types.map(&blk)
111+
def first = @method_types.first
112+
def size = @method_types.size
113+
def to_a = @method_types
114+
115+
# lazy cache: combination(2) for all-pair comparison
116+
def overloads_differ_in_args?
117+
return @overloads_differ_in_args if defined?(@overloads_differ_in_args)
118+
@overloads_differ_in_args = !@method_types.combination(2).all? { |a, b|
119+
positionals_match?(a, b) && keywords_match?(a, b)
120+
}
121+
end
122+
123+
def overloads_differ_at_top_level?
124+
return @overloads_differ_at_top_level if defined?(@overloads_differ_at_top_level)
125+
@overloads_differ_at_top_level = !@method_types.combination(2).all? { |a, b|
126+
positionals_match_shallow?(a, b) && keywords_match_shallow?(a, b)
127+
}
128+
end
129+
130+
private
131+
132+
# Check if two method types have structurally identical positional
133+
# parameter types (req, opt, rest).
134+
def positionals_match?(mt1, mt2)
135+
return false unless mt1.req_positionals.size == mt2.req_positionals.size
136+
return false unless mt1.opt_positionals.size == mt2.opt_positionals.size
137+
return false unless mt1.rest_positionals.nil? == mt2.rest_positionals.nil?
138+
mt1.req_positionals.zip(mt2.req_positionals).all? {|a, b| sig_types_match?(a, b) } &&
139+
mt1.opt_positionals.zip(mt2.opt_positionals).all? {|a, b| sig_types_match?(a, b) } &&
140+
(mt1.rest_positionals.nil? || sig_types_match?(mt1.rest_positionals, mt2.rest_positionals))
141+
end
142+
143+
# Check if two method types have identical positional parameter
144+
# types at the top level (ignoring type parameter contents).
145+
def positionals_match_shallow?(mt1, mt2)
146+
return false unless mt1.req_positionals.size == mt2.req_positionals.size
147+
return false unless mt1.opt_positionals.size == mt2.opt_positionals.size
148+
return false unless mt1.rest_positionals.nil? == mt2.rest_positionals.nil?
149+
mt1.req_positionals.zip(mt2.req_positionals).all? {|a, b| sig_types_match_shallow?(a, b) } &&
150+
mt1.opt_positionals.zip(mt2.opt_positionals).all? {|a, b| sig_types_match_shallow?(a, b) } &&
151+
(mt1.rest_positionals.nil? || sig_types_match_shallow?(mt1.rest_positionals, mt2.rest_positionals))
152+
end
153+
154+
# Check if two method types have structurally identical keyword
155+
# parameter types (req, opt, rest).
156+
def keywords_match?(mt1, mt2)
157+
return false unless mt1.req_keyword_keys == mt2.req_keyword_keys
158+
return false unless mt1.opt_keyword_keys == mt2.opt_keyword_keys
159+
return false unless mt1.rest_keywords.nil? == mt2.rest_keywords.nil?
160+
mt1.req_keyword_values.zip(mt2.req_keyword_values).all? {|a, b| sig_types_match?(a, b) } &&
161+
mt1.opt_keyword_values.zip(mt2.opt_keyword_values).all? {|a, b| sig_types_match?(a, b) } &&
162+
(mt1.rest_keywords.nil? || sig_types_match?(mt1.rest_keywords, mt2.rest_keywords))
163+
end
164+
165+
# Shallow version: compare keyword keys and structure, but use
166+
# shallow type comparison for values.
167+
def keywords_match_shallow?(mt1, mt2)
168+
return false unless mt1.req_keyword_keys == mt2.req_keyword_keys
169+
return false unless mt1.opt_keyword_keys == mt2.opt_keyword_keys
170+
return false unless mt1.rest_keywords.nil? == mt2.rest_keywords.nil?
171+
mt1.req_keyword_values.zip(mt2.req_keyword_values).all? {|a, b| sig_types_match_shallow?(a, b) } &&
172+
mt1.opt_keyword_values.zip(mt2.opt_keyword_values).all? {|a, b| sig_types_match_shallow?(a, b) } &&
173+
(mt1.rest_keywords.nil? || sig_types_match_shallow?(mt1.rest_keywords, mt2.rest_keywords))
174+
end
175+
176+
# Structural equality check for two SigTyNode objects.
177+
def sig_types_match?(a, b)
178+
return false unless a.class == b.class
179+
case a
180+
when AST::SigTyInstanceNode, AST::SigTyInterfaceNode
181+
a.cpath == b.cpath &&
182+
a.args.size == b.args.size &&
183+
a.args.zip(b.args).all? {|x, y| sig_types_match?(x, y) }
184+
when AST::SigTySingletonNode
185+
a.cpath == b.cpath
186+
when AST::SigTyTupleNode, AST::SigTyUnionNode, AST::SigTyIntersectionNode
187+
a.types.size == b.types.size &&
188+
a.types.zip(b.types).all? {|x, y| sig_types_match?(x, y) }
189+
when AST::SigTyRecordNode
190+
a.fields.size == b.fields.size &&
191+
a.fields.all? {|k, v| b.fields[k] && sig_types_match?(v, b.fields[k]) }
192+
when AST::SigTyOptionalNode, AST::SigTyProcNode
193+
sig_types_match?(a.type, b.type)
194+
when AST::SigTyVarNode
195+
a.var == b.var
196+
when AST::SigTyLiteralNode
197+
a.lit == b.lit
198+
when AST::SigTyAliasNode
199+
a.cpath == b.cpath && a.name == b.name &&
200+
a.args.size == b.args.size &&
201+
a.args.zip(b.args).all? {|x, y| sig_types_match?(x, y) }
202+
else
203+
true # Leaf types (bool, nil, self, void, untyped, etc.)
204+
end
205+
end
206+
207+
# Shallow structural equality: compare only the top-level type
208+
# identity without recursing into type parameters.
209+
def sig_types_match_shallow?(a, b)
210+
return false unless a.class == b.class
211+
case a
212+
when AST::SigTyInstanceNode, AST::SigTyInterfaceNode
213+
a.cpath == b.cpath
214+
when AST::SigTySingletonNode
215+
a.cpath == b.cpath
216+
when AST::SigTyTupleNode
217+
a.types.size == b.types.size
218+
when AST::SigTyUnionNode, AST::SigTyIntersectionNode
219+
a.types.size == b.types.size &&
220+
a.types.zip(b.types).all? {|x, y| sig_types_match_shallow?(x, y) }
221+
when AST::SigTyRecordNode
222+
a.fields.keys.sort == b.fields.keys.sort
223+
when AST::SigTyOptionalNode, AST::SigTyProcNode
224+
true
225+
when AST::SigTyVarNode
226+
a.var == b.var
227+
when AST::SigTyLiteralNode
228+
a.lit == b.lit
229+
when AST::SigTyAliasNode
230+
a.cpath == b.cpath && a.name == b.name
231+
else
232+
true
233+
end
234+
end
235+
end
236+
102237
class MethodDeclBox < Box
103238
def initialize(node, genv, cpath, singleton, mid, method_types, overloading)
104239
super(node)
@@ -274,17 +409,35 @@ def resolve_overloads(changes, genv, node, param_map, a_args, ret, &blk)
274409
# cyclic cases and avoids false "failed to resolve overloads"
275410
# diagnostics for untyped arguments.
276411
#
277-
# Top-level empty vertices are always uninformative. For type
278-
# parameter vertices (e.g., Array[T], Hash[K,V], tuples), we
279-
# only recurse when overloads differ in their positional or keyword
280-
# parameter types -- otherwise empty type params (like those of
281-
# `{}`) cannot cause oscillation and should not trigger bail-out.
282-
overloads_differ = !@method_types.each_cons(2).all? {|mt1, mt2|
283-
positionals_match?(mt1, mt2) && keywords_match?(mt1, mt2)
284-
}
285-
has_uninformative_args = if overloads_differ
286-
a_args.positionals.any? {|vtx| vertex_uninformative?(genv, vtx) } ||
287-
(a_args.keywords && vertex_uninformative?(genv, a_args.keywords))
412+
# We check at two levels:
413+
# 1. Top-level empty vertices are always uninformative.
414+
# 2. Empty type parameter vertices (e.g., Array[T] where T is
415+
# empty) are only uninformative when overloads differ solely
416+
# in their type parameters (e.g., Array[Integer] vs
417+
# Array[String]). When overloads differ at the top level
418+
# (e.g., Integer vs Float), the type parameter contents are
419+
# irrelevant for overload selection and should not trigger
420+
# bail-out.
421+
has_uninformative_args = if @method_types.overloads_differ_in_args?
422+
# Check whether overloads also differ at the top level (e.g.,
423+
# Integer vs Float) or only in their type parameters (e.g.,
424+
# Array[Integer] vs Array[String]).
425+
if @method_types.overloads_differ_at_top_level?
426+
# Overloads are distinguished by top-level types.
427+
# Only top-level empty vertices matter; empty type parameters
428+
# are irrelevant for overload selection.
429+
# However, splatted arguments have their elements extracted
430+
# during matching, so also check splat element vertices.
431+
a_args.positionals.any? {|vtx| vtx.types.empty? } ||
432+
splat_elements_uninformative?(genv, a_args) ||
433+
(a_args.keywords && a_args.keywords.types.empty?)
434+
else
435+
# Overloads differ only in type parameters (e.g.,
436+
# Array[Integer] vs Array[String]). Empty type parameter
437+
# vertices can cause oscillation, so check recursively.
438+
a_args.positionals.any? {|vtx| vertex_uninformative?(genv, vtx) } ||
439+
(a_args.keywords && vertex_uninformative?(genv, a_args.keywords))
440+
end
288441
else
289442
a_args.positionals.any? {|vtx| vtx.types.empty? } ||
290443
(a_args.keywords && a_args.keywords.types.empty?)
@@ -310,6 +463,23 @@ def resolve_overloads(changes, genv, node, param_map, a_args, ret, &blk)
310463
end
311464
end
312465

466+
# Check if any splatted argument has an Array element vertex
467+
# that is empty. Splat expansion extracts elements during
468+
# overload matching, so empty element types can cause oscillation
469+
# even when the top-level Array type is present.
470+
def splat_elements_uninformative?(genv, a_args)
471+
a_args.positionals.each_with_index do |vtx, i|
472+
next unless a_args.splat_flags[i]
473+
vtx.each_type do |ty|
474+
base = ty.base_type(genv)
475+
if base.is_a?(Type::Instance) && base.mod == genv.mod_ary && base.args[0]
476+
return true if base.args[0].types.empty?
477+
end
478+
end
479+
end
480+
false
481+
end
482+
313483
def vertex_uninformative?(genv, vtx, depth = 0)
314484
return true if vtx.types.empty?
315485
return false if depth > 3
@@ -363,59 +533,6 @@ def rest_keyword_args_typecheck?(genv, changes, keywords_vtx, method_type, param
363533
true
364534
end
365535

366-
# Check if two method types have structurally identical positional
367-
# parameter types (req, opt, rest).
368-
def positionals_match?(mt1, mt2)
369-
return false unless mt1.req_positionals.size == mt2.req_positionals.size
370-
return false unless mt1.opt_positionals.size == mt2.opt_positionals.size
371-
return false unless mt1.rest_positionals.nil? == mt2.rest_positionals.nil?
372-
mt1.req_positionals.zip(mt2.req_positionals).all? {|a, b| sig_types_match?(a, b) } &&
373-
mt1.opt_positionals.zip(mt2.opt_positionals).all? {|a, b| sig_types_match?(a, b) } &&
374-
(mt1.rest_positionals.nil? || sig_types_match?(mt1.rest_positionals, mt2.rest_positionals))
375-
end
376-
377-
# Check if two method types have structurally identical keyword
378-
# parameter types (req, opt, rest).
379-
def keywords_match?(mt1, mt2)
380-
return false unless mt1.req_keyword_keys == mt2.req_keyword_keys
381-
return false unless mt1.opt_keyword_keys == mt2.opt_keyword_keys
382-
return false unless mt1.rest_keywords.nil? == mt2.rest_keywords.nil?
383-
mt1.req_keyword_values.zip(mt2.req_keyword_values).all? {|a, b| sig_types_match?(a, b) } &&
384-
mt1.opt_keyword_values.zip(mt2.opt_keyword_values).all? {|a, b| sig_types_match?(a, b) } &&
385-
(mt1.rest_keywords.nil? || sig_types_match?(mt1.rest_keywords, mt2.rest_keywords))
386-
end
387-
388-
# Structural equality check for two SigTyNode objects.
389-
def sig_types_match?(a, b)
390-
return false unless a.class == b.class
391-
case a
392-
when AST::SigTyInstanceNode, AST::SigTyInterfaceNode
393-
a.cpath == b.cpath &&
394-
a.args.size == b.args.size &&
395-
a.args.zip(b.args).all? {|x, y| sig_types_match?(x, y) }
396-
when AST::SigTySingletonNode
397-
a.cpath == b.cpath
398-
when AST::SigTyTupleNode, AST::SigTyUnionNode, AST::SigTyIntersectionNode
399-
a.types.size == b.types.size &&
400-
a.types.zip(b.types).all? {|x, y| sig_types_match?(x, y) }
401-
when AST::SigTyRecordNode
402-
a.fields.size == b.fields.size &&
403-
a.fields.all? {|k, v| b.fields[k] && sig_types_match?(v, b.fields[k]) }
404-
when AST::SigTyOptionalNode, AST::SigTyProcNode
405-
sig_types_match?(a.type, b.type)
406-
when AST::SigTyVarNode
407-
a.var == b.var
408-
when AST::SigTyLiteralNode
409-
a.lit == b.lit
410-
when AST::SigTyAliasNode
411-
a.cpath == b.cpath && a.name == b.name &&
412-
a.args.size == b.args.size &&
413-
a.args.zip(b.args).all? {|x, y| sig_types_match?(x, y) }
414-
else
415-
true # Leaf types (bool, nil, self, void, untyped, etc.)
416-
end
417-
end
418-
419536
def show
420537
@method_types.map do |method_type|
421538
args = []

0 commit comments

Comments
 (0)