Skip to content

Commit 04b6b8f

Browse files
committed
Add RAII safeguard: ComputeEncoder destructors call endEncoding()
Makes CommandEncoder::endEncoding() idempotent (split into a non-virtual wrapper plus a protected virtual endEncodingImpl()), then has each backend's ComputeEncoder destructor call endEncoding() so an encoder destroyed without an explicit end still flushes pending barriers and pops its debug group.
1 parent 014151a commit 04b6b8f

4 files changed

Lines changed: 22 additions & 8 deletions

File tree

include/API/Encoder.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ enum class EncoderMode {
3535
class CommandEncoder {
3636
GPUAPI API;
3737
EncoderMode Mode;
38+
bool Ended = false;
39+
40+
protected:
41+
/// Backend-specific cleanup. Called exactly once, either explicitly via
42+
/// endEncoding() or implicitly from the most-derived destructor.
43+
virtual void endEncodingImpl() = 0;
3844

3945
public:
4046
CommandEncoder(GPUAPI API, EncoderMode Mode) : API(API), Mode(Mode) {}
@@ -45,6 +51,7 @@ class CommandEncoder {
4551
GPUAPI getAPI() const { return API; }
4652
EncoderMode getMode() const { return Mode; }
4753
bool isSerial() const { return Mode == EncoderMode::Serial; }
54+
bool isEnded() const { return Ended; }
4855

4956
/// Copy \p Size bytes from \p Src at \p SrcOffset to \p Dst at
5057
/// \p DstOffset.
@@ -68,7 +75,15 @@ class CommandEncoder {
6875
virtual void insertDebugSignpost(llvm::StringRef Label) {}
6976

7077
/// Finish recording. No further commands may be recorded after this call.
71-
virtual void endEncoding() = 0;
78+
/// Idempotent: safe to call more than once. If not called explicitly, the
79+
/// most-derived destructor invokes it as a safeguard against leaked open
80+
/// encoders.
81+
void endEncoding() {
82+
if (Ended)
83+
return;
84+
endEncodingImpl();
85+
Ended = true;
86+
}
7287
};
7388

7489
/// Encoder for recording compute dispatch commands.

lib/API/DX/Device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
551551
DXComputeEncoder(DXCommandBuffer &CB, EncoderMode Mode)
552552
: ComputeEncoder(GPUAPI::DirectX, Mode), CB(CB) {}
553553

554-
~DXComputeEncoder() override = default;
554+
~DXComputeEncoder() override { endEncoding(); }
555555

556556
void pushDebugGroup(llvm::StringRef Label) override {
557557
CB.CmdList->BeginEvent(0, Label.data(), Label.size() + 1);
@@ -617,7 +617,7 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
617617
CB.flushBarrier();
618618
}
619619

620-
void endEncoding() override { popDebugGroup(); }
620+
void endEncodingImpl() override { popDebugGroup(); }
621621
};
622622

623623
llvm::Expected<std::unique_ptr<offloadtest::ComputeEncoder>>

lib/API/MTL/MTLDevice.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include "MTLResources.h"
1515
#include "Support/Pipeline.h"
1616

17-
#include "llvm/ADT/ScopeExit.h"
1817
#include "llvm/ADT/SmallString.h"
1918
#include "llvm/Support/Error.h"
2019
#include "llvm/Support/FormatVariadic.h"
@@ -264,7 +263,7 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
264263
: ComputeEncoder(GPUAPI::Metal, Mode), CmdBuffer(CmdBuffer),
265264
ComputeEnc(Encoder) {}
266265

267-
~MTLComputeEncoder() override = default;
266+
~MTLComputeEncoder() override { endEncoding(); }
268267

269268
MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; }
270269

@@ -361,7 +360,7 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
361360
}
362361
}
363362

364-
void endEncoding() override {
363+
void endEncodingImpl() override {
365364
if (ComputeEnc) {
366365
barrier();
367366
ComputeEnc->popDebugGroup();

lib/API/VK/Device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ class VKComputeEncoder : public offloadtest::ComputeEncoder {
653653
VKComputeEncoder(VulkanCommandBuffer &CB, EncoderMode Mode)
654654
: ComputeEncoder(GPUAPI::Vulkan, Mode), CB(CB) {}
655655

656-
~VKComputeEncoder() override = default;
656+
~VKComputeEncoder() override { endEncoding(); }
657657

658658
void pushDebugGroup(llvm::StringRef Label) override {
659659
if (!CB.CmdBeginDebugUtilsLabel)
@@ -745,7 +745,7 @@ class VKComputeEncoder : public offloadtest::ComputeEncoder {
745745
VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT);
746746
}
747747

748-
void endEncoding() override { popDebugGroup(); }
748+
void endEncodingImpl() override { popDebugGroup(); }
749749
};
750750

751751
llvm::Expected<std::unique_ptr<offloadtest::ComputeEncoder>>

0 commit comments

Comments
 (0)