@@ -1146,8 +1146,9 @@ void SpirvEmitter::doStmt(const Stmt *stmt,
11461146 // All cases for expressions used as statements
11471147 SpirvInstruction *result = doExpr(expr);
11481148
1149- if (result && result->getKind() == SpirvInstruction::IK_ExecutionMode &&
1150- !attrs.empty()) {
1149+ if (result && !attrs.empty() &&
1150+ (result->getKind() == SpirvInstruction::IK_ExecutionMode ||
1151+ result->getKind() == SpirvInstruction::IK_ExecutionModeId)) {
11511152 // Handle [[vk::ext_capability(..)]] and [[vk::ext_extension(..)]]
11521153 // attributes for vk::ext_execution_mode[_id](..).
11531154 createSpirvIntrInstExt(
@@ -9161,10 +9162,10 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
91619162 retVal = processRawBufferStore(callExpr);
91629163 break;
91639164 case hlsl::IntrinsicOp::IOP_Vkext_execution_mode:
9164- retVal = processIntrinsicExecutionMode(callExpr, false );
9165+ retVal = processIntrinsicExecutionMode(callExpr);
91659166 break;
91669167 case hlsl::IntrinsicOp::IOP_Vkext_execution_mode_id:
9167- retVal = processIntrinsicExecutionMode (callExpr, true );
9168+ retVal = processIntrinsicExecutionModeId (callExpr);
91689169 break;
91699170 case hlsl::IntrinsicOp::IOP_saturate:
91709171 retVal = processIntrinsicSaturate(callExpr);
@@ -15120,8 +15121,7 @@ SpirvEmitter::processCooperativeMatrixGetLength(const CallExpr *call) {
1512015121}
1512115122
1512215123SpirvInstruction *
15123- SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr,
15124- bool useIdParams) {
15124+ SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr) {
1512515125 llvm::SmallVector<uint32_t, 2> execModesParams;
1512615126 uint32_t exeMode = 0;
1512715127 const auto args = expr->getArgs();
@@ -15145,9 +15145,38 @@ SpirvEmitter::processIntrinsicExecutionMode(const CallExpr *expr,
1514515145 assert(entryFunction != nullptr);
1514615146 assert(exeMode != 0);
1514715147
15148- return spvBuilder.addExecutionMode(
15149- entryFunction, static_cast<spv::ExecutionMode>(exeMode), execModesParams,
15150- expr->getExprLoc(), useIdParams);
15148+ return spvBuilder.addExecutionMode(entryFunction,
15149+ static_cast<spv::ExecutionMode>(exeMode),
15150+ execModesParams, expr->getExprLoc());
15151+ }
15152+
15153+ SpirvInstruction *
15154+ SpirvEmitter::processIntrinsicExecutionModeId(const CallExpr *expr) {
15155+ assert(expr->getNumArgs() > 0);
15156+ uint32_t exeMode = 0;
15157+ const Expr *modeExpr = expr->getArg(0);
15158+ Expr::EvalResult evalResult;
15159+ if (modeExpr->EvaluateAsRValue(evalResult, astContext) &&
15160+ !evalResult.HasSideEffects && evalResult.Val.isInt()) {
15161+ exeMode = evalResult.Val.getInt().getZExtValue();
15162+ } else {
15163+ emitError("The execution mode must be constant integer",
15164+ expr->getExprLoc());
15165+ return nullptr;
15166+ }
15167+
15168+ llvm::SmallVector<SpirvInstruction *, 2> execModesParams;
15169+ const auto args = expr->getArgs();
15170+ for (uint32_t i = 1; i < expr->getNumArgs(); ++i) {
15171+ const Expr *argExpr = args[i];
15172+ SpirvInstruction *argInst = doExpr(argExpr);
15173+ execModesParams.push_back(argInst);
15174+ }
15175+
15176+ assert(entryFunction != nullptr);
15177+ return spvBuilder.addExecutionModeId(entryFunction,
15178+ static_cast<spv::ExecutionMode>(exeMode),
15179+ execModesParams, expr->getExprLoc());
1515115180}
1515215181
1515315182SpirvInstruction *
@@ -15218,8 +15247,9 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
1521815247void SpirvEmitter::addDerivativeGroupExecutionMode() {
1521915248 assert(spvContext.isCS());
1522015249
15221- SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode(
15222- entryFunction, spv::ExecutionMode::LocalSize);
15250+ SpirvExecutionMode *numThreadsEm =
15251+ cast<SpirvExecutionMode>(spvBuilder.getModule()->findExecutionMode(
15252+ entryFunction, spv::ExecutionMode::LocalSize));
1522315253 auto numThreads = numThreadsEm->getParams();
1522415254
1522515255 // The layout of the quad is determined by the numer of threads in each
0 commit comments