Skip to content

Commit 4e48d34

Browse files
V-FEXrthekotagithub-actions[bot]
authored
[SM6.9] Lower vector any/all calls to vector op (microsoft#7753)
Fixes microsoft#7687 Implement vector reduction lowing for calls to `any` and `all` --------- Co-authored-by: Helena Kotas <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent c5854cf commit 4e48d34

11 files changed

Lines changed: 302 additions & 114 deletions

File tree

docs/DXIL.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2423,8 +2423,8 @@ ID Name Description
24232423
306 MatVecMulAdd multiplies a MxK dimension matrix and a K sized input vector and adds an M-sized bias vector
24242424
307 OuterProductAccumulate Computes the outer product between column vectors and an MxN matrix is accumulated component-wise atomically (with device scope) in memory
24252425
308 VectorAccumulate Accumulates the components of a vector component-wise atomically (with device scope) to the corresponding elements of an array in memory
2426-
309 ReservedD0 reserved
2427-
310 ReservedD1 reserved
2426+
309 VectorReduceAnd Bitwise AND reduction of the vector returning a scalar
2427+
310 VectorReduceOr Bitwise OR reduction of the vector returning a scalar
24282428
311 FDot computes the n-dimensional vector dot-product
24292429
=== ===================================================== =======================================================================================================================================================================================================================
24302430

include/dxc/DXIL/DxilConstants.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -524,8 +524,6 @@ enum class OpCode : unsigned {
524524
ReservedC7 = 300, // reserved
525525
ReservedC8 = 301, // reserved
526526
ReservedC9 = 302, // reserved
527-
ReservedD0 = 309, // reserved
528-
ReservedD1 = 310, // reserved
529527

530528
// Amplification shader instructions
531529
DispatchMesh = 173, // Amplification shader intrinsic DispatchMesh
@@ -1037,6 +1035,11 @@ enum class OpCode : unsigned {
10371035
Unpack4x8 = 219, // unpacks 4 8-bit signed or unsigned values into int32 or
10381036
// int16 vector
10391037

1038+
// Vector reduce to scalar
1039+
VectorReduceAnd =
1040+
309, // Bitwise AND reduction of the vector returning a scalar
1041+
VectorReduceOr = 310, // Bitwise OR reduction of the vector returning a scalar
1042+
10401043
// Wave
10411044
WaveActiveAllEqual = 115, // returns 1 if all the lanes have the same value
10421045
WaveActiveBallot = 116, // returns a struct with a bit set for each lane where
@@ -1381,6 +1384,9 @@ enum class OpCodeClass : unsigned {
13811384
// Unpacking intrinsics
13821385
Unpack4x8,
13831386

1387+
// Vector reduce to scalar
1388+
VectorReduce,
1389+
13841390
// Wave
13851391
WaveActiveAllEqual,
13861392
WaveActiveBallot,
@@ -1417,7 +1423,7 @@ enum class OpCodeClass : unsigned {
14171423
NumOpClasses_Dxil_1_7 = 153,
14181424
NumOpClasses_Dxil_1_8 = 174,
14191425

1420-
NumOpClasses = 195 // exclusive last value of enumeration
1426+
NumOpClasses = 196 // exclusive last value of enumeration
14211427
};
14221428
// OPCODECLASS-ENUM:END
14231429

include/dxc/DXIL/DxilInstructions.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10149,6 +10149,60 @@ struct DxilInst_VectorAccumulate {
1014910149
void set_arrayOffset(llvm::Value *val) { Instr->setOperand(3, val); }
1015010150
};
1015110151

10152+
/// This instruction Bitwise AND reduction of the vector returning a scalar
10153+
struct DxilInst_VectorReduceAnd {
10154+
llvm::Instruction *Instr;
10155+
// Construction and identification
10156+
DxilInst_VectorReduceAnd(llvm::Instruction *pInstr) : Instr(pInstr) {}
10157+
operator bool() const {
10158+
return hlsl::OP::IsDxilOpFuncCallInst(Instr,
10159+
hlsl::OP::OpCode::VectorReduceAnd);
10160+
}
10161+
// Validation support
10162+
bool isAllowed() const { return true; }
10163+
bool isArgumentListValid() const {
10164+
if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
10165+
return false;
10166+
return true;
10167+
}
10168+
// Metadata
10169+
bool requiresUniformInputs() const { return false; }
10170+
// Operand indexes
10171+
enum OperandIdx {
10172+
arg_a = 1,
10173+
};
10174+
// Accessors
10175+
llvm::Value *get_a() const { return Instr->getOperand(1); }
10176+
void set_a(llvm::Value *val) { Instr->setOperand(1, val); }
10177+
};
10178+
10179+
/// This instruction Bitwise OR reduction of the vector returning a scalar
10180+
struct DxilInst_VectorReduceOr {
10181+
llvm::Instruction *Instr;
10182+
// Construction and identification
10183+
DxilInst_VectorReduceOr(llvm::Instruction *pInstr) : Instr(pInstr) {}
10184+
operator bool() const {
10185+
return hlsl::OP::IsDxilOpFuncCallInst(Instr,
10186+
hlsl::OP::OpCode::VectorReduceOr);
10187+
}
10188+
// Validation support
10189+
bool isAllowed() const { return true; }
10190+
bool isArgumentListValid() const {
10191+
if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands())
10192+
return false;
10193+
return true;
10194+
}
10195+
// Metadata
10196+
bool requiresUniformInputs() const { return false; }
10197+
// Operand indexes
10198+
enum OperandIdx {
10199+
arg_a = 1,
10200+
};
10201+
// Accessors
10202+
llvm::Value *get_a() const { return Instr->getOperand(1); }
10203+
void set_a(llvm::Value *val) { Instr->setOperand(1, val); }
10204+
};
10205+
1015210206
/// This instruction computes the n-dimensional vector dot-product
1015310207
struct DxilInst_FDot {
1015410208
llvm::Instruction *Instr;

lib/DXIL/DxilOperations.cpp

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2687,22 +2687,23 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
26872687
{{0x400}},
26882688
{{0x63}}}, // Overloads: <hfwi
26892689

2690-
{OC::ReservedD0,
2691-
"ReservedD0",
2692-
OCC::Reserved,
2693-
"reserved",
2694-
Attribute::None,
2695-
0,
2696-
{},
2697-
{}}, // Overloads: v
2698-
{OC::ReservedD1,
2699-
"ReservedD1",
2700-
OCC::Reserved,
2701-
"reserved",
2702-
Attribute::None,
2703-
0,
2704-
{},
2705-
{}}, // Overloads: v
2690+
// Vector reduce to scalar
2691+
{OC::VectorReduceAnd,
2692+
"VectorReduceAnd",
2693+
OCC::VectorReduce,
2694+
"vectorReduce",
2695+
Attribute::ReadNone,
2696+
1,
2697+
{{0x400}},
2698+
{{0xf8}}}, // Overloads: <18wil
2699+
{OC::VectorReduceOr,
2700+
"VectorReduceOr",
2701+
OCC::VectorReduce,
2702+
"vectorReduce",
2703+
Attribute::ReadNone,
2704+
1,
2705+
{{0x400}},
2706+
{{0xf8}}}, // Overloads: <18wil
27062707

27072708
// Dot
27082709
{OC::FDot,
@@ -6018,14 +6019,16 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
60186019
A(pI32);
60196020
break;
60206021

6021-
//
6022-
case OpCode::ReservedD0:
6023-
A(pV);
6022+
// Vector reduce to scalar
6023+
case OpCode::VectorReduceAnd:
6024+
A(pVecElt);
60246025
A(pI32);
6026+
A(pETy);
60256027
break;
6026-
case OpCode::ReservedD1:
6027-
A(pV);
6028+
case OpCode::VectorReduceOr:
6029+
A(pVecElt);
60286030
A(pI32);
6031+
A(pETy);
60296032
break;
60306033

60316034
// Dot
@@ -6207,6 +6210,8 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
62076210
case OpCode::CreateHandleForLib:
62086211
case OpCode::WaveMatch:
62096212
case OpCode::VectorAccumulate:
6213+
case OpCode::VectorReduceAnd:
6214+
case OpCode::VectorReduceOr:
62106215
case OpCode::FDot:
62116216
if (FT->getNumParams() <= 1)
62126217
return nullptr;
@@ -6324,8 +6329,6 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
63246329
case OpCode::ReservedC7:
63256330
case OpCode::ReservedC8:
63266331
case OpCode::ReservedC9:
6327-
case OpCode::ReservedD0:
6328-
case OpCode::ReservedD1:
63296332
return Type::getVoidTy(Ctx);
63306333
case OpCode::CheckAccessFullyMapped:
63316334
case OpCode::SampleIndex:

lib/HLSL/DxilScalarizeVectorIntrinsics.cpp

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
// //
1010
///////////////////////////////////////////////////////////////////////////////
1111

12+
#include "dxc/DXIL/DxilConstants.h"
1213
#include "dxc/DXIL/DxilInstructions.h"
1314
#include "dxc/DXIL/DxilModule.h"
1415
#include "dxc/HLSL/DxilGenerationPass.h"
1516

1617
#include "llvm/ADT/StringRef.h"
18+
#include "llvm/IR/Constant.h"
19+
#include "llvm/IR/Constants.h"
1720
#include "llvm/IR/Function.h"
1821
#include "llvm/IR/IRBuilder.h"
1922
#include "llvm/IR/Instructions.h"
@@ -29,6 +32,7 @@ static void scalarizeVectorLoad(hlsl::OP *HlslOP, const DataLayout &DL,
2932
static void scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
3033
CallInst *CI);
3134
static void scalarizeVectorIntrinsic(hlsl::OP *HlslOP, CallInst *CI);
35+
static void scalarizeVectorReduce(hlsl::OP *HlslOP, CallInst *CI);
3236

3337
class DxilScalarizeVectorIntrinsics : public ModulePass {
3438
public:
@@ -53,24 +57,31 @@ class DxilScalarizeVectorIntrinsics : public ModulePass {
5357
for (auto F = M.functions().begin(); F != M.functions().end();) {
5458
Function *Func = &*(F++);
5559
DXIL::OpCodeClass OpClass;
56-
if (HlslOP->GetOpCodeClass(Func, OpClass)) {
60+
if (!HlslOP->GetOpCodeClass(Func, OpClass))
61+
continue;
62+
63+
bool NeedsRewrite = (Func->getReturnType()->isVectorTy() ||
64+
OpClass == DXIL::OpCodeClass::RawBufferVectorLoad ||
65+
OpClass == DXIL::OpCodeClass::RawBufferVectorStore ||
66+
OpClass == DXIL::OpCodeClass::VectorReduce);
67+
if (!NeedsRewrite)
68+
continue;
69+
70+
for (auto U = Func->user_begin(), UE = Func->user_end(); U != UE;) {
71+
CallInst *CI = cast<CallInst>(*(U++));
72+
5773
if (OpClass == DXIL::OpCodeClass::RawBufferVectorLoad)
58-
for (auto U = Func->user_begin(), UE = Func->user_end(); U != UE;) {
59-
CallInst *CI = cast<CallInst>(*(U++));
60-
scalarizeVectorLoad(HlslOP, M.getDataLayout(), CI);
61-
Changed = true;
62-
}
74+
scalarizeVectorLoad(HlslOP, M.getDataLayout(), CI);
6375
else if (OpClass == DXIL::OpCodeClass::RawBufferVectorStore)
64-
for (auto U = Func->user_begin(), UE = Func->user_end(); U != UE;) {
65-
CallInst *CI = cast<CallInst>(*(U++));
66-
scalarizeVectorStore(HlslOP, M.getDataLayout(), CI);
67-
Changed = true;
68-
}
76+
scalarizeVectorStore(HlslOP, M.getDataLayout(), CI);
77+
else if (OpClass == DXIL::OpCodeClass::VectorReduce)
78+
scalarizeVectorReduce(HlslOP, CI);
6979
else if (Func->getReturnType()->isVectorTy())
70-
for (auto U = Func->user_begin(), UE = Func->user_end(); U != UE;) {
71-
CallInst *CI = cast<CallInst>(*(U++));
72-
scalarizeVectorIntrinsic(HlslOP, CI);
73-
}
80+
scalarizeVectorIntrinsic(HlslOP, CI);
81+
else
82+
continue;
83+
84+
Changed = true;
7485
}
7586
}
7687
return Changed;
@@ -226,6 +237,33 @@ static void scalarizeVectorStore(hlsl::OP *HlslOP, const DataLayout &DL,
226237
CI->eraseFromParent();
227238
}
228239

240+
static void scalarizeVectorReduce(hlsl::OP *HlslOP, CallInst *CI) {
241+
IRBuilder<> Builder(CI);
242+
243+
OP::OpCode ReduceOp = OP::getOpCode(CI);
244+
245+
Value *VecArg = CI->getArgOperand(1);
246+
Type *VecTy = VecArg->getType();
247+
248+
Value *Result = Builder.CreateExtractElement(VecArg, (uint64_t)0);
249+
for (unsigned I = 1; I < VecTy->getVectorNumElements(); I++) {
250+
Value *Elt = Builder.CreateExtractElement(VecArg, I);
251+
252+
switch (ReduceOp) {
253+
case OP::OpCode::VectorReduceAnd:
254+
Result = Builder.CreateAnd(Result, Elt);
255+
break;
256+
case OP::OpCode::VectorReduceOr:
257+
Result = Builder.CreateOr(Result, Elt);
258+
break;
259+
default:
260+
assert(false && "Unexpected VectorReduce OpCode");
261+
}
262+
}
263+
264+
CI->replaceAllUsesWith(Result);
265+
}
266+
229267
// Scalarize native vector operation represented by `CI`, generating
230268
// scalar calls for each element of the its vector parameters.
231269
// Use `HlslOP` to retrieve the associated scalar op function.

0 commit comments

Comments
 (0)