Skip to content

Commit 9f45391

Browse files
committed
Add RAII safeguard: ComputeEncoder destructors call endEncoding()
Makes CommandEncoder::endEncoding() idempotent (split into a non-virtual wrapper plus a protected virtual endEncodingImpl()), then has each backend's ComputeEncoder destructor call endEncoding() so an encoder destroyed without an explicit end still flushes pending barriers and pops its debug group.
1 parent 54a7884 commit 9f45391

4 files changed

Lines changed: 22 additions & 8 deletions

File tree

include/API/Encoder.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ enum class EncoderMode {
3636
class CommandEncoder {
3737
GPUAPI API;
3838
EncoderMode Mode;
39+
bool Ended = false;
40+
41+
protected:
42+
/// Backend-specific cleanup. Called exactly once, either explicitly via
43+
/// endEncoding() or implicitly from the most-derived destructor.
44+
virtual void endEncodingImpl() = 0;
3945

4046
public:
4147
CommandEncoder(GPUAPI API, EncoderMode Mode) : API(API), Mode(Mode) {}
@@ -46,6 +52,7 @@ class CommandEncoder {
4652
GPUAPI getAPI() const { return API; }
4753
EncoderMode getMode() const { return Mode; }
4854
bool isSerial() const { return Mode == EncoderMode::Serial; }
55+
bool isEnded() const { return Ended; }
4956

5057
/// Copy \p Size bytes from \p Src at \p SrcOffset to \p Dst at
5158
/// \p DstOffset.
@@ -69,7 +76,15 @@ class CommandEncoder {
6976
virtual void insertDebugSignpost(llvm::StringRef Label) {}
7077

7178
/// Finish recording. No further commands may be recorded after this call.
72-
virtual void endEncoding() = 0;
79+
/// Idempotent: safe to call more than once. If not called explicitly, the
80+
/// most-derived destructor invokes it as a safeguard against leaked open
81+
/// encoders.
82+
void endEncoding() {
83+
if (Ended)
84+
return;
85+
endEncodingImpl();
86+
Ended = true;
87+
}
7388
};
7489

7590
/// Encoder for recording compute dispatch commands.

lib/API/DX/Device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
550550
DXComputeEncoder(DXCommandBuffer &CB, EncoderMode Mode)
551551
: ComputeEncoder(GPUAPI::DirectX, Mode), CB(CB) {}
552552

553-
~DXComputeEncoder() override = default;
553+
~DXComputeEncoder() override { endEncoding(); }
554554

555555
void pushDebugGroup(llvm::StringRef Label) override {
556556
CB.CmdList->BeginEvent(0, Label.data(), Label.size() + 1);
@@ -616,7 +616,7 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
616616
CB.flushBarrier();
617617
}
618618

619-
void endEncoding() override { popDebugGroup(); }
619+
void endEncodingImpl() override { popDebugGroup(); }
620620
};
621621

622622
llvm::Expected<std::unique_ptr<offloadtest::ComputeEncoder>>

lib/API/MTL/MTLDevice.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include "MTLResources.h"
1515
#include "Support/Pipeline.h"
1616

17-
#include "llvm/ADT/ScopeExit.h"
1817
#include "llvm/ADT/SmallString.h"
1918
#include "llvm/Support/Error.h"
2019
#include "llvm/Support/FormatVariadic.h"
@@ -264,7 +263,7 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
264263
: ComputeEncoder(GPUAPI::Metal, Mode), CmdBuffer(CmdBuffer),
265264
ComputeEnc(Encoder) {}
266265

267-
~MTLComputeEncoder() override = default;
266+
~MTLComputeEncoder() override { endEncoding(); }
268267

269268
MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; }
270269

@@ -361,7 +360,7 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
361360
}
362361
}
363362

364-
void endEncoding() override {
363+
void endEncodingImpl() override {
365364
if (ComputeEnc) {
366365
barrier();
367366
ComputeEnc->popDebugGroup();

lib/API/VK/Device.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ class VKComputeEncoder : public offloadtest::ComputeEncoder {
650650
VKComputeEncoder(VulkanCommandBuffer &CB, EncoderMode Mode)
651651
: ComputeEncoder(GPUAPI::Vulkan, Mode), CB(CB) {}
652652

653-
~VKComputeEncoder() override = default;
653+
~VKComputeEncoder() override { endEncoding(); }
654654

655655
void pushDebugGroup(llvm::StringRef Label) override {
656656
if (!CB.CmdBeginDebugUtilsLabel)
@@ -742,7 +742,7 @@ class VKComputeEncoder : public offloadtest::ComputeEncoder {
742742
VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT);
743743
}
744744

745-
void endEncoding() override { popDebugGroup(); }
745+
void endEncodingImpl() override { popDebugGroup(); }
746746
};
747747

748748
llvm::Expected<std::unique_ptr<offloadtest::ComputeEncoder>>

0 commit comments

Comments
 (0)