@@ -814,7 +814,7 @@ class ExecutionTest {
814814 void runCoopVecOuterProductSubtest(
815815 ID3D12Device *D3DDevice,
816816 D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
817- CoopVecOuterProductSubtestConfig &Config);
817+ CoopVecOuterProductSubtestConfig &Config, bool RunCompute );
818818
819819#endif // HAVE_COOPVEC_API
820820
@@ -12913,29 +12913,41 @@ void ExecutionTest::runCoopVecOuterProductTestConfig(
1291312913 continue;
1291412914 }
1291512915
12916- runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config);
12916+ // Run once in compute, then once in graphics (pixel shader)
12917+ runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, true);
12918+ runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, false);
1291712919 }
1291812920}
1291912921
1292012922void ExecutionTest::runCoopVecOuterProductSubtest(
1292112923 ID3D12Device *D3DDevice,
1292212924 D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
12923- CoopVecOuterProductSubtestConfig &Config) {
12925+ CoopVecOuterProductSubtestConfig &Config, bool RunCompute ) {
1292412926
1292512927 LogCommentFmt(
12926- L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s",
12928+ L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s, "
12929+ L"Stage: %s",
1292712930 Config.DimM, Config.DimN, Config.NumThreads,
12928- CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str());
12931+ CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str(),
12932+ RunCompute ? L"Compute" : L"Pixel");
1292912933
1293012934 // Create root signature with a single root entry for all SRVs and UAVs
1293112935 CComPtr<ID3D12RootSignature> RootSignature;
1293212936 {
12933- CD3DX12_DESCRIPTOR_RANGE ranges[2];
12934- ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0,
12935- 0); // InputVector1, InputVector2
12936- ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix
12937- CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr,
12938- 0);
12937+ CD3DX12_DESCRIPTOR_RANGE Ranges[2];
12938+ Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0);
12939+ Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0);
12940+
12941+ CD3DX12_ROOT_PARAMETER RootParams[2];
12942+ RootParams[0].InitAsDescriptorTable(_countof(Ranges), Ranges,
12943+ D3D12_SHADER_VISIBILITY_ALL);
12944+ RootParams[1].InitAsUnorderedAccessView(/* register */ 10, /* space */ 0,
12945+ D3D12_SHADER_VISIBILITY_ALL);
12946+
12947+ CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc;
12948+ RootSignatureDesc.Init(_countof(RootParams), RootParams, 0, nullptr,
12949+ D3D12_ROOT_SIGNATURE_FLAG_NONE);
12950+ CreateRootSignatureFromDesc(D3DDevice, &RootSignatureDesc, &RootSignature);
1293912951 }
1294012952
1294112953 // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV
@@ -13027,17 +13039,17 @@ void ExecutionTest::runCoopVecOuterProductSubtest(
1302713039
1302813040 // Create a compute pipeline state object.
1302913041 CComPtr<ID3D12PipelineState> ComputePipelineState;
13030- {
13031- std::string ShaderSource = R"(
13042+
13043+ std::string ShaderSource = R"(
1303213044#include "dx/linalg.h"
1303313045
1303413046ByteAddressBuffer InputVector1 : register(t0);
1303513047ByteAddressBuffer InputVector2 : register(t1);
1303613048RWByteAddressBuffer AccumMatrix : register(u0);
1303713049
13038- [shader("compute")]
13039- [numthreads(NUM_THREADS, 1, 1)]
13040- void main (uint threadIdx : SV_GroupThreadID )
13050+ RWStructuredBuffer<uint> AtomicCounter : register(u10);
13051+
13052+ void RunCoopVecTest (uint threadIdx)
1304113053{
1304213054 using namespace dx::linalg;
1304313055
@@ -13052,94 +13064,142 @@ void main(uint threadIdx : SV_GroupThreadID)
1305213064
1305313065 OuterProductAccumulate(input1, input2, mat);
1305413066}
13055- )";
1305613067
13057- auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13058- std::wstringstream Stream;
13059- Stream << L"-D" << Name << L"=" << Value;
13060- return Stream.str();
13061- };
13068+ [shader("compute")]
13069+ [numthreads(NUM_THREADS, 1, 1)]
13070+ void main(uint threadIdx : SV_GroupThreadID) // Compute entry point: each thread passes its group-thread index to the shared test body.
13071+ {
13072+ RunCoopVecTest(threadIdx);
13073+ }
1306213074
13063- auto CreateDefineFromString = [](const wchar_t *Name,
13064- const wchar_t *Value) {
13065- std::wstringstream Stream;
13066- Stream << L"-D" << Name << L"=" << Value;
13067- return Stream.str();
13068- };
13075+ float4 vs_main(uint vid : SV_VertexID) : SV_Position { // Emits one oversized triangle covering the whole viewport; positions are derived from SV_VertexID alone.
13076+ switch (vid) {
13077+ case 0:
13078+ return float4(-1, 1, 0, 1); // w must be 1: with w == 0 every vertex fails the 0 < w clip test and no pixel is ever shaded
13079+ case 1:
13080+ return float4(3, 1, 0, 1);
13081+ case 2:
13082+ return float4(-1, -3, 0, 1);
13083+ }
13084+ return float4(0, 0, 0, 0); // Fallback for any other vertex id; not expected for a 3-vertex draw.
13085+ }
1306913086
13070- int Stride = 0;
13071- const std::wstring HlslMatrixLayout =
13072- CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13073- int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13074- AccumulateProps.AccumulationType);
13075- switch (Config.MatrixLayout) {
13076- case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13077- Stride = Config.DimN * StrideMultiplier;
13078- break;
13079- case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13080- Stride = Config.DimM * StrideMultiplier;
13081- break;
13082- }
13087+ float4 ps_main() : SV_Target { // Pixel-shader entry point for the graphics variant of the test.
13088+ uint threadIdx;
13089+ InterlockedAdd(AtomicCounter[0], 1, threadIdx); // threadIdx receives the counter's pre-increment value, so each invocation gets a distinct index
13090+ RunCoopVecTest(threadIdx); // NOTE(review): one invocation per covered pixel may exceed NUM_THREADS - confirm RunCoopVecTest handles larger indices
13091+ return float4(1, 1, 1, 1);
13092+ }
13093+ )";
1308313094
13084- const int InputDivisor =
13085- CoopVecHelpers::GetNumPackedElementsForInputDataType(
13086- AccumulateProps.InputType);
13087- const std::wstring InputDataType =
13088- CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13089- const std::wstring AccumDataType =
13090- CoopVecHelpers::GetHlslDataTypeForDataType(
13091- AccumulateProps.AccumulationType);
13092- const std::wstring MatrixDataTypeEnum =
13093- CoopVecHelpers::GetHlslInterpretationForDataType(
13094- AccumulateProps.AccumulationType);
13095- const std::wstring InputInterpretationEnum =
13096- CoopVecHelpers::GetHlslInterpretationForDataType(
13097- AccumulateProps.InputType);
13098-
13099- auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13100- auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13101- auto NumThreadsDefine =
13102- CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13103- auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13104- auto InputDataTypeDefine =
13105- CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13106- auto InputDivisorDefine =
13107- CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13108- auto AccumDataTypeDefine =
13109- CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13110- auto InputInterpretationEnumDefine = CreateDefineFromString(
13111- L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13112- auto HlslMatrixLayoutDefine =
13113- CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13114- auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13115- L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13116- auto InputVector1StrideDefine = CreateDefineFromInt(
13117- L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13118- auto InputVector2StrideDefine = CreateDefineFromInt(
13119- L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13120-
13121- LPCWSTR Options[] = {
13122- L"-enable-16bit-types",
13123- DimMDefine.c_str(),
13124- DimNDefine.c_str(),
13125- NumThreadsDefine.c_str(),
13126- StrideDefine.c_str(),
13127- InputDataTypeDefine.c_str(),
13128- InputDivisorDefine.c_str(),
13129- AccumDataTypeDefine.c_str(),
13130- InputInterpretationEnumDefine.c_str(),
13131- HlslMatrixLayoutDefine.c_str(),
13132- MatrixDataTypeEnumDefine.c_str(),
13133- InputVector1StrideDefine.c_str(),
13134- InputVector2StrideDefine.c_str(),
13135- };
13095+ auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13096+ std::wstringstream Stream;
13097+ Stream << L"-D" << Name << L"=" << Value;
13098+ return Stream.str();
13099+ };
13100+
13101+ auto CreateDefineFromString = [](const wchar_t *Name, const wchar_t *Value) {
13102+ std::wstringstream Stream;
13103+ Stream << L"-D" << Name << L"=" << Value;
13104+ return Stream.str();
13105+ };
13106+
13107+ int Stride = 0;
13108+ const std::wstring HlslMatrixLayout =
13109+ CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13110+ int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13111+ AccumulateProps.AccumulationType);
13112+ switch (Config.MatrixLayout) {
13113+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13114+ Stride = Config.DimN * StrideMultiplier;
13115+ break;
13116+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13117+ Stride = Config.DimM * StrideMultiplier;
13118+ break;
13119+ }
1313613120
13137- CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13138- new LinAlgHeaderIncludeHandler(m_support);
13121+ const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(
13122+ AccumulateProps.InputType);
13123+ const std::wstring InputDataType =
13124+ CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13125+ const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(
13126+ AccumulateProps.AccumulationType);
13127+ const std::wstring MatrixDataTypeEnum =
13128+ CoopVecHelpers::GetHlslInterpretationForDataType(
13129+ AccumulateProps.AccumulationType);
13130+ const std::wstring InputInterpretationEnum =
13131+ CoopVecHelpers::GetHlslInterpretationForDataType(
13132+ AccumulateProps.InputType);
1313913133
13134+ auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13135+ auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13136+ auto NumThreadsDefine =
13137+ CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13138+ auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13139+ auto InputDataTypeDefine =
13140+ CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13141+ auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13142+ auto AccumDataTypeDefine =
13143+ CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13144+ auto InputInterpretationEnumDefine = CreateDefineFromString(
13145+ L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13146+ auto HlslMatrixLayoutDefine =
13147+ CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13148+ auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13149+ L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13150+ auto InputVector1StrideDefine = CreateDefineFromInt(
13151+ L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13152+ auto InputVector2StrideDefine = CreateDefineFromInt(
13153+ L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13154+
13155+ LPCWSTR Options[] = {
13156+ L"-enable-16bit-types",
13157+ DimMDefine.c_str(),
13158+ DimNDefine.c_str(),
13159+ NumThreadsDefine.c_str(),
13160+ StrideDefine.c_str(),
13161+ InputDataTypeDefine.c_str(),
13162+ InputDivisorDefine.c_str(),
13163+ AccumDataTypeDefine.c_str(),
13164+ InputInterpretationEnumDefine.c_str(),
13165+ HlslMatrixLayoutDefine.c_str(),
13166+ MatrixDataTypeEnumDefine.c_str(),
13167+ InputVector1StrideDefine.c_str(),
13168+ InputVector2StrideDefine.c_str(),
13169+ };
13170+
13171+ CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13172+ new LinAlgHeaderIncludeHandler(m_support);
13173+
13174+ if (RunCompute) {
1314013175 CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9",
1314113176 &ComputePipelineState, Options, _countof(Options),
1314213177 IncludeHandler);
13178+ } else {
13179+ CComPtr<ID3DBlob> VertexShader;
13180+ CComPtr<ID3DBlob> PixelShader;
13181+
13182+ CompileFromText(ShaderSource.c_str(), L"vs_main", L"vs_6_9", &VertexShader,
13183+ Options, _countof(Options), IncludeHandler);
13184+ CompileFromText(ShaderSource.c_str(), L"ps_main", L"ps_6_9", &PixelShader,
13185+ Options, _countof(Options), IncludeHandler);
13186+
13187+ D3D12_GRAPHICS_PIPELINE_STATE_DESC PsoDesc = {};
13188+ // InputLayout is left empty: vs_main takes only SV_VertexID as input.
13189+ PsoDesc.pRootSignature = RootSignature;
13190+ PsoDesc.VS = CD3DX12_SHADER_BYTECODE(VertexShader);
13191+ PsoDesc.PS = CD3DX12_SHADER_BYTECODE(PixelShader);
13192+ PsoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT);
13193+ PsoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);
13194+ PsoDesc.DepthStencilState.DepthEnable = FALSE;
13195+ PsoDesc.DepthStencilState.StencilEnable = FALSE;
13196+ PsoDesc.SampleMask = UINT_MAX;
13197+ PsoDesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
13198+ PsoDesc.NumRenderTargets = 1;
13199+ PsoDesc.RTVFormats[0] = DXGI_FORMAT_R8G8B8A8_UNORM;
13200+ PsoDesc.SampleDesc.Count = 1;
13201+ VERIFY_SUCCEEDED(D3DDevice->CreateGraphicsPipelineState(
13202+ &PsoDesc, IID_PPV_ARGS(&ComputePipelineState)));
1314313203 }
1314413204
1314513205 // Create a command list for the compute shader.
@@ -13282,6 +13342,14 @@ void main(uint threadIdx : SV_GroupThreadID)
1328213342 ConvertedMatrixResource);
1328313343 }
1328413344
13345+ // Create resource for atomic counter
13346+ CComPtr<ID3D12Resource> AtomicCounterResource;
13347+ uint32_t AtomicCounterInit = 0;
13348+ CreateTestResources(D3DDevice, CommandList, &AtomicCounterInit,
13349+ sizeof(AtomicCounterInit),
13350+ CD3DX12_RESOURCE_DESC::Buffer(sizeof(AtomicCounterInit)),
13351+ &AtomicCounterResource, nullptr);
13352+
1328513353 CommandList->Close();
1328613354 ExecuteCommandList(CommandQueue, CommandList);
1328713355 WaitForSignal(CommandQueue, FO);
@@ -13293,10 +13361,54 @@ void main(uint threadIdx : SV_GroupThreadID)
1329313361 CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle(
1329413362 DescriptorHeap->GetGPUDescriptorHandleForHeapStart());
1329513363
13296- CommandList->SetComputeRootSignature(RootSignature);
13297- CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13298- CommandList->SetPipelineState(ComputePipelineState);
13299- CommandList->Dispatch(1, 1, 1);
13364+ CComPtr<ID3D12DescriptorHeap> RtvHeap;
13365+ CComPtr<ID3D12Resource> RenderTarget;
13366+ CComPtr<ID3D12Resource> RenderTargetRead;
13367+
13368+ if (RunCompute) {
13369+ CommandList->SetComputeRootSignature(RootSignature);
13370+ CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13371+ CommandList->SetPipelineState(ComputePipelineState);
13372+ CommandList->Dispatch(1, 1, 1);
13373+ } else {
13374+ UINT FrameCount = 1;
13375+ UINT RtvDescSize = 0;
13376+ CreateRtvDescriptorHeap(D3DDevice, FrameCount, &RtvHeap, &RtvDescSize);
13377+ CreateRenderTargetAndReadback(D3DDevice, RtvHeap, 100, 100, &RenderTarget,
13378+ &RenderTargetRead);
13379+
13380+ D3D12_RESOURCE_DESC RtDesc = RenderTarget->GetDesc();
13381+ D3D12_VIEWPORT Viewport;
13382+ D3D12_RECT ScissorRect;
13383+
13384+ memset(&Viewport, 0, sizeof(Viewport));
13385+ Viewport.Height = (float)RtDesc.Height;
13386+ Viewport.Width = (float)RtDesc.Width;
13387+ Viewport.MaxDepth = 1.0f;
13388+ memset(&ScissorRect, 0, sizeof(ScissorRect));
13389+ ScissorRect.right = (long)RtDesc.Width;
13390+ ScissorRect.bottom = RtDesc.Height;
13391+ CommandList->SetGraphicsRootSignature(RootSignature);
13392+ CommandList->SetGraphicsRootDescriptorTable(0, ResHandle);
13393+ CommandList->SetGraphicsRootUnorderedAccessView(
13394+ 1, AtomicCounterResource->GetGPUVirtualAddress());
13395+ CommandList->RSSetViewports(1, &Viewport);
13396+ CommandList->RSSetScissorRects(1, &ScissorRect);
13397+
13398+ // Indicate that the buffer will be used as a render target.
13399+ RecordTransitionBarrier(CommandList, RenderTarget,
13400+ D3D12_RESOURCE_STATE_COPY_DEST,
13401+ D3D12_RESOURCE_STATE_RENDER_TARGET);
13402+
13403+ CD3DX12_CPU_DESCRIPTOR_HANDLE RtvHandle(
13404+ RtvHeap->GetCPUDescriptorHandleForHeapStart(), 0, RtvDescSize);
13405+ CommandList->OMSetRenderTargets(1, &RtvHandle, FALSE, nullptr);
13406+
13407+ CommandList->ClearRenderTargetView(RtvHandle, ClearColor, 0, nullptr);
13408+ CommandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
13409+ CommandList->DrawInstanced(3, 1, 0, 0);
13410+ }
13411+
1330013412 CommandList->Close();
1330113413 ExecuteCommandList(CommandQueue, CommandList);
1330213414 WaitForSignal(CommandQueue, FO);
0 commit comments