@@ -398,6 +398,36 @@ class DXQueue : public offloadtest::Queue {
398398 }
399399};
400400
401+ class DXCommandBuffer : public offloadtest ::CommandBuffer {
402+ public:
403+ static constexpr GPUAPI BackendAPI = GPUAPI::DirectX;
404+
405+ ComPtr<ID3D12CommandAllocator> Allocator;
406+ ComPtr<ID3D12GraphicsCommandList> CmdList;
407+
408+ static llvm::Expected<std::unique_ptr<DXCommandBuffer>>
409+ create (ComPtr<ID3D12Device> Device) {
410+ auto CB = std::unique_ptr<DXCommandBuffer>(new DXCommandBuffer ());
411+ if (auto Err = HR::toError (
412+ Device->CreateCommandAllocator (D3D12_COMMAND_LIST_TYPE_DIRECT,
413+ IID_PPV_ARGS (&CB->Allocator )),
414+ " Failed to create command allocator." ))
415+ return Err;
416+ if (auto Err = HR::toError (
417+ Device->CreateCommandList (0 , D3D12_COMMAND_LIST_TYPE_DIRECT,
418+ CB->Allocator .Get (), nullptr ,
419+ IID_PPV_ARGS (&CB->CmdList )),
420+ " Failed to create command list." ))
421+ return Err;
422+ return CB;
423+ }
424+
425+ ~DXCommandBuffer () override = default ;
426+
427+ private:
428+ DXCommandBuffer () : CommandBuffer(GPUAPI::DirectX) {}
429+ };
430+
401431class DXDevice : public offloadtest ::Device {
402432private:
403433 ComPtr<IDXCoreAdapter> Adapter;
@@ -429,8 +459,7 @@ class DXDevice : public offloadtest::Device {
429459 ComPtr<ID3D12RootSignature> RootSig;
430460 ComPtr<ID3D12DescriptorHeap> DescHeap;
431461 ComPtr<ID3D12PipelineState> PSO;
432- ComPtr<ID3D12CommandAllocator> Allocator;
433- ComPtr<ID3D12GraphicsCommandList> CmdList;
462+ std::unique_ptr<DXCommandBuffer> CB;
434463 std::unique_ptr<offloadtest::Fence> Fence;
435464
436465 // Resources for graphics pipelines.
@@ -692,19 +721,9 @@ class DXDevice : public offloadtest::Device {
692721 return llvm::Error::success ();
693722 }
694723
695- llvm::Error createCommandStructures (InvocationState &IS) {
696- if (auto Err = HR::toError (
697- Device->CreateCommandAllocator (D3D12_COMMAND_LIST_TYPE_DIRECT,
698- IID_PPV_ARGS (&IS.Allocator )),
699- " Failed to create command allocator." ))
700- return Err;
701- if (auto Err = HR::toError (
702- Device->CreateCommandList (0 , D3D12_COMMAND_LIST_TYPE_DIRECT,
703- IS.Allocator .Get (), nullptr ,
704- IID_PPV_ARGS (&IS.CmdList )),
705- " Failed to create command list." ))
706- return Err;
707- return llvm::Error::success ();
724+ llvm::Expected<std::unique_ptr<offloadtest::CommandBuffer>>
725+ createCommandBuffer () override {
726+ return DXCommandBuffer::create (Device);
708727 }
709728
710729 void addResourceUploadCommands (Resource &R, InvocationState &IS,
@@ -721,10 +740,10 @@ class DXDevice : public offloadtest::Device {
721740 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (Destination.Get (), 0 );
722741 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (Source.Get (), Footprint);
723742
724- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
743+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
725744 } else
726- IS.CmdList ->CopyBufferRegion (Destination.Get (), 0 , Source.Get (), 0 ,
727- R.size ());
745+ IS.CB -> CmdList ->CopyBufferRegion (Destination.Get (), 0 , Source.Get (), 0 ,
746+ R.size ());
728747 addUploadEndBarrier (IS, Destination, R.isReadWrite ());
729748 }
730749
@@ -1191,7 +1210,7 @@ class DXDevice : public offloadtest::Device {
11911210 {D3D12_RESOURCE_TRANSITION_BARRIER{
11921211 R.Get (), D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES,
11931212 D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_DEST}}};
1194- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1213+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
11951214 }
11961215
11971216 void addUploadEndBarrier (InvocationState &IS, ComPtr<ID3D12Resource> R,
@@ -1204,21 +1223,21 @@ class DXDevice : public offloadtest::Device {
12041223 D3D12_RESOURCE_STATE_COPY_DEST,
12051224 IsUAV ? D3D12_RESOURCE_STATE_UNORDERED_ACCESS
12061225 : D3D12_RESOURCE_STATE_GENERIC_READ}}};
1207- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1226+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
12081227 }
12091228
12101229 void addReadbackBeginBarrier (InvocationState &IS, ComPtr<ID3D12Resource> R) {
12111230 const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition (
12121231 R.Get (), D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
12131232 D3D12_RESOURCE_STATE_COPY_SOURCE);
1214- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1233+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
12151234 }
12161235
12171236 void addReadbackEndBarrier (InvocationState &IS, ComPtr<ID3D12Resource> R) {
12181237 const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition (
12191238 R.Get (), D3D12_RESOURCE_STATE_COPY_SOURCE,
12201239 D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
1221- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1240+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
12221241 }
12231242
12241243 llvm::Error waitForSignal (InvocationState &IS) {
@@ -1240,11 +1259,11 @@ class DXDevice : public offloadtest::Device {
12401259 }
12411260
12421261 llvm::Error executeCommandList (InvocationState &IS) {
1243- if (auto Err =
1244- HR::toError (IS. CmdList -> Close (), " Failed to close command list." ))
1262+ if (auto Err = HR::toError (IS. CB -> CmdList -> Close (),
1263+ " Failed to close command list." ))
12451264 return Err;
12461265
1247- ID3D12CommandList *const CmdLists[] = {IS.CmdList .Get ()};
1266+ ID3D12CommandList *const CmdLists[] = {IS.CB -> CmdList .Get ()};
12481267 GraphicsQueue.Queue ->ExecuteCommandLists (1 , CmdLists);
12491268
12501269 return waitForSignal (IS);
@@ -1254,11 +1273,11 @@ class DXDevice : public offloadtest::Device {
12541273 CD3DX12_GPU_DESCRIPTOR_HANDLE Handle;
12551274 if (IS.DescHeap ) {
12561275 ID3D12DescriptorHeap *const Heaps[] = {IS.DescHeap .Get ()};
1257- IS.CmdList ->SetDescriptorHeaps (1 , Heaps);
1276+ IS.CB -> CmdList ->SetDescriptorHeaps (1 , Heaps);
12581277 Handle = IS.DescHeap ->GetGPUDescriptorHandleForHeapStart ();
12591278 }
1260- IS.CmdList ->SetComputeRootSignature (IS.RootSig .Get ());
1261- IS.CmdList ->SetPipelineState (IS.PSO .Get ());
1279+ IS.CB -> CmdList ->SetComputeRootSignature (IS.RootSig .Get ());
1280+ IS.CB -> CmdList ->SetPipelineState (IS.PSO .Get ());
12621281
12631282 const uint32_t Inc = Device->GetDescriptorHandleIncrementSize (
12641283 D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
@@ -1278,14 +1297,15 @@ class DXDevice : public offloadtest::Device {
12781297 " Root constant cannot refer to resource arrays." );
12791298 const uint32_t NumValues =
12801299 Constant.BufferPtr ->size () / sizeof (uint32_t );
1281- IS.CmdList ->SetComputeRoot32BitConstants (
1300+ IS.CB -> CmdList ->SetComputeRoot32BitConstants (
12821301 RootParamIndex++, NumValues,
12831302 Constant.BufferPtr ->Data .back ().get (), ConstantOffset);
12841303 ConstantOffset += NumValues;
12851304 break ;
12861305 }
12871306 case dx::RootParamKind::DescriptorTable:
1288- IS.CmdList ->SetComputeRootDescriptorTable (RootParamIndex++, Handle);
1307+ IS.CB ->CmdList ->SetComputeRootDescriptorTable (RootParamIndex++,
1308+ Handle);
12891309 Handle.Offset (P.Sets [DescriptorTableIndex++].Resources .size (), Inc);
12901310 break ;
12911311 case dx::RootParamKind::RootDescriptor:
@@ -1296,17 +1316,17 @@ class DXDevice : public offloadtest::Device {
12961316 " Root descriptor cannot refer to resource arrays." );
12971317 switch (getDXKind (RootDescIt->first ->Kind )) {
12981318 case SRV:
1299- IS.CmdList ->SetComputeRootShaderResourceView (
1319+ IS.CB -> CmdList ->SetComputeRootShaderResourceView (
13001320 RootParamIndex++,
13011321 RootDescIt->second .back ().Buffer ->GetGPUVirtualAddress ());
13021322 break ;
13031323 case UAV:
1304- IS.CmdList ->SetComputeRootUnorderedAccessView (
1324+ IS.CB -> CmdList ->SetComputeRootUnorderedAccessView (
13051325 RootParamIndex++,
13061326 RootDescIt->second .back ().Buffer ->GetGPUVirtualAddress ());
13071327 break ;
13081328 case CBV:
1309- IS.CmdList ->SetComputeRootConstantBufferView (
1329+ IS.CB -> CmdList ->SetComputeRootConstantBufferView (
13101330 RootParamIndex++,
13111331 RootDescIt->second .back ().Buffer ->GetGPUVirtualAddress ());
13121332 break ;
@@ -1322,15 +1342,15 @@ class DXDevice : public offloadtest::Device {
13221342 // descriptor set layout. This is to make it easier to write tests that
13231343 // don't need complicated root signatures.
13241344 for (uint32_t Idx = 0u ; Idx < P.Sets .size (); ++Idx) {
1325- IS.CmdList ->SetComputeRootDescriptorTable (Idx, Handle);
1345+ IS.CB -> CmdList ->SetComputeRootDescriptorTable (Idx, Handle);
13261346 Handle.Offset (P.Sets [Idx].Resources .size (), Inc);
13271347 }
13281348 }
13291349
13301350 const llvm::ArrayRef<int > DispatchSize =
13311351 llvm::ArrayRef<int >(P.Shaders [0 ].DispatchSize );
13321352
1333- IS.CmdList ->Dispatch (DispatchSize[0 ], DispatchSize[1 ], DispatchSize[2 ]);
1353+ IS.CB -> CmdList ->Dispatch (DispatchSize[0 ], DispatchSize[1 ], DispatchSize[2 ]);
13341354
13351355 auto CopyBackResource = [&IS, this ](ResourcePair &R) {
13361356 if (R.first ->isTexture ()) {
@@ -1347,7 +1367,7 @@ class DXDevice : public offloadtest::Device {
13471367 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (RS.Readback .Get (),
13481368 Footprint);
13491369 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (RS.Buffer .Get (), 0 );
1350- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
1370+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
13511371 addReadbackEndBarrier (IS, RS.Buffer );
13521372 }
13531373 return ;
@@ -1356,7 +1376,7 @@ class DXDevice : public offloadtest::Device {
13561376 if (RS.Readback == nullptr )
13571377 continue ;
13581378 addReadbackBeginBarrier (IS, RS.Buffer );
1359- IS.CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
1379+ IS.CB -> CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
13601380 addReadbackEndBarrier (IS, RS.Buffer );
13611381 }
13621382 };
@@ -1536,8 +1556,8 @@ class DXDevice : public offloadtest::Device {
15361556 VBView.SizeInBytes = static_cast <UINT>(VBSize);
15371557 VBView.StrideInBytes = P.Bindings .getVertexStride ();
15381558
1539- IS.CmdList ->IASetPrimitiveTopology (D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
1540- IS.CmdList ->IASetVertexBuffers (0 , 1 , &VBView);
1559+ IS.CB -> CmdList ->IASetPrimitiveTopology (D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
1560+ IS.CB -> CmdList ->IASetVertexBuffers (0 , 1 , &VBView);
15411561
15421562 return llvm::Error::success ();
15431563 }
@@ -1615,16 +1635,16 @@ class DXDevice : public offloadtest::Device {
16151635 IS.RTVHeap ->GetCPUDescriptorHandleForHeapStart ();
16161636 Device->CreateRenderTargetView (IS.RT .Get (), nullptr , RTVHandle);
16171637
1618- IS.CmdList ->SetGraphicsRootSignature (IS.RootSig .Get ());
1638+ IS.CB -> CmdList ->SetGraphicsRootSignature (IS.RootSig .Get ());
16191639 if (IS.DescHeap ) {
16201640 ID3D12DescriptorHeap *const Heaps[] = {IS.DescHeap .Get ()};
1621- IS.CmdList ->SetDescriptorHeaps (1 , Heaps);
1622- IS.CmdList ->SetGraphicsRootDescriptorTable (
1641+ IS.CB -> CmdList ->SetDescriptorHeaps (1 , Heaps);
1642+ IS.CB -> CmdList ->SetGraphicsRootDescriptorTable (
16231643 0 , IS.DescHeap ->GetGPUDescriptorHandleForHeapStart ());
16241644 }
1625- IS.CmdList ->SetPipelineState (IS.PSO .Get ());
1645+ IS.CB -> CmdList ->SetPipelineState (IS.PSO .Get ());
16261646
1627- IS.CmdList ->OMSetRenderTargets (1 , &RTVHandle, false , nullptr );
1647+ IS.CB -> CmdList ->OMSetRenderTargets (1 , &RTVHandle, false , nullptr );
16281648
16291649 D3D12_VIEWPORT VP = {};
16301650 VP.Width =
@@ -1635,19 +1655,19 @@ class DXDevice : public offloadtest::Device {
16351655 VP.MaxDepth = 1 .0f ;
16361656 VP.TopLeftX = 0 .0f ;
16371657 VP.TopLeftY = 0 .0f ;
1638- IS.CmdList ->RSSetViewports (1 , &VP);
1658+ IS.CB -> CmdList ->RSSetViewports (1 , &VP);
16391659 const D3D12_RECT Scissor = {0 , 0 , static_cast <LONG>(VP.Width ),
16401660 static_cast <LONG>(VP.Height )};
1641- IS.CmdList ->RSSetScissorRects (1 , &Scissor);
1661+ IS.CB -> CmdList ->RSSetScissorRects (1 , &Scissor);
16421662
1643- IS.CmdList ->DrawInstanced (P.Bindings .getVertexCount (), 1 , 0 , 0 );
1663+ IS.CB -> CmdList ->DrawInstanced (P.Bindings .getVertexCount (), 1 , 0 , 0 );
16441664
16451665 // Transition the render target to copy source and copy to the readback
16461666 // buffer.
16471667 const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition (
16481668 IS.RT .Get (), D3D12_RESOURCE_STATE_RENDER_TARGET,
16491669 D3D12_RESOURCE_STATE_COPY_SOURCE);
1650- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1670+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
16511671
16521672 const CPUBuffer &B = *P.Bindings .RTargetBufferPtr ;
16531673 const D3D12_PLACED_SUBRESOURCE_FOOTPRINT Footprint{
@@ -1658,7 +1678,7 @@ class DXDevice : public offloadtest::Device {
16581678 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (IS.RTReadback .Get (), Footprint);
16591679 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (IS.RT .Get (), 0 );
16601680
1661- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
1681+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
16621682
16631683 auto CopyBackResource = [&IS, this ](ResourcePair &R) {
16641684 if (R.first ->isTexture ()) {
@@ -1675,7 +1695,7 @@ class DXDevice : public offloadtest::Device {
16751695 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (RS.Readback .Get (),
16761696 Footprint);
16771697 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (RS.Buffer .Get (), 0 );
1678- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
1698+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
16791699 addReadbackEndBarrier (IS, RS.Buffer );
16801700 }
16811701 return ;
@@ -1684,7 +1704,7 @@ class DXDevice : public offloadtest::Device {
16841704 if (RS.Readback == nullptr )
16851705 continue ;
16861706 addReadbackBeginBarrier (IS, RS.Buffer );
1687- IS.CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
1707+ IS.CB -> CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
16881708 addReadbackEndBarrier (IS, RS.Buffer );
16891709 }
16901710 };
@@ -1735,9 +1755,11 @@ class DXDevice : public offloadtest::Device {
17351755 return Err;
17361756 llvm::outs () << " Descriptor heap created.\n " ;
17371757
1738- if (auto Err = createCommandStructures (State))
1739- return Err;
1740- llvm::outs () << " Command structures created.\n " ;
1758+ auto CBOrErr = DXCommandBuffer::create (Device);
1759+ if (!CBOrErr)
1760+ return CBOrErr.takeError ();
1761+ State.CB = std::move (*CBOrErr);
1762+ llvm::outs () << " Command buffer created.\n " ;
17411763
17421764 auto FenceOrErr = createFence (" Fence" );
17431765 if (!FenceOrErr)
0 commit comments