Commit d78fb81

takahiroharada, KaoCC, jammm, mehmetoguzderin, and meistdan authored
Feature/oro 0 amdadvtech merge (#43)
* Add gitignore to the repository Signed-off-by: Chih-Chen Kao <[email protected]> * Fix missing CUDA properties. (#16) Signed-off-by: Chih-Chen Kao <[email protected]> * Feature/oro 0 radix sort (#19) * [ORO-0] Working 8 bit radix sort. * [ORO-0] Some optimization. * Create LICENSE * Update README.md (#15) * Feature/oro 0 raw get set (#19) * [ORO-0] Rename setter and getter. * [ORO-0] Fix when there is a dll but no device. * [ORO-0] Deletion function. * [ORO-0] Multi processor count. * [ORO-0] Extended the sort to more than 8 bits. Implemented tests. * [ORO-0] Moved temp buffer allocation out from the sort(). * [ORO-0] README. References. * [ORO-0] Debug flag. * Refactor the code to add the basic constructs to support selecting different scan algorithms. Add different implementation of the scan algorithm: CPU, single WG and all WG . Signed-off-by: Chih-Chen Kao <[email protected]> * Squashed commit of the following: commit 3f32bea2244653d59efb3c3eaa9433018dde5835 Author: takahiroharada <[email protected]> Date: Wed Apr 13 10:48:35 2022 -0700 [ORO-0] Fix nvrtc. * Optimization: Implement the single-pass kernel for GPU parallel scan. Fix a GPU memory bug. Signed-off-by: Chih-Chen Kao <[email protected]> * Feature/oro 0 kernel cache (#4) * [ORO-0] Cache kernel. * [ORO-0] Support newer HIP builds on windows (#22) * [ORO-0] Unit test. (#23) * Fix LDS scan bug. The previous implementation would lead to an error when the wavefront (wrap) size is not equal to the size of a workgroup (block). Since not all threads run simultaneously, for an input arrays larger than the wavefront size, the previous algorithm will not work because it performs the scan in-place on the input array. The results of one wavefront (wrap) will be overwritten by work items (threads) in another wavefront (wrap). Signed-off-by: Chih-Chen Kao <[email protected]> * Optimize the LDS scan algorithm. (#6) * Optimize the LDS scan algorithm. 
This version does not require a temp buffer and can support a LDS input size up to 2 times the workgroup size. Signed-off-by: Chih-Chen Kao <[email protected]> * Support an input array in LDS that is 2 times the WG size. Signed-off-by: Chih-Chen Kao <[email protected]> * Feature/oro 0 clean up (#7) * Squashed commit of the following: commit 3f32bea2244653d59efb3c3eaa9433018dde5835 Author: takahiroharada <[email protected]> Date: Wed Apr 13 10:48:35 2022 -0700 [ORO-0] Fix nvrtc. * [ORO-0] Clean up. * Feature/oro 0 clean up (#10) * Squashed commit of the following: commit 3f32bea2244653d59efb3c3eaa9433018dde5835 Author: takahiroharada <[email protected]> Date: Wed Apr 13 10:48:35 2022 -0700 [ORO-0] Fix nvrtc. * [ORO-0] Clean up. * [ORO-0] SortKernel1. Less complex. (#8) SortKernel (occupancy: 8) - vgpr: 128 - lds: 6704 SortKernel1 (occupancy: 9) - vgpr: 106 - lds 7720 * [ORO-0] Kernel execution time check. * Fix the memory access pattern and change it to coalesced memory access. (#11) Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Single kernel sort for small keys. (#12) * Optimize the Count kernel for less LDS usage to achieve full occupancy (#13) * Optimize the Count kernel to let it use less LDS and could achieve full occupancy. Signed-off-by: Chih-Chen Kao <[email protected]> * Remove __threadfence_block() Removes the boundary check in the inner loop. The upper bound is set only once before going into the loop. Signed-off-by: Chih-Chen Kao <[email protected]> * Introduce DRIVER and RTC APIs * Disable enum-variant * Improve paths * Add fields * Update Vulkan test * Define CUDA in terms of DRIVER and RTC * Optimize the sort kernel: single-pass 8bit sort & parallel scan in 4bit sort. (#14) * Fix a minor issue in CountKernel to make it more robust. Implement a single-pass 8-bit local sort. Implement a single-pass 8-bit local sort with shared bins. Signed-off-by: Chih-Chen Kao <[email protected]> * Fix nItemsPerWI and enable the version with shared LDS. 
Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Print driver version. * [ORO-0] Repro case. * Fix SORT_WG_SIZE. Fix stable sort order. Signed-off-by: Chih-Chen Kao <[email protected]> * Optimize sort kernel to remove inner boundary check. Adjust nItemsPerWI. Signed-off-by: Chih-Chen Kao <[email protected]> Co-authored-by: takahiroharada <[email protected]> * Merging another merge (#18) * Fix a minor issue in CountKernel to make it more robust. Implement a single-pass 8-bit local sort. Implement a single-pass 8-bit local sort with shared bins. Signed-off-by: Chih-Chen Kao <[email protected]> * Fix nItemsPerWI and enable the version with shared LDS. Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Print driver version. * [ORO-0] Repro case. * Fix SORT_WG_SIZE. Fix stable sort order. Signed-off-by: Chih-Chen Kao <[email protected]> * Optimize sort kernel to remove inner boundary check. Adjust nItemsPerWI. Signed-off-by: Chih-Chen Kao <[email protected]> * Calculate the number of WGs based on LDS and max-thread-per-WGP. (#15) * Calculate the number of WGs based on LDS and max-thread-per-WGP. Signed-off-by: Chih-Chen Kao <[email protected]> * Add a workaround for CUDA. Signed-off-by: Chih-Chen Kao <[email protected]> * Optimize the sort kernel: single-pass 8bit sort & parallel scan in 4bit sort. (#14) * Fix a minor issue in CountKernel to make it more robust. Implement a single-pass 8-bit local sort. Implement a single-pass 8-bit local sort with shared bins. Signed-off-by: Chih-Chen Kao <[email protected]> * Fix nItemsPerWI and enable the version with shared LDS. Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Print driver version. * [ORO-0] Repro case. * Fix SORT_WG_SIZE. Fix stable sort order. Signed-off-by: Chih-Chen Kao <[email protected]> * Optimize sort kernel to remove inner boundary check. Adjust nItemsPerWI. 
Signed-off-by: Chih-Chen Kao <[email protected]> Co-authored-by: takahiroharada <[email protected]> Co-authored-by: takahiroharada <[email protected]> Co-authored-by: Chih-Chen Kao <[email protected]> * Implement key-value pair sorting (#17) * Add gitignore to the repository Signed-off-by: Chih-Chen Kao <[email protected]> * Fix missing CUDA properties. (#16) Signed-off-by: Chih-Chen Kao <[email protected]> * Add basic structure for key-value pair sorting. Fix an error in single pass sort Signed-off-by: Chih-Chen Kao <[email protected]> * Add Value data in the test and sort it according to keys. Signed-off-by: Chih-Chen Kao <[email protected]> * Support Key only sorting. Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Make single pass kernel non compile time switch. * Support both Key-Only & Key-Value pair sort kernels Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Test change. * [ORO-0] A bug. * [ORO-0] NVIDIA occupancy computation fix. Test change. Tweak params to use single pass sort as much as possible. Co-authored-by: Takahiro Harada <[email protected]> Co-authored-by: takahiroharada <[email protected]> * [ORO-0] Revert demo code. * Fix missing CUDA properties. (#26) * Update Orochi.cpp * [ORO-0] Clean up. * [ORO-0] OroUtils. (#27) * [ORO-0] OroUtils. * [ORO-0] Linux build fix. * [ORO-0] Forgot to add. * [ORO-0] Linux build fix. * [ORO-0] Clean up. Co-authored-by: Chih-Chen Kao <[email protected]> Co-authored-by: Aaryaman Vasishta <[email protected]> Co-authored-by: Mehmet Oguz Derin <[email protected]> * Add kernel path and include dir to the functions. (#20) Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] BakeKernel. (#21) * [ORO-0] BakeKernel. 
* Update tools/genArgs.py commented code removal * Update tools/stringify.py commented code removal * Update tools/stringify.py commented code removal * Update tools/stringify.py commented code removal * Update tools/genArgs.py dead code removal * Update tools/stringify.py dead code removal * fix include Signed-off-by: Chih-Chen Kao <[email protected]> * fix script Signed-off-by: Chih-Chen Kao <[email protected]> * fix Signed-off-by: Chih-Chen Kao <[email protected]> Co-authored-by: Chih-Chen Kao <[email protected]> * Fix Orochi CUDA API (#23) Fix Orochi CUDA API Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Linux build fix. (#22) * [ORO-0] Linux build fix. * Fix Orochi CUDA API Signed-off-by: Chih-Chen Kao <[email protected]> Co-authored-by: Chih-Chen Kao <[email protected]> * Quick fix for old linux gcc which does not support std::exclusive_scan (#24) Quick fix for old linux gcc which does not support std::exclusive_scan Signed-off-by: Chih-Chen Kao <[email protected]> * Fix the kernel cache bug. (#25) Fix the kernel cache bug. The function should not return the oroFunctions that are created previously solely based on the names because they might be invalid. Signed-off-by: Chih-Chen Kao <[email protected]> * [ORO-0] Remove static variables. (#26) * [ORO-0] Remove static variables. * [ORO-0] Applied the suggestions. * [ORO-0] Linux regression fix. * Fix OrochiUtils::getFunctionFromString API (#27) Signed-off-by: Chih-Chen Kao <[email protected]> * Adding missing assert (#28) * Adding missing assert * Adding more asserts * Feature/oro 0 gpuopen merge (#31) * Fix oroGetDeviceProperties in cuda path. * Fix linux crash (#29) * [ORO-0] Added missing file. 
* [ORO-0] Remove printf from kernelExec and skip compilation of vulkan test on Linux (#31) * [ORO-0] Skip compilation of vulkan test on Linux * [ORO-0] Update kernelExec unit test - remove printf * [ORO-0] Remove cout * [ORO-0] Fix hipGetErrorString (#32) * [ORO-0] Fix hipGetErrorString It was incorrectly importing this API. Import the correct API in hipew. * [ORO-0] Remove printf from kernelExec and skip compilation of vulkan test on Linux (#31) * [ORO-0] Skip compilation of vulkan test on Linux * [ORO-0] Update kernelExec unit test - remove printf * [ORO-0] Remove cout * [ORO-0] Add Orochi error codes mapped to HIP/CUDA (#33) * Add missing path on Apple config. (#34) * [ORO-0] Adding hiprtc+comgr dlls to workaround the regression in 22.7.1 driver (#38) * [ORO-0] Adding hiprtc to workaround the regression in 22.7.1 driver released at 7/26/2022. * [ORO-0] Created win64 subdir. * [ORO-0] Add hiprtc.dll and comgr dll Co-authored-by: takahiroharada <[email protected]> * fix footnote markdown format (#39) * Fix orochi utils issue in unit tests Co-authored-by: Aaryaman Vasishta <[email protected]> Co-authored-by: Chih-Chen Kao <[email protected]> Co-authored-by: NevesLucas <[email protected]> Co-authored-by: PixelClear <[email protected]> Signed-off-by: Chih-Chen Kao <[email protected]> Co-authored-by: Chih-Chen Kao <[email protected]> Co-authored-by: Aaryaman Vasishta <[email protected]> Co-authored-by: Mehmet Oguz Derin <[email protected]> Co-authored-by: Daniel Meister <[email protected]> Co-authored-by: NevesLucas <[email protected]> Co-authored-by: PixelClear <[email protected]>
1 parent 03c4676 commit d78fb81
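
The commit message mentions a quick fix for older Linux GCC toolchains that do not support `std::exclusive_scan`. The patch itself is not shown on this page. As a minimal sketch, assuming a plain sequential CPU scan is acceptable, a hand-written replacement could look like this (the name `exclusiveScanCpu` is hypothetical and does not come from the commit):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the commit): sequential exclusive prefix sum.
// out[i] = init + in[0] + ... + in[i-1], so out[0] == init.
template <typename T>
std::vector<T> exclusiveScanCpu( const std::vector<T>& in, T init = T{} )
{
    std::vector<T> out( in.size() );
    T running = init;
    for( std::size_t i = 0; i < in.size(); i++ )
    {
        out[i] = running; // sum of everything before element i
        running += in[i]; // then fold element i into the running total
    }
    return out;
}
```

For the input `{ 3, 1, 4, 1, 5 }` this yields `{ 0, 3, 4, 8, 9 }`, which is the kind of bucket-offset table a radix sort builds from per-digit counts.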

8 files changed

Lines changed: 357 additions & 18 deletions

Orochi/OrochiUtils.cpp

Lines changed: 38 additions & 12 deletions
@@ -155,7 +155,7 @@ struct OrochiUtilsImpl
         return false;
     }

-    static void getCacheFileName( oroDevice device, const char* moduleName, const char* functionName, const char* options, std::string& binFileName )
+    static void getCacheFileName( oroDevice device, const char* moduleName, const char* functionName, const char* options, std::string& binFileName, const std::string& cacheDirectory )
     {
         auto hashBin = []( const char* s, const size_t size )
         {
@@ -220,7 +220,7 @@ struct OrochiUtilsImpl
         using namespace std::string_literals;

         deviceName = deviceName.substr( 0, deviceName.find( ":" ) );
-        binFileName = OrochiUtils::s_cacheDirectory + "/"s + moduleHash + "-"s + optionHash + ".v."s + deviceName + "."s + driverVersion + "_"s + std::to_string( 8 * sizeof( void* ) ) + ".bin"s;
+        binFileName = cacheDirectory + "/"s + moduleHash + "-"s + optionHash + ".v."s + deviceName + "."s + driverVersion + "_"s + std::to_string( 8 * sizeof( void* ) ) + ".bin"s;
     }
     static
     bool isFileUpToDate( const char* binaryFileName, const char* srcFileName )
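
The hunk above builds the cache file name from the cache directory, two hashes, the device name, the driver version, and the pointer width. The following is a minimal stand-alone sketch of that composition. The function name `makeCacheFileName` and its parameter list are hypothetical, and the hashes and driver version are passed in as ready-made strings because their computation is not visible in this hunk.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-alone version of the file-name composition shown in the diff.
// moduleHash, optionHash and driverVersion are assumed to be computed elsewhere.
std::string makeCacheFileName( const std::string& cacheDirectory, const std::string& moduleHash,
                               const std::string& optionHash, std::string deviceName,
                               const std::string& driverVersion )
{
    using namespace std::string_literals;
    // Keep only the part of the device name before the first ':' (the whole name if there is none).
    deviceName = deviceName.substr( 0, deviceName.find( ":" ) );
    // 8 * sizeof( void* ) is the pointer width in bits, e.g. 64 on a 64-bit build.
    return cacheDirectory + "/"s + moduleHash + "-"s + optionHash + ".v."s + deviceName + "."s + driverVersion + "_"s +
           std::to_string( 8 * sizeof( void* ) ) + ".bin"s;
}
```

With the default directory `"./cache/"` set in the new constructor, the result contains a doubled separator (`./cache//...`). That is typically harmless, since common filesystems treat repeated slashes as one. The device name and version strings used to try this out are made up.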
@@ -381,27 +381,47 @@ struct OrochiUtilsImpl
     }
 };

-char* OrochiUtils::s_cacheDirectory = "./cache/";
-std::map<std::string, oroFunction> OrochiUtils::s_kernelMap;
+OrochiUtils::OrochiUtils()
+{
+    m_cacheDirectory = "./cache/";
+}
+
+OrochiUtils::~OrochiUtils()
+{
+}

 oroFunction OrochiUtils::getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector<const char*>* optsIn )
 {
-    std::string cacheName = OrochiUtilsImpl::getCacheName( path, funcName );
-    if( s_kernelMap.find( cacheName.c_str() ) != s_kernelMap.end() )
+    const std::string cacheName = OrochiUtilsImpl::getCacheName( path, funcName );
+    if( m_kernelMap.find( cacheName.c_str() ) != m_kernelMap.end() )
     {
-        return s_kernelMap[ cacheName ];
+        return m_kernelMap[ cacheName ];
     }

     std::string source;
     if( !OrochiUtilsImpl::readSourceCode( path, source, 0 ) )
         return 0;

     oroFunction f = getFunction( device, source.c_str(), path, funcName, optsIn );
-    s_kernelMap[cacheName] = f;
+    m_kernelMap[cacheName] = f;
     return f;
 }

-oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* optsIn )
+oroFunction OrochiUtils::getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* optsIn,
+    int numHeaders, const char** headers, const char** includeNames )
+{
+    const std::string cacheName = OrochiUtilsImpl::getCacheName( path, funcName );
+    if( m_kernelMap.find( cacheName.c_str() ) != m_kernelMap.end() )
+    {
+        return m_kernelMap[cacheName];
+    }
+    oroFunction f = getFunction( device, source, path, funcName, optsIn, numHeaders, headers, includeNames );
+    m_kernelMap[cacheName] = f;
+    return f;
+}
+
+oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* optsIn,
+    int numHeaders, const char** headers, const char** includeNames )
 {
     std::vector<const char*> opts;
     opts.push_back( "-std=c++17" );
@@ -422,7 +442,7 @@ oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const
         std::string o;
         for(int i=0; i<opts.size(); i++)
             o.append( opts[i] );
-        OrochiUtilsImpl::getCacheFileName( device, path, funcName, o.c_str(), cacheFile );
+        OrochiUtilsImpl::getCacheFileName( device, path, funcName, o.c_str(), cacheFile, m_cacheDirectory );
     }
     if( OrochiUtilsImpl::isFileUpToDate( cacheFile.c_str(), path ) )
     {
@@ -433,7 +453,8 @@ oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const
     {
         orortcProgram prog;
         orortcResult e;
-        e = orortcCreateProgram( &prog, code, path, 0, 0, 0 );
+        e = orortcCreateProgram( &prog, code, path, numHeaders, headers, includeNames );
+        OROASSERT( e == ORORTC_SUCCESS, 0 );

         e = orortcCompileProgram( prog, opts.size(), opts.data() );
         if( e != ORORTC_SUCCESS )
@@ -449,18 +470,23 @@ oroFunction OrochiUtils::getFunction( oroDevice device, const char* code, const
         }
         size_t codeSize;
         e = orortcGetCodeSize( prog, &codeSize );
+        OROASSERT( e == ORORTC_SUCCESS, 0 );

         codec.resize( codeSize );
         e = orortcGetCode( prog, codec.data() );
+        OROASSERT( e == ORORTC_SUCCESS, 0 );
         e = orortcDestroyProgram( &prog );
+        OROASSERT( e == ORORTC_SUCCESS, 0 );

         //store cache
-        OrochiUtilsImpl::createDirectory( s_cacheDirectory );
+        OrochiUtilsImpl::createDirectory( m_cacheDirectory.c_str() );
         OrochiUtilsImpl::cacheBinaryToFile( codec, cacheFile );
     }
     oroModule module;
     oroError ee = oroModuleLoadData( &module, codec.data() );
+    OROASSERT( ee == oroSuccess, 0 );
     ee = oroModuleGetFunction( &function, module, funcName );
+    OROASSERT( ee == oroSuccess, 0 );

     return function;
 }
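
The diff for this file moves the kernel map from a static member to a per-object member, so each `OrochiUtils` instance keeps its own cache. The sketch below illustrates that idea without depending on Orochi. `FakeFunction` stands in for `oroFunction`, and `KernelCache` and `getOrBuild` are hypothetical names. It also uses a single `find` instead of `find` followed by `operator[]`, which is a small variation and not what the commit does.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>

// Stand-in for oroFunction so the sketch compiles without Orochi (assumption).
using FakeFunction = int;

// Hypothetical class mirroring the per-instance cache idea: each object owns its map.
class KernelCache
{
  public:
    // Returns the cached entry for 'name', or calls 'build' once and stores the result.
    FakeFunction getOrBuild( const std::string& name, const std::function<FakeFunction()>& build )
    {
        auto it = m_kernelMap.find( name ); // one lookup on a cache hit
        if( it != m_kernelMap.end() )
            return it->second;
        FakeFunction f = build();
        m_kernelMap.emplace( name, f );
        return f;
    }

  private:
    std::unordered_map<std::string, FakeFunction> m_kernelMap;
};
```

Two `KernelCache` objects never see each other's entries, so a handle built by one instance cannot be returned by another after it has become invalid.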

Orochi/OrochiUtils.h

Lines changed: 11 additions & 5 deletions
@@ -1,7 +1,7 @@
 #pragma once
 #include <Orochi/Orochi.h>
 #include <vector>
-#include <map>
+#include <unordered_map>
 #include <string>

 #if defined(_WIN32)
@@ -18,8 +18,14 @@ class OrochiUtils
         int x, y, z, w;
     };

-    static oroFunction getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector<const char*>* opts );
-    static oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* opts );
+    OrochiUtils();
+    ~OrochiUtils();
+
+    oroFunction getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector<const char*>* opts );
+    oroFunction getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* opts,
+        int numHeaders, const char** headers, const char** includeNames );
+    oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* opts,
+        int numHeaders = 0, const char** headers = 0, const char** includeNames = 0 );

     static void launch1D( oroFunction func, int nx, const void** args, int wgSize = 64, unsigned int sharedMemBytes = 0 );

@@ -64,6 +70,6 @@ class OrochiUtils
     }

 public:
-    static char* s_cacheDirectory;
-    static std::map<std::string, oroFunction> s_kernelMap;
+    std::string m_cacheDirectory;
+    std::unordered_map<std::string, oroFunction> m_kernelMap;
 };

Test/Stopwatch.h

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+/*
+        AMD copyrights (Copyright (c) 2011 Advanced Micro Devices, Inc. All rights reserved)
+*/
+#pragma once
+
+#if defined(__WINDOWS__)
+    #define NOMINMAX
+    #include <windows.h>
+
+    #define TIME_TYPE LARGE_INTEGER
+    #define QUERY_FREQ(f) QueryPerformanceFrequency(&f)
+    #define RECORD(t) QueryPerformanceCounter(&t)
+    #define GET_TIME(t) (t).QuadPart*1000.0
+    #define GET_FREQ(f) (f).QuadPart
+#else
+    #include <sys/time.h>
+
+    #define TIME_TYPE timeval
+    #define QUERY_FREQ(f) f.tv_sec = 1
+    #define RECORD(t) gettimeofday(&t, 0)
+    #define GET_TIME(t) ((t).tv_sec*1000.0+(t).tv_usec/1000.0)
+    #define GET_FREQ(f) 1.0
+#endif
+
+
+class Stopwatch
+{
+    public:
+        __inline
+        Stopwatch();
+        __inline
+        void init();
+        __inline
+        void start();
+        __inline
+        void split();
+        __inline
+        float getCurrent();
+        __inline
+        void stop();
+        __inline
+        float getMs();
+        __inline
+        void getMs( float* times, int capacity );
+
+    private:
+        enum
+        {
+            CAPACITY = 12,
+        };
+        int m_idx;
+
+        TIME_TYPE m_frequency;
+        TIME_TYPE m_t[CAPACITY];
+};
+
+__inline
+Stopwatch::Stopwatch()
+{
+//    QueryPerformanceFrequency( &m_frequency );
+    QUERY_FREQ( m_frequency );
+}
+
+__inline
+void Stopwatch::start()
+{
+    m_idx = 0;
+//    QueryPerformanceCounter(&m_t[m_idx++]);
+    RECORD( m_t[m_idx++] );
+}
+
+__inline
+void Stopwatch::split()
+{
+//    QueryPerformanceCounter(&m_t[m_idx++]);
+    RECORD( m_t[m_idx++] );
+}
+
+__inline
+float Stopwatch::getCurrent()
+{
+    TIME_TYPE t;
+    RECORD( t );
+    return (float)( GET_TIME(t) - GET_TIME(m_t[0]) )/GET_FREQ(m_frequency);
+}
+
+__inline
+void Stopwatch::stop()
+{
+    split();
+}
+
+__inline
+float Stopwatch::getMs()
+{
+//    return (float)(1000*(m_t[1].QuadPart - m_t[0].QuadPart))/m_frequency.QuadPart;
+    return (float)( GET_TIME( m_t[1] ) - GET_TIME( m_t[0] ) )/GET_FREQ( m_frequency );
+}
+
+__inline
+void Stopwatch::getMs(float* times, int capacity)
+{
+    for(int i=0; i<capacity; i++) times[i] = 0.f;
+
+    for(int i=0; i<std::min(capacity, m_idx-1); i++)
+    {
+//        times[i] = (float)(1000*(m_t[i+1].QuadPart - m_t[i].QuadPart))/m_frequency.QuadPart;
+        times[i] = (float)( GET_TIME( m_t[i+1] ) - GET_TIME( m_t[i] ) )/GET_FREQ( m_frequency );
+    }
+}
+
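
The `Stopwatch` added above picks its timer through platform macros (`QueryPerformanceCounter` on Windows, `gettimeofday` elsewhere) and stores up to `CAPACITY` time points. As a rough illustration of the same start/split/stop interface, here is a portable sketch built on `std::chrono::steady_clock`. The class name `ChronoStopwatch` is hypothetical and the code is not part of the commit. It stores time points in a `std::vector`, so it has no fixed capacity.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <vector>

// Hypothetical portable variant of the Stopwatch idea; not part of the commit.
class ChronoStopwatch
{
  public:
    // Resets the recorded time points and records the first one.
    void start()
    {
        m_t.clear();
        m_t.push_back( Clock::now() );
    }
    void split() { m_t.push_back( Clock::now() ); }
    void stop() { split(); }

    // Milliseconds between the first two recorded time points (0 if fewer than two exist).
    float getMs() const { return m_t.size() < 2 ? 0.f : elapsedMs( m_t[0], m_t[1] ); }

    // Fills times[i] with the i-th interval in milliseconds; unused slots are set to 0.
    // Assumes capacity >= 0.
    void getMs( float* times, int capacity ) const
    {
        std::fill( times, times + capacity, 0.f );
        const int n = std::min( capacity, static_cast<int>( m_t.size() ) - 1 );
        for( int i = 0; i < n; i++ )
            times[i] = elapsedMs( m_t[i], m_t[i + 1] );
    }

  private:
    using Clock = std::chrono::steady_clock;
    static float elapsedMs( Clock::time_point a, Clock::time_point b )
    {
        return std::chrono::duration<float, std::milli>( b - a ).count();
    }
    std::vector<Clock::time_point> m_t;
};
```

A typical use is `sw.start();`, the work to be timed, `sw.stop();`, then `sw.getMs()`. Calling `split()` between phases gives per-phase intervals through the array overload.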

UnitTest/main.cpp

Lines changed: 5 additions & 1 deletion
@@ -2,6 +2,9 @@
 #include <Orochi/Orochi.h>
 #include <Orochi/OrochiUtils.h>

+#if defined( OROASSERT )
+#undef OROASSERT
+#endif
 #define OROASSERT( x ) ASSERT_TRUE( x )
 #define OROCHECK( x ) { oroError e = x; OROASSERT( e == ORO_SUCCESS ); }

@@ -50,11 +53,12 @@ TEST_F( OroTestBase, deviceprops )

 TEST_F( OroTestBase, kernelExec )
 {
+    OrochiUtils o;
     int a_host = -1;
     int* a_device = nullptr;
     OROCHECK( oroMalloc( (oroDeviceptr*)&a_device, sizeof( int ) ) );
     OROCHECK( oroMemset( (oroDeviceptr)a_device, 0, sizeof( int ) ) );
-    oroFunction kernel = OrochiUtils::getFunctionFromFile( m_device, "../UnitTest/testKernel.h", "testKernel", 0 );
+    oroFunction kernel = o.getFunctionFromFile( m_device, "../UnitTest/testKernel.h", "testKernel", 0 );
     const void* args[] = { &a_device };
     OrochiUtils::launch1D( kernel, 64, args, 64 );
     OrochiUtils::waitForCompletion();
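
The first hunk of this file guards the test's one-argument `OROASSERT` with `#undef`, presumably because an included header now defines a two-argument `OROASSERT` (the form used in `OrochiUtils.cpp`). That library definition is not shown on this page. The pattern can be sketched with a made-up macro `CHECK`. Both definitions below are assumptions chosen for illustration.

```cpp
#include <cassert>

// Pretend a library header already defined CHECK with two parameters (assumption).
#define CHECK( cond, fallback ) ( ( cond ) ? 1 : ( fallback ) )

// The test code wants its own one-parameter CHECK. Redefining a macro with a
// different definition is ill-formed, so the old one is dropped first.
#if defined( CHECK )
#undef CHECK
#endif
#define CHECK( cond ) ( ( cond ) ? 1 : 0 )

// Uses the one-parameter version defined just above.
inline int isPositive( int x ) { return CHECK( x > 0 ); }
```

The `#if defined` guard keeps the file valid whether or not the header defines the macro, so the test does not depend on the header's internals.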

tools/bakeKernel.bat

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+echo // automatically generated, don't edit > ParallelPrimitives/cache/Kernels.h
+echo // automatically generated, don't edit > ParallelPrimitives/cache/KernelArgs.h
+python tools/stringify.py ./ParallelPrimitives/RadixSortKernels.h >> ParallelPrimitives/cache/Kernels.h
+python tools/genArgs.py ./ParallelPrimitives/RadixSortKernels.h >> ParallelPrimitives/cache/KernelArgs.h
+
+python tools/stringify.py ./ParallelPrimitives/RadixSortConfigs.h >> ParallelPrimitives/cache/Kernels.h

tools/bakeKernel.sh

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# mkdir hiprt/cache/
+echo "// automatically generated, don't edit" > ParallelPrimitives/cache/Kernels.h
+echo "// automatically generated, don't edit" > ParallelPrimitives/cache/KernelArgs.h
+python tools/stringify.py ./ParallelPrimitives/RadixSortKernels.h >> ParallelPrimitives/cache/Kernels.h
+python tools/genArgs.py ./ParallelPrimitives/RadixSortKernels.h >> ParallelPrimitives/cache/KernelArgs.h
+
+python tools/stringify.py ./ParallelPrimitives/RadixSortConfigs.h >> ParallelPrimitives/cache/Kernels.h

tools/genArgs.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+from __future__ import print_function
+
+import sys
+import os
+
+def genArgs( fileName, api, includes ):
+    with open(fileName) as f:
+        iName = os.path.basename( fileName ).split('.')[0]
+
+        print( '#if !defined(ORO_PP_LOAD_FROM_STRING)' )
+        print( ' static const char** '+iName+'Args = 0;' )
+        print( '#else' )
+        print( ' static const char* '+iName+'Args[] = {' )
+        includes += iName +'Includes[] = {'
+        for line in f.readlines():
+            a = line.strip('\r\n')
+            if a.find('#include') == -1:
+                continue
+            if a.find('#include') != -1 and a.find('inl.' + api) != -1:
+                continue
+            if (api == 'cl' or api == 'metal') and a.find('.cu') != -1:
+                continue
+            if (a.find('"') != -1 and a.find('#include') != -1):
+                continue
+
+            filename = os.path.basename(a.split('<')[1].split('>')[0])
+            includes += '"' + a.split('<')[1].split('>')[0] + '",'
+            name = filename.split('.' + api)[0]
+            name = name.split('.h')[0]
+            name = api + '_'+name
+            print ( name + ',' )
+        print( api + '_'+iName+'};' )
+        print( '#endif' )
+    return includes
+
+argvs = sys.argv
+
+files = []
+if len(argvs) >= 2:
+    files.append( argvs[1] )
+
+print( '#pragma once' )
+
+
+api = 'hip'
+
+# Visit each file
+print( 'namespace ' + api + ' {')
+
+includes = 'static const char* '
+for s in files:
+    includes = genArgs(s, api, includes)
+includes += '};'
+print( includes )
+print( '}\t//namespace ' + api)
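
For each angle-bracket `#include` line it keeps, the script above derives an identifier from the included path: it takes the basename, cuts it at `.<api>` and then at `.h`, and prefixes the API name. A hypothetical C++ port of just that step is sketched below. The function name `deriveArgName` is made up, and the include lines used to exercise it are illustrative inputs, not lines known to be in `RadixSortKernels.h`.

```cpp
#include <cassert>
#include <string>

// Hypothetical C++ port of the name derivation in genArgs.py for one '#include <...>' line.
// Assumes the line contains both '<' and '>' (the script skips quoted includes before this step).
std::string deriveArgName( const std::string& line, const std::string& api )
{
    const std::size_t lt = line.find( '<' );
    const std::size_t gt = line.find( '>', lt + 1 );
    const std::string path = line.substr( lt + 1, gt - lt - 1 );    // text between '<' and '>'
    std::string name = path.substr( path.find_last_of( '/' ) + 1 ); // basename; npos + 1 wraps to 0
    name = name.substr( 0, name.find( "." + api ) );                // cut at ".<api>" if present
    name = name.substr( 0, name.find( ".h" ) );                     // cut at ".h" if present
    return api + "_" + name;
}
```

For example, `deriveArgName( "#include <ParallelPrimitives/RadixSortConfigs.h>", "hip" )` returns `hip_RadixSortConfigs`, which is the form of the entries the script prints into the `...Args[]` array.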
