@@ -814,7 +814,7 @@ class ExecutionTest {
814814 void runCoopVecOuterProductSubtest(
815815 ID3D12Device *D3DDevice,
816816 D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
817- CoopVecOuterProductSubtestConfig &Config);
817+ CoopVecOuterProductSubtestConfig &Config, bool RunCompute );
818818
819819#endif // HAVE_COOPVEC_API
820820
@@ -12914,29 +12914,41 @@ void ExecutionTest::runCoopVecOuterProductTestConfig(
1291412914 continue;
1291512915 }
1291612916
12917- runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config);
12917+ // Run once in compute, then once in graphics (pixel shader)
12918+ runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, true);
12919+ runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, false);
1291812920 }
1291912921}
1292012922
1292112923void ExecutionTest::runCoopVecOuterProductSubtest(
1292212924 ID3D12Device *D3DDevice,
1292312925 D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
12924- CoopVecOuterProductSubtestConfig &Config) {
12926+ CoopVecOuterProductSubtestConfig &Config, bool RunCompute ) {
1292512927
1292612928 LogCommentFmt(
12927- L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s",
12929+ L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s, "
12930+ L"Stage: %s",
1292812931 Config.DimM, Config.DimN, Config.NumThreads,
12929- CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str());
12932+ CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str(),
12933+ RunCompute ? L"Compute" : L"Pixel");
1293012934
1293112935 // Create root signature with two root entries: a descriptor table for the SRVs and the matrix UAV, and a root UAV (u10) for the atomic counter
1293212936 CComPtr<ID3D12RootSignature> RootSignature;
1293312937 {
12934- CD3DX12_DESCRIPTOR_RANGE ranges[2];
12935- ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0,
12936- 0); // InputVector1, InputVector2
12937- ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix
12938- CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr,
12939- 0);
12938+ CD3DX12_DESCRIPTOR_RANGE Ranges[2];
12939+ Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0);
12940+ Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0);
12941+
12942+ CD3DX12_ROOT_PARAMETER RootParams[2];
12943+ RootParams[0].InitAsDescriptorTable(_countof(Ranges), Ranges,
12944+ D3D12_SHADER_VISIBILITY_ALL);
12945+ RootParams[1].InitAsUnorderedAccessView(/* register */ 10, /* space */ 0,
12946+ D3D12_SHADER_VISIBILITY_ALL);
12947+
12948+ CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc;
12949+ RootSignatureDesc.Init(_countof(RootParams), RootParams, 0, nullptr,
12950+ D3D12_ROOT_SIGNATURE_FLAG_NONE);
12951+ CreateRootSignatureFromDesc(D3DDevice, &RootSignatureDesc, &RootSignature);
1294012952 }
1294112953
1294212954 // Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV
@@ -13028,17 +13040,17 @@ void ExecutionTest::runCoopVecOuterProductSubtest(
1302813040
1302913041 // Create the pipeline state object: compute when RunCompute is set, otherwise graphics (vertex + pixel shader).
1303013042 CComPtr<ID3D12PipelineState> ComputePipelineState;
13031- {
13032- std::string ShaderSource = R"(
13043+
13044+ std::string ShaderSource = R"(
1303313045#include "dx/linalg.h"
1303413046
1303513047ByteAddressBuffer InputVector1 : register(t0);
1303613048ByteAddressBuffer InputVector2 : register(t1);
1303713049RWByteAddressBuffer AccumMatrix : register(u0);
1303813050
13039- [shader("compute")]
13040- [numthreads(NUM_THREADS, 1, 1)]
13041- void main (uint threadIdx : SV_GroupThreadID )
13051+ RWStructuredBuffer<uint> AtomicCounter : register(u10);
13052+
13053+ void RunCoopVecTest (uint threadIdx)
1304213054{
1304313055 using namespace dx::linalg;
1304413056
@@ -13053,94 +13065,142 @@ void main(uint threadIdx : SV_GroupThreadID)
1305313065
1305413066 OuterProductAccumulate(input1, input2, mat);
1305513067}
13056- )";
1305713068
13058- auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13059- std::wstringstream Stream;
13060- Stream << L"-D" << Name << L"=" << Value;
13061- return Stream.str();
13062- };
13069+ [shader("compute")]
13070+ [numthreads(NUM_THREADS, 1, 1)]
13071+ void main(uint threadIdx : SV_GroupThreadID) // Compute entry point: forwards the group-thread index to the shared test body.
13072+ {
13073+ RunCoopVecTest(threadIdx);
13074+ }
1306313075
13064- auto CreateDefineFromString = [](const wchar_t *Name,
13065- const wchar_t *Value) {
13066- std::wstringstream Stream;
13067- Stream << L"-D" << Name << L"=" << Value;
13068- return Stream.str();
13069- };
13076+ float4 vs_main(uint vid : SV_VertexID) : SV_Position { // Emits one oversized triangle from SV_VertexID alone; no vertex buffer is read.
13077+ switch (vid) {
13078+ case 0:
13079+ return float4(-1, 1, 0, 1); // w must be 1: with w == 0 the clip test rejects the vertex and nothing is rasterized
13080+ case 1:
13081+ return float4(3, 1, 0, 1);
13082+ case 2:
13083+ return float4(-1, -3, 0, 1);
13084+ }
13085+ return float4(0, 0, 0, 0); // fallback for vid > 2; the test draws exactly 3 vertices
13086+ }
1307013087
13071- int Stride = 0;
13072- const std::wstring HlslMatrixLayout =
13073- CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13074- int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13075- AccumulateProps.AccumulationType);
13076- switch (Config.MatrixLayout) {
13077- case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13078- Stride = Config.DimN * StrideMultiplier;
13079- break;
13080- case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13081- Stride = Config.DimM * StrideMultiplier;
13082- break;
13083- }
13088+ float4 ps_main() : SV_Target { // Pixel entry point: runs the shared test body once per invocation and writes a constant colour.
13089+ uint threadIdx; // receives the counter value from before the add, used in place of SV_GroupThreadID
13090+ InterlockedAdd(AtomicCounter[0], 1, threadIdx);
13091+ RunCoopVecTest(threadIdx);
13092+ return float4(1, 1, 1, 1);
13093+ }
13094+ )";
1308413095
13085- const int InputDivisor =
13086- CoopVecHelpers::GetNumPackedElementsForInputDataType(
13087- AccumulateProps.InputType);
13088- const std::wstring InputDataType =
13089- CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13090- const std::wstring AccumDataType =
13091- CoopVecHelpers::GetHlslDataTypeForDataType(
13092- AccumulateProps.AccumulationType);
13093- const std::wstring MatrixDataTypeEnum =
13094- CoopVecHelpers::GetHlslInterpretationForDataType(
13095- AccumulateProps.AccumulationType);
13096- const std::wstring InputInterpretationEnum =
13097- CoopVecHelpers::GetHlslInterpretationForDataType(
13098- AccumulateProps.InputType);
13099-
13100- auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13101- auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13102- auto NumThreadsDefine =
13103- CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13104- auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13105- auto InputDataTypeDefine =
13106- CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13107- auto InputDivisorDefine =
13108- CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13109- auto AccumDataTypeDefine =
13110- CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13111- auto InputInterpretationEnumDefine = CreateDefineFromString(
13112- L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13113- auto HlslMatrixLayoutDefine =
13114- CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13115- auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13116- L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13117- auto InputVector1StrideDefine = CreateDefineFromInt(
13118- L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13119- auto InputVector2StrideDefine = CreateDefineFromInt(
13120- L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13121-
13122- LPCWSTR Options[] = {
13123- L"-enable-16bit-types",
13124- DimMDefine.c_str(),
13125- DimNDefine.c_str(),
13126- NumThreadsDefine.c_str(),
13127- StrideDefine.c_str(),
13128- InputDataTypeDefine.c_str(),
13129- InputDivisorDefine.c_str(),
13130- AccumDataTypeDefine.c_str(),
13131- InputInterpretationEnumDefine.c_str(),
13132- HlslMatrixLayoutDefine.c_str(),
13133- MatrixDataTypeEnumDefine.c_str(),
13134- InputVector1StrideDefine.c_str(),
13135- InputVector2StrideDefine.c_str(),
13136- };
13096+ auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13097+ std::wstringstream Stream;
13098+ Stream << L"-D" << Name << L"=" << Value;
13099+ return Stream.str();
13100+ };
13101+
13102+ auto CreateDefineFromString = [](const wchar_t *Name, const wchar_t *Value) {
13103+ std::wstringstream Stream;
13104+ Stream << L"-D" << Name << L"=" << Value;
13105+ return Stream.str();
13106+ };
13107+
13108+ int Stride = 0;
13109+ const std::wstring HlslMatrixLayout =
13110+ CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13111+ int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13112+ AccumulateProps.AccumulationType);
13113+ switch (Config.MatrixLayout) {
13114+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13115+ Stride = Config.DimN * StrideMultiplier;
13116+ break;
13117+ case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13118+ Stride = Config.DimM * StrideMultiplier;
13119+ break;
13120+ }
1313713121
13138- CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13139- new LinAlgHeaderIncludeHandler(m_support);
13122+ const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(
13123+ AccumulateProps.InputType);
13124+ const std::wstring InputDataType =
13125+ CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13126+ const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(
13127+ AccumulateProps.AccumulationType);
13128+ const std::wstring MatrixDataTypeEnum =
13129+ CoopVecHelpers::GetHlslInterpretationForDataType(
13130+ AccumulateProps.AccumulationType);
13131+ const std::wstring InputInterpretationEnum =
13132+ CoopVecHelpers::GetHlslInterpretationForDataType(
13133+ AccumulateProps.InputType);
1314013134
13135+ auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13136+ auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13137+ auto NumThreadsDefine =
13138+ CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13139+ auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13140+ auto InputDataTypeDefine =
13141+ CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13142+ auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13143+ auto AccumDataTypeDefine =
13144+ CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13145+ auto InputInterpretationEnumDefine = CreateDefineFromString(
13146+ L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13147+ auto HlslMatrixLayoutDefine =
13148+ CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13149+ auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13150+ L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13151+ auto InputVector1StrideDefine = CreateDefineFromInt(
13152+ L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13153+ auto InputVector2StrideDefine = CreateDefineFromInt(
13154+ L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13155+
13156+ LPCWSTR Options[] = {
13157+ L"-enable-16bit-types",
13158+ DimMDefine.c_str(),
13159+ DimNDefine.c_str(),
13160+ NumThreadsDefine.c_str(),
13161+ StrideDefine.c_str(),
13162+ InputDataTypeDefine.c_str(),
13163+ InputDivisorDefine.c_str(),
13164+ AccumDataTypeDefine.c_str(),
13165+ InputInterpretationEnumDefine.c_str(),
13166+ HlslMatrixLayoutDefine.c_str(),
13167+ MatrixDataTypeEnumDefine.c_str(),
13168+ InputVector1StrideDefine.c_str(),
13169+ InputVector2StrideDefine.c_str(),
13170+ };
13171+
13172+ CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13173+ new LinAlgHeaderIncludeHandler(m_support);
13174+
13175+ if (RunCompute) {
1314113176 CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9",
1314213177 &ComputePipelineState, Options, _countof(Options),
1314313178 IncludeHandler);
13179+ } else {
13180+ CComPtr<ID3DBlob> VertexShader;
13181+ CComPtr<ID3DBlob> PixelShader;
13182+
13183+ CompileFromText(ShaderSource.c_str(), L"vs_main", L"vs_6_9", &VertexShader,
13184+ Options, _countof(Options), IncludeHandler);
13185+ CompileFromText(ShaderSource.c_str(), L"ps_main", L"ps_6_9", &PixelShader,
13186+ Options, _countof(Options), IncludeHandler);
13187+
13188+ D3D12_GRAPHICS_PIPELINE_STATE_DESC PsoDesc = {};
13189+ // InputLayout is left zero-initialized: vs_main derives positions from SV_VertexID.
13190+ PsoDesc.pRootSignature = RootSignature;
13191+ PsoDesc.VS = CD3DX12_SHADER_BYTECODE(VertexShader);
13192+ PsoDesc.PS = CD3DX12_SHADER_BYTECODE(PixelShader);
13193+ PsoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT);
13194+ PsoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);
13195+ PsoDesc.DepthStencilState.DepthEnable = FALSE;
13196+ PsoDesc.DepthStencilState.StencilEnable = FALSE;
13197+ PsoDesc.SampleMask = UINT_MAX;
13198+ PsoDesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
13199+ PsoDesc.NumRenderTargets = 1;
13200+ PsoDesc.RTVFormats[0] = DXGI_FORMAT_R8G8B8A8_UNORM;
13201+ PsoDesc.SampleDesc.Count = 1;
13202+ VERIFY_SUCCEEDED(D3DDevice->CreateGraphicsPipelineState(
13203+ &PsoDesc, IID_PPV_ARGS(&ComputePipelineState)));
1314413204 }
1314513205
1314613206 // Create a command list for the compute shader.
@@ -13283,6 +13343,14 @@ void main(uint threadIdx : SV_GroupThreadID)
1328313343 ConvertedMatrixResource);
1328413344 }
1328513345
13346+ // Create resource for atomic counter
13347+ CComPtr<ID3D12Resource> AtomicCounterResource;
13348+ uint32_t AtomicCounterInit = 0;
13349+ CreateTestResources(D3DDevice, CommandList, &AtomicCounterInit,
13350+ sizeof(AtomicCounterInit),
13351+ CD3DX12_RESOURCE_DESC::Buffer(sizeof(AtomicCounterInit)),
13352+ &AtomicCounterResource, nullptr);
13353+
1328613354 CommandList->Close();
1328713355 ExecuteCommandList(CommandQueue, CommandList);
1328813356 WaitForSignal(CommandQueue, FO);
@@ -13294,10 +13362,54 @@ void main(uint threadIdx : SV_GroupThreadID)
1329413362 CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle(
1329513363 DescriptorHeap->GetGPUDescriptorHandleForHeapStart());
1329613364
13297- CommandList->SetComputeRootSignature(RootSignature);
13298- CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13299- CommandList->SetPipelineState(ComputePipelineState);
13300- CommandList->Dispatch(1, 1, 1);
13365+ CComPtr<ID3D12DescriptorHeap> RtvHeap;
13366+ CComPtr<ID3D12Resource> RenderTarget;
13367+ CComPtr<ID3D12Resource> RenderTargetRead;
13368+
13369+ if (RunCompute) {
13370+ CommandList->SetComputeRootSignature(RootSignature);
13371+ CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13372+ CommandList->SetPipelineState(ComputePipelineState);
13373+ CommandList->Dispatch(1, 1, 1);
13374+ } else {
13375+ UINT FrameCount = 1;
13376+ UINT RtvDescSize = 0;
13377+ CreateRtvDescriptorHeap(D3DDevice, FrameCount, &RtvHeap, &RtvDescSize);
13378+ CreateRenderTargetAndReadback(D3DDevice, RtvHeap, 100, 100, &RenderTarget,
13379+ &RenderTargetRead);
13380+
13381+ D3D12_RESOURCE_DESC RtDesc = RenderTarget->GetDesc();
13382+ D3D12_VIEWPORT Viewport;
13383+ D3D12_RECT ScissorRect;
13384+
13385+ memset(&Viewport, 0, sizeof(Viewport));
13386+ Viewport.Height = (float)RtDesc.Height;
13387+ Viewport.Width = (float)RtDesc.Width;
13388+ Viewport.MaxDepth = 1.0f;
13389+ memset(&ScissorRect, 0, sizeof(ScissorRect));
13390+ ScissorRect.right = (long)RtDesc.Width;
13391+ ScissorRect.bottom = RtDesc.Height;
13392+ CommandList->SetGraphicsRootSignature(RootSignature);
13393+ CommandList->SetGraphicsRootDescriptorTable(0, ResHandle);
13394+ CommandList->SetGraphicsRootUnorderedAccessView(
13395+ 1, AtomicCounterResource->GetGPUVirtualAddress());
13396+ CommandList->RSSetViewports(1, &Viewport);
13397+ CommandList->RSSetScissorRects(1, &ScissorRect);
13398+
13399+ // Indicate that the buffer will be used as a render target.
13400+ RecordTransitionBarrier(CommandList, RenderTarget,
13401+ D3D12_RESOURCE_STATE_COPY_DEST,
13402+ D3D12_RESOURCE_STATE_RENDER_TARGET);
13403+
13404+ CD3DX12_CPU_DESCRIPTOR_HANDLE RtvHandle(
13405+ RtvHeap->GetCPUDescriptorHandleForHeapStart(), 0, RtvDescSize);
13406+ CommandList->OMSetRenderTargets(1, &RtvHandle, FALSE, nullptr);
13407+
13408+ CommandList->ClearRenderTargetView(RtvHandle, ClearColor, 0, nullptr);
13409+ CommandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
13410+ CommandList->DrawInstanced(3, 1, 0, 0);
13411+ }
13412+
1330113413 CommandList->Close();
1330213414 ExecuteCommandList(CommandQueue, CommandList);
1330313415 WaitForSignal(CommandQueue, FO);
0 commit comments