@@ -1460,6 +1460,10 @@ typedef APInt(__cdecl *IntBinaryEvalFuncType)(const APInt &, const APInt &);
14601460typedef float (__cdecl *FloatBinaryEvalFuncType)(float , float );
14611461typedef double (__cdecl *DoubleBinaryEvalFuncType)(double , double );
14621462
1463+ typedef APInt (__cdecl *IntTernaryEvalFuncType)(const APInt &, const APInt &, const APInt &);
1464+ typedef float (__cdecl *FloatTernaryEvalFuncType)(float , float , float );
1465+ typedef double (__cdecl *DoubleTernaryEvalFuncType)(double , double , double );
1466+
14631467Value *EvalUnaryIntrinsic (ConstantFP *fpV, FloatUnaryEvalFuncType floatEvalFunc,
14641468 DoubleUnaryEvalFuncType doubleEvalFunc) {
14651469 llvm::Type *Ty = fpV->getType ();
@@ -1510,6 +1514,45 @@ Value *EvalBinaryIntrinsic(Constant *cV0, Constant *cV1,
15101514 return Result;
15111515}
15121516
1517+ Value *EvalTernaryIntrinsic (Constant *cV0, Constant *cV1, Constant *cV2,
1518+ FloatTernaryEvalFuncType floatEvalFunc,
1519+ DoubleTernaryEvalFuncType doubleEvalFunc,
1520+ IntTernaryEvalFuncType intEvalFunc) {
1521+ llvm::Type *Ty = cV0->getType ();
1522+ Value *Result = nullptr ;
1523+ if (Ty->isDoubleTy ()) {
1524+ ConstantFP *fpV0 = cast<ConstantFP>(cV0);
1525+ ConstantFP *fpV1 = cast<ConstantFP>(cV1);
1526+ ConstantFP *fpV2 = cast<ConstantFP>(cV2);
1527+ double dV0 = fpV0->getValueAPF ().convertToDouble ();
1528+ double dV1 = fpV1->getValueAPF ().convertToDouble ();
1529+ double dV2 = fpV2->getValueAPF ().convertToDouble ();
1530+ Value *dResult = ConstantFP::get (Ty, doubleEvalFunc (dV0, dV1, dV2));
1531+ Result = dResult;
1532+ } else if (Ty->isFloatTy ()) {
1533+ ConstantFP *fpV0 = cast<ConstantFP>(cV0);
1534+ ConstantFP *fpV1 = cast<ConstantFP>(cV1);
1535+ ConstantFP *fpV2 = cast<ConstantFP>(cV2);
1536+ float fV0 = fpV0->getValueAPF ().convertToFloat ();
1537+ float fV1 = fpV1->getValueAPF ().convertToFloat ();
1538+ float fV2 = fpV2->getValueAPF ().convertToFloat ();
1539+ Value *dResult = ConstantFP::get (Ty, floatEvalFunc (fV0 , fV1 , fV2 ));
1540+ Result = dResult;
1541+ } else {
1542+ DXASSERT_NOMSG (Ty->isIntegerTy ());
1543+ DXASSERT_NOMSG (intEvalFunc);
1544+ ConstantInt *ciV0 = cast<ConstantInt>(cV0);
1545+ ConstantInt *ciV1 = cast<ConstantInt>(cV1);
1546+ ConstantInt *ciV2 = cast<ConstantInt>(cV2);
1547+ const APInt &iV0 = ciV0->getValue ();
1548+ const APInt &iV1 = ciV1->getValue ();
1549+ const APInt &iV2 = ciV2->getValue ();
1550+ Value *dResult = ConstantInt::get (Ty, intEvalFunc (iV0, iV1, iV2));
1551+ Result = dResult;
1552+ }
1553+ return Result;
1554+ }
1555+
15131556Value *EvalUnaryIntrinsic (CallInst *CI, FloatUnaryEvalFuncType floatEvalFunc,
15141557 DoubleUnaryEvalFuncType doubleEvalFunc) {
15151558 Value *V = CI->getArgOperand (0 );
@@ -1566,6 +1609,43 @@ Value *EvalBinaryIntrinsic(CallInst *CI, FloatBinaryEvalFuncType floatEvalFunc,
15661609 return Result;
15671610}
15681611
1612+ Value *EvalTernaryIntrinsic (CallInst *CI, FloatTernaryEvalFuncType floatEvalFunc,
1613+ DoubleTernaryEvalFuncType doubleEvalFunc,
1614+ IntTernaryEvalFuncType intEvalFunc = nullptr ) {
1615+ Value *V0 = CI->getArgOperand (0 );
1616+ Value *V1 = CI->getArgOperand (1 );
1617+ Value *V2 = CI->getArgOperand (2 );
1618+ llvm::Type *Ty = CI->getType ();
1619+ Value *Result = nullptr ;
1620+ if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
1621+ Result = UndefValue::get (Ty);
1622+ Constant *CV0 = cast<Constant>(V0);
1623+ Constant *CV1 = cast<Constant>(V1);
1624+ Constant *CV2 = cast<Constant>(V2);
1625+ IRBuilder<> Builder (CI);
1626+ for (unsigned i = 0 ; i < VT->getNumElements (); i++) {
1627+ Constant *cV0 = cast<Constant>(CV0->getAggregateElement (i));
1628+ Constant *cV1 = cast<Constant>(CV1->getAggregateElement (i));
1629+ Constant *cV2 = cast<Constant>(CV2->getAggregateElement (i));
1630+ Value *EltResult = EvalTernaryIntrinsic (cV0, cV1, cV2, floatEvalFunc,
1631+ doubleEvalFunc, intEvalFunc);
1632+ Result = Builder.CreateInsertElement (Result, EltResult, i);
1633+ }
1634+ } else {
1635+ Constant *cV0 = cast<Constant>(V0);
1636+ Constant *cV1 = cast<Constant>(V1);
1637+ Constant *cV2 = cast<Constant>(V2);
1638+ Result = EvalTernaryIntrinsic (cV0, cV1, cV2, floatEvalFunc, doubleEvalFunc,
1639+ intEvalFunc);
1640+ }
1641+ CI->replaceAllUsesWith (Result);
1642+ CI->eraseFromParent ();
1643+ return Result;
1644+
1645+ CI->eraseFromParent ();
1646+ return Result;
1647+ }
1648+
15691649void SimpleTransformForHLDXIRInst (Instruction *I, SmallInstSet &deadInsts) {
15701650
15711651 unsigned opcode = I->getOpcode ();
@@ -1789,6 +1869,18 @@ Value *TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp,
17891869 CI->eraseFromParent ();
17901870 return cNan;
17911871 } break ;
1872+ case IntrinsicOp::IOP_clamp: {
1873+ auto clampF = [](float a, float b, float c) {
1874+ return a < b ? b : a > c ? c : a;
1875+ };
1876+ auto clampD = [](double a, double b, double c) {
1877+ return a < b ? b : a > c ? c : a;
1878+ };
1879+ auto clampI = [](const APInt &a, const APInt &b, const APInt &c) -> APInt {
1880+ return a.slt (b) ? b : a.sgt (c) ? c : a;
1881+ };
1882+ return EvalTernaryIntrinsic (CI, clampF, clampD, clampI);
1883+ } break ;
17921884 default :
17931885 return nullptr ;
17941886 }
0 commit comments