Skip to content

Commit a1805a8

Browse files
committed
Accum to memory
1 parent 36faa95 commit a1805a8

1 file changed

Lines changed: 80 additions & 1 deletion

File tree

tools/clang/unittests/HLSLExec/LinAlgTests.cpp

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ class DxilConf_SM610_LinAlg {
325325
// Load/Store/Accumulate Memory
326326
TEST_METHOD(LoadMemory_Wave_16x16_F16);
327327
TEST_METHOD(StoreMemory_Wave_16x16_F16);
328+
TEST_METHOD(AccumulateMemory_Wave_16x16_F16);
328329

329330
// Element access
330331
TEST_METHOD(ElementAccess_Wave_16x16_F16);
@@ -1444,7 +1445,6 @@ void DxilConf_SM610_LinAlg::QueryAccumLayout() {
14441445
static const char LoadMemoryShader[] = R"(
14451446
RWByteAddressBuffer Input : register(u0);
14461447
RWByteAddressBuffer Output : register(u1);
1447-
14481448
groupshared ELEM_TYPE GsData[M_DIM * N_DIM];
14491449
14501450
#define ELEM_PER_THREAD (M_DIM * N_DIM / NUMTHREADS)
@@ -1593,4 +1593,83 @@ void DxilConf_SM610_LinAlg::StoreMemory_Wave_16x16_F16() {
15931593
runStoreMemory(D3DDevice, DxcSupport, Params, VerboseLogging, /*FillValue=*/7.0f);
15941594
}
15951595

1596+
// HLSL source for the accumulate-to-memory test.
//
// Flow of the shader:
//   1. Every thread writes FILL_VALUE into its ELEM_PER_THREAD-sized slice of
//      the groupshared array GsData, then the group synchronizes.
//   2. Threads for which WaveReadLaneFirst(threadID) is non-zero return.
//   3. The remaining threads declare a matrix, fill it with FILL_VALUE and
//      accumulate it into GsData via
//      __builtin_LinAlg_MatrixAccumulateToMemory.
//   4. Each remaining thread copies all M_DIM * N_DIM elements of GsData into
//      the Output buffer at u0.
//
// The source relies on these preprocessor defines: ELEM_TYPE, ELEM_SIZE,
// M_DIM, N_DIM, NUMTHREADS, COMP_TYPE, USE, SCOPE, STRIDE, LAYOUT, OFFSET and
// FILL_VALUE. OFFSET and FILL_VALUE are passed by runAccumulateMemory; the
// others are presumably emitted by buildCompilerArgs -- TODO confirm.
//
// NOTE(review): there is no barrier between the accumulate call and the
// read-back loop. Presumably the builtin's writes are visible to every lane
// that reaches the loop -- confirm against the LinAlg specification.
static const char AccumulateMemoryShader[] = R"(
RWByteAddressBuffer Output : register(u0);
groupshared ELEM_TYPE GsData[M_DIM * N_DIM];

#define ELEM_PER_THREAD (M_DIM * N_DIM / NUMTHREADS)

[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadID : SV_GroupIndex) {
  ELEM_TYPE fill = FILL_VALUE;
  for (uint I = 0; I < ELEM_PER_THREAD; ++I) {
    uint Index = threadID * ELEM_PER_THREAD + I;
    GsData[Index] = fill;
  }

  GroupMemoryBarrierWithGroupSync();

  if (WaveReadLaneFirst(threadID) != 0)
    return;

  __builtin_LinAlgMatrix
    [[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
    Mat;
  __builtin_LinAlg_FillMatrix(Mat, FILL_VALUE);

  __builtin_LinAlg_MatrixAccumulateToMemory(
    Mat, GsData, OFFSET, STRIDE, LAYOUT);

  for (uint I = 0; I < M_DIM*N_DIM; ++I) {
    Output.Store<ELEM_TYPE>(I*ELEM_SIZE, GsData[I]);
  }
}
)";
1629+
1630+
/// Builds, dispatches and validates the accumulate-to-memory shader.
///
/// The shader is compiled with OFFSET=0 and FILL_VALUE=\p FillValue, run as a
/// compute op with a single root UAV named "Output", and the read-back data is
/// compared against a matrix whose every element is \p FillValue * 2.
///
/// \param Device      Device handed to runShaderOp.
/// \param DxcSupport  Compiler loader handed to compileShader and runShaderOp.
/// \param Params      Matrix description; supplies the component type, M, N
///                    and the element/byte counts.
/// \param Verbose     Forwarded to compileShader and verifyComponentBuffer.
/// \param FillValue   Value passed to the shader as FILL_VALUE.
static void runAccumulateMemory(ID3D12Device *Device,
                                dxc::SpecificDllLoader &DxcSupport,
                                const MatrixParams &Params, bool Verbose,
                                float FillValue) {
  // Test-specific defines appended to the arguments derived from Params.
  std::stringstream Defines;
  Defines << " -DOFFSET=" << 0 << " -DFILL_VALUE=" << FillValue;
  std::string Args = buildCompilerArgs(Params, Defines.str().c_str());

  // Compile the shader on its own before it is wrapped in a compute op.
  compileShader(DxcSupport, AccumulateMemoryShader, "cs_6_10", Args, Verbose);

  // The shader fills memory with FillValue and then accumulates a matrix that
  // is also filled with FillValue, hence twice the value per element.
  auto Expected = makeExpectedMat(Params.CompType, Params.M, Params.N,
                                  FillValue * 2, /*Increment=*/false);

  const size_t ByteCount = Params.totalBytes();
  auto Op = createComputeOp(AccumulateMemoryShader, "cs_6_10", "UAV(u0)",
                            Args.c_str());
  addUAVBuffer(Op.get(), "Output", ByteCount, true);
  addRootUAV(Op.get(), 0, "Output");

  auto Result = runShaderOp(Device, DxcSupport, std::move(Op));

  MappedData OutData;
  Result->Test->GetReadBackData("Output", &OutData);

  const size_t ElementCount = Params.totalElements();
  VERIFY_IS_TRUE(verifyComponentBuffer(Params.CompType, OutData.data(),
                                       Expected, ElementCount, Verbose));
}
1661+
1662+
void DxilConf_SM610_LinAlg::AccumulateMemory_Wave_16x16_F16() {
1663+
MatrixParams Params = {};
1664+
Params.CompType = ComponentType::F16;
1665+
Params.M = 16;
1666+
Params.N = 16;
1667+
Params.Use = MatrixUse::Accumulator;
1668+
Params.Scope = MatrixScope::Wave;
1669+
Params.Layout = LinalgMatrixLayout::RowMajor;
1670+
Params.NumThreads = 64;
1671+
Params.Enable16Bit = true;
1672+
runAccumulateMemory(D3DDevice, DxcSupport, Params, VerboseLogging, /*FillValue=*/7.0f);
1673+
}
1674+
15961675
} // namespace LinAlg

0 commit comments

Comments
 (0)