Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 193 additions & 0 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,18 @@ class DxilConf_SM610_LinAlg {
TEST_CLASS_SETUP(setupClass);
TEST_METHOD_SETUP(setupMethod);

// Load/Store
TEST_METHOD(LoadStoreRoundtrip_Wave_F32);
TEST_METHOD(LoadStoreRoundtrip_Wave_I32);

// Splat Store
TEST_METHOD(SplatStore_Wave_F32);
TEST_METHOD(SplatStore_Wave_I32);

// Element access
TEST_METHOD(ElementAccess_Wave_F32);
TEST_METHOD(ElementAccess_Wave_I32);

private:
bool createDevice();

Expand Down Expand Up @@ -389,6 +396,7 @@ static void runSplatStore(ID3D12Device *Device,
}
#endif

// Build expected data.
std::vector<float> ExpectedFloats;
std::vector<int32_t> ExpectedInts;
switch (Params.CompType) {
Expand Down Expand Up @@ -456,4 +464,189 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
runSplatStore(D3DDevice, DxcSupport, Params, 7.0f, VerboseLogging);
}

static const char ElementAccessShader[] = R"(
RWByteAddressBuffer Input : register(u0);
RWByteAddressBuffer Output : register(u1);

// flatten the 2D index into a 1d index then scale by element size
uint cordToByteOffset(uint2 coord) {
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
}

[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadIndex : SV_GroupIndex) {
__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Mat;
__builtin_LinAlg_MatrixLoadFromDescriptor(
Mat, Input, 0, STRIDE, LAYOUT, 0);

// Copy Matrix values from input to output without assuming order
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
uint Offset = cordToByteOffset(Coord);
#if COMP_TYPE == 9
float Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, asuint(Elem));
#else
uint Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, Elem);
#endif
}

// Store each threads Length in the output after the copied matrix
uint finalIdx = (M_DIM * N_DIM + threadIndex) * ELEM_SIZE;
uint Len = __builtin_LinAlg_MatrixLength(Mat);
Output.Store(finalIdx, Len);
}
)";

static void runElementAccess(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, int MajorDim, bool Verbose) {
const size_t NumElements = Params.totalElements();
const size_t NumThreads = Params.NumThreads;
const size_t InputBufSize = Params.totalBytes();
const size_t ElementSize = elemSize(Params.CompType);
// Output: ElementSize bytes per element
// 1 element for each mat idx
// 1 element for each thread's length
const size_t OutputBufSize = (NumElements + NumThreads) * ElementSize;

std::stringstream ExtraDefs;
ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
ExtraDefs << " -DELEM_SIZE=" << ElementSize;
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);

#ifndef _HLK_CONF
// Skip GPU execution if no device.
if (!Device) {
hlsl_test::LogCommentFmt(
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
return;
}
#endif

// Build expected data.
std::vector<float> ExpectedFloats(NumElements);
std::vector<int32_t> ExpectedInts(NumElements);
for (size_t I = 0; I < NumElements; I++) {
ExpectedFloats[I] = static_cast<float>(I + 1);
ExpectedInts[I] = static_cast<int32_t>(I + 1);
}

auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
Args.c_str());
addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
addRootUAV(Op.get(), 0, "Input");
addRootUAV(Op.get(), 1, "Output");

auto Result =
runShaderOp(Device, DxcSupport, std::move(Op),
[&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
if (_stricmp(Name, "Input") != 0)
return;

switch (Params.CompType) {
case ComponentType::F32: {
float *Ptr = reinterpret_cast<float *>(Data.data());
for (size_t I = 0; I < NumElements; I++)
Ptr[I] = static_cast<float>(I + 1);
break;
}
case ComponentType::I32: {
int32_t *Ptr = reinterpret_cast<int32_t *>(Data.data());
for (size_t I = 0; I < NumElements; I++)
Ptr[I] = static_cast<int32_t>(I + 1);
break;
}
default:
VERIFY_IS_TRUE(false, "Saw unsupported component type");
break;
}
});

MappedData OutData;
Result->Test->GetReadBackData("Output", &OutData);
const uint32_t *Out = static_cast<const uint32_t *>(OutData.data());

// Build actual data.
std::vector<float> ActualFloats(NumElements);
std::vector<int32_t> ActualInts(NumElements);
for (size_t I = 0; I < NumElements * 4; I = I + 4) {
switch (Params.CompType) {
case ComponentType::F32: {
float Actual;
memcpy(&Actual, &Out[I], sizeof(float));
ActualFloats[I / 4] = Actual;
break;
}
case ComponentType::I32: {
ActualInts[I / 4] = Out[I];
break;
}
default:
VERIFY_IS_TRUE(false, "Saw unsupported component type");
break;
}
}

// Verify element values match input data.
switch (Params.CompType) {
case ComponentType::F32:
VERIFY_IS_TRUE(verifyFloatBuffer(ActualFloats.data(), ExpectedFloats.data(),
NumElements, Verbose));
break;
case ComponentType::I32:
VERIFY_IS_TRUE(verifyIntBuffer(ActualInts.data(), ExpectedInts.data(),
NumElements, Verbose));
break;
default:
VERIFY_IS_TRUE(false, "Saw unsupported component type");
break;
}

// The sum of the values returned by Length across all threads must be
// greater than or equal to the total number of matrix elements
size_t TotalLength = 0;
for (size_t I = NumElements * 4; I < (NumElements + NumThreads) * 4;
I = I + 4) {
TotalLength += Out[I];
}
VERIFY_IS_TRUE(TotalLength >= NumElements,
"Sum of all lengths must be gte num elements");
}

void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
MatrixParams Params = {};
Params.CompType = ComponentType::F32;
Params.M = 4;
Params.N = 4;
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.Enable16Bit = false;
runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
}

void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
MatrixParams Params = {};
Params.CompType = ComponentType::I32;
Params.M = 4;
Params.N = 4;
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.Enable16Bit = false;
runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
}

} // namespace LinAlg
Loading