@@ -74,6 +74,7 @@ pprtcResult nvrtc2pp( nvrtcResult a )
7474}
7575
7676#define __PP_FUNC1 ( cuname, hipname ) if ( s_api == API_CUDA ) return cu2pp( cu##cuname ); if ( s_api == API_HIP ) return hip2pp( hip##hipname );
77+ // #define __PP_FUNC1( cuname, hipname ) if( s_api == API_CUDA || API == API_CUDA ) return cu2pp( cu##cuname ); if( s_api == API_HIP || API == API_HIP ) return hip2pp( hip##hipname );
7778#define __PP_FUNC ( name ) if ( s_api == API_CUDA ) return cu2pp( cu##name ); if ( s_api == API_HIP ) return hip2pp( hip##name );
7879#define __PP_CTXT_FUNC ( name ) __PP_FUNC1(Ctx##name, name)
7980// #define __PP_CTXT_FUNC( name ) if( s_api == API_CUDA ) return cu2pp( cuCtx##name ); if( s_api == API_HIP ) return hip2pp( hip##name );
@@ -93,15 +94,16 @@ ppError PPAPI ppGetErrorString(ppError error, const char** pStr)
9394 return ppErrorUnknown;
9495}
9596
97+ template <Api API>
9698ppError PPAPI ppInit (unsigned int Flags)
9799{
98- __PP_FUNC ( Init (Flags) );
99100/*
100- if( s_api == API_CUDA )
101- return cu2pp( cuInit( Flags ) );
102- if( s_api == API_HIP )
103- return hip2pp( hipInit( Flags ) );
101+ if( s_api == API_CUDA || API == API_CUDA )
102+ printf("cuda\n");
103+ if( s_api == API_HIP || API == API_HIP )
104+ printf("hip\n" );
104105*/
106+ __PP_FUNC ( Init (Flags) );
105107 return ppErrorUnknown;
106108}
107109ppError PPAPI ppDriverGetVersion (int * driverVersion)
@@ -183,13 +185,24 @@ ppError PPAPI ppCtxCreate(ppCtx* pctx, unsigned int flags, ppDevice dev)
183185}
184186ppError PPAPI ppCtxDestroy (ppCtx ctx)
185187{
188+ __PP_FUNC1 ( CtxDestroy ( *ppCtx2cu (&ctx) ), CtxDestroy ( *ppCtx2hip (&ctx) ) );
186189 return ppErrorUnknown;
187190}
188191/*
189192ppError PPAPI ppCtxPushCurrent(ppCtx ctx);
190193ppError PPAPI ppCtxPopCurrent(ppCtx* pctx);
191- ppError PPAPI ppCtxSetCurrent(ppCtx ctx);
192- ppError PPAPI ppCtxGetCurrent(ppCtx* pctx);
194+ */
195+ ppError PPAPI ppCtxSetCurrent (ppCtx ctx)
196+ {
197+ __PP_FUNC1 ( CtxSetCurrent ( *ppCtx2cu (&ctx) ), CtxSetCurrent ( *ppCtx2hip (&ctx) ) );
198+ return ppErrorUnknown;
199+ }
200+ ppError PPAPI ppCtxGetCurrent (ppCtx* pctx)
201+ {
202+ __PP_FUNC1 ( CtxGetCurrent ( ppCtx2cu (pctx) ), CtxGetCurrent ( ppCtx2hip (pctx) ) );
203+ return ppErrorUnknown;
204+ }
205+ /*
193206ppError PPAPI ppCtxGetDevice(ppDevice* device);
194207ppError PPAPI ppCtxGetFlags(unsigned int* flags);
195208*/
@@ -222,7 +235,7 @@ ppError PPAPI ppModuleLoadData(ppModule* module, const void* image)
222235 __PP_FUNC1 ( ModuleLoadData ( (CUmodule*)module , image ), ModuleLoadData ( (hipModule_t*)module , image ) );
223236 return ppErrorUnknown;
224237}
225- ppError PPAPI ppModuleLoadDataEx (ppModule* module , const void * image, unsigned int numOptions, hipJitOption * options, void ** optionValues)
238+ ppError PPAPI ppModuleLoadDataEx (ppModule* module , const void * image, unsigned int numOptions, ppJitOption * options, void ** optionValues)
226239{
227240 __PP_FUNC1 ( ModuleLoadDataEx ( (CUmodule*)module , image, numOptions, (CUjit_option*)options, optionValues ),
228241 ModuleLoadDataEx ( (hipModule_t*)module , image, numOptions, (hipJitOption*)options, optionValues ) );
@@ -393,3 +406,4 @@ ppError PPAPI ppStreamCreate(ppStream* stream)
393406 return ppErrorUnknown;
394407}
395408
409+
0 commit comments