From b88d8b875f27b812e3511c7b4de045ba28c4669a Mon Sep 17 00:00:00 2001 From: pvcresin Date: Tue, 23 Jun 2026 19:30:08 +0900 Subject: [PATCH] Support type inference for pattern-matching capture variables Capture variables in `case/in` patterns were always inferred as `untyped` because the matched value was never threaded into the patterns: case a in Integer => n n + 1 # n was untyped in [String => s, Symbol => t] [s, t] # s, t were untyped end Add an `install_pattern`/`install_pattern0` pair (mirroring `install`/`install0`, so @changes is reconciled the same way during incremental analysis) that flows the matched value into each pattern. Local targets bind to it, `Const => var` narrows it by the class the same way `when Const` does, and array/find patterns decompose it into element vertices. `CaseMatchNode` and the one-line `=>`/`in` forms pass the subject in. Hash patterns still capture as untyped (extracting a value type per key needs a dedicated box); sub-patterns inside them are still matched. Promote scenario/known-issues/pattern-capture-narrowing.rb into scenario/patterns/capture.rb and update scenario/patterns/var_pat.rb. Co-authored-by: Claude Opus 4.8 (1M context) --- lib/typeprof/core/ast/base.rb | 13 ++++ lib/typeprof/core/ast/control.rb | 4 +- lib/typeprof/core/ast/misc.rb | 6 +- lib/typeprof/core/ast/pattern.rb | 61 +++++++++++++------ lib/typeprof/core/ast/variable.rb | 6 ++ .../known-issues/pattern-capture-narrowing.rb | 32 ---------- scenario/patterns/capture.rb | 23 +++++++ scenario/patterns/var_pat.rb | 8 +-- 8 files changed, 91 insertions(+), 62 deletions(-) delete mode 100644 scenario/known-issues/pattern-capture-narrowing.rb diff --git a/lib/typeprof/core/ast/base.rb b/lib/typeprof/core/ast/base.rb index b53d80035..6f4bf480f 100644 --- a/lib/typeprof/core/ast/base.rb +++ b/lib/typeprof/core/ast/base.rb @@ -118,6 +118,19 @@ def install0(_) raise "should override" end + # Counterpart of install/install0 for pattern position; reuses the install + # machinery so @changes is reconciled the same way during incremental analysis. + def install_pattern(genv, subject) + @ret = install_pattern0(genv, subject) + @changes.reinstall(genv) + @ret + end + + # By default a pattern behaves as a plain expression, ignoring the matched value. + def install_pattern0(genv, subject) + install0(genv) + end + def uninstall(genv) @changes.reinstall(genv) each_subnode do |subnode| diff --git a/lib/typeprof/core/ast/control.rb b/lib/typeprof/core/ast/control.rb index 2e64bdec6..6c315699d 100644 --- a/lib/typeprof/core/ast/control.rb +++ b/lib/typeprof/core/ast/control.rb @@ -414,9 +414,9 @@ def subnodes = { pivot:, patterns:, clauses:, else_clause: } def install0(genv) ret = Vertex.new(self) - @pivot&.install(genv) + subject = @pivot ? @pivot.install(genv) : Source.new(genv.nil_type) @patterns.zip(@clauses) do |pattern, clause| - pattern.install(genv) + pattern.install_pattern(genv, subject) @changes.add_edge(genv, clause.install(genv), ret) end @changes.add_edge(genv, @else_clause.install(genv), ret) if @else_clause diff --git a/lib/typeprof/core/ast/misc.rb b/lib/typeprof/core/ast/misc.rb index 52db028d6..1224084bf 100644 --- a/lib/typeprof/core/ast/misc.rb +++ b/lib/typeprof/core/ast/misc.rb @@ -332,8 +332,7 @@ def initialize(raw_node, lenv) def subnodes = { value:, pat: } def install0(genv) - @value.install(genv) - @pat.install(genv) + @pat.install_pattern(genv, @value.install(genv)) Source.new(genv.nil_type) end end @@ -350,8 +349,7 @@ def initialize(raw_node, lenv) def subnodes = { value:, pat: } def install0(genv) - @value.install(genv) - @pat.install(genv) + @pat.install_pattern(genv, @value.install(genv)) Source.new(genv.true_type, genv.false_type) end end diff --git a/lib/typeprof/core/ast/pattern.rb b/lib/typeprof/core/ast/pattern.rb index d418a3611..e6592cf54 100644 --- a/lib/typeprof/core/ast/pattern.rb +++ b/lib/typeprof/core/ast/pattern.rb @@ -22,14 +22,19 @@ def initialize(raw_node, lenv) def attrs = { rest: } def subnodes = { requireds:, rest_pattern:, posts: } - def install0(genv) - @requireds.each do |pat| - pat.install(genv) + def install_pattern0(genv, subject) + @requireds.each_with_index do |pat, i| + pat.install_pattern(genv, @changes.add_splat_box(genv, subject, i).ret) + end + if @rest_pattern + elem = @changes.add_splat_box(genv, subject).ret + @rest_pattern.install_pattern(genv, Source.new(genv.gen_ary_type(elem))) end - @rest_pattern.install(genv) if @rest_pattern @posts.each do |pat| - pat.install(genv) + # TODO: precise indices for post elements (those after `*rest`) + pat.install_pattern(genv, @changes.add_splat_box(genv, subject).ret) end + subject end end @@ -47,11 +52,13 @@ def initialize(raw_node, lenv) def attrs = { keys:, rest: } def subnodes = { values:, rest_pattern: } - def install0(genv) + def install_pattern0(genv, subject) + # TODO: extract each key's value type from `subject` (captures stay untyped for now) @values.each do |pat| - pat.install(genv) + pat.install_pattern(genv, Vertex.new(self)) end - @rest_pattern.install(genv) if @rest_pattern + @rest_pattern.install_pattern(genv, Vertex.new(self)) if @rest_pattern + subject end end @@ -67,12 +74,15 @@ def initialize(raw_node, lenv) def subnodes = { left:, requireds:, right: } - def install0(genv) - @left.install(genv) if @left + def install_pattern0(genv, subject) + elem = @changes.add_splat_box(genv, subject).ret + rest_ary = Source.new(genv.gen_ary_type(elem)) + @left.install_pattern(genv, rest_ary) if @left @requireds.each do |pat| - pat.install(genv) + pat.install_pattern(genv, elem) end - @right.install(genv) if @right + @right.install_pattern(genv, rest_ary) if @right + subject end end @@ -87,9 +97,10 @@ def initialize(raw_node, lenv) def subnodes = { left:, right: } - def install0(genv) - @left.install(genv) - @right.install(genv) + def install_pattern0(genv, subject) + @left.install_pattern(genv, subject) + @right.install_pattern(genv, subject) + subject end end @@ -104,9 +115,18 @@ def initialize(raw_node, lenv) def subnodes = { value:, target: } - def install0(genv) - @value.install(genv) - @target.install(genv) + def install_pattern0(genv, subject) + @value.install_pattern(genv, subject) + # For `Const => var`, narrow the capture by the class, as `when Const` does + narrowed = + if @value.is_a?(ConstantReadNode) && @value.static_ret + filtered = subject.new_vertex(genv, self) + IsAFilter.new(genv, self, filtered, false, @value.static_ret).next_vtx + else + subject + end + @target.install_pattern(genv, narrowed) + subject end end @@ -124,9 +144,10 @@ def initialize(raw_node, lenv) def subnodes = { cond:, body: } - def install0(genv) + def install_pattern0(genv, subject) @cond.install(genv) - @body.install(genv) + @body.install_pattern(genv, subject) + subject end end diff --git a/lib/typeprof/core/ast/variable.rb b/lib/typeprof/core/ast/variable.rb index fc34adc0c..f42898b7b 100644 --- a/lib/typeprof/core/ast/variable.rb +++ b/lib/typeprof/core/ast/variable.rb @@ -63,6 +63,12 @@ def install0(genv) val end + def install_pattern0(genv, subject) + install0(genv) + @changes.add_edge(genv, subject, @rhs.ret) + subject + end + def retrieve_at(pos, &blk) yield self if @var_code_range && @var_code_range.include?(pos) super(pos, &blk) diff --git a/scenario/known-issues/pattern-capture-narrowing.rb b/scenario/known-issues/pattern-capture-narrowing.rb deleted file mode 100644 index 7812cb25e..000000000 --- a/scenario/known-issues/pattern-capture-narrowing.rb +++ /dev/null @@ -1,32 +0,0 @@ -## update -def check(a) - case a - in Integer => n - n + 1 - in String => s - s.upcase - end -end - -check(1) -check("foo") - -## assert -class Object - def check: (Integer | String) -> (Integer | String)? -end - -## update -def check(a) - case a - in [Integer => n, String => s] - [n, s] - end -end - -check([42, "foo"]) - -## assert -class Object - def check: ([Integer, String]) -> [Integer, String]? -end diff --git a/scenario/patterns/capture.rb b/scenario/patterns/capture.rb index 08ae48a40..59a3d18bb 100644 --- a/scenario/patterns/capture.rb +++ b/scenario/patterns/capture.rb @@ -1,3 +1,21 @@ +## update: test.rb +def check(a) + case a + in Integer => n + n + 1 + in String => s + s.upcase + end +end + +check(1) +check("foo") + +## assert +class Object + def check: (Integer | String) -> (Integer | String) +end + ## update: test.rb def check(a) case a @@ -7,3 +25,8 @@ def check(a) end check([42, "foo"]) + +## assert +class Object + def check: ([Integer, String]) -> [Integer, String] +end diff --git a/scenario/patterns/var_pat.rb b/scenario/patterns/var_pat.rb index cb7b86e0d..4cb961ff8 100644 --- a/scenario/patterns/var_pat.rb +++ b/scenario/patterns/var_pat.rb @@ -2,7 +2,7 @@ def check(x) case x in y - y # TODO! + y end end @@ -11,14 +11,14 @@ def check(x) ## assert class Object - def check: (Integer | String) -> untyped + def check: (Integer | String) -> (Integer | String) end ## update: test.rb def check(x) case x in a, b, c, *rest - [a, b, c, rest] # TODO! + [a, b, c, rest] # TODO: a, b, c stay untyped because x is not an array end end @@ -27,7 +27,7 @@ def check(x) ## assert class Object - def check: (Integer | String) -> [untyped, untyped, untyped, untyped] + def check: (Integer | String) -> [untyped, untyped, untyped, Array[untyped]] end