Skip to content

Commit 8723771

Browse files
committed
Validate SPIR-V once per blob
1 parent 3541a9d commit 8723771

3 files changed

Lines changed: 25 additions & 12 deletions

File tree

examples_tests

include/nbl/asset/utils/ISPIRVEntryPointTrimmer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class NBL_API2 ISPIRVEntryPointTrimmer final : public core::IReferenceCounted
4949

5050
Result trim(const ICPUBuffer* spirvBuffer, const core::set<EntryPoint>& entryPoints, system::logger_opt_ptr logger = nullptr) const;
5151
bool ensureValidated(const ICPUBuffer* spirvBuffer, system::logger_opt_ptr logger = nullptr) const;
52+
void markValidated(const ICPUBuffer* spirvBuffer) const;
5253

5354
inline core::smart_refctd_ptr<const IShader> trim(const IShader* shader, const core::set<EntryPoint>& entryPoints, system::logger_opt_ptr logger = nullptr) const
5455
{

src/nbl/asset/utils/ISPIRVEntryPointTrimmer.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,17 @@ static bool validate(const uint32_t* binary, uint32_t binarySize, nbl::system::l
6565
return core.Validate(binary, binarySize, validatorOptions);
6666
}
6767

68-
bool ISPIRVEntryPointTrimmer::ensureValidated(const ICPUBuffer* spirvBuffer, system::logger_opt_ptr logger) const
68+
static nbl::core::blake3_hash_t getContentHash(const ICPUBuffer* spirvBuffer)
6969
{
7070
auto contentHash = spirvBuffer->getContentHash();
7171
if (contentHash == ICPUBuffer::INVALID_HASH)
7272
contentHash = spirvBuffer->computeContentHash();
73+
return contentHash;
74+
}
75+
76+
bool ISPIRVEntryPointTrimmer::ensureValidated(const ICPUBuffer* spirvBuffer, system::logger_opt_ptr logger) const
77+
{
78+
const auto contentHash = getContentHash(spirvBuffer);
7379

7480
{
7581
std::lock_guard lock(m_validationCacheMutex);
@@ -84,12 +90,18 @@ bool ISPIRVEntryPointTrimmer::ensureValidated(const ICPUBuffer* spirvBuffer, sys
8490

8591
{
8692
std::lock_guard lock(m_validationCacheMutex);
87-
m_validatedSpirvHashes.insert(contentHash);
93+
m_validatedSpirvHashes.emplace(contentHash);
8894
}
8995

9096
return true;
9197
}
9298

99+
void ISPIRVEntryPointTrimmer::markValidated(const ICPUBuffer* spirvBuffer) const
100+
{
101+
std::lock_guard lock(m_validationCacheMutex);
102+
m_validatedSpirvHashes.emplace(getContentHash(spirvBuffer));
103+
}
104+
93105
ISPIRVEntryPointTrimmer::Result ISPIRVEntryPointTrimmer::trim(const ICPUBuffer* spirvBuffer, const core::set<EntryPoint>& entryPoints, system::logger_opt_ptr logger) const
94106
{
95107
const auto* spirv = static_cast<const uint32_t*>(spirvBuffer->getPointer());
@@ -143,6 +155,15 @@ ISPIRVEntryPointTrimmer::Result ISPIRVEntryPointTrimmer::trim(const ICPUBuffer*
143155
return { length, opcode };
144156
};
145157

158+
if (!ensureValidated(spirvBuffer, logger))
159+
{
160+
logger.log("SPIR-V is not valid", system::ILogger::ELL_ERROR);
161+
return Result{
162+
.spirv = nullptr,
163+
.isSuccess = false,
164+
};
165+
}
166+
146167
{
147168
auto probeOffset = HEADER_SIZE;
148169
auto totalEntryPoints = 0u;
@@ -202,15 +223,6 @@ ISPIRVEntryPointTrimmer::Result ISPIRVEntryPointTrimmer::trim(const ICPUBuffer*
202223
}
203224
}
204225

205-
if (!ensureValidated(spirvBuffer, logger))
206-
{
207-
logger.log("SPIR-V is not valid", system::ILogger::ELL_ERROR);
208-
return Result{
209-
.spirv = nullptr,
210-
.isSuccess = false,
211-
};
212-
}
213-
214226
auto foundEntryPoint = 0;
215227

216228
// Keep in mind about this layout while reading all the code below: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#LogicalLayout

0 commit comments

Comments
 (0)