Skip to content
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,18 @@ class DxilConf_SM610_LinAlg {
TEST_CLASS_SETUP(setupClass);
TEST_METHOD_SETUP(setupMethod);

// Load/Store
TEST_METHOD(LoadStoreRoundtrip_Wave_F32);
TEST_METHOD(LoadStoreRoundtrip_Wave_I32);

// Splat Store
TEST_METHOD(SplatStore_Wave_F32);
TEST_METHOD(SplatStore_Wave_I32);

// Element access
TEST_METHOD(ElementAccess_Wave_F32);
TEST_METHOD(ElementAccess_Wave_I32);

private:
bool createDevice();

Expand Down Expand Up @@ -389,6 +396,7 @@ static void runSplatStore(ID3D12Device *Device,
}
#endif

// Build expected data.
Comment thread
alsepkow marked this conversation as resolved.
Outdated
std::vector<float> ExpectedFloats;
std::vector<int32_t> ExpectedInts;
switch (Params.CompType) {
Expand Down Expand Up @@ -456,4 +464,192 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
runSplatStore(D3DDevice, DxcSupport, Params, 7.0f, VerboseLogging);
}

static const char ElementAccessShader[] = R"(
RWByteAddressBuffer Input : register(u0);
RWByteAddressBuffer Output : register(u1);

// flatten the 2D index into a 1d index then scale by element size
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
uint cordToByteOffset(uint2 coord) {
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
}

[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadIndex : SV_GroupIndex) {
__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Mat;
__builtin_LinAlg_MatrixLoadFromDescriptor(
Comment thread
bob80905 marked this conversation as resolved.
Mat, Input, 0, STRIDE, LAYOUT, 0);

// Copy Matrix values from input to output without assuming order
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
uint Offset = cordToByteOffset(Coord);
#if COMP_TYPE == 9
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
float Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, asuint(Elem));
#else
uint Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, Elem);
#endif
}

// Store each threads Length in the output after the copied matrix
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
uint finalIdx = (M_DIM * N_DIM + threadIndex) * ELEM_SIZE;
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
uint Len = __builtin_LinAlg_MatrixLength(Mat);
Output.Store(finalIdx, Len);
}
)";

static void runElementAccess(ID3D12Device *Device,
dxc::SpecificDllLoader &DxcSupport,
const MatrixParams &Params, int MajorDim,
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
bool Verbose) {
const size_t NumElements = Params.totalElements();
const size_t NumThreads = Params.NumThreads;
const size_t InputBufSize = Params.totalBytes();
const size_t ElementSize = elemSize(Params.CompType);
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
// Output: ElementSize bytes per element
// 1 element for each mat idx
// 1 element for each thread's length
const size_t OutputBufSize = (NumElements + NumThreads) * ElementSize;

std::stringstream ExtraDefs;
ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
ExtraDefs << " -DELEM_SIZE=" << ElementSize;
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);

#ifndef _HLK_CONF
Comment thread
alsepkow marked this conversation as resolved.
// Skip GPU execution if no device.
if (!Device) {
hlsl_test::LogCommentFmt(
L"Shader compiled OK; skipping execution (no SM 6.10 device)");
WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
return;
}
#endif

// Build expected data.
std::vector<float> ExpectedFloats(NumElements);
std::vector<int32_t> ExpectedInts(NumElements);
for (size_t I = 0; I < NumElements; I++) {
ExpectedFloats[I] = static_cast<float>(I + 1);
ExpectedInts[I] = static_cast<int32_t>(I + 1);
}

auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
Args.c_str());
addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
addRootUAV(Op.get(), 0, "Input");
addRootUAV(Op.get(), 1, "Output");

auto Result =
runShaderOp(Device, DxcSupport, std::move(Op),
[&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
if (_stricmp(Name, "Input") != 0)
return;

switch (Params.CompType) {
case ComponentType::F32: {
float *Ptr = reinterpret_cast<float *>(Data.data());
for (size_t I = 0; I < NumElements; I++)
Ptr[I] = static_cast<float>(I + 1);
break;
}
case ComponentType::I32: {
int32_t *Ptr = reinterpret_cast<int32_t *>(Data.data());
for (size_t I = 0; I < NumElements; I++)
Ptr[I] = static_cast<int32_t>(I + 1);
break;
}
default:
VERIFY_IS_TRUE(false, "Saw unsupported component type");
break;
}
});

MappedData OutData;
Result->Test->GetReadBackData("Output", &OutData);
const uint32_t *Out = static_cast<const uint32_t *>(OutData.data());

// Build actual data.
Comment thread
V-FEXrt marked this conversation as resolved.
Outdated
std::vector<float> ActualFloats(NumElements);
std::vector<int32_t> ActualInts(NumElements);
for (size_t I = 0; I < NumElements * ElementSize; I = I + ElementSize) {
Comment thread
bob80905 marked this conversation as resolved.
Outdated
switch (Params.CompType) {
case ComponentType::F32: {
float Actual;
memcpy(&Actual, &Out[I], sizeof(float));
ActualFloats[I / 4] = Actual;
break;
}
case ComponentType::I32: {
ActualInts[I / 4] = Out[I];
break;
}
default:
VERIFY_IS_TRUE(false, "Saw unsupported component type");
break;
}
}

// Verify element values match input data.
switch (Params.CompType) {
case ComponentType::F32:
VERIFY_IS_TRUE(verifyFloatBuffer(ActualFloats.data(), ExpectedFloats.data(),
NumElements, Verbose));
break;
case ComponentType::I32:
VERIFY_IS_TRUE(verifyIntBuffer(ActualInts.data(), ExpectedInts.data(),
NumElements, Verbose));
break;
default:
VERIFY_IS_TRUE(false, "Saw unsupported component type");
break;
}

// The sum of the values returned by Length across all threads must be
// greater than or equal to the total number of matrix elements
size_t MatrixEndOffset = NumElements * ElementSize;
size_t LengthValuesEnd = MatrixEndOffset + (NumThreads * sizeof(uint32_t));
size_t TotalLength = 0;
for (size_t I = MatrixEndOffset; I < LengthValuesEnd;
I = I + sizeof(uint32_t)) {
TotalLength += Out[I];
}
VERIFY_IS_TRUE(TotalLength >= NumElements,
"Sum of all lengths must be gte num elements");
}

void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
MatrixParams Params = {};
Params.CompType = ComponentType::F32;
Params.M = 4;
Params.N = 4;
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.Enable16Bit = false;
runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
}

void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
MatrixParams Params = {};
Params.CompType = ComponentType::I32;
Params.M = 4;
Params.N = 4;
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.Enable16Bit = false;
runElementAccess(D3DDevice, DxcSupport, Params, Params.M, VerboseLogging);
}

} // namespace LinAlg
Loading