Skip to content

Commit 08e2587

Browse files
committed
Add constant evaluation for clamp() (#3581)
(cherry picked from commit 2bda44f)
1 parent ad955d8 commit 08e2587

2 files changed

Lines changed: 106 additions & 0 deletions

File tree

tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,10 @@ typedef APInt(__cdecl *IntBinaryEvalFuncType)(const APInt &, const APInt &);
14601460
typedef float(__cdecl *FloatBinaryEvalFuncType)(float, float);
14611461
typedef double(__cdecl *DoubleBinaryEvalFuncType)(double, double);
14621462

1463+
typedef APInt(__cdecl *IntTernaryEvalFuncType)(const APInt &, const APInt &, const APInt &);
1464+
typedef float(__cdecl *FloatTernaryEvalFuncType)(float, float, float);
1465+
typedef double(__cdecl *DoubleTernaryEvalFuncType)(double, double, double);
1466+
14631467
Value *EvalUnaryIntrinsic(ConstantFP *fpV, FloatUnaryEvalFuncType floatEvalFunc,
14641468
DoubleUnaryEvalFuncType doubleEvalFunc) {
14651469
llvm::Type *Ty = fpV->getType();
@@ -1510,6 +1514,45 @@ Value *EvalBinaryIntrinsic(Constant *cV0, Constant *cV1,
15101514
return Result;
15111515
}
15121516

1517+
Value *EvalTernaryIntrinsic(Constant *cV0, Constant *cV1, Constant *cV2,
1518+
FloatTernaryEvalFuncType floatEvalFunc,
1519+
DoubleTernaryEvalFuncType doubleEvalFunc,
1520+
IntTernaryEvalFuncType intEvalFunc) {
1521+
llvm::Type *Ty = cV0->getType();
1522+
Value *Result = nullptr;
1523+
if (Ty->isDoubleTy()) {
1524+
ConstantFP *fpV0 = cast<ConstantFP>(cV0);
1525+
ConstantFP *fpV1 = cast<ConstantFP>(cV1);
1526+
ConstantFP *fpV2 = cast<ConstantFP>(cV2);
1527+
double dV0 = fpV0->getValueAPF().convertToDouble();
1528+
double dV1 = fpV1->getValueAPF().convertToDouble();
1529+
double dV2 = fpV2->getValueAPF().convertToDouble();
1530+
Value *dResult = ConstantFP::get(Ty, doubleEvalFunc(dV0, dV1, dV2));
1531+
Result = dResult;
1532+
} else if (Ty->isFloatTy()) {
1533+
ConstantFP *fpV0 = cast<ConstantFP>(cV0);
1534+
ConstantFP *fpV1 = cast<ConstantFP>(cV1);
1535+
ConstantFP *fpV2 = cast<ConstantFP>(cV2);
1536+
float fV0 = fpV0->getValueAPF().convertToFloat();
1537+
float fV1 = fpV1->getValueAPF().convertToFloat();
1538+
float fV2 = fpV2->getValueAPF().convertToFloat();
1539+
Value *dResult = ConstantFP::get(Ty, floatEvalFunc(fV0, fV1, fV2));
1540+
Result = dResult;
1541+
} else {
1542+
DXASSERT_NOMSG(Ty->isIntegerTy());
1543+
DXASSERT_NOMSG(intEvalFunc);
1544+
ConstantInt *ciV0 = cast<ConstantInt>(cV0);
1545+
ConstantInt *ciV1 = cast<ConstantInt>(cV1);
1546+
ConstantInt *ciV2 = cast<ConstantInt>(cV2);
1547+
const APInt &iV0 = ciV0->getValue();
1548+
const APInt &iV1 = ciV1->getValue();
1549+
const APInt &iV2 = ciV2->getValue();
1550+
Value *dResult = ConstantInt::get(Ty, intEvalFunc(iV0, iV1, iV2));
1551+
Result = dResult;
1552+
}
1553+
return Result;
1554+
}
1555+
15131556
Value *EvalUnaryIntrinsic(CallInst *CI, FloatUnaryEvalFuncType floatEvalFunc,
15141557
DoubleUnaryEvalFuncType doubleEvalFunc) {
15151558
Value *V = CI->getArgOperand(0);
@@ -1566,6 +1609,43 @@ Value *EvalBinaryIntrinsic(CallInst *CI, FloatBinaryEvalFuncType floatEvalFunc,
15661609
return Result;
15671610
}
15681611

1612+
Value *EvalTernaryIntrinsic(CallInst *CI, FloatTernaryEvalFuncType floatEvalFunc,
1613+
DoubleTernaryEvalFuncType doubleEvalFunc,
1614+
IntTernaryEvalFuncType intEvalFunc = nullptr) {
1615+
Value *V0 = CI->getArgOperand(0);
1616+
Value *V1 = CI->getArgOperand(1);
1617+
Value *V2 = CI->getArgOperand(2);
1618+
llvm::Type *Ty = CI->getType();
1619+
Value *Result = nullptr;
1620+
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
1621+
Result = UndefValue::get(Ty);
1622+
Constant *CV0 = cast<Constant>(V0);
1623+
Constant *CV1 = cast<Constant>(V1);
1624+
Constant *CV2 = cast<Constant>(V2);
1625+
IRBuilder<> Builder(CI);
1626+
for (unsigned i = 0; i < VT->getNumElements(); i++) {
1627+
Constant *cV0 = cast<Constant>(CV0->getAggregateElement(i));
1628+
Constant *cV1 = cast<Constant>(CV1->getAggregateElement(i));
1629+
Constant *cV2 = cast<Constant>(CV2->getAggregateElement(i));
1630+
Value *EltResult = EvalTernaryIntrinsic(cV0, cV1, cV2, floatEvalFunc,
1631+
doubleEvalFunc, intEvalFunc);
1632+
Result = Builder.CreateInsertElement(Result, EltResult, i);
1633+
}
1634+
} else {
1635+
Constant *cV0 = cast<Constant>(V0);
1636+
Constant *cV1 = cast<Constant>(V1);
1637+
Constant *cV2 = cast<Constant>(V2);
1638+
Result = EvalTernaryIntrinsic(cV0, cV1, cV2, floatEvalFunc, doubleEvalFunc,
1639+
intEvalFunc);
1640+
}
1641+
CI->replaceAllUsesWith(Result);
1642+
CI->eraseFromParent();
1643+
return Result;
1644+
1645+
CI->eraseFromParent();
1646+
return Result;
1647+
}
1648+
15691649
void SimpleTransformForHLDXIRInst(Instruction *I, SmallInstSet &deadInsts) {
15701650

15711651
unsigned opcode = I->getOpcode();
@@ -1789,6 +1869,18 @@ Value *TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp,
17891869
CI->eraseFromParent();
17901870
return cNan;
17911871
} break;
1872+
case IntrinsicOp::IOP_clamp: {
1873+
auto clampF = [](float a, float b, float c) {
1874+
return a < b ? b : a > c ? c : a;
1875+
};
1876+
auto clampD = [](double a, double b, double c) {
1877+
return a < b ? b : a > c ? c : a;
1878+
};
1879+
auto clampI = [](const APInt &a, const APInt &b, const APInt &c) -> APInt {
1880+
return a.slt(b) ? b : a.sgt(c) ? c : a;
1881+
};
1882+
return EvalTernaryIntrinsic(CI, clampF, clampD, clampI);
1883+
} break;
17921884
default:
17931885
return nullptr;
17941886
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %dxc -T ps_6_0 %s -E main | %FileCheck %s
2+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 1.000000e+00)
3+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float -1.250000e+00)
4+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float 3.000000e+00)
5+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float 2.000000e+00)
6+
7+
[RootSignature("")]
8+
float4 main() : SV_Target {
9+
return float4(
10+
clamp(10, 0, 1),
11+
clamp(-1.0f, -2.5f, -1.25f),
12+
clamp((double)3, (double)-2, (double)5),
13+
clamp(-5LL, 2LL, 5LL));
14+
}

0 commit comments

Comments
 (0)