@@ -389,6 +389,36 @@ class DXQueue : public offloadtest::Queue {
389389 }
390390};
391391
392+ class DXCommandBuffer : public offloadtest ::CommandBuffer {
393+ public:
394+ static constexpr GPUAPI BackendAPI = GPUAPI::DirectX;
395+
396+ ComPtr<ID3D12CommandAllocator> Allocator;
397+ ComPtr<ID3D12GraphicsCommandList> CmdList;
398+
399+ static llvm::Expected<std::unique_ptr<DXCommandBuffer>>
400+ create (ComPtr<ID3D12Device> Device) {
401+ auto CB = std::unique_ptr<DXCommandBuffer>(new DXCommandBuffer ());
402+ if (auto Err = HR::toError (
403+ Device->CreateCommandAllocator (D3D12_COMMAND_LIST_TYPE_DIRECT,
404+ IID_PPV_ARGS (&CB->Allocator )),
405+ " Failed to create command allocator." ))
406+ return Err;
407+ if (auto Err = HR::toError (
408+ Device->CreateCommandList (0 , D3D12_COMMAND_LIST_TYPE_DIRECT,
409+ CB->Allocator .Get (), nullptr ,
410+ IID_PPV_ARGS (&CB->CmdList )),
411+ " Failed to create command list." ))
412+ return Err;
413+ return CB;
414+ }
415+
416+ ~DXCommandBuffer () override = default ;
417+
418+ private:
419+ DXCommandBuffer () : CommandBuffer(GPUAPI::DirectX) {}
420+ };
421+
392422class DXDevice : public offloadtest ::Device {
393423private:
394424 ComPtr<IDXCoreAdapter> Adapter;
@@ -420,8 +450,7 @@ class DXDevice : public offloadtest::Device {
420450 ComPtr<ID3D12RootSignature> RootSig;
421451 ComPtr<ID3D12DescriptorHeap> DescHeap;
422452 ComPtr<ID3D12PipelineState> PSO;
423- ComPtr<ID3D12CommandAllocator> Allocator;
424- ComPtr<ID3D12GraphicsCommandList> CmdList;
453+ std::unique_ptr<DXCommandBuffer> CB;
425454 std::unique_ptr<offloadtest::Fence> Fence;
426455
427456 // Resources for graphics pipelines.
@@ -683,19 +712,9 @@ class DXDevice : public offloadtest::Device {
683712 return llvm::Error::success ();
684713 }
685714
686- llvm::Error createCommandStructures (InvocationState &IS) {
687- if (auto Err = HR::toError (
688- Device->CreateCommandAllocator (D3D12_COMMAND_LIST_TYPE_DIRECT,
689- IID_PPV_ARGS (&IS.Allocator )),
690- " Failed to create command allocator." ))
691- return Err;
692- if (auto Err = HR::toError (
693- Device->CreateCommandList (0 , D3D12_COMMAND_LIST_TYPE_DIRECT,
694- IS.Allocator .Get (), nullptr ,
695- IID_PPV_ARGS (&IS.CmdList )),
696- " Failed to create command list." ))
697- return Err;
698- return llvm::Error::success ();
715+ llvm::Expected<std::unique_ptr<offloadtest::CommandBuffer>>
716+ createCommandBuffer () override {
717+ return DXCommandBuffer::create (Device);
699718 }
700719
701720 void addResourceUploadCommands (Resource &R, InvocationState &IS,
@@ -712,10 +731,10 @@ class DXDevice : public offloadtest::Device {
712731 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (Destination.Get (), 0 );
713732 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (Source.Get (), Footprint);
714733
715- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
734+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
716735 } else
717- IS.CmdList ->CopyBufferRegion (Destination.Get (), 0 , Source.Get (), 0 ,
718- R.size ());
736+ IS.CB -> CmdList ->CopyBufferRegion (Destination.Get (), 0 , Source.Get (), 0 ,
737+ R.size ());
719738 addUploadEndBarrier (IS, Destination, R.isReadWrite ());
720739 }
721740
@@ -1182,7 +1201,7 @@ class DXDevice : public offloadtest::Device {
11821201 {D3D12_RESOURCE_TRANSITION_BARRIER{
11831202 R.Get (), D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES,
11841203 D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_DEST}}};
1185- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1204+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
11861205 }
11871206
11881207 void addUploadEndBarrier (InvocationState &IS, ComPtr<ID3D12Resource> R,
@@ -1195,21 +1214,21 @@ class DXDevice : public offloadtest::Device {
11951214 D3D12_RESOURCE_STATE_COPY_DEST,
11961215 IsUAV ? D3D12_RESOURCE_STATE_UNORDERED_ACCESS
11971216 : D3D12_RESOURCE_STATE_GENERIC_READ}}};
1198- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1217+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
11991218 }
12001219
12011220 void addReadbackBeginBarrier (InvocationState &IS, ComPtr<ID3D12Resource> R) {
12021221 const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition (
12031222 R.Get (), D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
12041223 D3D12_RESOURCE_STATE_COPY_SOURCE);
1205- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1224+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
12061225 }
12071226
12081227 void addReadbackEndBarrier (InvocationState &IS, ComPtr<ID3D12Resource> R) {
12091228 const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition (
12101229 R.Get (), D3D12_RESOURCE_STATE_COPY_SOURCE,
12111230 D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
1212- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1231+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
12131232 }
12141233
12151234 llvm::Error waitForSignal (InvocationState &IS) {
@@ -1231,11 +1250,11 @@ class DXDevice : public offloadtest::Device {
12311250 }
12321251
12331252 llvm::Error executeCommandList (InvocationState &IS) {
1234- if (auto Err =
1235- HR::toError (IS. CmdList -> Close (), " Failed to close command list." ))
1253+ if (auto Err = HR::toError (IS. CB -> CmdList -> Close (),
1254+ " Failed to close command list." ))
12361255 return Err;
12371256
1238- ID3D12CommandList *const CmdLists[] = {IS.CmdList .Get ()};
1257+ ID3D12CommandList *const CmdLists[] = {IS.CB -> CmdList .Get ()};
12391258 GraphicsQueue.Queue ->ExecuteCommandLists (1 , CmdLists);
12401259
12411260 return waitForSignal (IS);
@@ -1245,11 +1264,11 @@ class DXDevice : public offloadtest::Device {
12451264 CD3DX12_GPU_DESCRIPTOR_HANDLE Handle;
12461265 if (IS.DescHeap ) {
12471266 ID3D12DescriptorHeap *const Heaps[] = {IS.DescHeap .Get ()};
1248- IS.CmdList ->SetDescriptorHeaps (1 , Heaps);
1267+ IS.CB -> CmdList ->SetDescriptorHeaps (1 , Heaps);
12491268 Handle = IS.DescHeap ->GetGPUDescriptorHandleForHeapStart ();
12501269 }
1251- IS.CmdList ->SetComputeRootSignature (IS.RootSig .Get ());
1252- IS.CmdList ->SetPipelineState (IS.PSO .Get ());
1270+ IS.CB -> CmdList ->SetComputeRootSignature (IS.RootSig .Get ());
1271+ IS.CB -> CmdList ->SetPipelineState (IS.PSO .Get ());
12531272
12541273 const uint32_t Inc = Device->GetDescriptorHandleIncrementSize (
12551274 D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
@@ -1269,14 +1288,15 @@ class DXDevice : public offloadtest::Device {
12691288 " Root constant cannot refer to resource arrays." );
12701289 const uint32_t NumValues =
12711290 Constant.BufferPtr ->size () / sizeof (uint32_t );
1272- IS.CmdList ->SetComputeRoot32BitConstants (
1291+ IS.CB -> CmdList ->SetComputeRoot32BitConstants (
12731292 RootParamIndex++, NumValues,
12741293 Constant.BufferPtr ->Data .back ().get (), ConstantOffset);
12751294 ConstantOffset += NumValues;
12761295 break ;
12771296 }
12781297 case dx::RootParamKind::DescriptorTable:
1279- IS.CmdList ->SetComputeRootDescriptorTable (RootParamIndex++, Handle);
1298+ IS.CB ->CmdList ->SetComputeRootDescriptorTable (RootParamIndex++,
1299+ Handle);
12801300 Handle.Offset (P.Sets [DescriptorTableIndex++].Resources .size (), Inc);
12811301 break ;
12821302 case dx::RootParamKind::RootDescriptor:
@@ -1287,17 +1307,17 @@ class DXDevice : public offloadtest::Device {
12871307 " Root descriptor cannot refer to resource arrays." );
12881308 switch (getDXKind (RootDescIt->first ->Kind )) {
12891309 case SRV:
1290- IS.CmdList ->SetComputeRootShaderResourceView (
1310+ IS.CB -> CmdList ->SetComputeRootShaderResourceView (
12911311 RootParamIndex++,
12921312 RootDescIt->second .back ().Buffer ->GetGPUVirtualAddress ());
12931313 break ;
12941314 case UAV:
1295- IS.CmdList ->SetComputeRootUnorderedAccessView (
1315+ IS.CB -> CmdList ->SetComputeRootUnorderedAccessView (
12961316 RootParamIndex++,
12971317 RootDescIt->second .back ().Buffer ->GetGPUVirtualAddress ());
12981318 break ;
12991319 case CBV:
1300- IS.CmdList ->SetComputeRootConstantBufferView (
1320+ IS.CB -> CmdList ->SetComputeRootConstantBufferView (
13011321 RootParamIndex++,
13021322 RootDescIt->second .back ().Buffer ->GetGPUVirtualAddress ());
13031323 break ;
@@ -1313,15 +1333,15 @@ class DXDevice : public offloadtest::Device {
13131333 // descriptor set layout. This is to make it easier to write tests that
13141334 // don't need complicated root signatures.
13151335 for (uint32_t Idx = 0u ; Idx < P.Sets .size (); ++Idx) {
1316- IS.CmdList ->SetComputeRootDescriptorTable (Idx, Handle);
1336+ IS.CB -> CmdList ->SetComputeRootDescriptorTable (Idx, Handle);
13171337 Handle.Offset (P.Sets [Idx].Resources .size (), Inc);
13181338 }
13191339 }
13201340
13211341 const llvm::ArrayRef<int > DispatchSize =
13221342 llvm::ArrayRef<int >(P.Shaders [0 ].DispatchSize );
13231343
1324- IS.CmdList ->Dispatch (DispatchSize[0 ], DispatchSize[1 ], DispatchSize[2 ]);
1344+ IS.CB -> CmdList ->Dispatch (DispatchSize[0 ], DispatchSize[1 ], DispatchSize[2 ]);
13251345
13261346 auto CopyBackResource = [&IS, this ](ResourcePair &R) {
13271347 if (R.first ->isTexture ()) {
@@ -1338,7 +1358,7 @@ class DXDevice : public offloadtest::Device {
13381358 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (RS.Readback .Get (),
13391359 Footprint);
13401360 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (RS.Buffer .Get (), 0 );
1341- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
1361+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
13421362 addReadbackEndBarrier (IS, RS.Buffer );
13431363 }
13441364 return ;
@@ -1347,7 +1367,7 @@ class DXDevice : public offloadtest::Device {
13471367 if (RS.Readback == nullptr )
13481368 continue ;
13491369 addReadbackBeginBarrier (IS, RS.Buffer );
1350- IS.CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
1370+ IS.CB -> CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
13511371 addReadbackEndBarrier (IS, RS.Buffer );
13521372 }
13531373 };
@@ -1527,8 +1547,8 @@ class DXDevice : public offloadtest::Device {
15271547 VBView.SizeInBytes = static_cast <UINT>(VBSize);
15281548 VBView.StrideInBytes = P.Bindings .getVertexStride ();
15291549
1530- IS.CmdList ->IASetPrimitiveTopology (D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
1531- IS.CmdList ->IASetVertexBuffers (0 , 1 , &VBView);
1550+ IS.CB -> CmdList ->IASetPrimitiveTopology (D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
1551+ IS.CB -> CmdList ->IASetVertexBuffers (0 , 1 , &VBView);
15321552
15331553 return llvm::Error::success ();
15341554 }
@@ -1606,16 +1626,16 @@ class DXDevice : public offloadtest::Device {
16061626 IS.RTVHeap ->GetCPUDescriptorHandleForHeapStart ();
16071627 Device->CreateRenderTargetView (IS.RT .Get (), nullptr , RTVHandle);
16081628
1609- IS.CmdList ->SetGraphicsRootSignature (IS.RootSig .Get ());
1629+ IS.CB -> CmdList ->SetGraphicsRootSignature (IS.RootSig .Get ());
16101630 if (IS.DescHeap ) {
16111631 ID3D12DescriptorHeap *const Heaps[] = {IS.DescHeap .Get ()};
1612- IS.CmdList ->SetDescriptorHeaps (1 , Heaps);
1613- IS.CmdList ->SetGraphicsRootDescriptorTable (
1632+ IS.CB -> CmdList ->SetDescriptorHeaps (1 , Heaps);
1633+ IS.CB -> CmdList ->SetGraphicsRootDescriptorTable (
16141634 0 , IS.DescHeap ->GetGPUDescriptorHandleForHeapStart ());
16151635 }
1616- IS.CmdList ->SetPipelineState (IS.PSO .Get ());
1636+ IS.CB -> CmdList ->SetPipelineState (IS.PSO .Get ());
16171637
1618- IS.CmdList ->OMSetRenderTargets (1 , &RTVHandle, false , nullptr );
1638+ IS.CB -> CmdList ->OMSetRenderTargets (1 , &RTVHandle, false , nullptr );
16191639
16201640 D3D12_VIEWPORT VP = {};
16211641 VP.Width =
@@ -1626,19 +1646,19 @@ class DXDevice : public offloadtest::Device {
16261646 VP.MaxDepth = 1 .0f ;
16271647 VP.TopLeftX = 0 .0f ;
16281648 VP.TopLeftY = 0 .0f ;
1629- IS.CmdList ->RSSetViewports (1 , &VP);
1649+ IS.CB -> CmdList ->RSSetViewports (1 , &VP);
16301650 const D3D12_RECT Scissor = {0 , 0 , static_cast <LONG>(VP.Width ),
16311651 static_cast <LONG>(VP.Height )};
1632- IS.CmdList ->RSSetScissorRects (1 , &Scissor);
1652+ IS.CB -> CmdList ->RSSetScissorRects (1 , &Scissor);
16331653
1634- IS.CmdList ->DrawInstanced (P.Bindings .getVertexCount (), 1 , 0 , 0 );
1654+ IS.CB -> CmdList ->DrawInstanced (P.Bindings .getVertexCount (), 1 , 0 , 0 );
16351655
16361656 // Transition the render target to copy source and copy to the readback
16371657 // buffer.
16381658 const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition (
16391659 IS.RT .Get (), D3D12_RESOURCE_STATE_RENDER_TARGET,
16401660 D3D12_RESOURCE_STATE_COPY_SOURCE);
1641- IS.CmdList ->ResourceBarrier (1 , &Barrier);
1661+ IS.CB -> CmdList ->ResourceBarrier (1 , &Barrier);
16421662
16431663 const CPUBuffer &B = *P.Bindings .RTargetBufferPtr ;
16441664 const D3D12_PLACED_SUBRESOURCE_FOOTPRINT Footprint{
@@ -1649,7 +1669,7 @@ class DXDevice : public offloadtest::Device {
16491669 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (IS.RTReadback .Get (), Footprint);
16501670 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (IS.RT .Get (), 0 );
16511671
1652- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
1672+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
16531673
16541674 auto CopyBackResource = [&IS, this ](ResourcePair &R) {
16551675 if (R.first ->isTexture ()) {
@@ -1666,7 +1686,7 @@ class DXDevice : public offloadtest::Device {
16661686 const CD3DX12_TEXTURE_COPY_LOCATION DstLoc (RS.Readback .Get (),
16671687 Footprint);
16681688 const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc (RS.Buffer .Get (), 0 );
1669- IS.CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
1689+ IS.CB -> CmdList ->CopyTextureRegion (&DstLoc, 0 , 0 , 0 , &SrcLoc, nullptr );
16701690 addReadbackEndBarrier (IS, RS.Buffer );
16711691 }
16721692 return ;
@@ -1675,7 +1695,7 @@ class DXDevice : public offloadtest::Device {
16751695 if (RS.Readback == nullptr )
16761696 continue ;
16771697 addReadbackBeginBarrier (IS, RS.Buffer );
1678- IS.CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
1698+ IS.CB -> CmdList ->CopyResource (RS.Readback .Get (), RS.Buffer .Get ());
16791699 addReadbackEndBarrier (IS, RS.Buffer );
16801700 }
16811701 };
@@ -1726,9 +1746,11 @@ class DXDevice : public offloadtest::Device {
17261746 return Err;
17271747 llvm::outs () << " Descriptor heap created.\n " ;
17281748
1729- if (auto Err = createCommandStructures (State))
1730- return Err;
1731- llvm::outs () << " Command structures created.\n " ;
1749+ auto CBOrErr = DXCommandBuffer::create (Device);
1750+ if (!CBOrErr)
1751+ return CBOrErr.takeError ();
1752+ State.CB = std::move (*CBOrErr);
1753+ llvm::outs () << " Command buffer created.\n " ;
17321754
17331755 auto FenceOrErr = createFence (" Fence" );
17341756 if (!FenceOrErr)
0 commit comments