Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 204 additions & 4 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ using hlsl::DXIL::MatrixScope;
using hlsl::DXIL::MatrixUse;

/// Return the byte size of a single element for the given component type.
static int elemSize(ComponentType CT) {
static int elementSize(ComponentType CT) {
switch (CT) {
case ComponentType::F16:
case ComponentType::I16:
Expand All @@ -64,22 +64,23 @@ struct MatrixParams {
bool Enable16Bit;

/// Stride, in bytes, passed along with the matrix layout.
///
/// \returns N * element size when Layout is RowMajor, otherwise
/// M * element size.
int strideBytes() const {
  const int ES = elementSize(CompType);
  if (Layout == LinalgMatrixLayout::RowMajor)
    return N * ES;
  return M * ES;
}

/// Number of elements in the matrix (M * N). M is widened to size_t before
/// the multiplication so the product is computed in size_t.
size_t totalElements() const {
  const size_t WideM = static_cast<size_t>(M);
  return WideM * N;
}

/// Size of the whole matrix in bytes: totalElements() multiplied by the
/// per-element size reported by elementSize() for CompType.
size_t totalBytes() const { return totalElements() * elementSize(CompType); }
};

static std::string buildCompilerArgs(const MatrixParams &Params,
const char *ExtraDefines = nullptr) {
std::stringstream SS;
SS << "-HV 202x";
SS << " -DCOMP_TYPE=" << static_cast<int>(Params.CompType);
SS << " -DCOMP_TYPE_F32=" << 9;
SS << " -DM_DIM=" << Params.M;
SS << " -DN_DIM=" << Params.N;
SS << " -DUSE=" << static_cast<int>(Params.Use);
Expand Down Expand Up @@ -152,11 +153,18 @@ class DxilConf_SM610_LinAlg {
TEST_CLASS_SETUP(setupClass);
TEST_METHOD_SETUP(setupMethod);

// Load/Store
TEST_METHOD(LoadStoreRoundtrip_Wave_F32);
TEST_METHOD(LoadStoreRoundtrip_Wave_I32);

// Splat Store
TEST_METHOD(SplatStore_Wave_F32);
TEST_METHOD(SplatStore_Wave_I32);

// Element access
TEST_METHOD(ElementAccess_Wave_F32);
TEST_METHOD(ElementAccess_Wave_I32);

private:
bool createDevice();

Expand Down Expand Up @@ -266,7 +274,6 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
}
#endif

// Build expected data.
std::vector<float> ExpectedFloats(NumElements);
std::vector<int32_t> ExpectedInts(NumElements);
for (size_t I = 0; I < NumElements; I++) {
Expand Down Expand Up @@ -456,4 +463,197 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_I32() {
runSplatStore(D3DDevice, DxcSupport, Params, 7.0f, VerboseLogging);
}

// HLSL compute shader used by the element-access tests.
//
// Each thread loads the matrix from `Input`, then for every element index it
// owns (0 .. MatrixLength - 1) asks the matrix for that element's 2D
// coordinate and writes the element value to `Output` at the byte offset
// derived from that coordinate. After the copy, each thread stores the matrix
// length it observed into `Output`, immediately after the matrix data, at a
// slot selected by its thread index.
//
// Macros expected on the command line: NUMTHREADS, COMP_TYPE, COMP_TYPE_F32,
// M_DIM, N_DIM, USE, SCOPE, STRIDE, LAYOUT, MAJOR_DIM and ELEM_SIZE.
static const char ElementAccessShader[] = R"(
RWByteAddressBuffer Input : register(u0);
RWByteAddressBuffer Output : register(u1);

// flatten the 2D index into a 1D index then scale by element size
uint coordToByteOffset(uint2 coord) {
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
}

[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadIndex : SV_GroupIndex) {
__builtin_LinAlgMatrix
[[__LinAlgMatrix_Attributes(COMP_TYPE, M_DIM, N_DIM, USE, SCOPE)]]
Mat;
__builtin_LinAlg_MatrixLoadFromDescriptor(
Mat, Input, 0, STRIDE, LAYOUT, 0);

// Copy Matrix values from input to output without assuming order
for (uint I = 0; I < __builtin_LinAlg_MatrixLength(Mat); ++I) {
uint2 Coord = __builtin_LinAlg_MatrixGetCoordinate(Mat, I);
uint Offset = coordToByteOffset(Coord);
#if COMP_TYPE == COMP_TYPE_F32
float Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, asuint(Elem));
#else
uint Elem;
__builtin_LinAlg_MatrixGetElement(Elem, Mat, I);
Output.Store(Offset, Elem);
#endif
}

// Save the matrix length that this thread saw. The length is written
// to the output right after the matrix, offset by the thread index
uint LenIdx = (M_DIM * N_DIM * ELEM_SIZE) + (threadIndex * sizeof(uint));
uint Len = __builtin_LinAlg_MatrixLength(Mat);
Output.Store(LenIdx, Len);
}
)";

/// Compile ElementAccessShader for \p Params and, when a device is available,
/// execute it and validate the read-back output.
///
/// The input buffer is filled with the values 1..NumElements in the element
/// type under test. The output buffer holds one element-sized slot per matrix
/// element followed by one uint32 per thread. The test checks that
///   * the element region of the output equals the input pattern, and
///   * the per-thread lengths sum to at least the number of matrix elements.
///
/// \param Device     Device to run on; when null (and _HLK_CONF is not
///                   defined) the test is reported as skipped after the
///                   compile step.
/// \param DxcSupport Compiler DLL loader forwarded to the compile/run helpers.
/// \param Params     Matrix description; only F32 and I32 component types are
///                   handled.
/// \param Verbose    Forwarded to the compile and buffer-verification helpers.
static void runElementAccess(ID3D12Device *Device,
                             dxc::SpecificDllLoader &DxcSupport,
                             const MatrixParams &Params, bool Verbose) {
  const size_t NumElements = Params.totalElements();
  const size_t NumThreads = Params.NumThreads;
  const size_t InputBufSize = Params.totalBytes();
  const size_t ElementSize = elementSize(Params.CompType);
  // NOTE(review): this selects M for RowMajor, whereas
  // MatrixParams::strideBytes() uses N for RowMajor. The two agree for square
  // matrices; confirm which dimension coordToByteOffset needs before adding
  // non-square cases.
  const size_t MajorDim =
      Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N;
  // Output layout: one ElementSize-byte slot per matrix element, followed by
  // one uint32 per thread holding the matrix length that thread reported.
  const size_t MatrixBytes = NumElements * ElementSize;
  const size_t OutputBufSize = MatrixBytes + NumThreads * sizeof(uint32_t);

  std::stringstream ExtraDefs;
  ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
  ExtraDefs << " -DELEM_SIZE=" << ElementSize;
  std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

  compileShader(DxcSupport, ElementAccessShader, "cs_6_10", Args, Verbose);

#ifndef _HLK_CONF
  // Skip GPU execution if no device.
  if (!Device) {
    hlsl_test::LogCommentFmt(
        L"Shader compiled OK; skipping execution (no SM 6.10 device)");
    WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped);
    return;
  }
#endif

  // Only F32 and I32 are handled below; reject anything else once, up front.
  const bool IsFloat = Params.CompType == ComponentType::F32;
  if (!IsFloat && Params.CompType != ComponentType::I32) {
    VERIFY_IS_TRUE(false, "Saw unsupported component type");
    return;
  }

  // Reference pattern 1..NumElements, built only for the type under test. It
  // is used both to seed the input buffer and as the expected output.
  std::vector<float> ExpectedFloats;
  std::vector<int32_t> ExpectedInts;
  if (IsFloat) {
    ExpectedFloats.resize(NumElements);
    for (size_t I = 0; I < NumElements; ++I)
      ExpectedFloats[I] = static_cast<float>(I + 1);
  } else {
    ExpectedInts.resize(NumElements);
    for (size_t I = 0; I < NumElements; ++I)
      ExpectedInts[I] = static_cast<int32_t>(I + 1);
  }

  auto Op = createComputeOp(ElementAccessShader, "cs_6_10", "UAV(u0), UAV(u1)",
                            Args.c_str());
  addUAVBuffer(Op.get(), "Input", InputBufSize, false, "byname");
  addUAVBuffer(Op.get(), "Output", OutputBufSize, true);
  addRootUAV(Op.get(), 0, "Input");
  addRootUAV(Op.get(), 1, "Output");

  auto Result = runShaderOp(
      Device, DxcSupport, std::move(Op),
      [&](LPCSTR Name, std::vector<BYTE> &Data, st::ShaderOp *) {
        if (_stricmp(Name, "Input") != 0)
          return;

        // Seed the input buffer with the reference pattern.
        if (IsFloat)
          memcpy(Data.data(), ExpectedFloats.data(),
                 NumElements * sizeof(float));
        else
          memcpy(Data.data(), ExpectedInts.data(),
                 NumElements * sizeof(int32_t));
      });

  MappedData OutData;
  Result->Test->GetReadBackData("Output", &OutData);
  const BYTE *Out = static_cast<const BYTE *>(OutData.data());

  // Verify element values match input data. Elements are copied out with
  // memcpy because the read-back buffer is a raw byte view.
  if (IsFloat) {
    std::vector<float> ActualFloats(NumElements);
    for (size_t I = 0; I < NumElements; ++I)
      memcpy(&ActualFloats[I], &Out[I * ElementSize], ElementSize);
    VERIFY_IS_TRUE(verifyFloatBuffer(ActualFloats.data(), ExpectedFloats.data(),
                                     NumElements, Verbose));
  } else {
    std::vector<int32_t> ActualInts(NumElements);
    for (size_t I = 0; I < NumElements; ++I)
      memcpy(&ActualInts[I], &Out[I * ElementSize], ElementSize);
    VERIFY_IS_TRUE(verifyIntBuffer(ActualInts.data(), ExpectedInts.data(),
                                   NumElements, Verbose));
  }

  // The sum of the values returned by Length across all threads must be
  // greater than or equal to the total number of matrix elements.
  size_t TotalLength = 0;
  for (size_t Thread = 0; Thread < NumThreads; ++Thread) {
    uint32_t Length;
    memcpy(&Length, &Out[MatrixBytes + Thread * sizeof(uint32_t)],
           sizeof(uint32_t));
    TotalLength += Length;
  }
  VERIFY_IS_TRUE(TotalLength >= NumElements,
                 "Sum of all lengths must be gte num elements");
}

void DxilConf_SM610_LinAlg::ElementAccess_Wave_F32() {
MatrixParams Params = {};
Params.CompType = ComponentType::F32;
Params.M = 16;
Params.N = 16;
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.Enable16Bit = false;
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
}

void DxilConf_SM610_LinAlg::ElementAccess_Wave_I32() {
MatrixParams Params = {};
Params.CompType = ComponentType::I32;
Params.M = 16;
Params.N = 16;
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.Enable16Bit = false;
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging);
}

} // namespace LinAlg
Loading