Skip to content

Commit 64f7487

Browse files
Support pixel shaders in OuterProduct tests
1 parent 9536047 commit 64f7487

1 file changed

Lines changed: 211 additions & 99 deletions

File tree

tools/clang/unittests/HLSLExec/ExecutionTest.cpp

Lines changed: 211 additions & 99 deletions
Original file line number | Diff line number | Diff line change
@@ -814,7 +814,7 @@ class ExecutionTest {
814814
void runCoopVecOuterProductSubtest(
815815
ID3D12Device *D3DDevice,
816816
D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
817-
CoopVecOuterProductSubtestConfig &Config);
817+
CoopVecOuterProductSubtestConfig &Config, bool RunCompute);
818818

819819
#endif // HAVE_COOPVEC_API
820820

@@ -12913,29 +12913,41 @@ void ExecutionTest::runCoopVecOuterProductTestConfig(
1291312913
continue;
1291412914
}
1291512915

12916-
runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config);
12916+
// Run once in compute, then once in graphics (pixel shader)
12917+
runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, true);
12918+
runCoopVecOuterProductSubtest(D3DDevice, AccumulateProps, Config, false);
1291712919
}
1291812920
}
1291912921

1292012922
void ExecutionTest::runCoopVecOuterProductSubtest(
1292112923
ID3D12Device *D3DDevice,
1292212924
D3D12_COOPERATIVE_VECTOR_PROPERTIES_ACCUMULATE &AccumulateProps,
12923-
CoopVecOuterProductSubtestConfig &Config) {
12925+
CoopVecOuterProductSubtestConfig &Config, bool RunCompute) {
1292412926

1292512927
LogCommentFmt(
12926-
L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s",
12928+
L"Running test for DimM: %d, DimN: %d, NumThreads: %d, MatrixLayout: %s, "
12929+
L"Stage: %s",
1292712930
Config.DimM, Config.DimN, Config.NumThreads,
12928-
CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str());
12931+
CoopVecHelpers::MatrixLayoutToFilterString(Config.MatrixLayout).c_str(),
12932+
RunCompute ? L"Compute" : L"Pixel");
1292912933

1293012934
// Create root signature with a single root entry for all SRVs and UAVs
1293112935
CComPtr<ID3D12RootSignature> RootSignature;
1293212936
{
12933-
CD3DX12_DESCRIPTOR_RANGE ranges[2];
12934-
ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0,
12935-
0); // InputVector1, InputVector2
12936-
ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0); // AccumMatrix
12937-
CreateRootSignatureFromRanges(D3DDevice, &RootSignature, ranges, 2, nullptr,
12938-
0);
12937+
CD3DX12_DESCRIPTOR_RANGE Ranges[2];
12938+
Ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 0, 0);
12939+
Ranges[1].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0);
12940+
12941+
CD3DX12_ROOT_PARAMETER RootParams[2];
12942+
RootParams[0].InitAsDescriptorTable(_countof(Ranges), Ranges,
12943+
D3D12_SHADER_VISIBILITY_ALL);
12944+
RootParams[1].InitAsUnorderedAccessView(/* register */ 10, /* space */ 0,
12945+
D3D12_SHADER_VISIBILITY_ALL);
12946+
12947+
CD3DX12_ROOT_SIGNATURE_DESC RootSignatureDesc;
12948+
RootSignatureDesc.Init(_countof(RootParams), RootParams, 0, nullptr,
12949+
D3D12_ROOT_SIGNATURE_FLAG_NONE);
12950+
CreateRootSignatureFromDesc(D3DDevice, &RootSignatureDesc, &RootSignature);
1293912951
}
1294012952

1294112953
// Create descriptor heap with space for 3 descriptors: 2 SRVs and 1 UAV
@@ -13027,17 +13039,17 @@ void ExecutionTest::runCoopVecOuterProductSubtest(
1302713039

1302813040
// Create a compute pipeline state object.
1302913041
CComPtr<ID3D12PipelineState> ComputePipelineState;
13030-
{
13031-
std::string ShaderSource = R"(
13042+
13043+
std::string ShaderSource = R"(
1303213044
#include "dx/linalg.h"
1303313045

1303413046
ByteAddressBuffer InputVector1 : register(t0);
1303513047
ByteAddressBuffer InputVector2 : register(t1);
1303613048
RWByteAddressBuffer AccumMatrix : register(u0);
1303713049

13038-
[shader("compute")]
13039-
[numthreads(NUM_THREADS, 1, 1)]
13040-
void main(uint threadIdx : SV_GroupThreadID)
13050+
RWStructuredBuffer<uint> AtomicCounter : register(u10);
13051+
13052+
void RunCoopVecTest(uint threadIdx)
1304113053
{
1304213054
using namespace dx::linalg;
1304313055

@@ -13052,94 +13064,142 @@ void main(uint threadIdx : SV_GroupThreadID)
1305213064

1305313065
OuterProductAccumulate(input1, input2, mat);
1305413066
}
13055-
)";
1305613067

13057-
auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13058-
std::wstringstream Stream;
13059-
Stream << L"-D" << Name << L"=" << Value;
13060-
return Stream.str();
13061-
};
13068+
[shader("compute")]
13069+
[numthreads(NUM_THREADS, 1, 1)]
13070+
void main(uint threadIdx : SV_GroupThreadID)
13071+
{
13072+
RunCoopVecTest(threadIdx);
13073+
}
1306213074

13063-
auto CreateDefineFromString = [](const wchar_t *Name,
13064-
const wchar_t *Value) {
13065-
std::wstringstream Stream;
13066-
Stream << L"-D" << Name << L"=" << Value;
13067-
return Stream.str();
13068-
};
13075+
float4 vs_main(uint vid : SV_VertexID) : SV_Position {
13076+
switch (vid) {
13077+
case 0:
13078+
return float4(-1, 1, 0, 0);
13079+
case 1:
13080+
return float4(3, 1, 0, 0);
13081+
case 2:
13082+
return float4(-1, -3, 0, 0);
13083+
}
13084+
return float4(0, 0, 0, 0);
13085+
}
1306913086

13070-
int Stride = 0;
13071-
const std::wstring HlslMatrixLayout =
13072-
CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13073-
int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13074-
AccumulateProps.AccumulationType);
13075-
switch (Config.MatrixLayout) {
13076-
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13077-
Stride = Config.DimN * StrideMultiplier;
13078-
break;
13079-
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13080-
Stride = Config.DimM * StrideMultiplier;
13081-
break;
13082-
}
13087+
float4 ps_main() : SV_Target {
13088+
uint threadIdx;
13089+
InterlockedAdd(AtomicCounter[0], 1, threadIdx);
13090+
RunCoopVecTest(threadIdx);
13091+
return float4(1, 1, 1, 1);
13092+
}
13093+
)";
1308313094

13084-
const int InputDivisor =
13085-
CoopVecHelpers::GetNumPackedElementsForInputDataType(
13086-
AccumulateProps.InputType);
13087-
const std::wstring InputDataType =
13088-
CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13089-
const std::wstring AccumDataType =
13090-
CoopVecHelpers::GetHlslDataTypeForDataType(
13091-
AccumulateProps.AccumulationType);
13092-
const std::wstring MatrixDataTypeEnum =
13093-
CoopVecHelpers::GetHlslInterpretationForDataType(
13094-
AccumulateProps.AccumulationType);
13095-
const std::wstring InputInterpretationEnum =
13096-
CoopVecHelpers::GetHlslInterpretationForDataType(
13097-
AccumulateProps.InputType);
13098-
13099-
auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13100-
auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13101-
auto NumThreadsDefine =
13102-
CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13103-
auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13104-
auto InputDataTypeDefine =
13105-
CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13106-
auto InputDivisorDefine =
13107-
CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13108-
auto AccumDataTypeDefine =
13109-
CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13110-
auto InputInterpretationEnumDefine = CreateDefineFromString(
13111-
L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13112-
auto HlslMatrixLayoutDefine =
13113-
CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13114-
auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13115-
L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13116-
auto InputVector1StrideDefine = CreateDefineFromInt(
13117-
L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13118-
auto InputVector2StrideDefine = CreateDefineFromInt(
13119-
L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13120-
13121-
LPCWSTR Options[] = {
13122-
L"-enable-16bit-types",
13123-
DimMDefine.c_str(),
13124-
DimNDefine.c_str(),
13125-
NumThreadsDefine.c_str(),
13126-
StrideDefine.c_str(),
13127-
InputDataTypeDefine.c_str(),
13128-
InputDivisorDefine.c_str(),
13129-
AccumDataTypeDefine.c_str(),
13130-
InputInterpretationEnumDefine.c_str(),
13131-
HlslMatrixLayoutDefine.c_str(),
13132-
MatrixDataTypeEnumDefine.c_str(),
13133-
InputVector1StrideDefine.c_str(),
13134-
InputVector2StrideDefine.c_str(),
13135-
};
13095+
auto CreateDefineFromInt = [](const wchar_t *Name, int Value) {
13096+
std::wstringstream Stream;
13097+
Stream << L"-D" << Name << L"=" << Value;
13098+
return Stream.str();
13099+
};
13100+
13101+
auto CreateDefineFromString = [](const wchar_t *Name, const wchar_t *Value) {
13102+
std::wstringstream Stream;
13103+
Stream << L"-D" << Name << L"=" << Value;
13104+
return Stream.str();
13105+
};
13106+
13107+
int Stride = 0;
13108+
const std::wstring HlslMatrixLayout =
13109+
CoopVecHelpers::MatrixLayoutToHlslLayoutString(Config.MatrixLayout);
13110+
int StrideMultiplier = CoopVecHelpers::GetStrideMultiplierForMatrixDataType(
13111+
AccumulateProps.AccumulationType);
13112+
switch (Config.MatrixLayout) {
13113+
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_ROW_MAJOR:
13114+
Stride = Config.DimN * StrideMultiplier;
13115+
break;
13116+
case D3D12_LINEAR_ALGEBRA_MATRIX_LAYOUT_COLUMN_MAJOR:
13117+
Stride = Config.DimM * StrideMultiplier;
13118+
break;
13119+
}
1313613120

13137-
CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13138-
new LinAlgHeaderIncludeHandler(m_support);
13121+
const int InputDivisor = CoopVecHelpers::GetNumPackedElementsForInputDataType(
13122+
AccumulateProps.InputType);
13123+
const std::wstring InputDataType =
13124+
CoopVecHelpers::GetHlslDataTypeForDataType(AccumulateProps.InputType);
13125+
const std::wstring AccumDataType = CoopVecHelpers::GetHlslDataTypeForDataType(
13126+
AccumulateProps.AccumulationType);
13127+
const std::wstring MatrixDataTypeEnum =
13128+
CoopVecHelpers::GetHlslInterpretationForDataType(
13129+
AccumulateProps.AccumulationType);
13130+
const std::wstring InputInterpretationEnum =
13131+
CoopVecHelpers::GetHlslInterpretationForDataType(
13132+
AccumulateProps.InputType);
1313913133

13134+
auto DimMDefine = CreateDefineFromInt(L"DIM_M", Config.DimM);
13135+
auto DimNDefine = CreateDefineFromInt(L"DIM_N", Config.DimN);
13136+
auto NumThreadsDefine =
13137+
CreateDefineFromInt(L"NUM_THREADS", Config.NumThreads);
13138+
auto StrideDefine = CreateDefineFromInt(L"STRIDE", Stride);
13139+
auto InputDataTypeDefine =
13140+
CreateDefineFromString(L"INPUT_DATA_TYPE", InputDataType.c_str());
13141+
auto InputDivisorDefine = CreateDefineFromInt(L"INPUT_DIVISOR", InputDivisor);
13142+
auto AccumDataTypeDefine =
13143+
CreateDefineFromString(L"ACCUM_DATA_TYPE", AccumDataType.c_str());
13144+
auto InputInterpretationEnumDefine = CreateDefineFromString(
13145+
L"INPUT_INTERPRETATION_ENUM", InputInterpretationEnum.c_str());
13146+
auto HlslMatrixLayoutDefine =
13147+
CreateDefineFromString(L"HLSL_MATRIX_LAYOUT", HlslMatrixLayout.c_str());
13148+
auto MatrixDataTypeEnumDefine = CreateDefineFromString(
13149+
L"MATRIX_DATA_TYPE_ENUM", MatrixDataTypeEnum.c_str());
13150+
auto InputVector1StrideDefine = CreateDefineFromInt(
13151+
L"INPUT_VECTOR_1_STRIDE", (int)InputVector1.getStride());
13152+
auto InputVector2StrideDefine = CreateDefineFromInt(
13153+
L"INPUT_VECTOR_2_STRIDE", (int)InputVector2.getStride());
13154+
13155+
LPCWSTR Options[] = {
13156+
L"-enable-16bit-types",
13157+
DimMDefine.c_str(),
13158+
DimNDefine.c_str(),
13159+
NumThreadsDefine.c_str(),
13160+
StrideDefine.c_str(),
13161+
InputDataTypeDefine.c_str(),
13162+
InputDivisorDefine.c_str(),
13163+
AccumDataTypeDefine.c_str(),
13164+
InputInterpretationEnumDefine.c_str(),
13165+
HlslMatrixLayoutDefine.c_str(),
13166+
MatrixDataTypeEnumDefine.c_str(),
13167+
InputVector1StrideDefine.c_str(),
13168+
InputVector2StrideDefine.c_str(),
13169+
};
13170+
13171+
CComPtr<LinAlgHeaderIncludeHandler> IncludeHandler =
13172+
new LinAlgHeaderIncludeHandler(m_support);
13173+
13174+
if (RunCompute) {
1314013175
CreateComputePSO(D3DDevice, RootSignature, ShaderSource.c_str(), L"cs_6_9",
1314113176
&ComputePipelineState, Options, _countof(Options),
1314213177
IncludeHandler);
13178+
} else {
13179+
CComPtr<ID3DBlob> VertexShader;
13180+
CComPtr<ID3DBlob> PixelShader;
13181+
13182+
CompileFromText(ShaderSource.c_str(), L"vs_main", L"vs_6_9", &VertexShader,
13183+
Options, _countof(Options), IncludeHandler);
13184+
CompileFromText(ShaderSource.c_str(), L"ps_main", L"ps_6_9", &PixelShader,
13185+
Options, _countof(Options), IncludeHandler);
13186+
13187+
D3D12_GRAPHICS_PIPELINE_STATE_DESC PsoDesc = {};
13188+
// psoDesc.InputLayout;
13189+
PsoDesc.pRootSignature = RootSignature;
13190+
PsoDesc.VS = CD3DX12_SHADER_BYTECODE(VertexShader);
13191+
PsoDesc.PS = CD3DX12_SHADER_BYTECODE(PixelShader);
13192+
PsoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT);
13193+
PsoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);
13194+
PsoDesc.DepthStencilState.DepthEnable = FALSE;
13195+
PsoDesc.DepthStencilState.StencilEnable = FALSE;
13196+
PsoDesc.SampleMask = UINT_MAX;
13197+
PsoDesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
13198+
PsoDesc.NumRenderTargets = 1;
13199+
PsoDesc.RTVFormats[0] = DXGI_FORMAT_R8G8B8A8_UNORM;
13200+
PsoDesc.SampleDesc.Count = 1;
13201+
VERIFY_SUCCEEDED(D3DDevice->CreateGraphicsPipelineState(
13202+
&PsoDesc, IID_PPV_ARGS(&ComputePipelineState)));
1314313203
}
1314413204

1314513205
// Create a command list for the compute shader.
@@ -13282,6 +13342,14 @@ void main(uint threadIdx : SV_GroupThreadID)
1328213342
ConvertedMatrixResource);
1328313343
}
1328413344

13345+
// Create resource for atomic counter
13346+
CComPtr<ID3D12Resource> AtomicCounterResource;
13347+
uint32_t AtomicCounterInit = 0;
13348+
CreateTestResources(D3DDevice, CommandList, &AtomicCounterInit,
13349+
sizeof(AtomicCounterInit),
13350+
CD3DX12_RESOURCE_DESC::Buffer(sizeof(AtomicCounterInit)),
13351+
&AtomicCounterResource, nullptr);
13352+
1328513353
CommandList->Close();
1328613354
ExecuteCommandList(CommandQueue, CommandList);
1328713355
WaitForSignal(CommandQueue, FO);
@@ -13293,10 +13361,54 @@ void main(uint threadIdx : SV_GroupThreadID)
1329313361
CD3DX12_GPU_DESCRIPTOR_HANDLE ResHandle(
1329413362
DescriptorHeap->GetGPUDescriptorHandleForHeapStart());
1329513363

13296-
CommandList->SetComputeRootSignature(RootSignature);
13297-
CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13298-
CommandList->SetPipelineState(ComputePipelineState);
13299-
CommandList->Dispatch(1, 1, 1);
13364+
CComPtr<ID3D12DescriptorHeap> RtvHeap;
13365+
CComPtr<ID3D12Resource> RenderTarget;
13366+
CComPtr<ID3D12Resource> RenderTargetRead;
13367+
13368+
if (RunCompute) {
13369+
CommandList->SetComputeRootSignature(RootSignature);
13370+
CommandList->SetComputeRootDescriptorTable(0, ResHandle);
13371+
CommandList->SetPipelineState(ComputePipelineState);
13372+
CommandList->Dispatch(1, 1, 1);
13373+
} else {
13374+
UINT FrameCount = 1;
13375+
UINT RtvDescSize = 0;
13376+
CreateRtvDescriptorHeap(D3DDevice, FrameCount, &RtvHeap, &RtvDescSize);
13377+
CreateRenderTargetAndReadback(D3DDevice, RtvHeap, 100, 100, &RenderTarget,
13378+
&RenderTargetRead);
13379+
13380+
D3D12_RESOURCE_DESC RtDesc = RenderTarget->GetDesc();
13381+
D3D12_VIEWPORT Viewport;
13382+
D3D12_RECT ScissorRect;
13383+
13384+
memset(&Viewport, 0, sizeof(Viewport));
13385+
Viewport.Height = (float)RtDesc.Height;
13386+
Viewport.Width = (float)RtDesc.Width;
13387+
Viewport.MaxDepth = 1.0f;
13388+
memset(&ScissorRect, 0, sizeof(ScissorRect));
13389+
ScissorRect.right = (long)RtDesc.Width;
13390+
ScissorRect.bottom = RtDesc.Height;
13391+
CommandList->SetGraphicsRootSignature(RootSignature);
13392+
CommandList->SetGraphicsRootDescriptorTable(0, ResHandle);
13393+
CommandList->SetGraphicsRootUnorderedAccessView(
13394+
1, AtomicCounterResource->GetGPUVirtualAddress());
13395+
CommandList->RSSetViewports(1, &Viewport);
13396+
CommandList->RSSetScissorRects(1, &ScissorRect);
13397+
13398+
// Indicate that the buffer will be used as a render target.
13399+
RecordTransitionBarrier(CommandList, RenderTarget,
13400+
D3D12_RESOURCE_STATE_COPY_DEST,
13401+
D3D12_RESOURCE_STATE_RENDER_TARGET);
13402+
13403+
CD3DX12_CPU_DESCRIPTOR_HANDLE RtvHandle(
13404+
RtvHeap->GetCPUDescriptorHandleForHeapStart(), 0, RtvDescSize);
13405+
CommandList->OMSetRenderTargets(1, &RtvHandle, FALSE, nullptr);
13406+
13407+
CommandList->ClearRenderTargetView(RtvHandle, ClearColor, 0, nullptr);
13408+
CommandList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
13409+
CommandList->DrawInstanced(3, 1, 0, 0);
13410+
}
13411+
1330013412
CommandList->Close();
1330113413
ExecuteCommandList(CommandQueue, CommandList);
1330213414
WaitForSignal(CommandQueue, FO);

0 commit comments

Comments
 (0)