Skip to content

Commit 2ef88d4

Browse files
authored
Improve Record type inference for hash literals (#360)
* Improve Record type field access semantics Record types have statically defined fields, unlike Hash which allows dynamic keys. This change makes the type inference more precise: - Non-existent field access (e.g., `record[:unknown]`) now returns `nil` instead of `untyped`, since Record fields are fixed at definition time - Symbol variable access (e.g., `record[key]`) now returns `nil | union of all field types` to account for the possibility that the key may not exist in the Record Hash type behavior remains unchanged (returns `untyped` for unknown keys). * Use Type::Record for hash literals with symbol keys This improves RBS output precision for hash literals. For example: - Before: `def foo: -> Hash[:a | :b, Integer | String]` - After: `def foo: -> { a: Integer, b: String }` Record type preserves individual field types instead of merging all keys and values into a single union type.
1 parent 38f9bdb commit 2ef88d4

14 files changed

Lines changed: 88 additions & 20 deletions

File tree

lib/typeprof/core/ast/value.rb

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,18 @@ def install0(genv)
293293
unified_key = Vertex.new(self)
294294
unified_val = Vertex.new(self)
295295
literal_pairs = {}
296+
all_symbol_keys = true
296297
@keys.zip(@vals) do |key, val|
297298
if key
298299
k = key.install(genv).new_vertex(genv, self)
299300
v = val.install(genv).new_vertex(genv, self)
300301
@changes.add_edge(genv, k, unified_key)
301302
@changes.add_edge(genv, v, unified_val)
302-
literal_pairs[key.lit] = v if key.is_a?(SymbolNode)
303+
if key.is_a?(SymbolNode)
304+
literal_pairs[key.lit] = v
305+
else
306+
all_symbol_keys = false
307+
end
303308
else
304309
if val.is_a?(DummyNilNode)
305310
h = @lenv.get_var(:"**anonymous_keyword")
@@ -310,10 +315,13 @@ def install0(genv)
310315
@changes.add_hash_splat_box(genv, h, unified_key, unified_val)
311316
end
312317
end
318+
base_hash_type = genv.gen_hash_type(unified_key, unified_val)
313319
if @splat
314-
Source.new(genv.gen_hash_type(unified_key, unified_val))
320+
Source.new(base_hash_type)
321+
elsif all_symbol_keys
322+
Source.new(Type::Record.new(genv, literal_pairs, base_hash_type))
315323
else
316-
Source.new(Type::Hash.new(genv, literal_pairs, genv.gen_hash_type(unified_key, unified_val)))
324+
Source.new(Type::Hash.new(genv, literal_pairs, base_hash_type))
317325
end
318326
end
319327
end

lib/typeprof/core/builtin.rb

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def array_push(changes, node, ty, a_args, ret)
9090
def hash_aref(changes, node, ty, a_args, ret)
9191
if a_args.positionals.size == 1
9292
case ty
93-
when Type::Hash, Type::Record
93+
when Type::Hash
9494
idx = node.positional_args[0]
9595
idx = idx.is_a?(AST::SymbolNode) ? idx.lit : nil
9696
value = ty.get_value(idx)
@@ -101,6 +101,20 @@ def hash_aref(changes, node, ty, a_args, ret)
101101
changes.add_edge(@genv, Source.new(), ret)
102102
end
103103
true
104+
when Type::Record
105+
idx = node.positional_args[0]
106+
idx = idx.is_a?(AST::SymbolNode) ? idx.lit : nil
107+
value = ty.get_value(idx)
108+
if value
109+
changes.add_edge(@genv, value, ret)
110+
else
111+
changes.add_edge(@genv, Source.new(@genv.nil_type), ret)
112+
end
113+
# Symbol variable access - add nil possibility
114+
if idx.nil?
115+
changes.add_edge(@genv, Source.new(@genv.nil_type), ret)
116+
end
117+
true
104118
else
105119
false
106120
end
@@ -125,6 +139,15 @@ def hash_aset(changes, node, ty, a_args, ret)
125139
end
126140
changes.add_edge(@genv, val, ret)
127141
true
142+
when Type::Record
143+
val = a_args.positionals[1]
144+
idx = node.positional_args[0]
145+
if idx.is_a?(AST::SymbolNode)
146+
field_vtx = ty.get_value(idx.lit)
147+
changes.add_edge(@genv, val, field_vtx) if field_vtx
148+
end
149+
changes.add_edge(@genv, val, ret)
150+
true
128151
else
129152
false
130153
end

lib/typeprof/core/env/method.rb

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def get_keyword_arg(genv, changes, name)
6969
case ty
7070
when Type::Hash
7171
changes.add_edge(genv, ty.get_value(name), vtx)
72+
when Type::Record
73+
field_vtx = ty.get_value(name)
74+
changes.add_edge(genv, field_vtx, vtx) if field_vtx
7275
when Type::Instance
7376
if ty.mod == genv.mod_hash
7477
changes.add_edge(genv, ty.args[1], vtx)

lib/typeprof/core/type.rb

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,33 @@ def self.strip_array(s)
1616
end
1717

1818
def self.extract_hash_value_type(s)
19-
if s.start_with?("Hash[") && s.end_with?("]")
19+
begin
2020
type = RBS::Parser.parse_type(s)
21+
collect_hash_value_types(type).uniq.join(" | ")
22+
rescue
23+
s
24+
end
25+
end
2126

22-
if type.is_a?(RBS::Types::Union)
23-
type.types.map {|t| t.args[1].to_s }.join(" | ")
27+
# Returns an array of value type strings from hash-like types
28+
def self.collect_hash_value_types(type)
29+
case type
30+
when RBS::Types::Record
31+
# Extract value types from record fields
32+
# all_fields returns { key => [type, required] }
33+
type.all_fields.values.map { |t, _required| t.to_s }
34+
when RBS::Types::ClassInstance
35+
# Handle Hash[K, V]
36+
if type.name.name == :Hash && type.args.size == 2
37+
[type.args[1].to_s]
2438
else
25-
type.args[1].to_s
39+
[type.to_s]
2640
end
41+
when RBS::Types::Union
42+
# Handle union of types - extract value types from all hash-like types
43+
type.types.flat_map { |t| collect_hash_value_types(t) }
2744
else
28-
s
45+
[type.to_s]
2946
end
3047
end
3148

scenario/args/anonymous_keyword.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ def bar(**)
1212
## assert
1313
class Object
1414
def foo: (**untyped) -> nil
15-
def bar: (**Integer | String | untyped) -> nil
15+
def bar: (**untyped | Integer | String) -> nil
1616
end

scenario/block/rbs_block6.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ def foo(h)
1111

1212
## assert
1313
class Object
14-
def foo: (Hash[:a, Integer] | Hash[:b, String]) -> (Integer | String)?
14+
def foo: ({ a: Integer } | { b: String }) -> (Integer | String)?
1515
end

scenario/hash/basic1.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def baz
1717

1818
## assert
1919
class Object
20-
def foo: -> Hash[:a | :b, Float | Integer | String]
20+
def foo: -> { a: Integer, b: String }
2121
def bar: -> Integer
22-
def baz: -> (Float | Integer | String)
22+
def baz: -> nil
2323
end

scenario/hash/hash-splat.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ def bar
1010
## assert
1111
class Object
1212
def foo: -> Hash[:a | :b, Integer]
13-
def bar: -> Hash[:a, Integer]
13+
def bar: -> { a: Integer }
1414
end

scenario/hash/implicit.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ def create
77

88
## assert
99
class Object
10-
def create: -> Hash[:x | :y, Integer | String]
10+
def create: -> { x: Integer, y: String }
1111
end

scenario/method/keywords.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,5 @@ def foo(a:, b: 1, **c)
7171

7272
## assert
7373
class Object
74-
def foo: (a: String, ?b: Integer, **Integer | String | true) -> Hash[:a | :b | :c, Integer | String | true]
74+
def foo: (a: String, ?b: Integer, **String | Integer | true) -> { a: String, b: Integer, c: true }
7575
end

0 commit comments

Comments
 (0)