@@ -9484,12 +9484,17 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
94849484 retVal = processIntrinsicPointerCast(callExpr, true);
94859485 break;
94869486 }
9487- INTRINSIC_SPIRV_OP_CASE(ddx, DPdx, true);
9488- INTRINSIC_SPIRV_OP_CASE(ddx_coarse, DPdxCoarse, false);
9489- INTRINSIC_SPIRV_OP_CASE(ddx_fine, DPdxFine, false);
9490- INTRINSIC_SPIRV_OP_CASE(ddy, DPdy, true);
9491- INTRINSIC_SPIRV_OP_CASE(ddy_coarse, DPdyCoarse, false);
9492- INTRINSIC_SPIRV_OP_CASE(ddy_fine, DPdyFine, false);
9487+ case hlsl::IntrinsicOp::IOP_ddx:
9488+ case hlsl::IntrinsicOp::IOP_ddx_coarse:
9489+ case hlsl::IntrinsicOp::IOP_ddx_fine:
9490+ case hlsl::IntrinsicOp::IOP_ddy:
9491+ case hlsl::IntrinsicOp::IOP_ddy_coarse:
9492+ case hlsl::IntrinsicOp::IOP_ddy_fine: {
9493+ retVal = processDerivativeIntrinsic(hlslOpcode, callExpr->getArg(0),
9494+ callExpr->getExprLoc(),
9495+ callExpr->getSourceRange());
9496+ break;
9497+ }
94939498 INTRINSIC_SPIRV_OP_CASE(countbits, BitCount, false);
94949499 INTRINSIC_SPIRV_OP_CASE(fmod, FRem, true);
94959500 INTRINSIC_SPIRV_OP_CASE(fwidth, Fwidth, true);
@@ -9572,6 +9577,77 @@ SpirvEmitter::processIntrinsicFirstbit(const CallExpr *callExpr,
95729577 srcRange);
95739578}
95749579
9580+ SpirvInstruction *SpirvEmitter::processMatrixDerivativeIntrinsic(
9581+ hlsl::IntrinsicOp hlslOpcode, const Expr *arg, SourceLocation loc,
9582+ SourceRange range) {
9583+ const auto actOnEachVec = [this, hlslOpcode, loc, range](
9584+ uint32_t /*index*/, QualType inType,
9585+ QualType outType, SpirvInstruction *curRow) {
9586+ return processDerivativeIntrinsic(hlslOpcode, curRow, loc, range);
9587+ };
9588+
9589+ return processEachVectorInMatrix(arg, arg->getType(), doExpr(arg),
9590+ actOnEachVec, loc, range);
9591+ }
9592+
9593+ SpirvInstruction *
9594+ SpirvEmitter::processDerivativeIntrinsic(hlsl::IntrinsicOp hlslOpcode,
9595+ const Expr *arg, SourceLocation loc,
9596+ SourceRange range) {
9597+ if (isMxNMatrix(arg->getType())) {
9598+ return processMatrixDerivativeIntrinsic(hlslOpcode, arg, loc, range);
9599+ }
9600+ return processDerivativeIntrinsic(hlslOpcode, doExpr(arg), loc, range);
9601+ }
9602+
9603+ SpirvInstruction *SpirvEmitter::processDerivativeIntrinsic(
9604+ hlsl::IntrinsicOp hlslOpcode, SpirvInstruction *arg, SourceLocation loc,
9605+ SourceRange range) {
9606+ QualType returnType = arg->getAstResultType();
9607+ assert(isFloatOrVecOfFloatType(returnType));
9608+
9609+ if (!spvContext.isPS())
9610+ addDerivativeGroupExecutionMode();
9611+ needsLegalization = true;
9612+
9613+ QualType B32Type = astContext.FloatTy;
9614+ uint32_t vectorSize = 0;
9615+ QualType elementType = returnType;
9616+ if (isVectorType(returnType, &elementType, &vectorSize)) {
9617+ B32Type = astContext.getExtVectorType(B32Type, vectorSize);
9618+ }
9619+
9620+ // Derivative operations work on 32-bit floats only. Cast to 32-bit if needed.
9621+ SpirvInstruction *operand = castToType(arg, returnType, B32Type, loc, range);
9622+
9623+ spv::Op opcode = spv::Op::OpNop;
9624+ switch (hlslOpcode) {
9625+ case hlsl::IntrinsicOp::IOP_ddx:
9626+ opcode = spv::Op::OpDPdx;
9627+ break;
9628+ case hlsl::IntrinsicOp::IOP_ddx_coarse:
9629+ opcode = spv::Op::OpDPdxCoarse;
9630+ break;
9631+ case hlsl::IntrinsicOp::IOP_ddx_fine:
9632+ opcode = spv::Op::OpDPdxFine;
9633+ break;
9634+ case hlsl::IntrinsicOp::IOP_ddy:
9635+ opcode = spv::Op::OpDPdy;
9636+ break;
9637+ case hlsl::IntrinsicOp::IOP_ddy_coarse:
9638+ opcode = spv::Op::OpDPdyCoarse;
9639+ break;
9640+ case hlsl::IntrinsicOp::IOP_ddy_fine:
9641+ opcode = spv::Op::OpDPdyFine;
9642+ break;
9643+ };
9644+
9645+ SpirvInstruction *result =
9646+ spvBuilder.createUnaryOp(opcode, B32Type, operand, loc, range);
9647+ result = castToType(result, B32Type, returnType, loc, range);
9648+ return result;
9649+ }
9650+
95759651// Returns true if the given expression can be used as an output parameter.
95769652//
95779653// Warning: this function could return false negatives.
0 commit comments