Skip to content

Commit 721087a

Browse files
Support odd matrix/vector sizes
1 parent fb89081 commit 721087a

1 file changed

Lines changed: 58 additions & 26 deletions

File tree

tools/clang/unittests/HLSLExec/ExecutionTest.cpp

Lines changed: 58 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12149,6 +12149,14 @@ void ExecutionTest::runCoopVecMulTestConfig(
1214912149
{32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
1215012150
{32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
1215112151
{32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12152+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12153+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12154+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12155+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12156+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12157+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12158+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12159+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
1215212160
{16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
1215312161
{16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
1215412162
{16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
@@ -12157,6 +12165,14 @@ void ExecutionTest::runCoopVecMulTestConfig(
1215712165
{32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
1215812166
{32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
1215912167
{32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12168+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12169+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12170+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12171+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12172+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12173+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12174+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12175+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
1216012176
{16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
1216112177
{16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
1216212178
{16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
@@ -12165,6 +12181,14 @@ void ExecutionTest::runCoopVecMulTestConfig(
1216512181
{32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
1216612182
{32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
1216712183
{32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12184+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12185+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12186+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12187+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12188+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12189+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12190+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12191+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
1216812192
{16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
1216912193
false},
1217012194
{16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
@@ -12181,6 +12205,22 @@ void ExecutionTest::runCoopVecMulTestConfig(
1218112205
false},
1218212206
{32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
1218312207
true},
12208+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12209+
false},
12210+
{17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12211+
true},
12212+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12213+
false},
12214+
{17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12215+
true},
12216+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12217+
false},
12218+
{1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12219+
true},
12220+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12221+
false},
12222+
{1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12223+
true},
1218412224
};
1218512225

1218612226
for (auto Config : TestConfigs) {
@@ -12280,18 +12320,15 @@ void ExecutionTest::runCoopVecMulSubtest(
1228012320
// FIXME: This does not capture all cases, but is sufficient for the preview
1228112321
// feature set
1228212322
if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) {
12283-
int32_t *InputBiasI32 = (int32_t *)InputBias.getBuffer();
12284-
float *InputVectorF32 = (float *)InputVector.getBuffer();
12285-
1228612323
for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) {
12324+
int32_t *InputBiasI32 = InputBias.getVector<int32_t>(0);
1228712325
for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) {
1228812326
int Acc = 0;
1228912327

1229012328
for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) {
1229112329
int InputElem;
1229212330
if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
12293-
InputElem = (int)
12294-
InputVectorF32[ThreadIdx * Config.InputPerThread + InputIdx];
12331+
InputElem = (int)InputVector.getVector<float>(ThreadIdx)[InputIdx];
1229512332
} else {
1229612333
InputElem = InputVector.getVector<int8_t>(ThreadIdx)[InputIdx];
1229712334
}
@@ -12315,22 +12352,21 @@ void ExecutionTest::runCoopVecMulSubtest(
1231512352
D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 ||
1231612353
MulProps.MatrixInterpretation ==
1231712354
D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) {
12318-
DirectX::PackedVector::HALF *InputVectorFP16 =
12319-
(DirectX::PackedVector::HALF *)InputVector.getBuffer();
12320-
DirectX::PackedVector::HALF *InputBiasFP16 =
12321-
(DirectX::PackedVector::HALF *)InputBias.getBuffer();
12322-
1232312355
// The CPU reference matrix is float
1232412356
std::vector<float> InputMatrixFP32(InputMatrix.size() / sizeof(float));
1232512357
std::memcpy(InputMatrixFP32.data(), InputMatrix.data(), InputMatrix.size());
1232612358

1232712359
for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) {
12360+
DirectX::PackedVector::HALF *InputVectorFP16 =
12361+
InputVector.getVector<DirectX::PackedVector::HALF>(ThreadIdx);
12362+
DirectX::PackedVector::HALF *InputBiasFP16 =
12363+
InputBias.getVector<DirectX::PackedVector::HALF>(0);
1232812364
for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) {
1232912365
float Acc = 0;
1233012366

1233112367
for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) {
12332-
float const InputElem = ConvertFloat16ToFloat32(
12333-
InputVectorFP16[ThreadIdx * Config.InputPerThread + InputIdx]);
12368+
float const InputElem =
12369+
ConvertFloat16ToFloat32(InputVectorFP16[InputIdx]);
1233412370
float const MatrixElem =
1233512371
InputMatrixFP32[OutputIdx * Config.InputPerThread + InputIdx];
1233612372
Acc += InputElem * MatrixElem;
@@ -12365,7 +12401,7 @@ void main(uint threadIdx : SV_GroupThreadID)
1236512401
using namespace dx::linalg;
1236612402

1236712403
uint inputOffset = (threadIdx * INPUT_VECTOR_STRIDE);
12368-
vector<INPUT_DATA_TYPE, INPUT_PER_THREAD / INPUT_DIVISOR> input = InputVector.Load<vector<INPUT_DATA_TYPE, INPUT_PER_THREAD / INPUT_DIVISOR> >(inputOffset);
12404+
vector<INPUT_DATA_TYPE, INPUT_VECTOR_NUM_ELEMENTS> input = InputVector.Load<vector<INPUT_DATA_TYPE, INPUT_VECTOR_NUM_ELEMENTS> >(inputOffset);
1236912405

1237012406
MatrixRef<MATRIX_DATA_TYPE_ENUM, OUTPUT_PER_THREAD, INPUT_PER_THREAD, HLSL_MATRIX_LAYOUT, /*transpose*/false> mat = { InputMatrix, 0, STRIDE };
1237112407

@@ -12439,8 +12475,9 @@ void main(uint threadIdx : SV_GroupThreadID)
1243912475
auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
1244012476
auto InputDataTypeDefine =
1244112477
CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType);
12442-
auto InputDivisorDefine =
12443-
CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
12478+
auto InputDivisorDefine = CreateDefineFromInt(
12479+
L"INPUT_VECTOR_NUM_ELEMENTS",
12480+
(Config.InputPerThread + InputDivisor - 1) / InputDivisor);
1244412481
auto AccumDataTypeDefine =
1244512482
CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType);
1244612483
auto InputInterpretationEnumDefine = CreateDefineFromString(
@@ -12596,11 +12633,12 @@ void main(uint threadIdx : SV_GroupThreadID)
1259612633
&ConvertInfo.DestInfo);
1259712634
}
1259812635

12636+
int SRVSize = (ConvertInfo.DestInfo.DestSize + 15) / 16 * 16;
12637+
1259912638
// Create resource to hold matrix copy
12600-
CreateTestResources(
12601-
D3DDevice, CommandList, nullptr, 0,
12602-
CD3DX12_RESOURCE_DESC::Buffer(ConvertInfo.DestInfo.DestSize),
12603-
&ConvertedMatrixResource, nullptr);
12639+
CreateTestResources(D3DDevice, CommandList, nullptr, SRVSize,
12640+
CD3DX12_RESOURCE_DESC::Buffer(SRVSize),
12641+
&ConvertedMatrixResource, nullptr);
1260412642

1260512643
// Set up data descriptors
1260612644
ConvertInfo.DataDesc.DestVA =
@@ -12613,13 +12651,7 @@ void main(uint threadIdx : SV_GroupThreadID)
1261312651
__uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11));
1261412652
CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1);
1261512653

12616-
// This increments baseHandle
12617-
if ((ConvertInfo.DestInfo.DestSize % 4) != 0) {
12618-
WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes");
12619-
return;
12620-
}
12621-
CreateRawSRV(D3DDevice, BaseHandle,
12622-
ConvertInfo.DestInfo.DestSize / sizeof(int32_t),
12654+
CreateRawSRV(D3DDevice, BaseHandle, SRVSize / sizeof(int32_t),
1262312655
ConvertedMatrixResource);
1262412656
}
1262512657

0 commit comments

Comments
 (0)