Skip to content

Commit 4273354

Browse files
authored
[SPIR-V] Fix InterlockedMin/Max codegen (#6609)
RWByteAddressBuffer has overloads of InterlockedMin and InterlockedMax for signed ints that were failing to compile because of mismatched types in the generated SPIR-V instruction. This change adds the missing cast when it is needed. At the same time, it removes redundant code that overrode the opcode in the codegen for the InterlockedMin/Max intrinsic non-member functions. If that override was necessary in the past, the frontend has since been fixed and it is no longer needed. Tests that verify these combinations and the necessary implicit casts have also been added. Fixes #3196. Related to #4189, #6254, #5707.
1 parent e61ea50 commit 4273354

3 files changed

Lines changed: 48 additions & 16 deletions

File tree

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3873,11 +3873,24 @@ SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods(
38733873
expr->getArg(3)->getLocStart(), range);
38743874
}
38753875
} else {
3876-
auto *value = doExpr(expr->getArg(1));
3876+
const Expr *value = expr->getArg(1);
3877+
SpirvInstruction *valueInstr = doExpr(expr->getArg(1));
3878+
3879+
// Since a RWAB is represented by an array of 32-bit unsigned integers, the
3880+
// destination pointee type will always be unsigned, and thus the SPIR-V
3881+
// instruction's result type and value type must also be unsigned. The
3882+
// signedness of the opcode is determined correctly by the frontend and will
3883+
// correctly determine the signedness of the actual operation, but the
3884+
// necessary argument type cast will not be added by the frontend in the
3885+
// case of a signed value.
3886+
valueInstr =
3887+
castToType(valueInstr, value->getType(), astContext.UnsignedIntTy,
3888+
value->getExprLoc(), range);
3889+
38773890
SpirvInstruction *originalVal = spvBuilder.createAtomicOp(
38783891
translateAtomicHlslOpcodeToSpirvOpcode(opcode),
38793892
astContext.UnsignedIntTy, ptr, spv::Scope::Device,
3880-
spv::MemorySemanticsMask::MaskNone, value,
3893+
spv::MemorySemanticsMask::MaskNone, valueInstr,
38813894
expr->getCallee()->getExprLoc(), range);
38823895
if (expr->getNumArgs() > 2) {
38833896
originalVal = castToType(originalVal, astContext.UnsignedIntTy,
@@ -9203,21 +9216,7 @@ SpirvEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
92039216
writeToOutputArg(originalVal, expr, 3);
92049217
} else {
92059218
auto *value = doArg(expr, 1);
9206-
// Since these atomic operations write through the provided pointer, the
9207-
// signed vs. unsigned opcode must be decided based on the pointee type
9208-
// of the first argument. However, the frontend decides the opcode based on
9209-
// the second argument (value). Therefore, the HLSL opcode provided by the
9210-
// frontend may be wrong. Therefore we need the following code to make sure
9211-
// we are using the correct SPIR-V opcode.
92129219
spv::Op atomicOp = translateAtomicHlslOpcodeToSpirvOpcode(opcode);
9213-
if (atomicOp == spv::Op::OpAtomicUMax && baseType->isSignedIntegerType())
9214-
atomicOp = spv::Op::OpAtomicSMax;
9215-
if (atomicOp == spv::Op::OpAtomicSMax && baseType->isUnsignedIntegerType())
9216-
atomicOp = spv::Op::OpAtomicUMax;
9217-
if (atomicOp == spv::Op::OpAtomicUMin && baseType->isSignedIntegerType())
9218-
atomicOp = spv::Op::OpAtomicSMin;
9219-
if (atomicOp == spv::Op::OpAtomicSMin && baseType->isUnsignedIntegerType())
9220-
atomicOp = spv::Op::OpAtomicUMin;
92219220
auto *originalVal = spvBuilder.createAtomicOp(
92229221
atomicOp, baseType, ptr, scope, spv::MemorySemanticsMask::MaskNone,
92239222
value, srcLoc);

tools/clang/test/CodeGenSPIRV/intrinsics.interlocked-methods.cs.hlsl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,22 @@ void main()
4949
// CHECK-NEXT: OpStore %original_i_val [[asmax29]]
5050
InterlockedMax(dest_i, 10, original_i_val);
5151

52+
// CHECK: [[val30:%[0-9]+]] = OpBitcast %uint %int_n5
53+
// CHECK-NEXT: [[aumax:%[0-9]+]] = OpAtomicUMax %uint %dest_u %uint_2 %uint_0 [[val30]]
54+
// CHECK-NEXT: [[res30:%[0-9]+]] = OpBitcast %int [[aumax]]
55+
// CHECK-NEXT: OpStore %original_i_val [[res30]]
56+
InterlockedMax(dest_u, -5, original_i_val);
57+
5258
// CHECK: [[umin30:%[0-9]+]] = OpAtomicUMin %uint %dest_u %uint_2 %uint_0 %uint_10
5359
// CHECK-NEXT: OpStore %original_u_val [[umin30]]
5460
InterlockedMin(dest_u, 10, original_u_val);
5561

62+
// CHECK: [[val31:%[0-9]+]] = OpBitcast %int %uint_5
63+
// CHECK-NEXT: [[asmin:%[0-9]+]] = OpAtomicSMin %int %dest_i %uint_2 %uint_0 [[val31]]
64+
// CHECK-NEXT: [[res31:%[0-9]+]] = OpBitcast %uint [[asmin]]
65+
// CHECK-NEXT: OpStore %original_u_val [[res31]]
66+
InterlockedMin(dest_i, 5u, original_u_val);
67+
5668
// CHECK: [[val2_31:%[0-9]+]] = OpLoad %int %val2
5769
// CHECK-NEXT: [[or31:%[0-9]+]] = OpAtomicOr %int %dest_i %uint_2 %uint_0 [[val2_31]]
5870
// CHECK-NEXT: OpStore %original_i_val [[or31]]

tools/clang/test/CodeGenSPIRV/method.rw-byte-address-buffer.atomic.hlsl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,26 @@ float4 main() : SV_Target
116116
// CHECK-NOT: [[val]]
117117
myBuffer.InterlockedCompareStore(/*offset=*/16, /*compare_value=*/30, /*value=*/42);
118118

119+
// CHECK: [[offset_19:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
120+
// CHECK-NEXT: [[ptr_19:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset_19]]
121+
// CHECK-NEXT: [[val_19:%[0-9]+]] = OpBitcast %uint %int_n1
122+
// CHECK-NEXT: {{%[0-9]+}} = OpAtomicSMin %uint [[ptr_19]] %uint_1 %uint_0 [[val_19]]
123+
myBuffer.InterlockedMin(0u, -1);
124+
125+
// CHECK: [[offset_20:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
126+
// CHECK-NEXT: [[ptr_20:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset_20]]
127+
// CHECK-NEXT: [[val_20:%[0-9]+]] = OpBitcast %uint %int_n1
128+
// CHECK-NEXT: [[res_20:%[0-9]+]] = OpAtomicSMax %uint [[ptr_20]] %uint_1 %uint_0 [[val_20]]
129+
// CHECK-NEXT: OpStore %originalVal [[res_20]]
130+
myBuffer.InterlockedMax(0u, -1, originalVal);
131+
132+
// CHECK: [[offset_21:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
133+
// CHECK-NEXT: [[ptr_21:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[offset_21]]
134+
// CHECK-NEXT: [[val_21:%[0-9]+]] = OpBitcast %uint %int_n1
135+
// CHECK-NEXT: [[res_21:%[0-9]+]] = OpAtomicSMin %uint [[ptr_21]] %uint_1 %uint_0 [[val_21]]
136+
// CHECK-NEXT: [[res_21b:%[0-9]+]] = OpBitcast %int [[res_21]]
137+
// CHECK-NEXT: OpStore %originalValAsInt [[res_21b]]
138+
myBuffer.InterlockedMin(0u, -1, originalValAsInt);
139+
119140
return 1.0;
120141
}

0 commit comments

Comments
 (0)