Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 60 additions & 36 deletions tools/clang/unittests/HLSLExec/LinAlgTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,37 +228,58 @@ static bool fillInputBuffer(LPCSTR Name, std::vector<BYTE> &Data,
return false;
}

/// Builds the reference contents of an M x N matrix for comparison against
/// shader output.
///
/// Element (I, J) holds StartingVal plus, when Increment is set, its row-major
/// linear index I * N + J. With Increment cleared every element holds
/// StartingVal.
///
/// \param CompType    Element type; selects which vector is returned.
/// \param M           Number of rows (outer loop extent). Expected to be
///                    non-negative; a negative value wraps when widened.
/// \param N           Number of columns (inner loop extent). Same expectation.
/// \param StartingVal Value of element (0, 0).
/// \param Increment   Add the element's row-major index to StartingVal.
/// \param Transpose   Store element (I, J) at J * M + I instead of I * N + J,
///                    i.e. lay the buffer out as the N x M transpose.
/// \returns A vector of float, int32_t or HLSLHalf_t matching CompType. An
///          unsupported CompType reports a verification failure and returns
///          the float vector.
static VariantCompType makeExpected(ComponentType CompType, int32_t M,
                                    int32_t N, float StartingVal,
                                    bool Increment = true,
                                    bool Transpose = false) {
  // Widen before multiplying so the element count and all index arithmetic
  // are done in size_t, and the loop bounds are compared unsigned-to-unsigned.
  const size_t Rows = static_cast<size_t>(M);
  const size_t Cols = static_cast<size_t>(N);
  const size_t NumElements = Rows * Cols;

  std::vector<float> Floats(NumElements);
  std::vector<int32_t> Ints(NumElements);
  std::vector<HLSLHalf_t> Halfs(NumElements);

  // The range check only depends on StartingVal, so do it once up front
  // instead of once per element.
  if (CompType == ComponentType::I32) {
    VERIFY_IS_TRUE(StartingVal < static_cast<float>(
                                     std::numeric_limits<int32_t>::max()),
                   "Value too large to cast to int32_t");
    VERIFY_IS_TRUE(StartingVal > static_cast<float>(
                                     std::numeric_limits<int32_t>::min()),
                   "Value too small to cast to int32_t");
  }

  for (size_t I = 0; I < Rows; ++I) {
    for (size_t J = 0; J < Cols; ++J) {
      // Row-major linear index: the row stride is the column count. Using the
      // row count here would index past NumElements whenever M > N.
      const size_t Value = I * Cols + J;
      // The transposed buffer is N x M, so its row stride is the row count of
      // the untransposed matrix.
      const size_t Idx = Transpose ? J * Rows + I : Value;

      switch (CompType) {
      case ComponentType::F32:
        Floats[Idx] = StartingVal + static_cast<float>(Increment ? Value : 0);
        break;
      case ComponentType::I32:
        Ints[Idx] = static_cast<int32_t>(StartingVal) +
                    static_cast<int32_t>(Increment ? Value : 0);
        break;
      case ComponentType::F16: {
        // Downcasting is safe here since HLSLHalf_t will clamp if F is too
        // large.
        float F = StartingVal + static_cast<float>(Increment ? Value : 0);
        Halfs[Idx] = HLSLHalf_t(F);
        break;
      }
      default:
        VERIFY_IS_TRUE(false, "Unable to fill unexpected ComponentType");
        break;
      }
    }
  }

  switch (CompType) {
  case ComponentType::F32:
    return Floats;
  case ComponentType::I32:
    return Ints;
  case ComponentType::F16:
    return Halfs;
  default:
    VERIFY_IS_TRUE(false, "Unable to fill unexpected ComponentType");
    return Floats;
  }
}

static void logCompiledButSkipping() {
Expand Down Expand Up @@ -384,6 +405,7 @@ static const char LoadStoreShader[] = R"(
RWByteAddressBuffer Output : register(u1);

#ifndef EMULATE_TEST
[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main() {
__builtin_LinAlgMatrix
Expand Down Expand Up @@ -429,7 +451,7 @@ static void runLoadStoreRoundtrip(ID3D12Device *Device,
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);
Comment thread
bob80905 marked this conversation as resolved.

// Construct the ShaderOp: two UAV buffers, load from one, store to other.
auto Op = createComputeOp(LoadStoreShader, Target.c_str(), "UAV(u0), UAV(u1)",
Expand Down Expand Up @@ -463,7 +485,7 @@ void DxilConf_SM610_LinAlg::LoadStoreRoundtrip_Wave_16x16_F16() {
Params.Use = MatrixUse::A;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.NumThreads = 64;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runLoadStoreRoundtrip(D3DDevice, DxcSupport, Params, VerboseLogging,
Expand All @@ -474,6 +496,7 @@ static const char SplatStoreShader[] = R"(
RWByteAddressBuffer Output : register(u0);

#ifndef EMULATE_TEST
[WaveSize(4, 64)]
[numthreads(NUMTHREADS, 1, 1)]
void main() {
__builtin_LinAlgMatrix
Expand Down Expand Up @@ -517,7 +540,8 @@ static void runSplatStore(ID3D12Device *Device,
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, FillValue, false);
auto Expected =
makeExpected(Params.CompType, Params.M, Params.N, FillValue, false);

auto Op = createComputeOp(SplatStoreShader, Target.c_str(), "UAV(u0)",
Args.c_str());
Expand All @@ -541,7 +565,7 @@ void DxilConf_SM610_LinAlg::SplatStore_Wave_16x16_F16() {
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.NumThreads = 64;
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runSplatStore(D3DDevice, DxcSupport, Params, 42.0f, VerboseLogging,
Expand All @@ -553,11 +577,13 @@ static const char ElementAccessShader[] = R"(
RWByteAddressBuffer Output : register(u1);

// flatten the 2D index into a 1D index then scale by element size
// Always store row-major and work it out in the test runner
uint coordToByteOffset(uint2 coord) {
return (coord.x * MAJOR_DIM + coord.y) * ELEM_SIZE;
return (coord.y * N_DIM + coord.x) * ELEM_SIZE;
}

#ifndef EMULATE_TEST
[WaveSize(4, 64)]
Comment thread
V-FEXrt marked this conversation as resolved.
Comment thread
bob80905 marked this conversation as resolved.
[numthreads(NUMTHREADS, 1, 1)]
void main(uint threadIndex : SV_GroupIndex) {
__builtin_LinAlgMatrix
Expand Down Expand Up @@ -605,8 +631,7 @@ static void runElementAccess(ID3D12Device *Device,
const size_t NumThreads = Params.NumThreads;
const size_t InputBufSize = Params.totalBytes();
const size_t ElementSize = elementSize(Params.CompType);
const size_t MajorDim =
Params.Layout == LinalgMatrixLayout::RowMajor ? Params.M : Params.N;

// Output: ElementSize bytes per element
// 1 element for each mat idx
// 1 uint for each thread's length
Expand All @@ -618,7 +643,6 @@ static void runElementAccess(ID3D12Device *Device,
Target = "cs_6_8";

std::stringstream ExtraDefs;
ExtraDefs << " -DMAJOR_DIM=" << MajorDim;
std::string Args = buildCompilerArgs(Params, ExtraDefs.str().c_str());

compileShader(DxcSupport, ElementAccessShader, Target.c_str(), Args, Verbose);
Expand All @@ -628,7 +652,7 @@ static void runElementAccess(ID3D12Device *Device,
return;
}

auto Expected = makeExpected(Params.CompType, NumElements, 1, true);
auto Expected = makeExpected(Params.CompType, Params.M, Params.N, 1);

auto Op = createComputeOp(ElementAccessShader, Target.c_str(),
"UAV(u0), UAV(u1)", Args.c_str());
Expand Down Expand Up @@ -674,7 +698,7 @@ void DxilConf_SM610_LinAlg::ElementAccess_Wave_16x16_F16() {
Params.Use = MatrixUse::Accumulator;
Params.Scope = MatrixScope::Wave;
Params.Layout = LinalgMatrixLayout::RowMajor;
Params.NumThreads = 4;
Params.NumThreads = 64;
Comment thread
V-FEXrt marked this conversation as resolved.
Params.Enable16Bit = true;
Params.EmulateTest = EmulateTest;
runElementAccess(D3DDevice, DxcSupport, Params, VerboseLogging, CompileOnly);
Expand Down