@@ -59,9 +59,12 @@ class Tutorial : public TutorialBase
5959 geomInput.primitive .triangleMesh = mesh;
6060 buildBvh ( geomInput );
6161
62- hiprtDevicePtr geomTemp = nullptr ;
62+ size_t geomTempSize;
63+ hiprtDevicePtr geomTemp;
6364 hiprtBuildOptions options;
6465 options.buildFlags = hiprtBuildFlagBitCustomBvhImport;
66+ CHECK_HIPRT ( hiprtGetGeometryBuildTemporaryBufferSize ( ctxt, geomInput, options, geomTempSize ) );
67+ CHECK_ORO ( oroMalloc ( reinterpret_cast <oroDeviceptr*>( &geomTemp ), geomTempSize ) );
6568
6669 hiprtGeometry geom;
6770 CHECK_HIPRT ( hiprtCreateGeometry ( ctxt, geomInput, options, geom ) );
@@ -97,6 +100,8 @@ class Tutorial : public TutorialBase
97100 CHECK_ORO ( oroFree ( reinterpret_cast <oroDeviceptr>( mesh.triangleIndices ) ) );
98101 CHECK_ORO ( oroFree ( reinterpret_cast <oroDeviceptr>( mesh.vertices ) ) );
99102 CHECK_ORO ( oroFree ( reinterpret_cast <oroDeviceptr>( pixels ) ) );
103+ CHECK_ORO ( oroFree ( reinterpret_cast <oroDeviceptr>( geomInput.nodeList .leafNodes ) ) );
104+ CHECK_ORO ( oroFree ( reinterpret_cast <oroDeviceptr>( geomInput.nodeList .internalNodes ) ) );
100105
101106 CHECK_HIPRT ( hiprtDestroyGeometry ( ctxt, geom ) );
102107 CHECK_HIPRT ( hiprtDestroyContext ( ctxt ) );
@@ -105,10 +110,11 @@ class Tutorial : public TutorialBase
105110
106111void Tutorial::buildBvh ( hiprtGeometryBuildInput& buildInput )
107112{
108- std::vector<hiprtBvhNode> nodes;
113+ std::vector<hiprtInternalNode> internalNodes;
114+ std::vector<Aabb> primBoxes;
109115 if ( buildInput.type == hiprtPrimitiveTypeTriangleMesh )
110116 {
111- std::vector<Aabb> primBoxes ( buildInput.primitive .triangleMesh .triangleCount );
117+ primBoxes. resize ( buildInput.primitive .triangleMesh .triangleCount );
112118 std::vector<uint8_t > verticesRaw (
113119 buildInput.primitive .triangleMesh .vertexCount * buildInput.primitive .triangleMesh .vertexStride );
114120 std::vector<uint8_t > trianglesRaw (
@@ -136,11 +142,11 @@ void Tutorial::buildBvh( hiprtGeometryBuildInput& buildInput )
136142 primBoxes[i].grow ( v1 );
137143 primBoxes[i].grow ( v2 );
138144 }
139- BvhBuilder::build ( buildInput.primitive .triangleMesh .triangleCount , primBoxes, nodes );
145+ BvhBuilder::build ( buildInput.primitive .triangleMesh .triangleCount , primBoxes, internalNodes );
140146 }
141147 else if ( buildInput.type == hiprtPrimitiveTypeAABBList )
142148 {
143- std::vector<Aabb> primBoxes ( buildInput.primitive .aabbList .aabbCount );
149+ primBoxes. resize ( buildInput.primitive .aabbList .aabbCount );
144150 std::vector<uint8_t > primBoxesRaw ( buildInput.primitive .aabbList .aabbCount * buildInput.primitive .aabbList .aabbStride );
145151 CHECK_ORO ( oroMemcpyDtoH (
146152 primBoxesRaw.data (),
@@ -153,13 +159,32 @@ void Tutorial::buildBvh( hiprtGeometryBuildInput& buildInput )
153159 primBoxes[i].m_min = make_float3 ( ptr[0 ] );
154160 primBoxes[i].m_max = make_float3 ( ptr[1 ] );
155161 }
156- BvhBuilder::build ( buildInput.primitive .aabbList .aabbCount , primBoxes, nodes );
162+ BvhBuilder::build ( buildInput.primitive .aabbList .aabbCount , primBoxes, internalNodes );
157163 }
158- CHECK_ORO (
159- oroMalloc ( reinterpret_cast <oroDeviceptr*>( &buildInput.nodeList .nodes ), nodes.size () * sizeof ( hiprtBvhNode ) ) );
164+
165+ std::vector<hiprtLeafNode> leafNodes ( primBoxes.size () );
166+ for ( uint32_t i = 0 ; i < primBoxes.size (); ++i )
167+ {
168+ leafNodes[i].primID = i;
169+ leafNodes[i].aabbMin = primBoxes[i].m_min ;
170+ leafNodes[i].aabbMax = primBoxes[i].m_max ;
171+ }
172+
173+ buildInput.nodeList .nodeCount = static_cast <uint32_t >( leafNodes.size () );
174+
175+ CHECK_ORO ( oroMalloc (
176+ reinterpret_cast <oroDeviceptr*>( &buildInput.nodeList .leafNodes ), leafNodes.size () * sizeof ( hiprtLeafNode ) ) );
177+ CHECK_ORO ( oroMemcpyHtoD (
178+ reinterpret_cast <oroDeviceptr>( buildInput.nodeList .leafNodes ),
179+ leafNodes.data (),
180+ leafNodes.size () * sizeof ( hiprtLeafNode ) ) );
181+ CHECK_ORO ( oroMalloc (
182+ reinterpret_cast <oroDeviceptr*>( &buildInput.nodeList .internalNodes ),
183+ internalNodes.size () * sizeof ( hiprtInternalNode ) ) );
160184 CHECK_ORO ( oroMemcpyHtoD (
161- reinterpret_cast <oroDeviceptr>( buildInput.nodeList .nodes ), nodes.data (), nodes.size () * sizeof ( hiprtBvhNode ) ) );
162- buildInput.nodeList .nodeCount = static_cast <uint32_t >( nodes.size () );
185+ reinterpret_cast <oroDeviceptr>( buildInput.nodeList .internalNodes ),
186+ internalNodes.data (),
187+ internalNodes.size () * sizeof ( hiprtInternalNode ) ) );
163188}
164189
165190int main ( int argc, char ** argv )
0 commit comments