Skip to content

Commit ba828b2

Browse files
Fold 0 <<,>>,/,% n to 0. Fold a / 1 to a. Fold a % 1 to 0. Fold f % 1.0 to 0.0. Fold 0.0 % f to 0.0 (#6020)
* Fold 0 <<,>>,/,% n to 0. Fold a / 1 to a. Fold a % 1 to 0 * Fold OpFMod (f % 1.0) = 0.0 and (0.0 % f) = 0.0
1 parent bac6ca7 commit ba828b2

2 files changed

Lines changed: 291 additions & 36 deletions

File tree

source/opt/folding_rules.cpp

Lines changed: 131 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,20 +2454,51 @@ FoldingRule RedundantFDiv() {
24542454
FloatConstantKind kind0 = getFloatConstantKind(constants[0]);
24552455
FloatConstantKind kind1 = getFloatConstantKind(constants[1]);
24562456

2457-
if (kind0 == FloatConstantKind::Zero) {
2457+
if (kind0 == FloatConstantKind::Zero || kind1 == FloatConstantKind::One) {
24582458
inst->SetOpcode(spv::Op::OpCopyObject);
24592459
inst->SetInOperands(
24602460
{{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
24612461
return true;
24622462
}
24632463

2464-
if (kind1 == FloatConstantKind::One) {
2464+
return false;
2465+
};
2466+
}
2467+
2468+
// Returns a folding rule for OpFMod that replaces (0.0 % f) with the zero
// operand itself.
//
// (f % 1.0) is deliberately NOT folded to 0.0: OpFMod by 1.0 yields the
// fractional part of |f| (e.g. 2.5 % 1.0 == 0.5), so the result is zero only
// when |f| happens to be integral, which cannot be known for a non-constant
// operand.
FoldingRule RedundantFMod() {
  return [](IRContext*, Instruction* inst,
            const std::vector<const analysis::Constant*>& constants) {
    assert(inst->opcode() == spv::Op::OpFMod &&
           "Wrong opcode. Should be OpFMod.");
    assert(constants.size() == 2);

    if (!inst->IsFloatingPointFoldingAllowed()) {
      return false;
    }

    // Only a zero dividend makes the operation redundant.
    if (getFloatConstantKind(constants[0]) != FloatConstantKind::Zero) {
      return false;
    }

    // 0.0 % f => 0.0: forward the (zero) first operand unchanged.
    inst->SetOpcode(spv::Op::OpCopyObject);
    inst->SetInOperands(
        {{SPV_OPERAND_TYPE_ID, {inst->GetSingleWordInOperand(0)}}});
    return true;
  };
}
@@ -2507,15 +2538,16 @@ FoldingRule RedundantFMix() {
25072538
};
25082539
}
25092540

2510-
// Returns a folding rule that folds the instruction to the operand not being
2511-
// checked if the operand that is checked is zero.
2512-
FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg) {
2513-
return [arg](IRContext* context, Instruction* inst,
2514-
const std::vector<const analysis::Constant*>& constants) {
2541+
// Returns a folding rule that folds the instruction to operand |foldToArg|
2542+
// (0 or 1) if operand |arg| (0 or 1) is a zero constant.
2543+
FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg, uint32_t foldToArg) {
2544+
return [arg, foldToArg](
2545+
IRContext* context, Instruction* inst,
2546+
const std::vector<const analysis::Constant*>& constants) {
25152547
assert(constants.size() == 2);
25162548

25172549
if (constants[arg] && constants[arg]->IsZero()) {
2518-
auto operand = inst->GetSingleWordInOperand(1 - arg);
2550+
auto operand = inst->GetSingleWordInOperand(foldToArg);
25192551
auto operand_type = constants[arg]->type();
25202552

25212553
const analysis::Type* inst_type =
@@ -2533,7 +2565,7 @@ FoldingRule RedundantBinaryOpWithZeroOperand(uint32_t arg) {
25332565
}
25342566

25352567
// This rule handles any of RedundantBinaryRhs0Ops with a 0 or vector 0 on the
2536-
// right-hand side.
2568+
// right-hand side (a | 0 => a).
25372569
static const constexpr spv::Op RedundantBinaryRhs0Ops[] = {
25382570
spv::Op::OpBitwiseOr,
25392571
spv::Op::OpBitwiseXor,
@@ -2548,11 +2580,11 @@ FoldingRule RedundantBinaryRhs0(spv::Op op) {
25482580
op) != std::end(RedundantBinaryRhs0Ops) &&
25492581
"Wrong opcode.");
25502582
(void)op;
2551-
return RedundantBinaryOpWithZeroOperand(1);
2583+
return RedundantBinaryOpWithZeroOperand(1, 0);
25522584
}
25532585

25542586
// This rule handles any of RedundantBinaryLhs0Ops with a 0 or vector 0 on the
2555-
// left-hand side.
2587+
// left-hand side (0 | a => a).
25562588
static const constexpr spv::Op RedundantBinaryLhs0Ops[] = {
25572589
spv::Op::OpBitwiseOr, spv::Op::OpBitwiseXor, spv::Op::OpIAdd};
25582590
FoldingRule RedundantBinaryLhs0(spv::Op op) {
@@ -2561,7 +2593,86 @@ FoldingRule RedundantBinaryLhs0(spv::Op op) {
25612593
op) != std::end(RedundantBinaryLhs0Ops) &&
25622594
"Wrong opcode.");
25632595
(void)op;
2564-
return RedundantBinaryOpWithZeroOperand(0);
2596+
return RedundantBinaryOpWithZeroOperand(0, 1);
2597+
}
2598+
2599+
// This rule handles shifts and divisions of 0 or vector 0 by any amount
2600+
// (0 >> a => 0).
2601+
static const constexpr spv::Op RedundantBinaryLhs0To0Ops[] = {
2602+
spv::Op::OpShiftRightLogical,
2603+
spv::Op::OpShiftRightArithmetic,
2604+
spv::Op::OpShiftLeftLogical,
2605+
spv::Op::OpSDiv,
2606+
spv::Op::OpUDiv,
2607+
spv::Op::OpSMod,
2608+
spv::Op::OpUMod};
2609+
FoldingRule RedundantBinaryLhs0To0(spv::Op op) {
2610+
assert(std::find(std::begin(RedundantBinaryLhs0To0Ops),
2611+
std::end(RedundantBinaryLhs0To0Ops),
2612+
op) != std::end(RedundantBinaryLhs0To0Ops) &&
2613+
"Wrong opcode.");
2614+
(void)op;
2615+
return RedundantBinaryOpWithZeroOperand(0, 0);
2616+
}
2617+
2618+
// Returns true if all elements in |c| are 1.
2619+
bool IsAllInt1(const analysis::Constant* c) {
2620+
if (auto composite = c->AsCompositeConstant()) {
2621+
auto& components = composite->GetComponents();
2622+
return std::all_of(std::begin(components), std::end(components), IsAllInt1);
2623+
} else if (c->AsIntConstant()) {
2624+
return c->GetSignExtendedValue() == 1;
2625+
}
2626+
2627+
return false;
2628+
}
2629+
2630+
// This rule handles divisions by 1 or vector 1 (a / 1 => a).
2631+
FoldingRule RedundantSUDiv() {
2632+
return [](IRContext* context, Instruction* inst,
2633+
const std::vector<const analysis::Constant*>& constants) {
2634+
assert(constants.size() == 2);
2635+
assert((inst->opcode() == spv::Op::OpUDiv ||
2636+
inst->opcode() == spv::Op::OpSDiv) &&
2637+
"Wrong opcode.");
2638+
2639+
if (constants[1] && IsAllInt1(constants[1])) {
2640+
auto operand = inst->GetSingleWordInOperand(0);
2641+
auto operand_type = constants[1]->type();
2642+
2643+
const analysis::Type* inst_type =
2644+
context->get_type_mgr()->GetType(inst->type_id());
2645+
if (inst_type->IsSame(operand_type)) {
2646+
inst->SetOpcode(spv::Op::OpCopyObject);
2647+
} else {
2648+
inst->SetOpcode(spv::Op::OpBitcast);
2649+
}
2650+
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {operand}}});
2651+
return true;
2652+
}
2653+
return false;
2654+
};
2655+
}
2656+
2657+
// This rule handles modulo from division by 1 or vector 1 (a % 1 => 0).
2658+
FoldingRule RedundantSUMod() {
2659+
return [](IRContext* context, Instruction* inst,
2660+
const std::vector<const analysis::Constant*>& constants) {
2661+
assert(constants.size() == 2);
2662+
assert((inst->opcode() == spv::Op::OpUMod ||
2663+
inst->opcode() == spv::Op::OpSMod) &&
2664+
"Wrong opcode.");
2665+
2666+
if (constants[1] && IsAllInt1(constants[1])) {
2667+
auto type = context->get_type_mgr()->GetType(inst->type_id());
2668+
auto zero_id = context->get_constant_mgr()->GetNullConstId(type);
2669+
2670+
inst->SetOpcode(spv::Op::OpCopyObject);
2671+
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {zero_id}}});
2672+
return true;
2673+
}
2674+
return false;
2675+
};
25652676
}
25662677

25672678
// This rule looks for a dot with a constant vector containing a single 1 and
@@ -2905,6 +3016,12 @@ void FoldingRules::AddFoldingRules() {
29053016
rules_[op].push_back(RedundantBinaryRhs0(op));
29063017
for (auto op : RedundantBinaryLhs0Ops)
29073018
rules_[op].push_back(RedundantBinaryLhs0(op));
3019+
for (auto op : RedundantBinaryLhs0To0Ops)
3020+
rules_[op].push_back(RedundantBinaryLhs0To0(op));
3021+
rules_[spv::Op::OpSDiv].push_back(RedundantSUDiv());
3022+
rules_[spv::Op::OpUDiv].push_back(RedundantSUDiv());
3023+
rules_[spv::Op::OpSMod].push_back(RedundantSUMod());
3024+
rules_[spv::Op::OpUMod].push_back(RedundantSUMod());
29083025

29093026
rules_[spv::Op::OpBitcast].push_back(BitCastScalarOrVector());
29103027

@@ -2937,6 +3054,8 @@ void FoldingRules::AddFoldingRules() {
29373054
rules_[spv::Op::OpFDiv].push_back(MergeDivMulArithmetic());
29383055
rules_[spv::Op::OpFDiv].push_back(MergeDivNegateArithmetic());
29393056

3057+
rules_[spv::Op::OpFMod].push_back(RedundantFMod());
3058+
29403059
rules_[spv::Op::OpFMul].push_back(RedundantFMul());
29413060
rules_[spv::Op::OpFMul].push_back(MergeMulMulArithmetic());
29423061
rules_[spv::Op::OpFMul].push_back(MergeMulDivArithmetic());

0 commit comments

Comments
 (0)