Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions Orochi/Orochi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -799,20 +799,23 @@ struct ioroCtx_t

struct ioroDevice
{
private:
oroU32 m_api : 4;
oroU32 m_deviceIdx : 16;
static constexpr oroU32 ApiBits = 5;
static constexpr oroU32 ApiMask = ( 1u << ApiBits ) - 1u;
static constexpr oroU32 DeviceBits = 16;
static constexpr oroU32 DeviceMask = ( 1u << DeviceBits ) - 1u;

public:
ioroDevice( int src = 0)
oroU32 m_value = 0;

explicit ioroDevice( oroU32 packed = 0 )
: m_value( packed )
{
((int*)this)[0] = src;
}

oroApi getApi() const { return (oroApi)m_api; }
void setApi(oroApi api) { m_api = api; }
int getDevice() const { return m_deviceIdx; }
void setDevice( int d ) { m_deviceIdx = d; }
oroU32 packed() const { return m_value; }
oroApi getApi() const { return (oroApi)( m_value & ApiMask ); }
void setApi(oroApi api) { m_value = ( m_value & ~ApiMask ) | ( (oroU32)api & ApiMask ); }
int getDevice() const { return (int)( ( m_value >> ApiBits ) & DeviceMask ); }
void setDevice( int d ) { m_value = ( m_value & ~( DeviceMask << ApiBits ) ) | ( ( (oroU32)d & DeviceMask ) << ApiBits ); }
};

inline
Expand Down Expand Up @@ -932,15 +935,16 @@ oroError oroCtxCreateFromRawDestroy( oroCtx ctx )

oroDevice oroGetRawDevice( oroDevice dev )
{
ioroDevice d( dev );
return d.getDevice();
ioroDevice d( (oroU32)dev );
return (oroDevice)d.getDevice();
}

oroDevice oroSetRawDevice( oroApi api, oroDevice dev )
{
ioroDevice d( dev );
ioroDevice d( 0 );
d.setApi( api );
return *(oroDevice*)&d;
d.setDevice( (int)dev );
return (oroDevice)d.packed();
}

//=================================
Expand Down Expand Up @@ -1078,7 +1082,7 @@ oroError OROAPI oroGetDeviceCount(int* count, oroApi iapi)

oroError OROAPI oroGetDeviceProperties(oroDeviceProp_t* props, oroDevice dev)
{
ioroDevice d( dev );
ioroDevice d( (oroU32)dev );
int deviceId = d.getDevice();
oroApi api = d.getApi();
*props = {};
Expand All @@ -1105,7 +1109,7 @@ oroError OROAPI oroDeviceGet(oroDevice* device, int ordinal )
auto e = hipDeviceGet(&t, ordinal);
d.setApi( api );
d.setDevice( t );
*(ioroDevice*)device = d;
*device = (oroDevice)d.packed();
return hip2oro(e);
}
if (api & ORO_API_CUDADRIVER)
Expand All @@ -1115,7 +1119,7 @@ oroError OROAPI oroDeviceGet(oroDevice* device, int ordinal )
auto e = CU4ORO::cuDeviceGet(&t, ordinal);
d.setApi(api);
d.setDevice(t);
*(ioroDevice*)device = d;
*device = (oroDevice)d.packed();
return cu2oro(e);
#endif
}
Expand All @@ -1124,7 +1128,7 @@ oroError OROAPI oroDeviceGet(oroDevice* device, int ordinal )

oroError OROAPI oroDeviceGetName(char* name, int len, oroDevice dev)
{
ioroDevice d( dev );
ioroDevice d( (oroU32)dev );
__ORO_FUNCX( d.getApi(),
CU4ORO::cuDeviceGetName(name, len, d.getDevice() ),
hipDeviceGetName(name, len, d.getDevice() )
Expand All @@ -1136,7 +1140,7 @@ oroError OROAPI oroDeviceGetName(char* name, int len, oroDevice dev)

oroError OROAPI oroDeviceGetAttribute(int* pi, oroDeviceAttribute_t attrib, oroDevice dev)
{
ioroDevice d( dev );
ioroDevice d( (oroU32)dev );
__ORO_FUNCX( d.getApi(),
CU4ORO::cuDeviceGetAttribute( pi, (CU4ORO::CUdevice_attribute)attrib, d.getDevice() ),
hipDeviceGetAttribute( pi, (hipDeviceAttribute_t)attrib, d.getDevice() ) );
Expand All @@ -1145,7 +1149,7 @@ oroError OROAPI oroDeviceGetAttribute(int* pi, oroDeviceAttribute_t attrib, oroD

oroError OROAPI oroCtxCreate(oroCtx* pctx, unsigned int flags, oroDevice dev)
{
ioroDevice d( dev );
ioroDevice d( (oroU32)dev );
ioroCtx_t* ctxt = new ioroCtx_t;
ctxt->setApi( d.getApi() );
(*pctx) = ctxt;
Expand Down