Skip to content

Commit cfd066f

Browse files
committed
baked kernel compression - improve coding
1 parent 5c0cb32 commit cfd066f

4 files changed

Lines changed: 51 additions & 20 deletions

File tree

Orochi/OrochiUtils.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -794,35 +794,45 @@ void OrochiUtils::launch2D( oroFunction func, int nx, int ny, const void** args,
794794
OROASSERT( e == oroSuccess, 0 );
795795
}
796796

797-
void OrochiUtils::DecompressPrecompiled(std::vector<unsigned char>& out, const unsigned char* compressedInput, size_t compressedInput_sizeByte, size_t uncompressed_sizeByte)
797+
void OrochiUtils::HandlePrecompiled(std::vector<unsigned char>& out, const CompressedBuffer& buffer)
798798
{
799-
if ( uncompressed_sizeByte > 0 ) // if the input data is actually compressed
800-
{
801799
#ifdef ORO_LINK_ZSTD
802-
out.assign(uncompressed_sizeByte,0);
800+
out.assign(buffer.uncompressedSize,0);
803801

804802
size_t decompressedSize = ZSTD_decompress(
805803
out.data(), // final uncompressed buffer
806804
out.size(), // final size
807-
compressedInput, // compressed buffer
808-
compressedInput_sizeByte // compressed buffer - size
805+
buffer.data, // compressed buffer
806+
buffer.size // compressed buffer - size
809807
);
810808

811-
if ( decompressedSize != uncompressed_sizeByte )
809+
if ( decompressedSize != buffer.uncompressedSize )
812810
throw std::runtime_error( "ERROR: ZSTD_decompress FAILED." );
813811
#else
814-
815812
throw std::runtime_error( "ERROR: ZSTD is not part of this build." );
816-
817813
#endif
814+
return;
815+
}
818816

819-
}
820-
else // if the input data is NOT compressed, buypass this decompress process.
821-
{
822-
out = std::vector<unsigned char>(compressedInput, compressedInput + compressedInput_sizeByte );
823-
}
817+
818+
void OrochiUtils::HandlePrecompiled(std::vector<unsigned char>& out, const RawBuffer& buffer)
819+
{
820+
out = std::vector<unsigned char>(buffer.data, buffer.data + buffer.size );
824821
return;
825822
}
826823

827824

825+
void OrochiUtils::HandlePrecompiled(std::vector<unsigned char>& out, const unsigned char* rawData, size_t rawData_sizeByte, std::optional<size_t> uncompressed_sizeByte)
826+
{
827+
if (uncompressed_sizeByte.has_value()) {
828+
// if the input buffer is compressed :
829+
CompressedBuffer buffer{ rawData, rawData_sizeByte, uncompressed_sizeByte.value() };
830+
HandlePrecompiled(out, buffer );
831+
} else {
832+
// if the input buffer is not compressed
833+
RawBuffer buffer{ rawData, rawData_sizeByte };
834+
HandlePrecompiled(out, buffer );
835+
}
836+
}
837+
828838

Orochi/OrochiUtils.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <filesystem>
2828
#include <unordered_map>
2929
#include <vector>
30+
#include <optional>
3031

3132
#if defined( GNUC )
3233
#include <signal.h>
@@ -84,8 +85,19 @@ class OrochiUtils
8485
static void launch1D( oroFunction func, int nx, const void** args, int wgSize = 64, unsigned int sharedMemBytes = 0, oroStream stream = 0 );
8586
static void launch2D( oroFunction func, int nx, int ny, const void** args, int wgSizeX = 8, int wgSizeY = 8, unsigned int sharedMemBytes = 0, oroStream stream = 0 );
8687

87-
// if 'uncompressed_sizeByte' is set to 0, it means the input value is not compressed and this function will output the raw buffer.
88-
static void DecompressPrecompiled(std::vector<unsigned char>& out, const unsigned char* compressedInput, size_t compressedInput_sizeByte, size_t uncompressed_sizeByte);
88+
89+
struct CompressedBuffer {
90+
const unsigned char* data = nullptr; // compressed data
91+
size_t size = 0; // size in byte of 'data'
92+
size_t uncompressedSize = 0; // size of byte of the uncompressed data.
93+
};
94+
struct RawBuffer {
95+
const unsigned char* data = nullptr;
96+
size_t size = 0;
97+
};
98+
static void HandlePrecompiled(std::vector<unsigned char>& out, const CompressedBuffer& buffer);
99+
static void HandlePrecompiled(std::vector<unsigned char>& out, const RawBuffer& buffer);
100+
static void HandlePrecompiled(std::vector<unsigned char>& out, const unsigned char* rawData, size_t rawData_sizeByte, std::optional<size_t> uncompressed_sizeByte=std::nullopt);
89101

90102
template<typename T>
91103
static void malloc( T*& ptr, size_t n )

ParallelPrimitives/RadixSort.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ static const char** RadixSortKernelsIncludes = nullptr;
5555
const unsigned char oro_compiled_kernels_h[] = "";
5656
const size_t oro_compiled_kernels_h_size = 0;
5757
const size_t oro_compiled_kernels_h_size_uncompressed = 0;
58+
const bool oro_compiled_kernels_h_isCompressed = false;
5859
#endif
5960

6061
constexpr uint64_t div_round_up64( uint64_t val, uint64_t divisor ) noexcept { return ( val + divisor - 1 ) / divisor; }
@@ -191,7 +192,7 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
191192
if constexpr( usePrecompiledAndBakedKernel )
192193
{
193194
std::vector<unsigned char> binary;
194-
OrochiUtils::DecompressPrecompiled(binary, oro_compiled_kernels_h, oro_compiled_kernels_h_size, oro_compiled_kernels_h_size_uncompressed);
195+
OrochiUtils::HandlePrecompiled(binary, oro_compiled_kernels_h, oro_compiled_kernels_h_size, oro_compiled_kernels_h_isCompressed ? std::optional<size_t>{oro_compiled_kernels_h_size_uncompressed} : std::nullopt);
195196
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData(binary.data(), binary.size(), record.kernelName.c_str() );
196197
}
197198
else if constexpr( useBakeKernel )

scripts/convert_binary_to_array.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,17 @@ def binary_to_c_array(bin_file, array_name, size_BeforeCompression, compression_
1010
c_array = f'const unsigned char {array_name}[] = {{\n {hex_array}\n}};\n'
1111
c_array += f'const size_t {array_name}_size = sizeof({array_name}); // {len(binary_data)}\n'
1212

13-
if not compression_activated:
14-
size_BeforeCompression = 0 # set value to 0 if we are not using compression.
15-
c_array += f'const size_t {array_name}_size_uncompressed = {size_BeforeCompression}; // set to 0 if NOT using the ZSTD compression.\n'
13+
c_array += f'const size_t {array_name}_size_uncompressed = '
14+
if compression_activated:
15+
c_array += f'{size_BeforeCompression}; // size of the data in bytes, once it has been uncompressed.\n'
16+
else:
17+
c_array += f'{array_name}_size; // same than raw buffer, because data is not compressed.\n'
18+
19+
c_array += f'const bool {array_name}_isCompressed = '
20+
if compression_activated:
21+
c_array += f'true;\n'
22+
else:
23+
c_array += f'false;\n'
1624
return c_array
1725

1826
if __name__ == "__main__":

0 commit comments

Comments
 (0)