@@ -9,6 +9,7 @@ using MatrixBTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::B, MatrixScope::Wa
99using MatrixBTyInt = Matrix<ComponentType::I32, 4 , 4 , MatrixUse::B, MatrixScope::Wave>;
1010using MatrixAccumTy = Matrix<ComponentType::F32, 4 , 4 , MatrixUse::Accumulator, MatrixScope::Wave>;
1111using TSMatrixATy = Matrix<ComponentType::F32, 4 , 4 , MatrixUse::A, MatrixScope::Thread>;
12+ using TSMatrixAccumTy = Matrix<ComponentType::F32, 4 , 4 , MatrixUse::Accumulator, MatrixScope::Thread>;
1213
1314ByteAddressBuffer BAB : register (t0);
1415RWByteAddressBuffer RWBAB : register (u0);
@@ -128,7 +129,7 @@ void main(uint ID : SV_GroupID)
128129// Matrix::InterlockedAccumulate to groupshared memory
129130//
130131// CHECK: call void @dx.op.linAlgMatrixAccumulateToMemory.mC9M4N4U2S1.f32(i32 -2147483620,
131- // CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %18 ,
132+ // CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S1 %[[ACCUM0]] ,
132133// CHECK-SAME: float addrspace(3)* getelementptr inbounds ([256 x float],
133134// CHECK-SAME: [256 x float] addrspace(3)* @"\01?SharedArr@@3PAMA", i32 0, i32 0), i32 0, i32 16, i32 1)
134135// CHECK-SAME: ; LinAlgMatrixAccumulateToMemory(matrix,memory,offset,stride,layout)
@@ -166,16 +167,20 @@ void main(uint ID : SV_GroupID)
166167// Matrix::Load for thread-scope matrix
167168//
168169// CHECK: %[[TSMATA:.*]] = call %dx.types.LinAlgMatrixC9M4N4U0S0 @dx.op.linAlgMatrixLoadFromDescriptor.mC9M4N4U0S0(
169- // CHECK-SAME: i32 -2147483634, %dx.types.Handle %24 , i32 0, i32 16, i32 1, i32 4)
170+ // CHECK-SAME: i32 -2147483634, %dx.types.Handle %{{[0-9]+}} , i32 0, i32 16, i32 1, i32 4)
170171// CHECK-SAME: ; LinAlgMatrixLoadFromDescriptor(handle,offset,stride,layout,align)
171172 TSMatrixATy TSMatA = TSMatrixATy::Load<MatrixLayoutEnum::ColMajor>(BAB, 0 , 16 );
172173
173174// Matrix::InterlockedAccumulate for thread-scope matrix
174175//
175- // CHECK: call void @dx.op.linAlgMatrixAccumulateToDescriptor.mC9M4N4U0S0(i32 -2147483621,
176- // CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U0S0 %25, %dx.types.Handle %26, i32 0, i32 16, i32 1, i32 4)
176+ // CHECK: %[[TSACCUM:.*]] = call %dx.types.LinAlgMatrixC9M4N4U2S0 @dx.op.linAlgMatrixOuterProduct.mC9M4N4U2S0.v4f32.v4f32
177+ // CHECK: call void @dx.op.linAlgMatrixAccumulateToDescriptor.mC9M4N4U2S0(i32 -2147483621,
178+ // CHECK-SAME: %dx.types.LinAlgMatrixC9M4N4U2S0 %[[TSACCUM]], %dx.types.Handle %{{[0-9]+}}, i32 0, i32 16, i32 1, i32 4)
177179// CHECK-SAME: ; LinAlgMatrixAccumulateToDescriptor(matrix,handle,offset,stride,layout,align)
178- TSMatA.InterlockedAccumulate (RWBAB, 0 , 16 , MatrixLayoutEnum::ColMajor);
180+ vector <float , 4 > vec1 = 1.0f ;
181+ vector <float , 4 > vec2 = 2.0f ;
182+ TSMatrixAccumTy TSMatAccum = OuterProduct<ComponentType::F32>(vec1, vec2);
183+ TSMatAccum.InterlockedAccumulate (RWBAB, 0 , 16 , MatrixLayoutEnum::ColMajor);
179184
180185// CHECK: call i32 @dx.op.linAlgMatrixQueryAccumulatorLayout(i32 -2147483626) ; LinAlgMatrixQueryAccumulatorLayout()
181186 MatrixUseEnum layout = AccumulatorLayout ();
0 commit comments