Skip to content

Commit 2c1fd89

Browse files
authored
spirv-opt: Handle ID overflow in SplitCombinedImageSamplerPass (#6406)
This pass creates new instructions without checking if the new IDs will overflow. This can lead to a crash if the module has many IDs. This CL adds checks for ID overflow and returns an error if it happens.
1 parent 1944e8d commit 2c1fd89

1 file changed

Lines changed: 34 additions & 12 deletions

File tree

source/opt/split_combined_image_sampler_pass.cpp

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ Instruction* SplitCombinedImageSamplerPass::GetSamplerType() {
142142
analysis::Sampler s;
143143
uint32_t sampler_type_id = type_mgr_->GetTypeInstruction(&s);
144144
sampler_type_ = def_use_mgr_->GetDef(sampler_type_id);
145+
if (sampler_type_ == nullptr) return nullptr;
145146
assert(first_sampled_image_type_);
146147
sampler_type_->InsertBefore(first_sampled_image_type_);
147148
RegisterNewGlobal(sampler_type_->result_id());
@@ -169,6 +170,7 @@ std::pair<Instruction*, Instruction*> SplitCombinedImageSamplerPass::SplitType(
169170
auto* image_type =
170171
def_use_mgr_->GetDef(combined_kind_type.GetSingleWordInOperand(0));
171172
auto* sampler_type = GetSamplerType();
173+
if (!sampler_type) return {nullptr, nullptr};
172174
type_remap_[combined_kind_type.result_id()] = {&combined_kind_type,
173175
image_type, sampler_type};
174176
return {image_type, sampler_type};
@@ -187,7 +189,9 @@ std::pair<Instruction*, Instruction*> SplitCombinedImageSamplerPass::SplitType(
187189
// this defensively.
188190
if (image_pointee && sampler_pointee) {
189191
auto* ptr_image = MakeUniformConstantPointer(image_pointee);
192+
if (!ptr_image) return {nullptr, nullptr};
190193
auto* ptr_sampler = MakeUniformConstantPointer(sampler_pointee);
194+
if (!ptr_sampler) return {nullptr, nullptr};
191195
type_remap_[combined_kind_type.result_id()] = {
192196
&combined_kind_type, ptr_image, ptr_sampler};
193197
return {ptr_image, ptr_sampler};
@@ -207,18 +211,22 @@ std::pair<Instruction*, Instruction*> SplitCombinedImageSamplerPass::SplitType(
207211
analysis::Array array_image_ty(image_ty, array_ty->length_info());
208212
const uint32_t array_image_ty_id =
209213
type_mgr_->GetTypeInstruction(&array_image_ty);
214+
if (array_image_ty_id == 0) return {nullptr, nullptr};
210215
auto* array_image_ty_inst = def_use_mgr_->GetDef(array_image_ty_id);
211216
if (!IsKnownGlobal(array_image_ty_id)) {
212217
array_image_ty_inst->InsertBefore(&combined_kind_type);
213218
RegisterNewGlobal(array_image_ty_id);
214219
// GetTypeInstruction also updated the def-use manager.
215220
}
216221

222+
auto* sampler_ty_inst = GetSamplerType();
223+
if (!sampler_ty_inst) return {nullptr, nullptr};
217224
analysis::Array sampler_array_ty(
218-
type_mgr_->GetType(GetSamplerType()->result_id()),
225+
type_mgr_->GetType(sampler_ty_inst->result_id()),
219226
array_ty->length_info());
220227
const uint32_t array_sampler_ty_id =
221228
type_mgr_->GetTypeInstruction(&sampler_array_ty);
229+
if (array_sampler_ty_id == 0) return {nullptr, nullptr};
222230
auto* array_sampler_ty_inst = def_use_mgr_->GetDef(array_sampler_ty_id);
223231
if (!IsKnownGlobal(array_sampler_ty_id)) {
224232
array_sampler_ty_inst->InsertBefore(&combined_kind_type);
@@ -240,17 +248,21 @@ std::pair<Instruction*, Instruction*> SplitCombinedImageSamplerPass::SplitType(
240248
analysis::RuntimeArray array_image_ty(image_ty);
241249
const uint32_t array_image_ty_id =
242250
type_mgr_->GetTypeInstruction(&array_image_ty);
251+
if (array_image_ty_id == 0) return {nullptr, nullptr};
243252
auto* array_image_ty_inst = def_use_mgr_->GetDef(array_image_ty_id);
244253
if (!IsKnownGlobal(array_image_ty_id)) {
245254
array_image_ty_inst->InsertBefore(&combined_kind_type);
246255
RegisterNewGlobal(array_image_ty_id);
247256
// GetTypeInstruction also updated the def-use manager.
248257
}
249258

259+
auto* sampler_ty_inst = GetSamplerType();
260+
if (!sampler_ty_inst) return {nullptr, nullptr};
250261
analysis::RuntimeArray sampler_array_ty(
251-
type_mgr_->GetType(GetSamplerType()->result_id()));
262+
type_mgr_->GetType(sampler_ty_inst->result_id()));
252263
const uint32_t array_sampler_ty_id =
253264
type_mgr_->GetTypeInstruction(&sampler_array_ty);
265+
if (array_sampler_ty_id == 0) return {nullptr, nullptr};
254266
auto* array_sampler_ty_inst = def_use_mgr_->GetDef(array_sampler_ty_id);
255267
if (!IsKnownGlobal(array_sampler_ty_id)) {
256268
array_sampler_ty_inst->InsertBefore(&combined_kind_type);
@@ -273,14 +285,14 @@ spv_result_t SplitCombinedImageSamplerPass::RemapVar(
273285
// Create an image variable, and a sampler variable.
274286
auto* combined_var_type = def_use_mgr_->GetDef(combined_var->type_id());
275287
auto [ptr_image_ty, ptr_sampler_ty] = SplitType(*combined_var_type);
276-
assert(ptr_image_ty);
277-
assert(ptr_sampler_ty);
278-
// TODO(1841): Handle id overflow.
288+
if (!ptr_image_ty || !ptr_sampler_ty) return SPV_ERROR_INTERNAL;
279289
Instruction* sampler_var = builder.AddVariable(
280290
ptr_sampler_ty->result_id(), SpvStorageClassUniformConstant);
281-
// TODO(1841): Handle id overflow.
291+
if (sampler_var == nullptr) return SPV_ERROR_INTERNAL;
282292
Instruction* image_var = builder.AddVariable(ptr_image_ty->result_id(),
283293
SpvStorageClassUniformConstant);
294+
if (image_var == nullptr) return SPV_ERROR_INTERNAL;
295+
284296
modified_ = true;
285297
return RemapUses(combined_var, image_var, sampler_var);
286298
}
@@ -356,12 +368,12 @@ spv_result_t SplitCombinedImageSamplerPass::RemapUses(
356368

357369
// Create loads for the image part and sampler part.
358370
builder.SetInsertPoint(load);
359-
// TODO(1841): Handle id overflow.
360371
auto* image = builder.AddLoad(PointeeTypeId(use.image_part),
361372
use.image_part->result_id());
362-
// TODO(1841): Handle id overflow.
373+
if (!image) return SPV_ERROR_INTERNAL;
363374
auto* sampler = builder.AddLoad(PointeeTypeId(use.sampler_part),
364375
use.sampler_part->result_id());
376+
if (!sampler) return SPV_ERROR_INTERNAL;
365377

366378
// Move decorations, such as RelaxedPrecision.
367379
auto* deco_mgr = context()->get_decoration_mgr();
@@ -372,6 +384,7 @@ spv_result_t SplitCombinedImageSamplerPass::RemapUses(
372384
// Create a sampled image from the loads of the two parts.
373385
auto* sampled_image = builder.AddSampledImage(
374386
load->type_id(), image->result_id(), sampler->result_id());
387+
if (!sampled_image) return SPV_ERROR_INTERNAL;
375388
// Replace the original sampled image value with the new one.
376389
std::unordered_set<Instruction*> users;
377390
def_use_mgr_->ForEachUse(
@@ -463,14 +476,18 @@ spv_result_t SplitCombinedImageSamplerPass::RemapUses(
463476

464477
auto [result_image_part_ty, result_sampler_part_ty] =
465478
SplitType(*def_use_mgr_->GetDef(original_access_chain->type_id()));
466-
// TODO(1841): Handle id overflow.
479+
if (!result_image_part_ty || !result_sampler_part_ty)
480+
return Fail() << "failed to split type for access chain";
467481
auto* result_image_part = builder.AddOpcodeAccessChain(
468482
use.user->opcode(), result_image_part_ty->result_id(),
469483
use.image_part->result_id(), indices);
470-
// TODO(1841): Handle id overflow.
484+
if (!result_image_part)
485+
return Fail() << "failed to create access chain for image part";
471486
auto* result_sampler_part = builder.AddOpcodeAccessChain(
472487
use.user->opcode(), result_sampler_part_ty->result_id(),
473488
use.sampler_part->result_id(), indices);
489+
if (!result_sampler_part)
490+
return Fail() << "failed to create access chain for sampler part";
474491

475492
// Remap uses of the original access chain.
476493
add_remap(original_access_chain, result_image_part,
@@ -521,8 +538,7 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
521538
if (combined_types_.find(param_ty_id) != combined_types_.end()) {
522539
auto* param_type = def_use_mgr_->GetDef(param_ty_id);
523540
auto [image_type, sampler_type] = SplitType(*param_type);
524-
assert(image_type);
525-
assert(sampler_type);
541+
if (!image_type || !sampler_type) return SPV_ERROR_INTERNAL;
526542
// The image and sampler types must already exist, so there is no
527543
// need to move them to the right spot.
528544
new_params.push_back(type_mgr_->GetType(image_type->result_id()));
@@ -579,6 +595,11 @@ spv_result_t SplitCombinedImageSamplerPass::RemapFunctions() {
579595
auto* combined_inst = param.release();
580596
auto* combined_type = def_use_mgr_->GetDef(combined_inst->type_id());
581597
auto [image_type, sampler_type] = SplitType(*combined_type);
598+
if (!image_type || !sampler_type) {
599+
error = true;
600+
return;
601+
}
602+
582603
uint32_t image_param_id = context()->TakeNextId();
583604
if (image_param_id == 0) {
584605
error = true;
@@ -621,6 +642,7 @@ Instruction* SplitCombinedImageSamplerPass::MakeUniformConstantPointer(
621642
Instruction* pointee) {
622643
uint32_t ptr_id = type_mgr_->FindPointerToType(
623644
pointee->result_id(), spv::StorageClass::UniformConstant);
645+
if (ptr_id == 0) return nullptr;
624646
auto* ptr = def_use_mgr_->GetDef(ptr_id);
625647
if (!IsKnownGlobal(ptr_id)) {
626648
// The pointer type was created at the end. Put it right after the

0 commit comments

Comments
 (0)