Skip to content

Commit 0a3cc27

Browse files
committed
Add rest keywords (**) support for overload resolution and fix FIXME
This allows overload resolution for `(**Integer) -> A | (**String) -> B`.
1 parent d7e471f commit 0a3cc27

4 files changed

Lines changed: 81 additions & 3 deletions

File tree

lib/typeprof/core/graph/box.rb

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@ def match_arguments?(genv, changes, param_map, a_args, method_type)
190190
method_type.opt_keyword_keys.zip(method_type.opt_keyword_values) do |key, ty|
191191
return false unless keyword_arg_typecheck?(genv, changes, a_args.keywords, key, ty, param_map)
192192
end
193+
if method_type.rest_keywords
194+
return false unless rest_keyword_args_typecheck?(genv, changes, a_args.keywords, method_type, param_map)
195+
end
193196
end
194197

195198
return true
@@ -336,6 +339,30 @@ def keyword_arg_typecheck?(genv, changes, keywords_vtx, key, expected_ty, param_
336339
true
337340
end
338341

342+
# Typecheck rest keyword argument values (those not consumed by named
343+
# keywords) against the method type's rest_keywords type.
344+
def rest_keyword_args_typecheck?(genv, changes, keywords_vtx, method_type, param_map)
345+
named_keys = method_type.req_keyword_keys + method_type.opt_keyword_keys
346+
rest_ty = method_type.rest_keywords
347+
keywords_vtx.each_type do |kw_ty|
348+
case kw_ty
349+
when Type::Record
350+
kw_ty.fields.each do |key, val_vtx|
351+
next if named_keys.include?(key)
352+
return false unless rest_ty.typecheck(genv, changes, val_vtx, param_map)
353+
end
354+
when Type::Hash
355+
val_vtx = kw_ty.base_type(genv).args[1]
356+
return false if val_vtx && !rest_ty.typecheck(genv, changes, val_vtx, param_map)
357+
when Type::Instance
358+
if kw_ty.mod == genv.mod_hash && kw_ty.args[1]
359+
return false unless rest_ty.typecheck(genv, changes, kw_ty.args[1], param_map)
360+
end
361+
end
362+
end
363+
true
364+
end
365+
339366
# Check if two method types have structurally identical positional
340367
# parameter types (req, opt, rest).
341368
def positionals_match?(mt1, mt2)
@@ -726,8 +753,18 @@ def pass_arguments(changes, genv, a_args)
726753
end
727754

728755
if @node.rest_keywords
729-
# FIXME: Extract the rest keywords excluding req_keywords and opt_keywords.
730-
changes.add_edge(genv, a_args.keywords, @f_args.rest_keywords)
756+
named_keys = @node.req_keywords + @node.opt_keywords
757+
a_args.keywords.each_type do |kw_ty|
758+
case kw_ty
759+
when Type::Record
760+
rest_fields = kw_ty.fields.reject {|key, _| named_keys.include?(key) }
761+
base = kw_ty.base_type(genv)
762+
rest_record = Type::Record.new(genv, rest_fields, base)
763+
changes.add_edge(genv, Source.new(rest_record), @f_args.rest_keywords)
764+
when Type::Hash, Type::Instance
765+
changes.add_edge(genv, Source.new(kw_ty), @f_args.rest_keywords)
766+
end
767+
end
731768
end
732769
end
733770

scenario/method/keywords.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,5 @@ def foo(a:, b: 1, **c)
7171

7272
## assert
7373
class Object
74-
def foo: (a: String, ?b: Integer, **String | Integer | true) -> { a: String, b: Integer, c: true }
74+
def foo: (a: String, ?b: Integer, **true) -> { c: true }
7575
end
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
## update: test.rbs
2+
class C
3+
def foo: (mode: :read, **Integer) -> Array[Integer]
4+
| (mode: :write, **String) -> Array[String]
5+
end
6+
7+
## update: test.rb
8+
class C
9+
def bar
10+
foo(mode: :read, x: 1)
11+
end
12+
def baz
13+
foo(mode: :write, x: "a")
14+
end
15+
end
16+
17+
## assert: test.rb
18+
class C
19+
def bar: -> Array[Integer]
20+
def baz: -> Array[String]
21+
end
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
## update: test.rbs
2+
class C
3+
def foo: (**Integer) -> Integer | (**String) -> String
4+
end
5+
6+
## update: test.rb
7+
class C
8+
def bar
9+
foo(x: 1, y: 2)
10+
end
11+
def baz
12+
foo(x: "a", y: "b")
13+
end
14+
end
15+
16+
## assert: test.rb
17+
class C
18+
def bar: -> Integer
19+
def baz: -> String
20+
end

0 commit comments

Comments
 (0)