Skip to content

Commit 374044c

Browse files
authored
Fix assert from matrix lower needlessly casting constant to instruction. (#4115) (#4118)
(cherry picked from commit 39dd31f)
1 parent 3360880 commit 374044c

2 files changed

Lines changed: 38 additions & 4 deletions

File tree

lib/HLSL/HLMatrixLowerPass.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,9 +1309,11 @@ static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opc
1309 1309
// Conversions to bools are comparisons
1310 1310
if (DstTy->getScalarSizeInBits() == 1) {
1311 1311
// fcmp une is what regular clang uses in C++ for (bool)f;
1312-
return cast<Instruction>(SrcTy->isIntOrIntVectorTy()
1313-
? Builder.CreateICmpNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
1314-
: Builder.CreateFCmpUNE(SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool"));
1312+
return SrcTy->isIntOrIntVectorTy()
1313+
? Builder.CreateICmpNE(
1314+
SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool")
1315+
: Builder.CreateFCmpUNE(
1316+
SrcVal, llvm::Constant::getNullValue(SrcTy), "tobool");
1315 1317
}
1316 1318

1317 1319
// Cast necessary
@@ -1321,7 +1323,7 @@ static Value *convertScalarOrVector(Value *SrcVal, Type *DstTy, HLCastOpcode Opc
1321 1323
Opcode == HLCastOpcode::UnsignedUnsignedCast;
1322 1324
auto CastOp = static_cast<Instruction::CastOps>(HLModule::GetNumericCastOp(
1323 1325
SrcTy, SrcIsUnsigned, DstTy, DstIsUnsigned));
1324-
return cast<Instruction>(Builder.CreateCast(CastOp, SrcVal, DstTy));
1326+
return Builder.CreateCast(CastOp, SrcVal, DstTy);
1325 1327
}
1326 1328

1327 1329
Value *HLMatrixLowerPass::lowerHLCast(CallInst *Call, Value *Src, Type *DstTy,
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %dxc -E main -T ps_6_2 -enable-16bit-types %s | FileCheck %s
2+
3+
// CHECK: [9 x half] [half 0xH3C00, half 0xH4400, half 0xH4700, half 0xH4000, half 0xH4500, half 0xH4800, half 0xH4200, half 0xH4600, half 0xH4880]
4+
// CHECK: fptoui float
5+
// CHECK: getelementptr [9 x half]
6+
// CHECK: load half
7+
// CHECK: call half @dx.op.tertiary.f16(i32 46, half 0xH4000,
8+
// CHECK: lshr i32 411,
9+
// CHECK: icmp ne
10+
// CHECK: %[[all:[^ ]*]] = uitofp i1 %{{.*}} to float
11+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0,
12+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1,
13+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2,
14+
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float %[[all]])
15+
16+
struct Foo { // Aggregate returned by fn(): one half-precision and one boolean 3x3 matrix.
17+
half3x3 hmat; // half elements; the RUN line compiles with -enable-16bit-types
18+
bool3x3 bmat; // boolean elements; fn() fills this from an int3x3 constant
19+
};
20+
21+
Foo fn() { // Builds a Foo whose matrices are initialized from constants of a different element type.
22+
Foo foo;
23+
foo.hmat = float3x3(1, 2, 3, 4, 5, 6, 7, 8, 9); // float3x3 constant assigned to a half3x3 member
24+
foo.bmat = int3x3(1, 2, 0, -5, 14, 3, 0, 0, 21); // int3x3 constant assigned to a bool3x3 member; presumably the constant-to-bool conversion this commit's fix targets
25+
return foo;
26+
}
27+
28+
float4 main(float a : A) : SV_Target { // Shader entry point (RUN line: -E main -T ps_6_2).
29+
Foo foo = fn();
30+
float3 v = float3(a, a * a, a + a); // NOTE(review): v is never read in this function
31+
return float4(mul(foo.hmat, foo.hmat[a]), all(foo.bmat[a])); // both matrices are indexed with the float input a; xyz = mul(hmat, hmat[a]), w = all(bmat[a])
32+
}

0 commit comments

Comments
 (0)