@@ -2705,6 +2705,8 @@ llvm::StringRef OP::ConstructOverloadName(Type *Ty, DXIL::OpCode opCode,
27052705}
27062706
27072707const char *OP::GetOpCodeName (OpCode opCode) {
2708+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2709+ " otherwise caller passed OOB index" );
27082710 return m_OpCodeProps[(unsigned )opCode].pOpCodeName ;
27092711}
27102712
@@ -2717,22 +2719,26 @@ const char *OP::GetAtomicOpName(DXIL::AtomicBinOpCode OpCode) {
27172719}
27182720
27192721OP::OpCodeClass OP::GetOpCodeClass (OpCode opCode) {
2722+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2723+ " otherwise caller passed OOB index" );
27202724 return m_OpCodeProps[(unsigned )opCode].opCodeClass ;
27212725}
27222726
27232727const char *OP::GetOpCodeClassName (OpCode opCode) {
2728+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2729+ " otherwise caller passed OOB index" );
27242730 return m_OpCodeProps[(unsigned )opCode].pOpCodeClassName ;
27252731}
27262732
27272733llvm::Attribute::AttrKind OP::GetMemAccessAttr (OpCode opCode) {
2734+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2735+ " otherwise caller passed OOB index" );
27282736 return m_OpCodeProps[(unsigned )opCode].FuncAttr ;
27292737}
27302738
27312739bool OP::IsOverloadLegal (OpCode opCode, Type *pType) {
2732- if (!pType)
2733- return false ;
2734- if (opCode == OpCode::NumOpCodes)
2735- return false ;
2740+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
2741+ " otherwise caller passed OOB index" );
27362742 unsigned TypeSlot = GetTypeSlot (pType);
27372743 return TypeSlot != UINT_MAX &&
27382744 m_OpCodeProps[(unsigned )opCode].bAllowOverload [TypeSlot];
@@ -2808,13 +2814,8 @@ bool OP::IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode) {
28082814}
28092815
28102816OP::OpCode OP::getOpCode (const llvm::Instruction *I) {
2811- auto *OpConst = llvm::dyn_cast<llvm::ConstantInt>(I->getOperand (0 ));
2812- if (!OpConst)
2813- return OpCode::NumOpCodes;
2814- uint64_t OpCodeVal = OpConst->getZExtValue ();
2815- if (OpCodeVal >= static_cast <uint64_t >(OP::OpCode::NumOpCodes))
2816- return OP::OpCode::NumOpCodes;
2817- return static_cast <OP::OpCode>(OpCodeVal);
2817+ return (OP::OpCode)llvm::cast<llvm::ConstantInt>(I->getOperand (0 ))
2818+ ->getZExtValue ();
28182819}
28192820
28202821OP::OpCode OP::GetDxilOpFuncCallInst (const llvm::Instruction *I) {
@@ -3524,7 +3525,9 @@ void OP::RefreshCache() {
35243525 CallInst *CI = cast<CallInst>(*F.user_begin ());
35253526 OpCode OpCode = OP::GetDxilOpFuncCallInst (CI);
35263527 Type *pOverloadType = OP::GetOverloadType (OpCode, &F);
3527- GetOpFunc (OpCode, pOverloadType);
3528+ Function *OpFunc = GetOpFunc (OpCode, pOverloadType);
3529+ (void )(OpFunc);
3530+ DXASSERT_NOMSG (OpFunc == &F);
35283531 }
35293532 }
35303533}
@@ -3543,15 +3546,13 @@ void OP::FixOverloadNames() {
35433546 CallInst *CI = cast<CallInst>(*F.user_begin ());
35443547 DXIL::OpCode opCode = OP::GetDxilOpFuncCallInst (CI);
35453548 llvm::Type *Ty = OP::GetOverloadType (opCode, &F);
3546- if (!OP::IsOverloadLegal (opCode, Ty))
3547- continue ;
3548- if (!isa<StructType>(Ty) && !isa<PointerType>(Ty))
3549- continue ;
3550-
3551- std::string funcName;
3552- if (OP::ConstructOverloadName (Ty, opCode, funcName)
3553- .compare (F.getName ()) != 0 )
3554- F.setName (funcName);
3549+ if (isa<StructType>(Ty) || isa<PointerType>(Ty)) {
3550+ std::string funcName;
3551+ if (OP::ConstructOverloadName (Ty, opCode, funcName)
3552+ .compare (F.getName ()) != 0 ) {
3553+ F.setName (funcName);
3554+ }
3555+ }
35553556 }
35563557 }
35573558}
@@ -3562,11 +3563,12 @@ void OP::UpdateCache(OpCodeClass opClass, Type *Ty, llvm::Function *F) {
35623563}
35633564
35643565Function *OP::GetOpFunc (OpCode opCode, Type *pOverloadType) {
3565- if (opCode == OpCode::NumOpCodes)
3566- return nullptr ;
3567- if (!IsOverloadLegal (opCode, pOverloadType))
3568- return nullptr ;
3569-
3566+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
3567+ " otherwise caller passed OOB OpCode" );
3568+ assert (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes);
3569+ DXASSERT (IsOverloadLegal (opCode, pOverloadType),
3570+ " otherwise the caller requested illegal operation overload (eg HLSL "
3571+ " function with unsupported types for mapped intrinsic function)" );
35703572 OpCodeClass opClass = m_OpCodeProps[(unsigned )opCode].opCodeClass ;
35713573 Function *&F =
35723574 m_OpCodeClassCache[(unsigned )opClass].pOverloads [pOverloadType];
@@ -5509,8 +5511,8 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
55095511 // and return values to ensure that ResRetType is constructed in the
55105512 // RefreshCache case.
55115513 if (Function *existF = m_pModule->getFunction (funcName)) {
5512- if (existF->getFunctionType () != pFT)
5513- return nullptr ;
5514+ DXASSERT (existF->getFunctionType () == pFT,
5515+ " existing function must have the expected function type " ) ;
55145516 F = existF;
55155517 UpdateCache (opClass, pOverloadType, F);
55165518 return F;
@@ -5529,6 +5531,9 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
55295531
55305532const SmallMapVector<llvm::Type *, llvm::Function *, 8 > &
55315533OP::GetOpFuncList (OpCode opCode) const {
5534+ DXASSERT (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes,
5535+ " otherwise caller passed OOB OpCode" );
5536+ assert (0 <= (unsigned )opCode && opCode < OpCode::NumOpCodes);
55325537 return m_OpCodeClassCache[(unsigned )m_OpCodeProps[(unsigned )opCode]
55335538 .opCodeClass ]
55345539 .pOverloads ;
@@ -5626,8 +5631,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
56265631 case OpCode::CallShader:
56275632 case OpCode::Pack4x8:
56285633 case OpCode::WaveMatrix_Fill:
5629- if (FT->getNumParams () <= 2 )
5630- return nullptr ;
5634+ DXASSERT_NOMSG (FT->getNumParams () > 2 );
56315635 return FT->getParamType (2 );
56325636 case OpCode::MinPrecXRegStore:
56335637 case OpCode::StoreOutput:
@@ -5637,8 +5641,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
56375641 case OpCode::StoreVertexOutput:
56385642 case OpCode::StorePrimitiveOutput:
56395643 case OpCode::DispatchMesh:
5640- if (FT->getNumParams () <= 4 )
5641- return nullptr ;
5644+ DXASSERT_NOMSG (FT->getNumParams () > 4 );
56425645 return FT->getParamType (4 );
56435646 case OpCode::IsNaN:
56445647 case OpCode::IsInf:
@@ -5656,27 +5659,22 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
56565659 case OpCode::WaveActiveAllEqual:
56575660 case OpCode::CreateHandleForLib:
56585661 case OpCode::WaveMatch:
5659- if (FT->getNumParams () <= 1 )
5660- return nullptr ;
5662+ DXASSERT_NOMSG (FT->getNumParams () > 1 );
56615663 return FT->getParamType (1 );
56625664 case OpCode::TextureStore:
56635665 case OpCode::TextureStoreSample:
5664- if (FT->getNumParams () <= 5 )
5665- return nullptr ;
5666+ DXASSERT_NOMSG (FT->getNumParams () > 5 );
56665667 return FT->getParamType (5 );
56675668 case OpCode::TraceRay:
5668- if (FT->getNumParams () <= 15 )
5669- return nullptr ;
5669+ DXASSERT_NOMSG (FT->getNumParams () > 15 );
56705670 return FT->getParamType (15 );
56715671 case OpCode::ReportHit:
56725672 case OpCode::WaveMatrix_ScalarOp:
5673- if (FT->getNumParams () <= 3 )
5674- return nullptr ;
5673+ DXASSERT_NOMSG (FT->getNumParams () > 3 );
56755674 return FT->getParamType (3 );
56765675 case OpCode::WaveMatrix_LoadGroupShared:
56775676 case OpCode::WaveMatrix_StoreGroupShared:
5678- if (FT->getNumParams () <= 2 )
5679- return nullptr ;
5677+ DXASSERT_NOMSG (FT->getNumParams () > 2 );
56805678 return FT->getParamType (2 )->getPointerElementType ();
56815679 case OpCode::CreateHandle:
56825680 case OpCode::BufferUpdateCounter:
0 commit comments