Skip to content

Fix AMX support using MultiRamp#9122

Draft
abadams wants to merge 66 commits intomainfrom
abadams/fix_amx
Draft

Fix AMX support using MultiRamp#9122
abadams wants to merge 66 commits intomainfrom
abadams/fix_amx

Conversation

@abadams
Copy link
Copy Markdown
Member

@abadams abadams commented May 5, 2026

Rewrites the AMX support to use MultiRamp. Ignore this PR until after MultiRamp has gone in, though this is a good demonstration that the MultiRamp concept is useful. This, I believe, fixes the outstanding bugs in AMX support identified by #8350

Validated by running the AMX tests under SDE.

Future work is generalizing the AMX support to be willing to ingest larger vectors, and automatically slice it up into multiple tile-level operations.

abadams and others added 30 commits January 26, 2026 15:52
The previous comment reported a time that seemed to have regressed. It
was not 8.2ms on main - more like 11
Before:

Computing best tile sizes for each type
.................................................
bytes, tile width, tile height, bandwidth (GB/s):
1 8 8 20.9997
1 16 8 20.8329
1 8 16 18.5702
1 8 32 17.2463
1 8 64 14.312

2 8 16 19.2047
2 8 8 18.8368
2 16 8 17.0593
2 8 32 17.0591
2 4 8 15.7681

4 8 8 24.9364
4 4 16 22.9699
4 8 16 22.5743
4 4 32 22.255
4 4 8 20.4468

8 8 8 38.4094
8 16 4 28.4167
8 16 8 27.6184
8 8 4 27.6062
8 8 16 26.8693

After:

Computing best tile sizes for each type
.................................................
bytes, tile width, tile height, bandwidth (GB/s):
1 16 32 34.1921
1 16 16 31.8399
1 8 16 25.575
1 16 64 25.1665
1 32 16 25.0061

2 8 32 28.2635
2 8 16 27.7648
2 16 16 27.2126
2 16 32 23.9034
2 8 8 23.6345

4 8 16 34.5303
4 8 8 28.3653
4 16 8 26.8521
4 8 32 26.084
4 16 16 24.4519

8 8 8 33.7163
8 8 4 29.1339
8 4 16 26.418
8 16 4 25.4663
8 2 8 24.3949
Also better algorithm for innermost containing stmt
abadams and others added 28 commits March 11, 2026 13:09
Co-authored-by: Claude Code <[email protected]>
Adds a MultiRamp IR helper that generalizes the old InterleavedRamp: a
nested ramp with a scalar base, a vector of strides (innermost first),
and a vector of per-dim lane counts. Supports in-place mul/add/div/mod
with symbolic strides where possible, reorder, slice, flatten into 1D
ramps, shuffle index construction for permutations and slices, and an
alias-free predicate.

Replaces InterleavedRamp recognition and handling in VectorizeLoops with
MultiRamp. The reduction-store path peels stride-zero and non-alias-free
dims (turning the latter into unrolled containing loops), computes the
per-iteration shuffle mask from the pre-peel shape via
shuffle_from_slice, and gracefully falls back when alias-freedom can't
be proven.

Wires MultiRamp into the simplifier: the Load and Store rules that
recognize a ramp-of-ramp index now use MultiRamp to rotate the stride-1
dim outermost via a single Shuffle::make_transpose, fixing a latent
correctness bug in the old rule for triply-nested ramps.

In FlattenNestedRamps, teaches the Load and Store visitors to recognize
multiramp indices and emit a concat of per-outer-multi-index 1D ramp
loads/stores, rather than a single flat-indexed load that downstream
passes struggle to combine. The existing bounded-span-to-dense-load
path runs first so strided-gather patterns (e.g. HVX vdelta) are
preserved.

Adds correctness tests for the MultiRamp API (test/correctness/multiramp.cpp)
and a nested-vectorize reduction test (transposed_vector_reduce.cpp).

Co-authored-by: Claude <[email protected]>
API and doc cleanup on MultiRamp:
 - Add a real invariants block to the class header; clarify that 0-dim
   (scalar) multiramps are legal and all methods handle them.
 - Switch all class member comments to doxygen style.
 - Reword alias_free, alias_free_slice, rotate_stride_one_innermost, and
   the shuffle_from_* overloads so each leads with what it does.
 - Fix is_multiramp's Mul branch to use fresh local MultiRamps rather
   than letting a failed first attempt leak partial state into the
   second.
 - Drop the single-dim shuffle_from_slice overload in favour of the
   multi-dim version.
 - Relax add() so it handles 0-dim inputs trivially (base+base); this
   also makes operator== work for 0-dim via its existing add path.
 - Add accept/mutate methods (Function-idiom) so callers don't reach
   into base/strides to walk scalar subexpressions.
 - Add alias_free_slice (replaces ad-hoc in-caller peeling) and
   rotate_stride_one_innermost (replaces near-duplicate dance in the
   simplifier rules).

Use those APIs from VectorizeLoops, FlattenNestedRamps, Simplify_Exprs,
and Simplify_Stmts. In particular the atomic-store reduction block in
VectorizeLoops is restructured: one alias_free_slice call discovers
both stride-zero peels (handled via VectorReduce or a tree reduction
over a reordered b) and symbolic/overlapping aliasing peels (handled
via an unrolled cartesian-product loop block). The b's current lane
layout is tracked as a MultiRamp (b_shape_mr) so the per-iteration
slice of the reduced vector can be computed via shuffle_from_slice.

Extend the downsampling/atomic-vectorize test (downsampling_reduce.cpp)
to exercise MultiRamp::div through the vectorize path, and expand the
multiramp API tests.

Co-authored-by: Claude <[email protected]>
Three places in Halide iterate the cartesian product of a box of
integer sizes — MultiRamp::shuffle_from_permuted, MultiRamp::flatten,
MultiRamp::shuffle_from_slice, and the unroll block in VectorizeLoops's
atomic-store reduction path. Replace each manual decompose-via-
rem%/rem/ loop with a call to a new for_each_coordinate helper in
Util.h that invokes a callback on each coordinate in lex order.

Co-authored-by: Claude <[email protected]>
is_multiramp's recursive impl may leave its output in a partial state
on failure (e.g. after a successful Ramp::base recursion followed by a
failed stride check). The contract previously told callers "don't read
*result on failure," but honoring that required every recursive call
site to use a fresh local MultiRamp. Split is_multiramp into an
internal impl that keeps the old partial-state behavior plus a thin
public wrapper that commits to *result only on success. The public
contract is now the cleaner "untouched on failure." The Mul branch's
fresh-local workaround falls out.

VectorizeLoops's atomic-store reduction path needs to shuffle `b` from
its original lane order into a permuted order (so a subsequent
reduction tree can slice contiguous sub-vectors per stride-zero peel).
It was calling shuffle_from_permuted directly, but that method returns
indices for the opposite direction — "permuted → original," which is
what the Simplify_Exprs caller wants. Invert the result as a
permutation. Without this, any case with multiple stride-zero peels
that needed a real reorder summed lanes that should have gone to
different output addresses.

Simplify_Shuffle's "slice of concat" rule drops concat vectors that
don't overlap with the slice's range. The overlap check tested whether
the concat vector's start OR its last lane was inside the slice, but
missed the case where the slice is entirely contained within one
concat vector (neither endpoint inside). When every vector contained
the slice, new_concat_vectors came out empty and Shuffle::make_concat
tripped its empty-vector assert. Replaced with a standard interval-
overlap check. This is a pre-existing bug that the randomized test
exercised through multiramp-derived slice patterns.

Co-authored-by: Claude <[email protected]>
Add hand-picked tests for MultiRamp API properties that weren't
previously covered: mul, operator==, alias_free_slice (unique lanes /
zero-stride peeling / degenerate scalar), rotate_stride_one_innermost
(rotation + transpose round-trip), and is_multiramp round-trips for a
handful of shapes.

Add test_random to transposed_vector_reduce.cpp: 1000 random
quasi-affine store/load index pairs over a 3-dim RDom, each compiled
scalarly and with .atomic().vectorize() across all three RVars and
compared. This test found all three bugs fixed in the preceding
commit.

Co-authored-by: Claude <[email protected]>
Ran a weak subagent (Haiku) over the MultiRamp PR as an adversarial
comprehension test — asking it to explain the code in detail, then
fixing whatever it got wrong. The theory: if a weaker model
misreads something, the comment is probably unclear, not the model.

Fixes prompted by the review:

- Simplify_Exprs.cpp / Simplify_Stmts.cpp: stale "outermost" wording
  from before rotate_stride_one_outermost was renamed to
  rotate_stride_one_innermost. The comments contradicted the function
  name and Haiku echoed the contradiction.
- MultiRamp.h alias_free: state explicitly that the returned Expr is
  a sufficient (not necessary) condition for lane uniqueness.
- MultiRamp.h alias_free_slice: clarify that kept dims are a subset
  preserving order, not necessarily a prefix.
- VectorizeLoops.cpp: rename ContainingLoop -> UnrolledLoop and note
  that the peeled dims are fully unrolled into a flat Block, not a
  runtime loop nest (despite the old name).
- MultiRamp.h alias_free_slice: note that stride-zero and purely
  symbolic strides always peel (added by Andrew directly).

A second Haiku pass after these edits answered every question
correctly, including the ones it got wrong the first time.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
Replaces the four hand-rolled AST pattern matchers in
ExtractTileOperations with a uniform multiramp-based approach. The
previous matchers were brittle (Add operand ordering, tile_x/tile_y
swap, missing stride checks, etc. — see #8350) and had separate
branches for the broadcast / collapsed / general RHS cases.

The new flow lifts each load index to a MultiRamp, then coerces it
into the canonical AMX shape via MultiRamp::strides_for_shape (a new
one-sided gcd-walk that returns per-target-dim strides), and reads
off the extracted strides from known slot positions. The inner-K /
broadcast / row-stride bits all fall out of the same path.

Also adds:
- A normalizing MultiRamp constructor that drops extent-1 dims.
- MultiRamp::strides_for_shape to map an MR's lane sequence onto a
  caller-specified dim shape.
- Asymmetric tile-size cases to the correctness test (tile_x !=
  tile_y) — these would have caught the swap bug listed in #8350.
- Constraints on the perf-test ImageParams' VNNI inner stride so the
  contiguity check now folds away cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
@abadams abadams marked this pull request as draft May 5, 2026 22:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants