@@ -12149,6 +12149,14 @@ void ExecutionTest::runCoopVecMulTestConfig(
1214912149 {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
1215012150 {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
1215112151 {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12152+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12153+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12154+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12155+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12156+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12157+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
12158+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, false},
12159+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR, true},
1215212160 {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
1215312161 {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
1215412162 {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
@@ -12157,6 +12165,14 @@ void ExecutionTest::runCoopVecMulTestConfig(
1215712165 {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
1215812166 {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
1215912167 {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12168+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12169+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12170+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12171+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12172+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12173+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
12174+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, false},
12175+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR, true},
1216012176 {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
1216112177 {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
1216212178 {16, 16, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
@@ -12165,6 +12181,14 @@ void ExecutionTest::runCoopVecMulTestConfig(
1216512181 {32, 8, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
1216612182 {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
1216712183 {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12184+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12185+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12186+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12187+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12188+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12189+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
12190+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, false},
12191+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_MUL_OPTIMAL, true},
1216812192 {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
1216912193 false},
1217012194 {16, 16, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
@@ -12181,6 +12205,22 @@ void ExecutionTest::runCoopVecMulTestConfig(
1218112205 false},
1218212206 {32, 8, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
1218312207 true},
12208+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12209+ false},
12210+ {17, 63, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12211+ true},
12212+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12213+ false},
12214+ {17, 63, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12215+ true},
12216+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12217+ false},
12218+ {1, 1, 16, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12219+ true},
12220+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12221+ false},
12222+ {1, 1, 32, 1, D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_OUTER_PRODUCT_OPTIMAL,
12223+ true},
1218412224 };
1218512225
1218612226 for (auto Config : TestConfigs) {
@@ -12280,18 +12320,15 @@ void ExecutionTest::runCoopVecMulSubtest(
1228012320 // FIXME: This does not capture all cases, but is sufficient for the preview
1228112321 // feature set
1228212322 if (MulProps.MatrixInterpretation == D3D12_LINEAR_ALGEBRA_DATATYPE_SINT8) {
12283- int32_t *InputBiasI32 = (int32_t *)InputBias.getBuffer();
12284- float *InputVectorF32 = (float *)InputVector.getBuffer();
12285-
1228612323 for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) {
12324+ int32_t *InputBiasI32 = InputBias.getVector<int32_t>(0);
1228712325 for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) {
1228812326 int Acc = 0;
1228912327
1229012328 for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) {
1229112329 int InputElem;
1229212330 if (MulProps.InputType == D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT32) {
12293- InputElem = (int)
12294- InputVectorF32[ThreadIdx * Config.InputPerThread + InputIdx];
12331+ InputElem = (int)InputVector.getVector<float>(ThreadIdx)[InputIdx];
1229512332 } else {
1229612333 InputElem = InputVector.getVector<int8_t>(ThreadIdx)[InputIdx];
1229712334 }
@@ -12315,22 +12352,21 @@ void ExecutionTest::runCoopVecMulSubtest(
1231512352 D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E4M3 ||
1231612353 MulProps.MatrixInterpretation ==
1231712354 D3D12_LINEAR_ALGEBRA_DATATYPE_FLOAT_E5M2) {
12318- DirectX::PackedVector::HALF *InputVectorFP16 =
12319- (DirectX::PackedVector::HALF *)InputVector.getBuffer();
12320- DirectX::PackedVector::HALF *InputBiasFP16 =
12321- (DirectX::PackedVector::HALF *)InputBias.getBuffer();
12322-
1232312355 // The CPU reference matrix is float
1232412356 std::vector<float> InputMatrixFP32(InputMatrix.size() / sizeof(float));
1232512357 std::memcpy(InputMatrixFP32.data(), InputMatrix.data(), InputMatrix.size());
1232612358
1232712359 for (int ThreadIdx = 0; ThreadIdx < Config.NumThreads; ++ThreadIdx) {
12360+ DirectX::PackedVector::HALF *InputVectorFP16 =
12361+ InputVector.getVector<DirectX::PackedVector::HALF>(ThreadIdx);
12362+ DirectX::PackedVector::HALF *InputBiasFP16 =
12363+ InputBias.getVector<DirectX::PackedVector::HALF>(0);
1232812364 for (int OutputIdx = 0; OutputIdx < Config.OutputPerThread; ++OutputIdx) {
1232912365 float Acc = 0;
1233012366
1233112367 for (int InputIdx = 0; InputIdx < Config.InputPerThread; ++InputIdx) {
12332- float const InputElem = ConvertFloat16ToFloat32(
12333- InputVectorFP16[ThreadIdx * Config.InputPerThread + InputIdx]);
12368+ float const InputElem =
12369+ ConvertFloat16ToFloat32( InputVectorFP16[InputIdx]);
1233412370 float const MatrixElem =
1233512371 InputMatrixFP32[OutputIdx * Config.InputPerThread + InputIdx];
1233612372 Acc += InputElem * MatrixElem;
@@ -12365,7 +12401,7 @@ void main(uint threadIdx : SV_GroupThreadID)
1236512401 using namespace dx::linalg;
1236612402
1236712403 uint inputOffset = (threadIdx * INPUT_VECTOR_STRIDE);
12368- vector<INPUT_DATA_TYPE, INPUT_PER_THREAD / INPUT_DIVISOR > input = InputVector.Load<vector<INPUT_DATA_TYPE, INPUT_PER_THREAD / INPUT_DIVISOR > >(inputOffset);
12404+ vector<INPUT_DATA_TYPE, INPUT_VECTOR_NUM_ELEMENTS > input = InputVector.Load<vector<INPUT_DATA_TYPE, INPUT_VECTOR_NUM_ELEMENTS > >(inputOffset);
1236912405
1237012406 MatrixRef<MATRIX_DATA_TYPE_ENUM, OUTPUT_PER_THREAD, INPUT_PER_THREAD, HLSL_MATRIX_LAYOUT, /*transpose*/false> mat = { InputMatrix, 0, STRIDE };
1237112407
@@ -12439,8 +12475,9 @@ void main(uint threadIdx : SV_GroupThreadID)
1243912475 auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
1244012476 auto InputDataTypeDefine =
1244112477 CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType);
12442- auto InputDivisorDefine =
12443- CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
12478+ auto InputDivisorDefine = CreateDefineFromInt(
12479+ L"INPUT_VECTOR_NUM_ELEMENTS",
12480+ (Config.InputPerThread + InputDivisor - 1) / InputDivisor);
1244412481 auto AccumDataTypeDefine =
1244512482 CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType);
1244612483 auto InputInterpretationEnumDefine = CreateDefineFromString(
@@ -12596,11 +12633,12 @@ void main(uint threadIdx : SV_GroupThreadID)
1259612633 &ConvertInfo.DestInfo);
1259712634 }
1259812635
12636+ int SRVSize = (ConvertInfo.DestInfo.DestSize + 15) / 16 * 16;
12637+
1259912638 // Create resource to hold matrix copy
12600- CreateTestResources(
12601- D3DDevice, CommandList, nullptr, 0,
12602- CD3DX12_RESOURCE_DESC::Buffer(ConvertInfo.DestInfo.DestSize),
12603- &ConvertedMatrixResource, nullptr);
12639+ CreateTestResources(D3DDevice, CommandList, nullptr, SRVSize,
12640+ CD3DX12_RESOURCE_DESC::Buffer(SRVSize),
12641+ &ConvertedMatrixResource, nullptr);
1260412642
1260512643 // Set up data descriptors
1260612644 ConvertInfo.DataDesc.DestVA =
@@ -12613,13 +12651,7 @@ void main(uint threadIdx : SV_GroupThreadID)
1261312651 __uuidof(ID3D12GraphicsCommandList11), (void **)&CommandList11));
1261412652 CommandList11->ConvertLinearAlgebraMatrix(&ConvertInfo, 1);
1261512653
12616- // This increments baseHandle
12617- if ((ConvertInfo.DestInfo.DestSize % 4) != 0) {
12618- WEX::Logging::Log::Error(L"DestSize is not aligned to 4 bytes");
12619- return;
12620- }
12621- CreateRawSRV(D3DDevice, BaseHandle,
12622- ConvertInfo.DestInfo.DestSize / sizeof(int32_t),
12654+ CreateRawSRV(D3DDevice, BaseHandle, SRVSize / sizeof(int32_t),
1262312655 ConvertedMatrixResource);
1262412656 }
1262512657
0 commit comments