Skip to content

Commit a92eda5

Browse files
Support pixel shaders in OuterProduct tests
1 parent 04ec196 commit a92eda5

1 file changed

Lines changed: 211 additions & 99 deletions

File tree

tools/clang/unittests/HLSLExec/ExecutionTest.cpp

Lines changed: 211 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ class ExecutionTest {
814814
void runCoopVecOuterProductSubtest(
815815
ID3D12Device *D3DDevice,
816816
D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
817-
CoopVecOuterProductSubtestConfig &Config);
817+
CoopVecOuterProductSubtestConfig &Config, bool RunCompute);
818818

819819
#endif // HAVE_COOPVEC_API
820820

@@ -12914,29 +12914,41 @@ void ExecutionTest::runCoopVecOuterProductTestConfig(
1291412914
continue;
1291512915
}
1291612916

12917-
runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config);
12917+
// Run once in compute, then once in graphics (pixel shader)
12918+
runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, true);
12919+
runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, false);
1291812920
}
1291912921
}
1292012922

1292112923
void ExecutionTest::runCoopVecOuterProductSubtest(
1292212924
ID3D12Device *D3DDevice,
1292312925
D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
12924-
CoopVecOuterProductSubtestConfig &Config) {
12926+
CoopVecOuterProductSubtestConfig &Config, bool RunCompute) {
1292512927

1292612928
LogCommentFmt(
12927-
L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s",
12929+
L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s, "
12930+
L"Stage: %s",
1292812931
Config.DimM, Config.DimN, Config.NumThreads,
12929-
CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str());
12932+
CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str(),
12933+
RunCompute ? L"Compute" : L"Pixel");
1293012934

1293112935
// Create root signature with a single root entry for all SRVs and UAVs
1293212936
CComPtr<ID3D12RootSignature> RootSignature;
1293312937
{
12934-
CD3DX12_DESCRIPTOR_RANGE ranges[2];
12935-
ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0,
12936-
0); // InputVector1, InputVector2
12937-
ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix
12938-
CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr,
12939-
0);
12938+
CD3DX12_DESCRIPTOR_RANGE Ranges[2];
12939+
Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0);
12940+
Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0);
12941+
12942+
CD3DX12_ROOT_PARAMETER RootParams[2];
12943+
RootParams[0].InitAsDescriptorTable(_countof(Ranges), Ranges,
12944+
D3D12_SHADER_VISIBILITY_ALL);
12945+
RootParams[1].InitAsUnorderedAccessView(/* register */ 10, /* space */ 0,
12946+
D3D12_SHADER_VISIBILITY_ALL);
12947+
12948+
CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc;
12949+
RootSignatureDesc.Init(_countof(RootParams), RootParams, 0, nullptr,
12950+
D3D12_ROOT_SIGNATURE_FLAG_NONE);
12951+
CreateRootSignatureFromDesc(D3DDevice, &RootSignatureDesc, &RootSignature);
1294012952
}
1294112953

1294212954
// Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV
@@ -13028,17 +13040,17 @@ void ExecutionTest::runCoopVecOuterProductSubtest(
1302813040

1302913041
// Create a compute pipeline state object.
1303013042
CComPtr<ID3D12PipelineState> ComputePipelineState;
13031-
{
13032-
std::string ShaderSource = R"(
13043+
13044+
std::string ShaderSource = R"(
1303313045
#include "dx/linalg.h"
1303413046

1303513047
ByteAddressBuffer InputVector1 : register(t0);
1303613048
ByteAddressBuffer InputVector2 : register(t1);
1303713049
RWByteAddressBuffer AccumMatrix : register(u0);
1303813050

13039-
[shader("compute")]
13040-
[numthreads(NUM_THREADS, 1, 1)]
13041-
void main(uint threadIdx : SV_GroupThreadID)
13051+
RWStructuredBuffer<uint> AtomicCounter : register(u10);
13052+
13053+
void RunCoopVecTest(uint threadIdx)
1304213054
{
1304313055
using namespace dx::linalg;
1304413056

@@ -13053,94 +13065,142 @@ void main(uint threadIdx : SV_GroupThreadID)
1305313065

1305413066
OuterProductAccumulate(input1, input2, mat);
1305513067
}
13056-
)";
1305713068

13058-
auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13059-
std::wstringstream Stream;
13060-
Stream << L"-D" << Name << L"=" << Value;
13061-
return Stream.str();
13062-
};
13069+
[shader("compute")]
13070+
[numthreads(NUM_THREADS, 1, 1)]
13071+
void main(uint threadIdx : SV_GroupThreadID)
13072+
{
13073+
RunCoopVecTest(threadIdx);
13074+
}
1306313075

13064-
auto CreateDefineFromString = [](const wchar_t *Name,
13065-
const wchar_t *Value) {
13066-
std::wstringstream Stream;
13067-
Stream << L"-D" << Name << L"=" << Value;
13068-
return Stream.str();
13069-
};
13076+
float4 vs_main(uint vid : SV_VertexID) : SV_Position {
13077+
switch (vid) {
13078+
case 0:
13079+
return float4(-1, 1, 0, 0);
13080+
case 1:
13081+
return float4(3, 1, 0, 0);
13082+
case 2:
13083+
return float4(-1, -3, 0, 0);
13084+
}
13085+
return float4(0, 0, 0, 0);
13086+
}
1307013087

13071-
int Stride = 0;
13072-
const std::wstring HlslMatrixLayout =
13073-
CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13074-
int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13075-
AccumulateProps.AccumulationType);
13076-
switch (Config.MatrixLayout) {
13077-
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13078-
Stride = Config.DimN * StrideMultiplier;
13079-
break;
13080-
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13081-
Stride = Config.DimM * StrideMultiplier;
13082-
break;
13083-
}
13088+
float4 ps_main() : SV_Target {
13089+
uint threadIdx;
13090+
InterlockedAdd(AtomicCounter[0], 1, threadIdx);
13091+
RunCoopVecTest(threadIdx);
13092+
return float4(1, 1, 1, 1);
13093+
}
13094+
)";
1308413095

13085-
const int InputDivisor =
13086-
CoopVecHelpers::GetNumPackedElementsForInputDataType(
13087-
AccumulateProps.InputType);
13088-
const std::wstring InputDataType =
13089-
CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13090-
const std::wstring AccumDataType =
13091-
CoopVecHelpers::GetHlslDataTypeForDataType(
13092-
AccumulateProps.AccumulationType);
13093-
const std::wstring MatrixDataTypeEnum =
13094-
CoopVecHelpers::GetHlslInterpretationForDataType(
13095-
AccumulateProps.AccumulationType);
13096-
const std::wstring InputInterpretationEnum =
13097-
CoopVecHelpers::GetHlslInterpretationForDataType(
13098-
AccumulateProps.InputType);
13099-
13100-
auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13101-
auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13102-
auto NumThreadsDefine =
13103-
CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13104-
auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13105-
auto InputDataTypeDefine =
13106-
CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13107-
auto InputDivisorDefine =
13108-
CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13109-
auto AccumDataTypeDefine =
13110-
CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13111-
auto InputInterpretationEnumDefine = CreateDefineFromString(
13112-
L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13113-
auto HlslMatrixLayoutDefine =
13114-
CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13115-
auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13116-
L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13117-
auto InputVector1StrideDefine = CreateDefineFromInt(
13118-
L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13119-
auto InputVector2StrideDefine = CreateDefineFromInt(
13120-
L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13121-
13122-
LPCWSTR Options[] = {
13123-
L"-enable-16bit-types",
13124-
DimMDefine.c_str(),
13125-
DimNDefine.c_str(),
13126-
NumThreadsDefine.c_str(),
13127-
StrideDefine.c_str(),
13128-
InputDataTypeDefine.c_str(),
13129-
InputDivisorDefine.c_str(),
13130-
AccumDataTypeDefine.c_str(),
13131-
InputInterpretationEnumDefine.c_str(),
13132-
HlslMatrixLayoutDefine.c_str(),
13133-
MatrixDataTypeEnumDefine.c_str(),
13134-
InputVector1StrideDefine.c_str(),
13135-
InputVector2StrideDefine.c_str(),
13136-
};
13096+
auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13097+
std::wstringstream Stream;
13098+
Stream << L"-D" << Name << L"=" << Value;
13099+
return Stream.str();
13100+
};
13101+
13102+
auto CreateDefineFromString = [](const wchar_t *Name, const wchar_t *Value) {
13103+
std::wstringstream Stream;
13104+
Stream << L"-D" << Name << L"=" << Value;
13105+
return Stream.str();
13106+
};
13107+
13108+
int Stride = 0;
13109+
const std::wstring HlslMatrixLayout =
13110+
CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13111+
int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13112+
AccumulateProps.AccumulationType);
13113+
switch (Config.MatrixLayout) {
13114+
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13115+
Stride = Config.DimN * StrideMultiplier;
13116+
break;
13117+
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13118+
Stride = Config.DimM * StrideMultiplier;
13119+
break;
13120+
}
1313713121

13138-
CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13139-
new LinAlgHeaderIncludeHandler(m_support);
13122+
const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(
13123+
AccumulateProps.InputType);
13124+
const std::wstring InputDataType =
13125+
CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13126+
const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(
13127+
AccumulateProps.AccumulationType);
13128+
const std::wstring MatrixDataTypeEnum =
13129+
CoopVecHelpers::GetHlslInterpretationForDataType(
13130+
AccumulateProps.AccumulationType);
13131+
const std::wstring InputInterpretationEnum =
13132+
CoopVecHelpers::GetHlslInterpretationForDataType(
13133+
AccumulateProps.InputType);
1314013134

13135+
auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13136+
auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13137+
auto NumThreadsDefine =
13138+
CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13139+
auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13140+
auto InputDataTypeDefine =
13141+
CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13142+
auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13143+
auto AccumDataTypeDefine =
13144+
CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13145+
auto InputInterpretationEnumDefine = CreateDefineFromString(
13146+
L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13147+
auto HlslMatrixLayoutDefine =
13148+
CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13149+
auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13150+
L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13151+
auto InputVector1StrideDefine = CreateDefineFromInt(
13152+
L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13153+
auto InputVector2StrideDefine = CreateDefineFromInt(
13154+
L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13155+
13156+
LPCWSTR Options[] = {
13157+
L"-enable-16bit-types",
13158+
DimMDefine.c_str(),
13159+
DimNDefine.c_str(),
13160+
NumThreadsDefine.c_str(),
13161+
StrideDefine.c_str(),
13162+
InputDataTypeDefine.c_str(),
13163+
InputDivisorDefine.c_str(),
13164+
AccumDataTypeDefine.c_str(),
13165+
InputInterpretationEnumDefine.c_str(),
13166+
HlslMatrixLayoutDefine.c_str(),
13167+
MatrixDataTypeEnumDefine.c_str(),
13168+
InputVector1StrideDefine.c_str(),
13169+
InputVector2StrideDefine.c_str(),
13170+
};
13171+
13172+
CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13173+
new LinAlgHeaderIncludeHandler(m_support);
13174+
13175+
if (RunCompute) {
1314113176
CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9",
1314213177
&ComputePipelineState, Options, _countof(Options),
1314313178
IncludeHandler);
13179+
} else {
13180+
CComPtr<ID3DBlob> VertexShader;
13181+
CComPtr<ID3DBlob> PixelShader;
13182+
13183+
CompileFromText(ShaderSource.c_str(), L"vs_main", L"vs_6_9", &VertexShader,
13184+
Options, _countof(Options), IncludeHandler);
13185+
CompileFromText(ShaderSource.c_str(), L"ps_main", L"ps_6_9", &PixelShader,
13186+
Options, _countof(Options), IncludeHandler);
13187+
13188+
D3D12_GRAPHICS_PIPELINE_STATE_DESC PsoDesc = {};
13189+
// psoDesc.InputLayout;
13190+
PsoDesc.pRootSignature = RootSignature;
13191+
PsoDesc.VS = CD3DX12_SHADER_BYTECODE(VertexShader);
13192+
PsoDesc.PS = CD3DX12_SHADER_BYTECODE(PixelShader);
13193+
PsoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT);
13194+
PsoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);
13195+
PsoDesc.DepthStencilState.DepthEnable = FALSE;
13196+
PsoDesc.DepthStencilState.StencilEnable = FALSE;
13197+
PsoDesc.SampleMask = UINT_MAX;
13198+
PsoDesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
13199+
PsoDesc.NumRenderTargets = 1;
13200+
PsoDesc.RTVFormats[0] = DXGI_FORMAT_R8G8B8A8_UNORM;
13201+
PsoDesc.SampleDesc.Count = 1;
13202+
VERIFY_SUCCEEDED(D3DDevice->CreateGraphicsPipelineState(
13203+
&PsoDesc, IID_PPV_ARGS(&ComputePipelineState)));
1314413204
}
1314513205

1314613206
// Create a command list for the compute shader.
@@ -13283,6 +13343,14 @@ void main(uint threadIdx : SV_GroupThreadID)
1328313343
ConvertedMatrixResource);
1328413344
}
1328513345

13346+
// Create resource for atomic counter
13347+
CComPtr<ID3D12Resource> AtomicCounterResource;
13348+
uint32_t AtomicCounterInit = 0;
13349+
CreateTestResources(D3DDevice, CommandList, &AtomicCounterInit,
13350+
sizeof(AtomicCounterInit),
13351+
CD3DX12_RESOURCE_DESC::Buffer(sizeof(AtomicCounterInit)),
13352+
&AtomicCounterResource, nullptr);
13353+
1328613354
CommandList->Close();
1328713355
ExecuteCommandList(CommandQueue, CommandList);
1328813356
WaitForSignal(CommandQueue, FO);
@@ -13294,10 +13362,54 @@ void main(uint threadIdx : SV_GroupThreadID)
1329413362
CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle(
1329513363
DescriptorHeap->GetGPUDescriptorHandleForHeapStart());
1329613364

13297-
CommandList->SetComputeRootSignature(RootSignature);
13298-
CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13299-
CommandList->SetPipelineState(ComputePipelineState);
13300-
CommandList->Dispatch(1, 1, 1);
13365+
CComPtr<ID3D12DescriptorHeap> RtvHeap;
13366+
CComPtr<ID3D12Resource> RenderTarget;
13367+
CComPtr<ID3D12Resource> RenderTargetRead;
13368+
13369+
if (RunCompute) {
13370+
CommandList->SetComputeRootSignature(RootSignature);
13371+
CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13372+
CommandList->SetPipelineState(ComputePipelineState);
13373+
CommandList->Dispatch(1, 1, 1);
13374+
} else {
13375+
UINT FrameCount = 1;
13376+
UINT RtvDescSize = 0;
13377+
CreateRtvDescriptorHeap(D3DDevice, FrameCount, &RtvHeap, &RtvDescSize);
13378+
CreateRenderTargetAndReadback(D3DDevice, RtvHeap, 100, 100, &RenderTarget,
13379+
&RenderTargetRead);
13380+
13381+
D3D12_RESOURCE_DESC RtDesc = RenderTarget->GetDesc();
13382+
D3D12_VIEWPORT Viewport;
13383+
D3D12_RECT ScissorRect;
13384+
13385+
memset(&Viewport, 0, sizeof(Viewport));
13386+
Viewport.Height = (float)RtDesc.Height;
13387+
Viewport.Width = (float)RtDesc.Width;
13388+
Viewport.MaxDepth = 1.0f;
13389+
memset(&ScissorRect, 0, sizeof(ScissorRect));
13390+
ScissorRect.right = (long)RtDesc.Width;
13391+
ScissorRect.bottom = RtDesc.Height;
13392+
CommandList->SetGraphicsRootSignature(RootSignature);
13393+
CommandList->SetGraphicsRootDescriptorTable(0, ResHandle);
13394+
CommandList->SetGraphicsRootUnorderedAccessView(
13395+
1, AtomicCounterResource->GetGPUVirtualAddress());
13396+
CommandList->RSSetViewports(1, &Viewport);
13397+
CommandList->RSSetScissorRects(1, &ScissorRect);
13398+
13399+
// Indicate that the buffer will be used as a render target.
13400+
RecordTransitionBarrier(CommandList, RenderTarget,
13401+
D3D12_RESOURCE_STATE_COPY_DEST,
13402+
D3D12_RESOURCE_STATE_RENDER_TARGET);
13403+
13404+
CD3DX12_CPU_DESCRIPTOR_HANDLE RtvHandle(
13405+
RtvHeap->GetCPUDescriptorHandleForHeapStart(), 0, RtvDescSize);
13406+
CommandList->OMSetRenderTargets(1, &RtvHandle, FALSE, nullptr);
13407+
13408+
CommandList->ClearRenderTargetView(RtvHandle, ClearColor, 0, nullptr);
13409+
CommandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
13410+
CommandList->DrawInstanced(3, 1, 0, 0);
13411+
}
13412+
1330113413
CommandList->Close();
1330213414
ExecuteCommandList(CommandQueue, CommandList);
1330313415
WaitForSignal(CommandQueue, FO);

0 commit comments

Comments
 (0)