Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/internal/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@

namespace cudecomp {

using cudaGraph = uniqueHandle<cudaGraph_t, cudaGraphDestroy>;
using cudaGraphExec = uniqueHandle<cudaGraphExec_t, cudaGraphExecDestroy>;

class graphCache {
using key_type = std::tuple<void*, void*, int, int, cudecompPencilInfo_t, cudecompPencilInfo_t, cudecompDataType_t>;

Expand All @@ -40,12 +43,10 @@ class graphCache {
cudaStream_t startCapture(const key_type& key, cudaStream_t stream) const;
void endCapture(const key_type& key);
bool cached(const key_type& key) const;
void clear();
void clear() noexcept;

private:
void clearNoThrow() noexcept;

std::unordered_map<key_type, cudaGraphExec_t> graph_cache_;
std::unordered_map<key_type, cudaGraphExec> graph_cache_;
cudaStream graph_stream_;
};

Expand Down
38 changes: 38 additions & 0 deletions include/internal/raii_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,44 @@

namespace cudecomp {

template <typename T, auto destroy_fn> class uniqueHandle {
public:
uniqueHandle() = default;
explicit uniqueHandle(T handle) : handle_(handle) {}
~uniqueHandle() noexcept { resetNoThrow(); }

uniqueHandle(const uniqueHandle&) = delete;
uniqueHandle& operator=(const uniqueHandle&) = delete;

uniqueHandle(uniqueHandle&& other) noexcept : handle_(std::exchange(other.handle_, T{})) {}

uniqueHandle& operator=(uniqueHandle&& other) noexcept {
if (this != &other) {
resetNoThrow();
handle_ = std::exchange(other.handle_, T{});
}
return *this;
}

T get() const noexcept { return handle_; }
T* put() noexcept {
resetNoThrow();
return &handle_;
}
T release() noexcept { return std::exchange(handle_, T{}); }
operator T() const noexcept { return handle_; }

private:
void resetNoThrow() noexcept {
if (handle_) {
destroy_fn(handle_);
handle_ = T{};
}
}

T handle_{};
};

template <unsigned int flags> class cudaEventBase {
public:
cudaEventBase() { CHECK_CUDA(cudaEventCreateWithFlags(&event_, flags)); }
Expand Down
32 changes: 16 additions & 16 deletions include/internal/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "internal/cudecomp_kernels.h"
#include "internal/nvtx.h"
#include "internal/performance.h"
#include "internal/raii_wrappers.h"
#include "internal/utils.h"

namespace cudecomp {
Expand Down Expand Up @@ -72,6 +73,10 @@ template <typename T> static inline uint32_t getAlignment(const T* ptr) {
return 1;
}

using cutensorTensorDesc = uniqueHandle<cutensorTensorDescriptor_t, cutensorDestroyTensorDescriptor>;
using cutensorOperationDesc = uniqueHandle<cutensorOperationDescriptor_t, cutensorDestroyOperationDescriptor>;
using cutensorPlan = uniqueHandle<cutensorPlan_t, cutensorDestroyPlan>;

template <typename T>
static void localPermute(const cudecompHandle_t handle, const std::array<int64_t, 3>& extent_in,
const std::array<int32_t, 3>& order_out, const std::array<int64_t, 3>& strides_in,
Expand Down Expand Up @@ -132,28 +137,23 @@ static void localPermute(const cudecompHandle_t handle, const std::array<int64_t
auto strides_in_ptr = anyNonzeros(strides_in) ? strides_in.data() : nullptr;
auto strides_out_ptr = anyNonzeros(strides_out) ? strides_out.data() : nullptr;

cutensorTensorDescriptor_t desc_in;
CHECK_CUTENSOR(cutensorCreateTensorDescriptor(handle->cutensor_handle, &desc_in, 3, extent_in.data(), strides_in_ptr,
cutensor_type, getAlignment(input)));
cutensorTensorDescriptor_t desc_out;
CHECK_CUTENSOR(cutensorCreateTensorDescriptor(handle->cutensor_handle, &desc_out, 3, extent_out.data(),
cutensorTensorDesc desc_in;
CHECK_CUTENSOR(cutensorCreateTensorDescriptor(handle->cutensor_handle, desc_in.put(), 3, extent_in.data(),
strides_in_ptr, cutensor_type, getAlignment(input)));
cutensorTensorDesc desc_out;
CHECK_CUTENSOR(cutensorCreateTensorDescriptor(handle->cutensor_handle, desc_out.put(), 3, extent_out.data(),
strides_out_ptr, cutensor_type, getAlignment(output)));

cutensorOperationDescriptor_t desc_op;
CHECK_CUTENSOR(cutensorCreatePermutation(handle->cutensor_handle, &desc_op, desc_in, order_in.data(),
CUTENSOR_OP_IDENTITY, desc_out, order_out.data(),
cutensorOperationDesc desc_op;
CHECK_CUTENSOR(cutensorCreatePermutation(handle->cutensor_handle, desc_op.put(), desc_in.get(), order_in.data(),
CUTENSOR_OP_IDENTITY, desc_out.get(), order_out.data(),
getCutensorComputeType(cutensor_type)));

cutensorPlan_t plan;
CHECK_CUTENSOR(cutensorCreatePlan(handle->cutensor_handle, &plan, desc_op, handle->cutensor_plan_pref, 0));
cutensorPlan plan;
CHECK_CUTENSOR(cutensorCreatePlan(handle->cutensor_handle, plan.put(), desc_op.get(), handle->cutensor_plan_pref, 0));

T one(1);
CHECK_CUTENSOR(cutensorPermute(handle->cutensor_handle, plan, &one, input, output, stream));

CHECK_CUTENSOR(cutensorDestroyTensorDescriptor(desc_in));
CHECK_CUTENSOR(cutensorDestroyTensorDescriptor(desc_out));
CHECK_CUTENSOR(cutensorDestroyOperationDescriptor(desc_op));
CHECK_CUTENSOR(cutensorDestroyPlan(plan));
CHECK_CUTENSOR(cutensorPermute(handle->cutensor_handle, plan.get(), &one, input, output, stream));
}

#else
Expand Down
42 changes: 37 additions & 5 deletions src/cudecomp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,31 @@ static void cleanupFailedGridDescCreate(cudecompHandle_t handle, cudecompGridDes
releaseUnusedHandleResources(handle, release_streams);
}

#if CUDART_VERSION >= 12030
struct cuMemAllocationGuard {
~cuMemAllocationGuard() noexcept {
if (!active) return;
if (mapped) { cuFnTable.pfn_cuMemUnmap(ptr, size); }
if (address_reserved) { cuFnTable.pfn_cuMemAddressFree(ptr, size); }
if (handle_created) { cuFnTable.pfn_cuMemRelease(handle); }
}

cuMemAllocationGuard() = default;
cuMemAllocationGuard(const cuMemAllocationGuard&) = delete;
cuMemAllocationGuard& operator=(const cuMemAllocationGuard&) = delete;

void release() noexcept { active = false; }

CUmemGenericAllocationHandle handle = 0;
CUdeviceptr ptr = 0;
size_t size = 0;
bool handle_created = false;
bool address_reserved = false;
bool mapped = false;
bool active = true;
};
#endif

} // namespace
} // namespace cudecomp

Expand Down Expand Up @@ -1173,17 +1198,24 @@ cudecompResult_t cudecompMalloc(cudecompHandle_t handle, cudecompGridDesc_t grid
buffer_size_bytes = (buffer_size_bytes + granularity - 1) / granularity * granularity;

// Allocate memory
CUmemGenericAllocationHandle cumem_handle;
CHECK_CUDA_DRV(cuMemCreate(&cumem_handle, buffer_size_bytes, &prop, 0));
CHECK_CUDA_DRV(cuMemAddressReserve((CUdeviceptr*)buffer, buffer_size_bytes, granularity, 0, 0));
CHECK_CUDA_DRV(cuMemMap((CUdeviceptr)*buffer, buffer_size_bytes, 0, cumem_handle, 0));
cuMemAllocationGuard cumem_guard;
cumem_guard.size = buffer_size_bytes;
CHECK_CUDA_DRV(cuMemCreate(&cumem_guard.handle, buffer_size_bytes, &prop, 0));
cumem_guard.handle_created = true;
CHECK_CUDA_DRV(cuMemAddressReserve(&cumem_guard.ptr, buffer_size_bytes, granularity, 0, 0));
cumem_guard.address_reserved = true;
CHECK_CUDA_DRV(cuMemMap(cumem_guard.ptr, buffer_size_bytes, 0, cumem_guard.handle, 0));
cumem_guard.mapped = true;

// Set read/write access
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = cu_dev;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CHECK_CUDA_DRV(cuMemSetAccess((CUdeviceptr)*buffer, buffer_size_bytes, &accessDesc, 1));
CHECK_CUDA_DRV(cuMemSetAccess(cumem_guard.ptr, buffer_size_bytes, &accessDesc, 1));

*buffer = reinterpret_cast<void*>(cumem_guard.ptr);
cumem_guard.release();
#endif
} else {
CHECK_CUDA(cudaMalloc(buffer, buffer_size_bytes));
Expand Down
32 changes: 9 additions & 23 deletions src/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <tuple>
#include <unordered_map>
#include <utility>

#include <cuda_runtime.h>

Expand All @@ -27,10 +28,10 @@

namespace cudecomp {

graphCache::~graphCache() noexcept { clearNoThrow(); }
graphCache::~graphCache() noexcept { clear(); }

void graphCache::replay(const graphCache::key_type& key, cudaStream_t stream) const {
CHECK_CUDA(cudaGraphLaunch(graph_cache_.at(key), stream));
CHECK_CUDA(cudaGraphLaunch(graph_cache_.at(key).get(), stream));
}

cudaStream_t graphCache::startCapture(const graphCache::key_type& key, cudaStream_t stream) const {
Expand All @@ -39,31 +40,16 @@ cudaStream_t graphCache::startCapture(const graphCache::key_type& key, cudaStrea
}

void graphCache::endCapture(const graphCache::key_type& key) {
cudaGraph_t graph;
cudaGraphExec_t graph_exec;
CHECK_CUDA(cudaStreamEndCapture(graph_stream_, &graph));
CHECK_CUDA(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0));
CHECK_CUDA(cudaGraphDestroy(graph));
cudaGraph graph;
cudaGraphExec graph_exec;
CHECK_CUDA(cudaStreamEndCapture(graph_stream_, graph.put()));
CHECK_CUDA(cudaGraphInstantiate(graph_exec.put(), graph.get(), nullptr, nullptr, 0));

graph_cache_[key] = graph_exec;
graph_cache_[key] = std::move(graph_exec);
}

bool graphCache::cached(const graphCache::key_type& key) const { return graph_cache_.count(key) > 0; }

void graphCache::clear() {
for (auto& entry : graph_cache_) {
CHECK_CUDA(cudaGraphExecDestroy(entry.second));
}

graph_cache_.clear();
}

void graphCache::clearNoThrow() noexcept {
for (auto& entry : graph_cache_) {
cudaGraphExecDestroy(entry.second);
}

graph_cache_.clear();
}
void graphCache::clear() noexcept { graph_cache_.clear(); }

} // namespace cudecomp
Loading