Skip to content

Commit 41217f3

Browse files
committed
Fix DXIL OuterProductAccmulate param ordering (minterp, mlayout, mstride)
1 parent faa8f8f commit 41217f3

2 files changed

Lines changed: 30 additions & 12 deletions

File tree

include/dxc/DXIL/DxilInstructions.h

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10075,9 +10075,9 @@ struct DxilInst_OuterProductAccumulate {
1007510075
arg_inputVector2 = 2,
1007610076
arg_matrixBuffer = 3,
1007710077
arg_matrixOffset = 4,
10078-
arg_matrixStride = 5,
10079-
arg_matrixIntepretation = 6,
10080-
arg_matrixLayout = 7,
10078+
arg_matrixIntepretation = 5,
10079+
arg_matrixLayout = 6,
10080+
arg_matrixStride = 7,
1008110081
};
1008210082
// Accessors
1008310083
llvm::Value *get_inputVector1() const { return Instr->getOperand(1); }
@@ -10088,12 +10088,30 @@ struct DxilInst_OuterProductAccumulate {
1008810088
void set_matrixBuffer(llvm::Value *val) { Instr->setOperand(3, val); }
1008910089
llvm::Value *get_matrixOffset() const { return Instr->getOperand(4); }
1009010090
void set_matrixOffset(llvm::Value *val) { Instr->setOperand(4, val); }
10091-
llvm::Value *get_matrixStride() const { return Instr->getOperand(5); }
10092-
void set_matrixStride(llvm::Value *val) { Instr->setOperand(5, val); }
10093-
llvm::Value *get_matrixIntepretation() const { return Instr->getOperand(6); }
10094-
void set_matrixIntepretation(llvm::Value *val) { Instr->setOperand(6, val); }
10095-
llvm::Value *get_matrixLayout() const { return Instr->getOperand(7); }
10096-
void set_matrixLayout(llvm::Value *val) { Instr->setOperand(7, val); }
10091+
llvm::Value *get_matrixIntepretation() const { return Instr->getOperand(5); }
10092+
void set_matrixIntepretation(llvm::Value *val) { Instr->setOperand(5, val); }
10093+
int32_t get_matrixIntepretation_val() const {
10094+
return (int32_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(5))
10095+
->getZExtValue());
10096+
}
10097+
void set_matrixIntepretation_val(int32_t val) {
10098+
Instr->setOperand(5, llvm::Constant::getIntegerValue(
10099+
llvm::IntegerType::get(Instr->getContext(), 32),
10100+
llvm::APInt(32, (uint64_t)val)));
10101+
}
10102+
llvm::Value *get_matrixLayout() const { return Instr->getOperand(6); }
10103+
void set_matrixLayout(llvm::Value *val) { Instr->setOperand(6, val); }
10104+
int32_t get_matrixLayout_val() const {
10105+
return (int32_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(6))
10106+
->getZExtValue());
10107+
}
10108+
void set_matrixLayout_val(int32_t val) {
10109+
Instr->setOperand(6, llvm::Constant::getIntegerValue(
10110+
llvm::IntegerType::get(Instr->getContext(), 32),
10111+
llvm::APInt(32, (uint64_t)val)));
10112+
}
10113+
llvm::Value *get_matrixStride() const { return Instr->getOperand(7); }
10114+
void set_matrixStride(llvm::Value *val) { Instr->setOperand(7, val); }
1009710115
};
1009810116

1009910117
/// This instruction Accumulates the components of a vector component-wise

utils/hct/hctdb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6413,9 +6413,9 @@ def UFI(name, **mappings):
64136413
db_dxil_param(3, "$x1", "inputVector2", "input vector 2"),
64146414
db_dxil_param(4, "res", "matrixBuffer", "matrix resource"),
64156415
db_dxil_param(5, "i32", "matrixOffset", "matrix offset"),
6416-
db_dxil_param(6, "i32", "matrixStride", "matrix stride"),
6417-
db_dxil_param(7, "i32", "matrixIntepretation", "matrix intepretation"),
6418-
db_dxil_param(8, "i32", "matrixLayout", "matrix layout"),
6416+
db_dxil_param(6, "i32", "matrixIntepretation", "matrix intepretation", is_const=True),
6417+
db_dxil_param(7, "i32", "matrixLayout", "matrix layout", is_const=True),
6418+
db_dxil_param(8, "i32", "matrixStride", "matrix stride"),
64196419
],
64206420
)
64216421
next_op_idx += 1

0 commit comments

Comments
 (0)