Skip to content

Commit 6bcb151

Browse files
committed
Add SetShaderTableIndex+LoadLocalRootConstant tests / host code for local constants / simplifications
1 parent 6cb6843 commit 6bcb151

2 files changed

Lines changed: 353 additions & 14 deletions

File tree

tools/clang/unittests/HLSLExec/ExecutionTest.cpp

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,8 @@ class ExecutionTest {
303303
TEST_METHOD(SERGetAttributesTest);
304304
TEST_METHOD(SERTraceHitMissNopTest);
305305
TEST_METHOD(SERIsMissTest);
306+
TEST_METHOD(SERShaderTableIndexTest);
307+
TEST_METHOD(SERLoadLocalRootTableConstantTest);
306308
TEST_METHOD(LifetimeIntrinsicTest)
307309
TEST_METHOD(WaveIntrinsicsTest);
308310
TEST_METHOD(WaveIntrinsicsDDITest);
@@ -2248,11 +2250,12 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
22482250
CComPtr<ID3D12RootSignature> pLocalRootSignature;
22492251
{
22502252
CD3DX12_DESCRIPTOR_RANGE bufferRanges[1];
2251-
CD3DX12_ROOT_PARAMETER rootParameters[1];
2253+
CD3DX12_ROOT_PARAMETER rootParameters[2];
22522254
bufferRanges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 2, 1, 0,
22532255
2); // vertexBuffer(t1), indexBuffer(t2)
22542256
rootParameters[0].InitAsDescriptorTable(
22552257
_countof(bufferRanges), bufferRanges, D3D12_SHADER_VISIBILITY_ALL);
2258+
rootParameters[1].InitAsConstants(4, 1, 0, D3D12_SHADER_VISIBILITY_ALL);
22562259

22572260
CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc;
22582261
rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr,
@@ -2316,6 +2319,9 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
23162319
if (useIS) {
23172320
lib->DefineExport(L"intersection");
23182321
}
2322+
if (useMesh && useProceduralGeometry) {
2323+
lib->DefineExport(L"chAABB");
2324+
}
23192325

23202326
const int maxRecursion = 1;
23212327
stateObjectDesc.CreateSubobject<CD3DX12_RAYTRACING_SHADER_CONFIG_SUBOBJECT>()
@@ -2329,6 +2335,10 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
23292335
stateObjectDesc
23302336
.CreateSubobject<CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT>();
23312337
globalRootSigSubObj->SetRootSignature(pGlobalRootSignature);
2338+
// Set Local Root Signature subobject.
2339+
stateObjectDesc.CreateSubobject<CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT>()
2340+
->SetRootSignature(pLocalRootSignature);
2341+
23322342
auto exports = stateObjectDesc.CreateSubobject<
23332343
CD3DX12_SUBOBJECT_TO_EXPORTS_ASSOCIATION_SUBOBJECT>();
23342344
exports->SetSubobjectToAssociate(*globalRootSigSubObj);
@@ -2339,6 +2349,9 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
23392349
if (useIS) {
23402350
exports->AddExport(L"intersection");
23412351
}
2352+
if (useMesh && useProceduralGeometry) {
2353+
exports->AddExport(L"chAABB");
2354+
}
23422355

23432356
auto hitGroup =
23442357
stateObjectDesc.CreateSubobject<CD3DX12_HIT_GROUP_SUBOBJECT>();
@@ -2350,15 +2363,23 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
23502363
}
23512364
hitGroup->SetHitGroupExport(L"HitGroup");
23522365

2366+
if (useMesh && useProceduralGeometry) {
2367+
auto hitGroupAABB =
2368+
stateObjectDesc.CreateSubobject<CD3DX12_HIT_GROUP_SUBOBJECT>();
2369+
hitGroupAABB->SetClosestHitShaderImport(L"chAABB");
2370+
hitGroupAABB->SetAnyHitShaderImport(L"anyhit");
2371+
if (useIS) {
2372+
hitGroup->SetIntersectionShaderImport(L"intersection");
2373+
hitGroup->SetHitGroupType(D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE);
2374+
}
2375+
hitGroupAABB->SetHitGroupExport(L"HitGroupAABB");
2376+
}
2377+
23532378
CComPtr<ID3D12StateObject> pStateObject;
23542379
CComPtr<ID3D12StateObjectProperties> pStateObjectProperties;
23552380
VERIFY_SUCCEEDED(
23562381
pDevice->CreateStateObject(stateObjectDesc, IID_PPV_ARGS(&pStateObject)));
23572382
VERIFY_SUCCEEDED(pStateObject->QueryInterface(&pStateObjectProperties));
2358-
stateObjectDesc.CreateSubobject<CD3DX12_LOCAL_ROOT_SIGNATURE_SUBOBJECT>()
2359-
->SetRootSignature(pLocalRootSignature);
2360-
stateObjectDesc.CreateSubobject<CD3DX12_GLOBAL_ROOT_SIGNATURE_SUBOBJECT>()
2361-
->SetRootSignature(pGlobalRootSignature);
23622383

23632384
// Create SBT
23642385
ShaderTable shaderTable;
@@ -2367,21 +2388,33 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
23672388
1, // miss count
23682389
useMesh && useProceduralGeometry ? 2 : 1, // hit group count
23692390
1, // ray type count
2370-
2 // dwords per root table
2391+
4 // dwords per root table
23712392
);
23722393

2394+
int localRootConsts[4] = {12, 34, 56, 78};
23732395
memcpy(shaderTable.GetRaygenShaderIdPtr(0),
23742396
pStateObjectProperties->GetShaderIdentifier(L"raygen"),
23752397
SHADER_ID_SIZE_IN_BYTES);
2398+
memcpy(shaderTable.GetRaygenRootTablePtr(0), localRootConsts,
2399+
sizeof(localRootConsts));
23762400
memcpy(shaderTable.GetMissShaderIdPtr(0, 0),
23772401
pStateObjectProperties->GetShaderIdentifier(L"miss"),
23782402
SHADER_ID_SIZE_IN_BYTES);
2403+
memcpy(shaderTable.GetMissRootTablePtr(0, 0), localRootConsts,
2404+
sizeof(localRootConsts));
23792405
memcpy(shaderTable.GetHitGroupShaderIdPtr(0, 0),
23802406
pStateObjectProperties->GetShaderIdentifier(L"HitGroup"),
23812407
SHADER_ID_SIZE_IN_BYTES);
2408+
memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), localRootConsts,
2409+
sizeof(localRootConsts));
2410+
if (useMesh && useProceduralGeometry) {
2411+
memcpy(shaderTable.GetHitGroupShaderIdPtr(0, 1),
2412+
pStateObjectProperties->GetShaderIdentifier(L"HitGroupAABB"),
2413+
SHADER_ID_SIZE_IN_BYTES);
2414+
}
23822415

2383-
auto tbl = pDescriptorHeap->GetGPUDescriptorHandleForHeapStart().ptr;
2384-
memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), &tbl, 8);
2416+
// auto tbl = pDescriptorHeap->GetGPUDescriptorHandleForHeapStart().ptr;
2417+
// memcpy(shaderTable.GetHitGroupRootTablePtr(0, 0), &tbl, 8);
23852418

23862419
// Create a command allocator and list.
23872420
CComPtr<ID3D12CommandAllocator> pCommandAllocator;
@@ -2521,6 +2554,7 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
25212554
pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs,
25222555
&prebuildInfo);
25232556

2557+
scratchResource.Release();
25242558
ReallocScratchResource(pDevice, &scratchResource,
25252559
prebuildInfo.ScratchDataSizeInBytes);
25262560
AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes,
@@ -2597,6 +2631,7 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
25972631
&prebuildInfo);
25982632

25992633
// Allocate scratch and result buffers for the BLAS
2634+
scratchResource.Release();
26002635
ReallocScratchResource(pDevice, &scratchResource,
26012636
prebuildInfo.ScratchDataSizeInBytes);
26022637
AllocateBuffer(pDevice, prebuildInfo.ResultDataMaxSizeInBytes,
@@ -2654,6 +2689,9 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
26542689
pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs,
26552690
&prebuildInfo);
26562691

2692+
scratchResource.Release();
2693+
ReallocScratchResource(pDevice, &scratchResource,
2694+
prebuildInfo.ScratchDataSizeInBytes);
26572695
AllocateBuffer(
26582696
pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true,
26592697
D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS");
@@ -2691,6 +2729,9 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
26912729
pDevice->GetRaytracingAccelerationStructurePrebuildInfo(&accelInputs,
26922730
&prebuildInfo);
26932731

2732+
scratchResource.Release();
2733+
ReallocScratchResource(pDevice, &scratchResource,
2734+
prebuildInfo.ScratchDataSizeInBytes);
26942735
AllocateBuffer(
26952736
pDevice, prebuildInfo.ResultDataMaxSizeInBytes, &tlasResource, true,
26962737
D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, L"TLAS");
@@ -2710,6 +2751,13 @@ CComPtr<ID3D12Resource> ExecutionTest::RunDXRTest(
27102751
pCommandList->ResourceBarrier(1, (const D3D12_RESOURCE_BARRIER *)&barrier);
27112752
}
27122753

2754+
// Set the local root constants.
2755+
pCommandList->SetComputeRootSignature(pLocalRootSignature);
2756+
pCommandList->SetComputeRoot32BitConstant(1, 12, 0);
2757+
pCommandList->SetComputeRoot32BitConstant(1, 34, 1);
2758+
pCommandList->SetComputeRoot32BitConstant(1, 56, 2);
2759+
pCommandList->SetComputeRoot32BitConstant(1, 78, 3);
2760+
27132761
shaderTable.Upload(pCommandList);
27142762

27152763
ID3D12DescriptorHeap *const pHeaps[1] = {pDescriptorHeap};

0 commit comments

Comments
 (0)