Skip to content

Commit 9a1bfe5

Browse files
MarijnS95claude
andcommitted
Add CommandBuffer abstraction with backend-specific downcasting
Command buffer creation and management was previously spread across each backend's executeProgram() with no shared interface, making it impossible to manage command buffers from backend-agnostic code. This introduces a CommandBuffer base class on Device so that higher-level code can create and pass around command buffers without knowing the backend. Per-object allocator/pool ownership also prepares for future async execution where multiple command buffers may be in-flight with independent lifetimes. - DX: DXCommandBuffer owns Allocator, CmdList, Fence, Event - VK: VKCommandBuffer owns CmdPool, CmdBuffer; each submission creates a new CommandBuffer for independent lifetime management - MTL: MTLCommandBuffer wraps MTL::CommandBuffer Device::createCommandBuffer() returns Expected<unique_ptr<CommandBuffer>> with a default "not supported" implementation. Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent 1522442 commit 9a1bfe5

5 files changed

Lines changed: 266 additions & 137 deletions

File tree

include/API/CommandBuffer.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===- CommandBuffer.h - Offload Command Buffer API -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#ifndef OFFLOADTEST_API_COMMANDBUFFER_H
13+
#define OFFLOADTEST_API_COMMANDBUFFER_H
14+
15+
#include "API/API.h"
16+
17+
#include <cassert>
18+
19+
namespace offloadtest {
20+
21+
class CommandBuffer {
22+
GPUAPI API;
23+
24+
public:
25+
explicit CommandBuffer(GPUAPI API) : API(API) {}
26+
virtual ~CommandBuffer() = default;
27+
CommandBuffer(const CommandBuffer &) = delete;
28+
CommandBuffer &operator=(const CommandBuffer &) = delete;
29+
30+
GPUAPI getAPI() const { return API; }
31+
32+
template <typename T> T &as() {
33+
assert(API == T::BackendAPI && "CommandBuffer backend mismatch");
34+
return static_cast<T &>(*this);
35+
}
36+
template <typename T> const T &as() const {
37+
assert(API == T::BackendAPI && "CommandBuffer backend mismatch");
38+
return static_cast<const T &>(*this);
39+
}
40+
};
41+
42+
} // namespace offloadtest
43+
44+
#endif // OFFLOADTEST_API_COMMANDBUFFER_H

include/API/Device.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
#include "API/API.h"
1818
#include "API/Capabilities.h"
19+
#include "API/CommandBuffer.h"
1920
#include "llvm/ADT/StringRef.h"
2021
#include "llvm/ADT/iterator_range.h"
22+
#include "llvm/Support/Error.h"
2123

2224
#include <memory>
2325
#include <string>
@@ -99,6 +101,12 @@ class Device {
99101
size_t SizeInBytes) = 0;
100102
virtual void printExtra(llvm::raw_ostream &OS) {}
101103

104+
virtual llvm::Expected<std::unique_ptr<CommandBuffer>> createCommandBuffer() {
105+
return llvm::createStringError(
106+
std::errc::not_supported,
107+
"createCommandBuffer not implemented for this backend");
108+
}
109+
102110
virtual ~Device() = 0;
103111

104112
llvm::StringRef getDescription() const { return Description; }

lib/API/DX/Device.cpp

Lines changed: 76 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,36 @@ class DXQueue : public offloadtest::Queue {
389389
}
390390
};
391391

392+
class DXCommandBuffer : public offloadtest::CommandBuffer {
393+
public:
394+
static constexpr GPUAPI BackendAPI = GPUAPI::DirectX;
395+
396+
ComPtr<ID3D12CommandAllocator> Allocator;
397+
ComPtr<ID3D12GraphicsCommandList> CmdList;
398+
399+
static llvm::Expected<std::unique_ptr<DXCommandBuffer>>
400+
create(ComPtr<ID3D12Device> Device) {
401+
auto CB = std::unique_ptr<DXCommandBuffer>(new DXCommandBuffer());
402+
if (auto Err = HR::toError(
403+
Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT,
404+
IID_PPV_ARGS(&CB->Allocator)),
405+
"Failed to create command allocator."))
406+
return Err;
407+
if (auto Err = HR::toError(
408+
Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT,
409+
CB->Allocator.Get(), nullptr,
410+
IID_PPV_ARGS(&CB->CmdList)),
411+
"Failed to create command list."))
412+
return Err;
413+
return CB;
414+
}
415+
416+
~DXCommandBuffer() override = default;
417+
418+
private:
419+
DXCommandBuffer() : CommandBuffer(GPUAPI::DirectX) {}
420+
};
421+
392422
class DXDevice : public offloadtest::Device {
393423
private:
394424
ComPtr<IDXCoreAdapter> Adapter;
@@ -420,8 +450,7 @@ class DXDevice : public offloadtest::Device {
420450
ComPtr<ID3D12RootSignature> RootSig;
421451
ComPtr<ID3D12DescriptorHeap> DescHeap;
422452
ComPtr<ID3D12PipelineState> PSO;
423-
ComPtr<ID3D12CommandAllocator> Allocator;
424-
ComPtr<ID3D12GraphicsCommandList> CmdList;
453+
std::unique_ptr<DXCommandBuffer> CB;
425454
std::unique_ptr<offloadtest::Fence> Fence;
426455

427456
// Resources for graphics pipelines.
@@ -683,19 +712,9 @@ class DXDevice : public offloadtest::Device {
683712
return llvm::Error::success();
684713
}
685714

686-
llvm::Error createCommandStructures(InvocationState &IS) {
687-
if (auto Err = HR::toError(
688-
Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT,
689-
IID_PPV_ARGS(&IS.Allocator)),
690-
"Failed to create command allocator."))
691-
return Err;
692-
if (auto Err = HR::toError(
693-
Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT,
694-
IS.Allocator.Get(), nullptr,
695-
IID_PPV_ARGS(&IS.CmdList)),
696-
"Failed to create command list."))
697-
return Err;
698-
return llvm::Error::success();
715+
llvm::Expected<std::unique_ptr<offloadtest::CommandBuffer>>
716+
createCommandBuffer() override {
717+
return DXCommandBuffer::create(Device);
699718
}
700719

701720
void addResourceUploadCommands(Resource &R, InvocationState &IS,
@@ -712,10 +731,10 @@ class DXDevice : public offloadtest::Device {
712731
const CD3DX12_TEXTURE_COPY_LOCATION DstLoc(Destination.Get(), 0);
713732
const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc(Source.Get(), Footprint);
714733

715-
IS.CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
734+
IS.CB->CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
716735
} else
717-
IS.CmdList->CopyBufferRegion(Destination.Get(), 0, Source.Get(), 0,
718-
R.size());
736+
IS.CB->CmdList->CopyBufferRegion(Destination.Get(), 0, Source.Get(), 0,
737+
R.size());
719738
addUploadEndBarrier(IS, Destination, R.isReadWrite());
720739
}
721740

@@ -1182,7 +1201,7 @@ class DXDevice : public offloadtest::Device {
11821201
{D3D12_RESOURCE_TRANSITION_BARRIER{
11831202
R.Get(), D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES,
11841203
D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_COPY_DEST}}};
1185-
IS.CmdList->ResourceBarrier(1, &Barrier);
1204+
IS.CB->CmdList->ResourceBarrier(1, &Barrier);
11861205
}
11871206

11881207
void addUploadEndBarrier(InvocationState &IS, ComPtr<ID3D12Resource> R,
@@ -1195,21 +1214,21 @@ class DXDevice : public offloadtest::Device {
11951214
D3D12_RESOURCE_STATE_COPY_DEST,
11961215
IsUAV ? D3D12_RESOURCE_STATE_UNORDERED_ACCESS
11971216
: D3D12_RESOURCE_STATE_GENERIC_READ}}};
1198-
IS.CmdList->ResourceBarrier(1, &Barrier);
1217+
IS.CB->CmdList->ResourceBarrier(1, &Barrier);
11991218
}
12001219

12011220
void addReadbackBeginBarrier(InvocationState &IS, ComPtr<ID3D12Resource> R) {
12021221
const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition(
12031222
R.Get(), D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
12041223
D3D12_RESOURCE_STATE_COPY_SOURCE);
1205-
IS.CmdList->ResourceBarrier(1, &Barrier);
1224+
IS.CB->CmdList->ResourceBarrier(1, &Barrier);
12061225
}
12071226

12081227
void addReadbackEndBarrier(InvocationState &IS, ComPtr<ID3D12Resource> R) {
12091228
const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition(
12101229
R.Get(), D3D12_RESOURCE_STATE_COPY_SOURCE,
12111230
D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
1212-
IS.CmdList->ResourceBarrier(1, &Barrier);
1231+
IS.CB->CmdList->ResourceBarrier(1, &Barrier);
12131232
}
12141233

12151234
llvm::Error waitForSignal(InvocationState &IS) {
@@ -1231,11 +1250,11 @@ class DXDevice : public offloadtest::Device {
12311250
}
12321251

12331252
llvm::Error executeCommandList(InvocationState &IS) {
1234-
if (auto Err =
1235-
HR::toError(IS.CmdList->Close(), "Failed to close command list."))
1253+
if (auto Err = HR::toError(IS.CB->CmdList->Close(),
1254+
"Failed to close command list."))
12361255
return Err;
12371256

1238-
ID3D12CommandList *const CmdLists[] = {IS.CmdList.Get()};
1257+
ID3D12CommandList *const CmdLists[] = {IS.CB->CmdList.Get()};
12391258
GraphicsQueue.Queue->ExecuteCommandLists(1, CmdLists);
12401259

12411260
return waitForSignal(IS);
@@ -1245,11 +1264,11 @@ class DXDevice : public offloadtest::Device {
12451264
CD3DX12_GPU_DESCRIPTOR_HANDLE Handle;
12461265
if (IS.DescHeap) {
12471266
ID3D12DescriptorHeap *const Heaps[] = {IS.DescHeap.Get()};
1248-
IS.CmdList->SetDescriptorHeaps(1, Heaps);
1267+
IS.CB->CmdList->SetDescriptorHeaps(1, Heaps);
12491268
Handle = IS.DescHeap->GetGPUDescriptorHandleForHeapStart();
12501269
}
1251-
IS.CmdList->SetComputeRootSignature(IS.RootSig.Get());
1252-
IS.CmdList->SetPipelineState(IS.PSO.Get());
1270+
IS.CB->CmdList->SetComputeRootSignature(IS.RootSig.Get());
1271+
IS.CB->CmdList->SetPipelineState(IS.PSO.Get());
12531272

12541273
const uint32_t Inc = Device->GetDescriptorHandleIncrementSize(
12551274
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
@@ -1269,14 +1288,15 @@ class DXDevice : public offloadtest::Device {
12691288
"Root constant cannot refer to resource arrays.");
12701289
const uint32_t NumValues =
12711290
Constant.BufferPtr->size() / sizeof(uint32_t);
1272-
IS.CmdList->SetComputeRoot32BitConstants(
1291+
IS.CB->CmdList->SetComputeRoot32BitConstants(
12731292
RootParamIndex++, NumValues,
12741293
Constant.BufferPtr->Data.back().get(), ConstantOffset);
12751294
ConstantOffset += NumValues;
12761295
break;
12771296
}
12781297
case dx::RootParamKind::DescriptorTable:
1279-
IS.CmdList->SetComputeRootDescriptorTable(RootParamIndex++, Handle);
1298+
IS.CB->CmdList->SetComputeRootDescriptorTable(RootParamIndex++,
1299+
Handle);
12801300
Handle.Offset(P.Sets[DescriptorTableIndex++].Resources.size(), Inc);
12811301
break;
12821302
case dx::RootParamKind::RootDescriptor:
@@ -1287,17 +1307,17 @@ class DXDevice : public offloadtest::Device {
12871307
"Root descriptor cannot refer to resource arrays.");
12881308
switch (getDXKind(RootDescIt->first->Kind)) {
12891309
case SRV:
1290-
IS.CmdList->SetComputeRootShaderResourceView(
1310+
IS.CB->CmdList->SetComputeRootShaderResourceView(
12911311
RootParamIndex++,
12921312
RootDescIt->second.back().Buffer->GetGPUVirtualAddress());
12931313
break;
12941314
case UAV:
1295-
IS.CmdList->SetComputeRootUnorderedAccessView(
1315+
IS.CB->CmdList->SetComputeRootUnorderedAccessView(
12961316
RootParamIndex++,
12971317
RootDescIt->second.back().Buffer->GetGPUVirtualAddress());
12981318
break;
12991319
case CBV:
1300-
IS.CmdList->SetComputeRootConstantBufferView(
1320+
IS.CB->CmdList->SetComputeRootConstantBufferView(
13011321
RootParamIndex++,
13021322
RootDescIt->second.back().Buffer->GetGPUVirtualAddress());
13031323
break;
@@ -1313,15 +1333,15 @@ class DXDevice : public offloadtest::Device {
13131333
// descriptor set layout. This is to make it easier to write tests that
13141334
// don't need complicated root signatures.
13151335
for (uint32_t Idx = 0u; Idx < P.Sets.size(); ++Idx) {
1316-
IS.CmdList->SetComputeRootDescriptorTable(Idx, Handle);
1336+
IS.CB->CmdList->SetComputeRootDescriptorTable(Idx, Handle);
13171337
Handle.Offset(P.Sets[Idx].Resources.size(), Inc);
13181338
}
13191339
}
13201340

13211341
const llvm::ArrayRef<int> DispatchSize =
13221342
llvm::ArrayRef<int>(P.Shaders[0].DispatchSize);
13231343

1324-
IS.CmdList->Dispatch(DispatchSize[0], DispatchSize[1], DispatchSize[2]);
1344+
IS.CB->CmdList->Dispatch(DispatchSize[0], DispatchSize[1], DispatchSize[2]);
13251345

13261346
auto CopyBackResource = [&IS, this](ResourcePair &R) {
13271347
if (R.first->isTexture()) {
@@ -1338,7 +1358,7 @@ class DXDevice : public offloadtest::Device {
13381358
const CD3DX12_TEXTURE_COPY_LOCATION DstLoc(RS.Readback.Get(),
13391359
Footprint);
13401360
const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc(RS.Buffer.Get(), 0);
1341-
IS.CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
1361+
IS.CB->CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
13421362
addReadbackEndBarrier(IS, RS.Buffer);
13431363
}
13441364
return;
@@ -1347,7 +1367,7 @@ class DXDevice : public offloadtest::Device {
13471367
if (RS.Readback == nullptr)
13481368
continue;
13491369
addReadbackBeginBarrier(IS, RS.Buffer);
1350-
IS.CmdList->CopyResource(RS.Readback.Get(), RS.Buffer.Get());
1370+
IS.CB->CmdList->CopyResource(RS.Readback.Get(), RS.Buffer.Get());
13511371
addReadbackEndBarrier(IS, RS.Buffer);
13521372
}
13531373
};
@@ -1527,8 +1547,8 @@ class DXDevice : public offloadtest::Device {
15271547
VBView.SizeInBytes = static_cast<UINT>(VBSize);
15281548
VBView.StrideInBytes = P.Bindings.getVertexStride();
15291549

1530-
IS.CmdList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
1531-
IS.CmdList->IASetVertexBuffers(0, 1, &VBView);
1550+
IS.CB->CmdList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
1551+
IS.CB->CmdList->IASetVertexBuffers(0, 1, &VBView);
15321552

15331553
return llvm::Error::success();
15341554
}
@@ -1606,16 +1626,16 @@ class DXDevice : public offloadtest::Device {
16061626
IS.RTVHeap->GetCPUDescriptorHandleForHeapStart();
16071627
Device->CreateRenderTargetView(IS.RT.Get(), nullptr, RTVHandle);
16081628

1609-
IS.CmdList->SetGraphicsRootSignature(IS.RootSig.Get());
1629+
IS.CB->CmdList->SetGraphicsRootSignature(IS.RootSig.Get());
16101630
if (IS.DescHeap) {
16111631
ID3D12DescriptorHeap *const Heaps[] = {IS.DescHeap.Get()};
1612-
IS.CmdList->SetDescriptorHeaps(1, Heaps);
1613-
IS.CmdList->SetGraphicsRootDescriptorTable(
1632+
IS.CB->CmdList->SetDescriptorHeaps(1, Heaps);
1633+
IS.CB->CmdList->SetGraphicsRootDescriptorTable(
16141634
0, IS.DescHeap->GetGPUDescriptorHandleForHeapStart());
16151635
}
1616-
IS.CmdList->SetPipelineState(IS.PSO.Get());
1636+
IS.CB->CmdList->SetPipelineState(IS.PSO.Get());
16171637

1618-
IS.CmdList->OMSetRenderTargets(1, &RTVHandle, false, nullptr);
1638+
IS.CB->CmdList->OMSetRenderTargets(1, &RTVHandle, false, nullptr);
16191639

16201640
D3D12_VIEWPORT VP = {};
16211641
VP.Width =
@@ -1626,19 +1646,19 @@ class DXDevice : public offloadtest::Device {
16261646
VP.MaxDepth = 1.0f;
16271647
VP.TopLeftX = 0.0f;
16281648
VP.TopLeftY = 0.0f;
1629-
IS.CmdList->RSSetViewports(1, &VP);
1649+
IS.CB->CmdList->RSSetViewports(1, &VP);
16301650
const D3D12_RECT Scissor = {0, 0, static_cast<LONG>(VP.Width),
16311651
static_cast<LONG>(VP.Height)};
1632-
IS.CmdList->RSSetScissorRects(1, &Scissor);
1652+
IS.CB->CmdList->RSSetScissorRects(1, &Scissor);
16331653

1634-
IS.CmdList->DrawInstanced(P.Bindings.getVertexCount(), 1, 0, 0);
1654+
IS.CB->CmdList->DrawInstanced(P.Bindings.getVertexCount(), 1, 0, 0);
16351655

16361656
// Transition the render target to copy source and copy to the readback
16371657
// buffer.
16381658
const D3D12_RESOURCE_BARRIER Barrier = CD3DX12_RESOURCE_BARRIER::Transition(
16391659
IS.RT.Get(), D3D12_RESOURCE_STATE_RENDER_TARGET,
16401660
D3D12_RESOURCE_STATE_COPY_SOURCE);
1641-
IS.CmdList->ResourceBarrier(1, &Barrier);
1661+
IS.CB->CmdList->ResourceBarrier(1, &Barrier);
16421662

16431663
const CPUBuffer &B = *P.Bindings.RTargetBufferPtr;
16441664
const D3D12_PLACED_SUBRESOURCE_FOOTPRINT Footprint{
@@ -1649,7 +1669,7 @@ class DXDevice : public offloadtest::Device {
16491669
const CD3DX12_TEXTURE_COPY_LOCATION DstLoc(IS.RTReadback.Get(), Footprint);
16501670
const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc(IS.RT.Get(), 0);
16511671

1652-
IS.CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
1672+
IS.CB->CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
16531673

16541674
auto CopyBackResource = [&IS, this](ResourcePair &R) {
16551675
if (R.first->isTexture()) {
@@ -1666,7 +1686,7 @@ class DXDevice : public offloadtest::Device {
16661686
const CD3DX12_TEXTURE_COPY_LOCATION DstLoc(RS.Readback.Get(),
16671687
Footprint);
16681688
const CD3DX12_TEXTURE_COPY_LOCATION SrcLoc(RS.Buffer.Get(), 0);
1669-
IS.CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
1689+
IS.CB->CmdList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
16701690
addReadbackEndBarrier(IS, RS.Buffer);
16711691
}
16721692
return;
@@ -1675,7 +1695,7 @@ class DXDevice : public offloadtest::Device {
16751695
if (RS.Readback == nullptr)
16761696
continue;
16771697
addReadbackBeginBarrier(IS, RS.Buffer);
1678-
IS.CmdList->CopyResource(RS.Readback.Get(), RS.Buffer.Get());
1698+
IS.CB->CmdList->CopyResource(RS.Readback.Get(), RS.Buffer.Get());
16791699
addReadbackEndBarrier(IS, RS.Buffer);
16801700
}
16811701
};
@@ -1726,9 +1746,11 @@ class DXDevice : public offloadtest::Device {
17261746
return Err;
17271747
llvm::outs() << "Descriptor heap created.\n";
17281748

1729-
if (auto Err = createCommandStructures(State))
1730-
return Err;
1731-
llvm::outs() << "Command structures created.\n";
1749+
auto CBOrErr = DXCommandBuffer::create(Device);
1750+
if (!CBOrErr)
1751+
return CBOrErr.takeError();
1752+
State.CB = std::move(*CBOrErr);
1753+
llvm::outs() << "Command buffer created.\n";
17321754

17331755
auto FenceOrErr = createFence("Fence");
17341756
if (!FenceOrErr)

0 commit comments

Comments
 (0)